Compare commits
131 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 38bb5f3fab | |||
| 129a1d1e98 | |||
| 0f6dbfdc44 | |||
| 76d89f59af | |||
| 1a3a92bd1b | |||
| 4018320614 | |||
| d9700baa8d | |||
| 82456ff62e | |||
| 83fd4cf5c2 | |||
| 38d4da82b9 | |||
| 19e0e0e8e6 | |||
| 1d0653cdab | |||
| 95cff8c5fb | |||
| ad2415ede7 | |||
| 1e40cea199 | |||
| 9d6557d173 | |||
| 224db483d7 | |||
| 8237822441 | |||
| 65bf7c3b18 | |||
| 76cbc580f0 | |||
| 391b22aef7 | |||
| f8e8f979a2 | |||
| fb0ed1162b | |||
| 3f519744aa | |||
| 2505f6245f | |||
| 29ad2c6201 | |||
| 27e5ff0a8e | |||
| 128a7c23e6 | |||
| efb19eb748 | |||
| 2c499484b7 | |||
| 33d9d0d875 | |||
| f219834f5c | |||
| 7a94a683c4 | |||
| 2e6fdf2344 | |||
| 3d139c1a24 | |||
| f957981c8b | |||
| 584c61acb5 | |||
| f95a5202bf | |||
| d954460380 | |||
| f4240bb8c1 | |||
| 7caef4987f | |||
| 9b91af8ab7 | |||
| 506fba9ebf | |||
| 461a31e5d8 | |||
| e3a0dcd6fc | |||
| 12ada0115f | |||
| 7b0421d8c6 | |||
| 477d6d0cde | |||
| de61ac529d | |||
| 7f496c2f18 | |||
| 590235138f | |||
| 543c448b72 | |||
| 35c26ce22a | |||
| c2592c9f12 | |||
| b969d66978 | |||
| 1f808cdc62 | |||
| 497f637f58 | |||
| be686a8d0d | |||
| 7b7baea851 | |||
| a3de0fc78d | |||
| ab77154975 | |||
| c5d720f73d | |||
| 983819860f | |||
| f820945d9f | |||
| da5395a8ae | |||
| 86b919e4f7 | |||
| 233343c010 | |||
| 3a612898c6 | |||
| 3f7a3e3354 | |||
| 17a71aea72 | |||
| 7d3c5ac78c | |||
| d87c5ef439 | |||
| ef3e17317c | |||
| 1187b84c54 | |||
| 45336bd9ce | |||
| 36cf7debce | |||
| 027c222e82 | |||
| d00f148b76 | |||
| 48bc215f20 | |||
| 08bd9e672a | |||
| c5f1a2fccf | |||
| 655d647d40 | |||
| f3f0a2c553 | |||
| 5453a6c6d6 | |||
| 21c08a37d7 | |||
| 2bd261fbbf | |||
| cffc68df58 | |||
| 6e5335df1e | |||
| 16265e834e | |||
| 565a15bc9b | |||
| 76a2cb1af5 | |||
| 684f21740d | |||
| 86ca61d6ca | |||
| f0521cfa3c | |||
| 0c5d189aff | |||
| d7c8213eee | |||
| 63924ac687 | |||
| 6c47e9ea23 | |||
| aede045549 | |||
| 2ea08aa168 | |||
| d4b9248202 | |||
| fd6c623560 | |||
| 99da498679 | |||
| a20b817c28 | |||
| d5a1792f07 | |||
| beb99c17de | |||
| 8913f9f5c1 | |||
| acd5f01b4b | |||
| 6c62d8f5e6 | |||
| 5000f15021 | |||
| 44be5a0d1e | |||
| 3ca2aae9ca | |||
| 01080302a5 | |||
| 61d6c728b9 | |||
| 648787e739 | |||
| d2950e7615 | |||
| df8f695e84 | |||
| 8bb48ffdda | |||
| 4cfbf544a0 | |||
| a2ce74f398 | |||
| 0060dee222 | |||
| 5ff1058f30 | |||
| 500fc5e2a4 | |||
| baba9e6ede | |||
| b36619b905 | |||
| 937f50f0ae | |||
| a16755dd66 | |||
| 8bdc35f91f | |||
| 5b32c4d79d | |||
| 8625543413 | |||
| e18094825a |
@@ -1,2 +0,0 @@
|
||||
enabled: true
|
||||
preservePullRequestTitle: true
|
||||
@@ -0,0 +1,178 @@
|
||||
# Automatically backport merged PRs to the last N release branches when the
|
||||
# "backport" label is applied. Works whether the label is added before or
|
||||
# after the PR is merged.
|
||||
#
|
||||
# Usage:
|
||||
# 1. Add the "backport" label to a PR targeting main.
|
||||
# 2. When the PR merges (or if already merged), the workflow detects the
|
||||
# latest release/* branches and opens one cherry-pick PR per branch.
|
||||
#
|
||||
# The created backport PRs follow existing repo conventions:
|
||||
# - Branch: backport/<pr>-to-<version>
|
||||
# - Title: <original PR title> (#<pr>)
|
||||
# - Body: links back to the original PR and merge commit
|
||||
|
||||
name: Backport
|
||||
on:
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types:
|
||||
- closed
|
||||
- labeled
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
|
||||
# fire in quick succession.
|
||||
concurrency:
|
||||
group: backport-${{ github.event.pull_request.number }}
|
||||
|
||||
jobs:
|
||||
detect:
|
||||
name: Detect target branches
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
contains(github.event.pull_request.labels.*.name, 'backport')
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
branches: ${{ steps.find.outputs.branches }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Need all refs to discover release branches.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Find latest release branches
|
||||
id: find
|
||||
run: |
|
||||
# List remote release branches matching the exact release/2.X
|
||||
# pattern (no suffixes like release/2.31_hotfix), sort by minor
|
||||
# version descending, and take the top 3.
|
||||
BRANCHES=$(
|
||||
git branch -r \
|
||||
| grep -E '^\s*origin/release/2\.[0-9]+$' \
|
||||
| sed 's|.*origin/||' \
|
||||
| sort -t. -k2 -n -r \
|
||||
| head -3
|
||||
)
|
||||
|
||||
if [ -z "$BRANCHES" ]; then
|
||||
echo "No release branches found."
|
||||
echo "branches=[]" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Convert to JSON array for the matrix.
|
||||
JSON=$(echo "$BRANCHES" | jq -Rnc '[inputs | select(length > 0)]')
|
||||
echo "branches=$JSON" >> "$GITHUB_OUTPUT"
|
||||
echo "Will backport to: $JSON"
|
||||
|
||||
backport:
|
||||
name: "Backport to ${{ matrix.branch }}"
|
||||
needs: detect
|
||||
if: needs.detect.outputs.branches != '[]'
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
branch: ${{ fromJson(needs.detect.outputs.branches) }}
|
||||
fail-fast: false
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Full history required for cherry-pick.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cherry-pick and open PR
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
RELEASE_VERSION="${{ matrix.branch }}"
|
||||
# Strip the release/ prefix for naming.
|
||||
VERSION="${RELEASE_VERSION#release/}"
|
||||
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Check if backport branch already exists (idempotency for re-runs).
|
||||
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
|
||||
echo "Backport branch ${BACKPORT_BRANCH} already exists, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Create the backport branch from the target release branch.
|
||||
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_VERSION}"
|
||||
|
||||
# Cherry-pick the merge commit. Use -x to record provenance and
|
||||
# -m1 to pick the first parent (the main branch side).
|
||||
CONFLICTS=false
|
||||
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
|
||||
echo "::warning::Cherry-pick to ${RELEASE_VERSION} had conflicts."
|
||||
CONFLICTS=true
|
||||
|
||||
# Abort the failed cherry-pick and create an empty commit
|
||||
# explaining the situation.
|
||||
git cherry-pick --abort
|
||||
git commit --allow-empty -m "Cherry-pick of #${PR_NUMBER} requires manual resolution
|
||||
|
||||
The automatic cherry-pick of ${MERGE_SHA} to ${RELEASE_VERSION} had conflicts.
|
||||
Please cherry-pick manually:
|
||||
|
||||
git cherry-pick -x -m1 ${MERGE_SHA}"
|
||||
fi
|
||||
|
||||
git push origin "$BACKPORT_BRANCH"
|
||||
|
||||
TITLE="${PR_TITLE} (#${PR_NUMBER})"
|
||||
BODY=$(cat <<EOF
|
||||
Backport of ${PR_URL}
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
if [ "$CONFLICTS" = true ]; then
|
||||
TITLE="${TITLE} (conflicts)"
|
||||
BODY="${BODY}
|
||||
|
||||
> [!WARNING]
|
||||
> The automatic cherry-pick had conflicts.
|
||||
> Please resolve manually by cherry-picking the original merge commit:
|
||||
>
|
||||
> \`\`\`
|
||||
> git fetch origin ${BACKPORT_BRANCH}
|
||||
> git checkout ${BACKPORT_BRANCH}
|
||||
> git reset --hard origin/${RELEASE_VERSION}
|
||||
> git cherry-pick -x -m1 ${MERGE_SHA}
|
||||
> # resolve conflicts, then push
|
||||
> \`\`\`"
|
||||
fi
|
||||
|
||||
# Check if a PR already exists for this branch (idempotency
|
||||
# for re-runs).
|
||||
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_VERSION" --state all --json number --jq '.[0].number // empty')
|
||||
if [ -n "$EXISTING_PR" ]; then
|
||||
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
gh pr create \
|
||||
--base "$RELEASE_VERSION" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
@@ -0,0 +1,143 @@
|
||||
# Automatically cherry-pick merged PRs to the latest release branch when the
|
||||
# "cherry-pick" label is applied. Works whether the label is added before or
|
||||
# after the PR is merged.
|
||||
#
|
||||
# Usage:
|
||||
# 1. Add the "cherry-pick" label to a PR targeting main.
|
||||
# 2. When the PR merges (or if already merged), the workflow detects the
|
||||
# latest release/* branch and opens a cherry-pick PR against it.
|
||||
#
|
||||
# The created PRs follow existing repo conventions:
|
||||
# - Branch: backport/<pr>-to-<version>
|
||||
# - Title: <original PR title> (#<pr>)
|
||||
# - Body: links back to the original PR and merge commit
|
||||
|
||||
name: Cherry-pick to release
|
||||
on:
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types:
|
||||
- closed
|
||||
- labeled
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
|
||||
# fire in quick succession.
|
||||
concurrency:
|
||||
group: cherry-pick-${{ github.event.pull_request.number }}
|
||||
|
||||
jobs:
|
||||
cherry-pick:
|
||||
name: Cherry-pick to latest release
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
contains(github.event.pull_request.labels.*.name, 'cherry-pick')
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Full history required for cherry-pick and branch discovery.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cherry-pick and open PR
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Find the latest release branch matching the exact release/2.X
|
||||
# pattern (no suffixes like release/2.31_hotfix).
|
||||
RELEASE_BRANCH=$(
|
||||
git branch -r \
|
||||
| grep -E '^\s*origin/release/2\.[0-9]+$' \
|
||||
| sed 's|.*origin/||' \
|
||||
| sort -t. -k2 -n -r \
|
||||
| head -1
|
||||
)
|
||||
|
||||
if [ -z "$RELEASE_BRANCH" ]; then
|
||||
echo "::error::No release branch found."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Strip the release/ prefix for naming.
|
||||
VERSION="${RELEASE_BRANCH#release/}"
|
||||
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
|
||||
|
||||
echo "Target branch: $RELEASE_BRANCH"
|
||||
echo "Backport branch: $BACKPORT_BRANCH"
|
||||
|
||||
# Check if backport branch already exists (idempotency for re-runs).
|
||||
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
|
||||
echo "Branch ${BACKPORT_BRANCH} already exists, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Create the backport branch from the target release branch.
|
||||
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_BRANCH}"
|
||||
|
||||
# Cherry-pick the merge commit. Use -x to record provenance and
|
||||
# -m1 to pick the first parent (the main branch side).
|
||||
CONFLICT=false
|
||||
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
|
||||
CONFLICT=true
|
||||
echo "::warning::Cherry-pick to ${RELEASE_BRANCH} had conflicts."
|
||||
|
||||
# Abort the failed cherry-pick and create an empty commit with
|
||||
# instructions so the PR can still be opened.
|
||||
git cherry-pick --abort
|
||||
git commit --allow-empty -m "cherry-pick of #${PR_NUMBER} failed — resolve conflicts manually
|
||||
|
||||
Cherry-pick of ${MERGE_SHA} onto ${RELEASE_BRANCH} had conflicts.
|
||||
To resolve:
|
||||
git fetch origin ${BACKPORT_BRANCH}
|
||||
git checkout ${BACKPORT_BRANCH}
|
||||
git cherry-pick -x -m1 ${MERGE_SHA}
|
||||
# resolve conflicts
|
||||
git push origin ${BACKPORT_BRANCH}"
|
||||
fi
|
||||
|
||||
git push origin "$BACKPORT_BRANCH"
|
||||
|
||||
BODY=$(cat <<EOF
|
||||
Cherry-pick of ${PR_URL}
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
TITLE="${PR_TITLE} (#${PR_NUMBER})"
|
||||
if [ "$CONFLICT" = true ]; then
|
||||
TITLE="[CONFLICT] ${TITLE}"
|
||||
fi
|
||||
|
||||
# Check if a PR already exists for this branch (idempotency
|
||||
# for re-runs). Use --state all to catch closed/merged PRs too.
|
||||
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_BRANCH" --state all --json number --jq '.[0].number // empty')
|
||||
if [ -n "$EXISTING_PR" ]; then
|
||||
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
gh pr create \
|
||||
--base "$RELEASE_BRANCH" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
+17
-17
@@ -35,7 +35,7 @@ jobs:
|
||||
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -157,7 +157,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -247,7 +247,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -272,7 +272,7 @@ jobs:
|
||||
if: ${{ !cancelled() }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -327,7 +327,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -379,7 +379,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -575,7 +575,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -637,7 +637,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -709,7 +709,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -736,7 +736,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -769,7 +769,7 @@ jobs:
|
||||
name: ${{ matrix.variant.name }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -849,7 +849,7 @@ jobs:
|
||||
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -930,7 +930,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1005,7 +1005,7 @@ jobs:
|
||||
if: always()
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1043,7 +1043,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1097,7 +1097,7 @@ jobs:
|
||||
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1479,7 +1479,7 @@ jobs:
|
||||
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
steps:
|
||||
- name: Dependabot metadata
|
||||
id: metadata
|
||||
uses: dependabot/fetch-metadata@21025c705c08248db411dc16f3619e6b5f9ea21a # v2.5.0
|
||||
uses: dependabot/fetch-metadata@ffa630c65fa7e0ecfa0625b5ceda64399aea1b36 # v3.0.0
|
||||
with:
|
||||
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -65,7 +65,7 @@ jobs:
|
||||
packages: write # to retag image as dogfood
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
needs: deploy
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
if: github.repository_owner == 'coder'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -125,7 +125,7 @@ jobs:
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# Ensures that only bug fixes are cherry-picked to release branches.
|
||||
# PRs targeting release/* must have a title starting with "fix:" or "fix(scope):".
|
||||
name: PR Cherry-Pick Check
|
||||
|
||||
on:
|
||||
# zizmor: ignore[dangerous-triggers] Only reads PR metadata and comments; does not checkout PR code.
|
||||
pull_request_target:
|
||||
types: [opened, reopened, edited]
|
||||
branches:
|
||||
- "release/*"
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
check-cherry-pick:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Check PR title for bug fix
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
const baseBranch = context.payload.pull_request.base.ref;
|
||||
const author = context.payload.pull_request.user.login;
|
||||
|
||||
console.log(`PR #${prNumber}: "${title}" -> ${baseBranch}`);
|
||||
|
||||
// Match conventional commit "fix:" or "fix(scope):" prefix.
|
||||
const isBugFix = /^fix(\(.+\))?:/.test(title);
|
||||
|
||||
if (isBugFix) {
|
||||
console.log("PR title indicates a bug fix. No action needed.");
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("PR title does not indicate a bug fix. Commenting.");
|
||||
|
||||
// Check for an existing comment from this bot to avoid duplicates
|
||||
// on title edits.
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
});
|
||||
|
||||
const marker = "<!-- cherry-pick-check -->";
|
||||
const existingComment = comments.find(
|
||||
(c) => c.body && c.body.includes(marker),
|
||||
);
|
||||
|
||||
const body = [
|
||||
marker,
|
||||
`👋 Hey @${author}!`,
|
||||
"",
|
||||
`This PR is targeting the \`${baseBranch}\` release branch, but its title does not start with \`fix:\` or \`fix(scope):\`.`,
|
||||
"",
|
||||
"Only **bug fixes** should be cherry-picked to release branches. If this is a bug fix, please update the PR title to match the conventional commit format:",
|
||||
"",
|
||||
"```",
|
||||
"fix: description of the bug fix",
|
||||
"fix(scope): description of the bug fix",
|
||||
"```",
|
||||
"",
|
||||
"If this is **not** a bug fix, it likely should not target a release branch.",
|
||||
].join("\n");
|
||||
|
||||
if (existingComment) {
|
||||
console.log(`Updating existing comment ${existingComment.id}.`);
|
||||
await github.rest.issues.updateComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
comment_id: existingComment.id,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body,
|
||||
});
|
||||
}
|
||||
|
||||
core.warning(
|
||||
`PR #${prNumber} targets ${baseBranch} but is not a bug fix. Title must start with "fix:" or "fix(scope):".`,
|
||||
);
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
packages: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
runs-on: "ubuntu-latest"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
pull-requests: write # needed for commenting on PRs
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -288,7 +288,7 @@ jobs:
|
||||
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ jobs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -121,22 +121,22 @@ jobs:
|
||||
fi
|
||||
|
||||
# Derive the release branch from the version tag.
|
||||
# Standard: 2.10.2 -> release/2.10
|
||||
# RC: 2.32.0-rc.0 -> release/2.32-rc.0
|
||||
# Non-RC releases must be on a release/X.Y branch.
|
||||
# RC tags are allowed on any branch (typically main).
|
||||
version="$(./scripts/version.sh)"
|
||||
# Strip any pre-release suffix first (e.g. 2.32.0-rc.0 -> 2.32.0)
|
||||
base_version="${version%%-*}"
|
||||
# Then strip patch to get major.minor (e.g. 2.32.0 -> 2.32)
|
||||
release_branch="release/${base_version%.*}"
|
||||
|
||||
if [[ "$version" == *-rc.* ]]; then
|
||||
# Extract major.minor and rc suffix from e.g. 2.32.0-rc.0
|
||||
base_version="${version%%-rc.*}" # 2.32.0
|
||||
major_minor="${base_version%.*}" # 2.32
|
||||
rc_suffix="${version##*-rc.}" # 0
|
||||
release_branch="release/${major_minor}-rc.${rc_suffix}"
|
||||
echo "RC release detected — skipping release branch check (RC tags are cut from main)."
|
||||
else
|
||||
release_branch=release/${version%.*}
|
||||
fi
|
||||
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
|
||||
if [[ -z "${branch_contains_tag}" ]]; then
|
||||
echo "Ref tag must exist in a branch named ${release_branch} when creating a release, did you use scripts/release.sh?"
|
||||
exit 1
|
||||
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
|
||||
if [[ -z "${branch_contains_tag}" ]]; then
|
||||
echo "Ref tag must exist in a branch named ${release_branch} when creating a non-RC release, did you use scripts/release.sh?"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -z "${CODER_RELEASE_NOTES}" ]]; then
|
||||
@@ -673,7 +673,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -749,7 +749,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -47,6 +47,6 @@ jobs:
|
||||
|
||||
# Upload the results to GitHub's code scanning dashboard.
|
||||
- name: "Upload to code-scanning"
|
||||
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
@@ -27,7 +27,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
uses: github/codeql-action/init@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
|
||||
with:
|
||||
languages: go, javascript
|
||||
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
rm Makefile
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
uses: github/codeql-action/analyze@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
|
||||
|
||||
- name: Send Slack notification on failure
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -18,7 +18,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -96,7 +96,7 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
actions: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ typ = "typ"
|
||||
styl = "styl"
|
||||
edn = "edn"
|
||||
Inferrable = "Inferrable"
|
||||
IIF = "IIF"
|
||||
|
||||
[files]
|
||||
extend-exclude = [
|
||||
|
||||
@@ -21,7 +21,7 @@ jobs:
|
||||
pull-requests: write # required to post PR review comments by the action
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -103,3 +103,6 @@ PLAN.md
|
||||
|
||||
# Ignore any dev licenses
|
||||
license.txt
|
||||
-e
|
||||
# Agent planning documents (local working files).
|
||||
docs/plans/
|
||||
|
||||
@@ -91,6 +91,59 @@ define atomic_write
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
endef
|
||||
|
||||
# Helper binary targets. Built with go build -o to avoid caching
|
||||
# link-stage executables in GOCACHE. Each binary is a real Make
|
||||
# target so parallel -j builds serialize correctly instead of
|
||||
# racing on the same output path.
|
||||
|
||||
_gen/bin/apitypings: $(wildcard scripts/apitypings/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apitypings
|
||||
|
||||
_gen/bin/auditdocgen: $(wildcard scripts/auditdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/auditdocgen
|
||||
|
||||
_gen/bin/check-scopes: $(wildcard scripts/check-scopes/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/check-scopes
|
||||
|
||||
_gen/bin/clidocgen: $(wildcard scripts/clidocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/clidocgen
|
||||
|
||||
_gen/bin/dbdump: $(wildcard coderd/database/gen/dump/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./coderd/database/gen/dump
|
||||
|
||||
_gen/bin/examplegen: $(wildcard scripts/examplegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/examplegen
|
||||
|
||||
_gen/bin/gensite: $(wildcard scripts/gensite/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/gensite
|
||||
|
||||
_gen/bin/apikeyscopesgen: $(wildcard scripts/apikeyscopesgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apikeyscopesgen
|
||||
|
||||
_gen/bin/metricsdocgen: $(wildcard scripts/metricsdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen
|
||||
|
||||
_gen/bin/metricsdocgen-scanner: $(wildcard scripts/metricsdocgen/scanner/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen/scanner
|
||||
|
||||
_gen/bin/modeloptionsgen: $(wildcard scripts/modeloptionsgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/modeloptionsgen
|
||||
|
||||
_gen/bin/typegen: $(wildcard scripts/typegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/typegen
|
||||
|
||||
# Shared temp directory for atomic writes. Lives at the project root
|
||||
# so all targets share the same filesystem, and is gitignored.
|
||||
# Order-only prerequisite: recipes that need it depend on | _gen
|
||||
@@ -201,6 +254,7 @@ endif
|
||||
|
||||
clean:
|
||||
rm -rf build/ site/build/ site/out/
|
||||
rm -rf _gen/bin
|
||||
mkdir -p build/
|
||||
git restore site/out/
|
||||
.PHONY: clean
|
||||
@@ -654,8 +708,8 @@ lint/go:
|
||||
go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./...
|
||||
.PHONY: lint/go
|
||||
|
||||
lint/examples:
|
||||
go run ./scripts/examplegen/main.go -lint
|
||||
lint/examples: | _gen/bin/examplegen
|
||||
_gen/bin/examplegen -lint
|
||||
.PHONY: lint/examples
|
||||
|
||||
# Use shfmt to determine the shell files, takes editorconfig into consideration.
|
||||
@@ -693,8 +747,8 @@ lint/actions/zizmor:
|
||||
.PHONY: lint/actions/zizmor
|
||||
|
||||
# Verify api_key_scope enum contains all RBAC <resource>:<action> values.
|
||||
lint/check-scopes: coderd/database/dump.sql
|
||||
go run ./scripts/check-scopes
|
||||
lint/check-scopes: coderd/database/dump.sql | _gen/bin/check-scopes
|
||||
_gen/bin/check-scopes
|
||||
.PHONY: lint/check-scopes
|
||||
|
||||
# Verify migrations do not hardcode the public schema.
|
||||
@@ -734,8 +788,8 @@ lint/typos: build/typos-$(TYPOS_VERSION)
|
||||
# The pre-push hook is allowlisted, see scripts/githooks/pre-push.
|
||||
#
|
||||
# pre-commit uses two phases: gen+fmt first, then lint+build. This
|
||||
# avoids races where gen's `go run` creates temporary .go files that
|
||||
# lint's find-based checks pick up. Within each phase, targets run in
|
||||
# avoids races where gen creates temporary .go files that lint's
|
||||
# find-based checks pick up. Within each phase, targets run in
|
||||
# parallel via -j. It fails if any tracked files have unstaged
|
||||
# changes afterward.
|
||||
|
||||
@@ -949,8 +1003,8 @@ gen/mark-fresh:
|
||||
|
||||
# Runs migrations to output a dump of the database schema after migrations are
|
||||
# applied.
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql)
|
||||
go run ./coderd/database/gen/dump/main.go
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql) | _gen/bin/dbdump
|
||||
_gen/bin/dbdump
|
||||
touch "$@"
|
||||
|
||||
# Generates Go code for querying the database.
|
||||
@@ -1067,88 +1121,88 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen
|
||||
$(call atomic_write,go run -C ./scripts/apitypings main.go,./scripts/biome_format.sh)
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen _gen/bin/apitypings
|
||||
$(call atomic_write,_gen/bin/apitypings,./scripts/biome_format.sh)
|
||||
|
||||
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
|
||||
(cd site/ && pnpm run gen:provisioner)
|
||||
touch "$@"
|
||||
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen _gen/bin/gensite
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
go run ./scripts/gensite/ -icons "$$tmpfile" && \
|
||||
_gen/bin/gensite -icons "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen
|
||||
$(call atomic_write,go run ./scripts/examplegen/main.go)
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen _gen/bin/examplegen
|
||||
$(call atomic_write,_gen/bin/examplegen)
|
||||
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac object)
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac object)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go because `go run` compiles
|
||||
# coderd/rbac which includes it.
|
||||
# NOTE: depends on object_gen.go because the generator build
|
||||
# compiles coderd/rbac which includes it.
|
||||
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go | _gen
|
||||
coderd/rbac/object_gen.go | _gen _gen/bin/typegen
|
||||
# Write to a temp file first to avoid truncating the package
|
||||
# during build since the generator imports the rbac package.
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac scopenames)
|
||||
$(call atomic_write,_gen/bin/typegen rbac scopenames)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
# Write to a temp file to avoid truncating the target, which
|
||||
# would break the codersdk package and any parallel build targets.
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac codersdk)
|
||||
$(call atomic_write,_gen/bin/typegen rbac codersdk)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/apikeyscopesgen
|
||||
# Generate SDK constants for external API key scopes.
|
||||
$(call atomic_write,go run ./scripts/apikeyscopesgen)
|
||||
$(call atomic_write,_gen/bin/apikeyscopesgen)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac typescript,./scripts/biome_format.sh)
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac typescript,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go countries,./scripts/biome_format.sh)
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen countries,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen
|
||||
$(call atomic_write,go run ./scripts/modeloptionsgen/main.go | tail -n +2,./scripts/biome_format.sh)
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen _gen/bin/modeloptionsgen
|
||||
$(call atomic_write,_gen/bin/modeloptionsgen | tail -n +2,./scripts/biome_format.sh)
|
||||
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen
|
||||
$(call atomic_write,go run ./scripts/metricsdocgen/scanner)
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen _gen/bin/metricsdocgen-scanner
|
||||
$(call atomic_write,_gen/bin/metricsdocgen-scanner)
|
||||
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen _gen/bin/metricsdocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/metricsdocgen/main.go --prometheus-doc-file="$$tmpfile" && \
|
||||
_gen/bin/metricsdocgen --prometheus-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen _gen/bin/clidocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && \
|
||||
tmpdir=$$(realpath "$$tmpdir") && \
|
||||
mkdir -p "$$tmpdir/docs/reference/cli" && \
|
||||
cp docs/manifest.json "$$tmpdir/docs/manifest.json" && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" go run ./scripts/clidocgen && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" _gen/bin/clidocgen && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
for f in "$$tmpdir/docs/reference/cli/"*.md; do mv "$$f" "docs/reference/cli/$$(basename "$$f")"; done && \
|
||||
rm -rf "$$tmpdir"
|
||||
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen _gen/bin/auditdocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/auditdocgen/main.go --audit-doc-file="$$tmpfile" && \
|
||||
_gen/bin/auditdocgen --audit-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
+14
-6
@@ -102,6 +102,8 @@ type Options struct {
|
||||
ReportMetadataInterval time.Duration
|
||||
ServiceBannerRefreshInterval time.Duration
|
||||
BlockFileTransfer bool
|
||||
BlockReversePortForwarding bool
|
||||
BlockLocalPortForwarding bool
|
||||
Execer agentexec.Execer
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
@@ -214,6 +216,8 @@ func New(options Options) Agent {
|
||||
subsystems: options.Subsystems,
|
||||
logSender: agentsdk.NewLogSender(options.Logger),
|
||||
blockFileTransfer: options.BlockFileTransfer,
|
||||
blockReversePortForwarding: options.BlockReversePortForwarding,
|
||||
blockLocalPortForwarding: options.BlockLocalPortForwarding,
|
||||
|
||||
prometheusRegistry: prometheusRegistry,
|
||||
metrics: newAgentMetrics(prometheusRegistry),
|
||||
@@ -280,6 +284,8 @@ type agent struct {
|
||||
sshServer *agentssh.Server
|
||||
sshMaxTimeout time.Duration
|
||||
blockFileTransfer bool
|
||||
blockReversePortForwarding bool
|
||||
blockLocalPortForwarding bool
|
||||
|
||||
lifecycleUpdate chan struct{}
|
||||
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
|
||||
@@ -331,12 +337,14 @@ func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
func (a *agent) init() {
|
||||
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
|
||||
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||
UpdateEnv: a.updateCommandEnv,
|
||||
WorkingDirectory: func() string { return a.manifest.Load().Directory },
|
||||
BlockFileTransfer: a.blockFileTransfer,
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||
UpdateEnv: a.updateCommandEnv,
|
||||
WorkingDirectory: func() string { return a.manifest.Load().Directory },
|
||||
BlockFileTransfer: a.blockFileTransfer,
|
||||
BlockReversePortForwarding: a.blockReversePortForwarding,
|
||||
BlockLocalPortForwarding: a.blockLocalPortForwarding,
|
||||
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
|
||||
var connectionType proto.Connection_Type
|
||||
switch magicType {
|
||||
|
||||
@@ -986,6 +986,161 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_TCPLocalForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
remotePort := tcpAddr.Port
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||||
require.ErrorContains(t, err, "administratively prohibited")
|
||||
}
|
||||
|
||||
func TestAgent_TCPRemoteForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
localhost := netip.MustParseAddr("127.0.0.1")
|
||||
randomPort := testutil.RandomPortNoListen(t)
|
||||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||||
_, err = sshClient.ListenTCP(addr)
|
||||
require.ErrorContains(t, err, "tcpip-forward request denied by peer")
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
tmpdir := testutil.TempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
|
||||
l, err := net.Listen("unix", remoteSocketPath)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.Dial("unix", remoteSocketPath)
|
||||
require.ErrorContains(t, err, "administratively prohibited")
|
||||
}
|
||||
|
||||
func TestAgent_UnixRemoteForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
tmpdir := testutil.TempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.ListenUnix(remoteSocketPath)
|
||||
require.ErrorContains(t, err, "streamlocal-forward@openssh.com request denied by peer")
|
||||
}
|
||||
|
||||
// TestAgent_LocalBlockedDoesNotAffectReverse verifies that blocking
|
||||
// local port forwarding does not prevent reverse port forwarding from
|
||||
// working. A field-name transposition at any plumbing hop would cause
|
||||
// both directions to be blocked when only one flag is set.
|
||||
func TestAgent_LocalBlockedDoesNotAffectReverse(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Reverse forwarding must still work.
|
||||
localhost := netip.MustParseAddr("127.0.0.1")
|
||||
var ll net.Listener
|
||||
for {
|
||||
randomPort := testutil.RandomPortNoListen(t)
|
||||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||||
ll, err = sshClient.ListenTCP(addr)
|
||||
if err != nil {
|
||||
t.Logf("error remote forwarding: %s", err.Error())
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out getting random listener")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
_ = ll.Close()
|
||||
}
|
||||
|
||||
// TestAgent_ReverseBlockedDoesNotAffectLocal verifies that blocking
|
||||
// reverse port forwarding does not prevent local port forwarding from
|
||||
// working.
|
||||
func TestAgent_ReverseBlockedDoesNotAffectLocal(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
remotePort := tcpAddr.Port
|
||||
go echoOnce(t, rl)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Local forwarding must still work.
|
||||
conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwarding(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
|
||||
@@ -134,6 +134,33 @@ func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) {
|
||||
}, ResolvePaths(mcpConfigFile, workingDir)
|
||||
}
|
||||
|
||||
// ContextPartsFromDir reads instruction files and discovers skills
|
||||
// from a specific directory, using default file names. This is used
|
||||
// by the CLI chat context commands to read context from an arbitrary
|
||||
// directory without consulting agent env vars.
|
||||
func ContextPartsFromDir(dir string) []codersdk.ChatMessagePart {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
|
||||
if entry, found := readInstructionFileFromDir(dir, DefaultInstructionsFile); found {
|
||||
parts = append(parts, entry)
|
||||
}
|
||||
|
||||
// Reuse ResolvePaths so CLI skill discovery follows the same
|
||||
// project-relative path handling as agent config resolution.
|
||||
skillParts := discoverSkills(
|
||||
ResolvePaths(strings.Join([]string{DefaultSkillsDir, "skills"}, ","), dir),
|
||||
DefaultSkillMetaFile,
|
||||
)
|
||||
parts = append(parts, skillParts...)
|
||||
|
||||
// Guarantee non-nil slice.
|
||||
if parts == nil {
|
||||
parts = []codersdk.ChatMessagePart{}
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// MCPConfigFiles returns the resolved MCP configuration file
|
||||
// paths for the agent's MCP manager.
|
||||
func (api *API) MCPConfigFiles() []string {
|
||||
|
||||
@@ -23,18 +23,144 @@ func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartTyp
|
||||
return out
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
func writeSkillMetaFileInRoot(t *testing.T, skillsRoot, name, description string) string {
|
||||
t.Helper()
|
||||
|
||||
// Clear all env vars so defaults are used.
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
skillDir := filepath.Join(skillsRoot, name)
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte("---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
return skillDir
|
||||
}
|
||||
|
||||
func writeSkillMetaFile(t *testing.T, dir, name, description string) string {
|
||||
t.Helper()
|
||||
return writeSkillMetaFileInRoot(t, filepath.Join(dir, ".agents", "skills"), name, description)
|
||||
}
|
||||
|
||||
func TestContextPartsFromDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReturnsInstructionFilePart", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("project instructions"), 0o600))
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Empty(t, skillParts)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "project instructions", contextParts[0].ContextFileContent)
|
||||
require.False(t, contextParts[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFile(t, dir, "my-skill", "A test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillPartsFromSkillsDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFileInRoot(
|
||||
t,
|
||||
filepath.Join(dir, "skills"),
|
||||
"my-skill",
|
||||
"A test skill",
|
||||
)
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsEmptyForEmptyDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(t.TempDir())
|
||||
|
||||
require.NotNil(t, parts)
|
||||
require.Empty(t, parts)
|
||||
})
|
||||
|
||||
t.Run("ReturnsCombinedResults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("combined instructions"), 0o600))
|
||||
skillDir := writeSkillMetaFile(t, dir, "combined-skill", "Combined test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 2)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "combined instructions", contextParts[0].ContextFileContent)
|
||||
require.Equal(t, "combined-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
})
|
||||
}
|
||||
|
||||
func setupConfigTestEnv(t *testing.T, overrides map[string]string) string {
|
||||
t.Helper()
|
||||
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
for key, value := range overrides {
|
||||
t.Setenv(key, value)
|
||||
}
|
||||
|
||||
return fakeHome
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -46,20 +172,18 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CustomEnvVars", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
optInstructions := t.TempDir()
|
||||
optSkills := t.TempDir()
|
||||
optMCP := platformAbsPath("opt", "mcp.json")
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: optInstructions,
|
||||
agentcontextconfig.EnvInstructionsFile: "CUSTOM.md",
|
||||
agentcontextconfig.EnvSkillsDirs: optSkills,
|
||||
agentcontextconfig.EnvSkillMetaFile: "META.yaml",
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
|
||||
// Create files matching the custom names so we can
|
||||
// verify the env vars actually change lookup behavior.
|
||||
@@ -85,15 +209,12 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("WhitespaceInFileNames", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsFile: " CLAUDE.md ",
|
||||
})
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// Create a file matching the trimmed name.
|
||||
@@ -106,19 +227,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "hello", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CommaSeparatedDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
a := t.TempDir()
|
||||
b := t.TempDir()
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: a + "," + b,
|
||||
})
|
||||
|
||||
// Put instruction files in both dirs.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600))
|
||||
@@ -133,17 +248,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "from b", ctxFiles[1].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsInstructionFiles", func(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
|
||||
// Create ~/.coder/AGENTS.md
|
||||
coderDir := filepath.Join(fakeHome, ".coder")
|
||||
@@ -164,16 +272,9 @@ func TestConfig(t *testing.T) {
|
||||
require.False(t, ctxFiles[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
|
||||
// Create AGENTS.md in the working directory.
|
||||
@@ -193,16 +294,9 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("TruncatesLargeInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
largeContent := strings.Repeat("a", 64*1024+100)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600))
|
||||
@@ -215,79 +309,47 @@ func TestConfig(t *testing.T) {
|
||||
require.Len(t, ctxFiles[0].ContextFileContent, 64*1024)
|
||||
})
|
||||
|
||||
t.Run("SanitizesHTMLComments", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
sanitizationTests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "SanitizesHTMLComments",
|
||||
input: "visible\n<!-- hidden -->content",
|
||||
expected: "visible\ncontent",
|
||||
},
|
||||
{
|
||||
name: "SanitizesInvisibleUnicode",
|
||||
input: "before\u200bafter",
|
||||
expected: "beforeafter",
|
||||
},
|
||||
{
|
||||
name: "NormalizesCRLF",
|
||||
input: "line1\r\nline2\rline3",
|
||||
expected: "line1\nline2\nline3",
|
||||
},
|
||||
}
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
for _, tt := range sanitizationTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte(tt.input),
|
||||
0o600,
|
||||
))
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("visible\n<!-- hidden -->content"),
|
||||
0o600,
|
||||
))
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("SanitizesInvisibleUnicode", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// U+200B (zero-width space) should be stripped.
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("before\u200bafter"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("NormalizesCRLF", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("line1\r\nline2\rline3"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, tt.expected, ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DiscoversSkills", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
@@ -320,17 +382,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkipsMissingDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
nonExistent := filepath.Join(t.TempDir(), "does-not-exist")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: nonExistent,
|
||||
agentcontextconfig.EnvSkillsDirs: nonExistent,
|
||||
})
|
||||
|
||||
workDir := t.TempDir()
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
@@ -340,17 +398,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, cfg.Parts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
|
||||
optMCP := platformAbsPath("opt", "custom.json")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
|
||||
workDir := t.TempDir()
|
||||
_, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -358,14 +412,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{optMCP}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkillNameMustMatchDir", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir := filepath.Join(workDir, "skills")
|
||||
@@ -385,14 +435,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, skillParts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DuplicateSkillsFirstWins", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir1 := filepath.Join(workDir, "skills1")
|
||||
|
||||
@@ -117,6 +117,10 @@ type Config struct {
|
||||
X11MaxPort *int
|
||||
// BlockFileTransfer restricts use of file transfer applications.
|
||||
BlockFileTransfer bool
|
||||
// BlockReversePortForwarding disables reverse port forwarding (ssh -R).
|
||||
BlockReversePortForwarding bool
|
||||
// BlockLocalPortForwarding disables local port forwarding (ssh -L).
|
||||
BlockLocalPortForwarding bool
|
||||
// ReportConnection.
|
||||
ReportConnection reportConnectionFunc
|
||||
// Experimental: allow connecting to running containers via Docker exec.
|
||||
@@ -190,7 +194,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
}
|
||||
|
||||
forwardHandler := &ssh.ForwardedTCPHandler{}
|
||||
unixForwardHandler := newForwardedUnixHandler(logger)
|
||||
unixForwardHandler := newForwardedUnixHandler(logger, config.BlockReversePortForwarding)
|
||||
|
||||
metrics := newSSHServerMetrics(prometheusRegistry)
|
||||
s := &Server{
|
||||
@@ -229,8 +233,15 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, s.config.ReportConnection, newChan, &s.connCountJetBrains)
|
||||
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
|
||||
},
|
||||
"direct-streamlocal@openssh.com": directStreamLocalHandler,
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-streamlocal@openssh.com": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
|
||||
if s.config.BlockLocalPortForwarding {
|
||||
s.logger.Warn(ctx, "unix local port forward blocked")
|
||||
_ = newChan.Reject(gossh.Prohibited, "local port forwarding is disabled")
|
||||
return
|
||||
}
|
||||
directStreamLocalHandler(srv, conn, newChan, ctx)
|
||||
},
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
},
|
||||
ConnectionFailedCallback: func(conn net.Conn, err error) {
|
||||
s.logger.Warn(ctx, "ssh connection failed",
|
||||
@@ -250,6 +261,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
// be set before we start listening.
|
||||
HostSigners: []ssh.Signer{},
|
||||
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
||||
if s.config.BlockLocalPortForwarding {
|
||||
s.logger.Warn(ctx, "local port forward blocked",
|
||||
slog.F("destination_host", destinationHost),
|
||||
slog.F("destination_port", destinationPort))
|
||||
return false
|
||||
}
|
||||
// Allow local port forwarding all!
|
||||
s.logger.Debug(ctx, "local port forward",
|
||||
slog.F("destination_host", destinationHost),
|
||||
@@ -260,6 +277,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
return true
|
||||
},
|
||||
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
if s.config.BlockReversePortForwarding {
|
||||
s.logger.Warn(ctx, "reverse port forward blocked",
|
||||
slog.F("bind_host", bindHost),
|
||||
slog.F("bind_port", bindPort))
|
||||
return false
|
||||
}
|
||||
// Allow reverse port forwarding all!
|
||||
s.logger.Debug(ctx, "reverse port forward",
|
||||
slog.F("bind_host", bindHost),
|
||||
|
||||
@@ -35,8 +35,9 @@ type forwardedStreamLocalPayload struct {
|
||||
// streamlocal forwarding (aka. unix forwarding) instead of TCP forwarding.
|
||||
type forwardedUnixHandler struct {
|
||||
sync.Mutex
|
||||
log slog.Logger
|
||||
forwards map[forwardKey]net.Listener
|
||||
log slog.Logger
|
||||
forwards map[forwardKey]net.Listener
|
||||
blockReversePortForwarding bool
|
||||
}
|
||||
|
||||
type forwardKey struct {
|
||||
@@ -44,10 +45,11 @@ type forwardKey struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler {
|
||||
func newForwardedUnixHandler(log slog.Logger, blockReversePortForwarding bool) *forwardedUnixHandler {
|
||||
return &forwardedUnixHandler{
|
||||
log: log,
|
||||
forwards: make(map[forwardKey]net.Listener),
|
||||
log: log,
|
||||
forwards: make(map[forwardKey]net.Listener),
|
||||
blockReversePortForwarding: blockReversePortForwarding,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +64,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
||||
|
||||
switch req.Type {
|
||||
case "streamlocal-forward@openssh.com":
|
||||
if h.blockReversePortForwarding {
|
||||
log.Warn(ctx, "unix reverse port forward blocked")
|
||||
return false, nil
|
||||
}
|
||||
var reqPayload streamLocalForwardPayload
|
||||
err := gossh.Unmarshal(req.Payload, &reqPayload)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -620,6 +622,11 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
defer artifact.Reader.Close()
|
||||
defer func() {
|
||||
if artifact.ThumbnailReader != nil {
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if artifact.Size > workspacesdk.MaxRecordingSize {
|
||||
a.logger.Warn(ctx, "recording file exceeds maximum size",
|
||||
@@ -633,10 +640,60 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "video/mp4")
|
||||
rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
|
||||
// Discard the thumbnail if it exceeds the maximum size.
|
||||
// The server-side consumer also enforces this per-part, but
|
||||
// rejecting it here avoids streaming a large thumbnail over
|
||||
// the wire for nothing.
|
||||
if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize {
|
||||
a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("size", artifact.ThumbnailSize),
|
||||
slog.F("max_size", workspacesdk.MaxThumbnailSize),
|
||||
)
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
artifact.ThumbnailReader = nil
|
||||
artifact.ThumbnailSize = 0
|
||||
}
|
||||
|
||||
// The multipart response is best-effort: once WriteHeader(200) is
|
||||
// called, CreatePart failures produce a truncated response without
|
||||
// the closing boundary. The server-side consumer handles this
|
||||
// gracefully, preserving any parts read before the error.
|
||||
mw := multipart.NewWriter(rw)
|
||||
defer mw.Close()
|
||||
rw.Header().Set("Content-Type", "multipart/mixed; boundary="+mw.Boundary())
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = io.Copy(rw, artifact.Reader)
|
||||
|
||||
// Part 1: video/mp4 (always present).
|
||||
videoPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"video/mp4"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(videoPart, artifact.Reader); err != nil {
|
||||
a.logger.Warn(ctx, "failed to write video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Part 2: image/jpeg (present only when thumbnail was extracted).
|
||||
if artifact.ThumbnailReader != nil {
|
||||
thumbPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"image/jpeg"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create thumbnail multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
_, _ = io.Copy(thumbPart, artifact.ThumbnailReader)
|
||||
}
|
||||
}
|
||||
|
||||
// coordFromAction extracts the coordinate pair from a DesktopAction,
|
||||
|
||||
@@ -4,12 +4,17 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -59,6 +64,8 @@ type fakeDesktop struct {
|
||||
lastKeyDown string
|
||||
lastKeyUp string
|
||||
|
||||
thumbnailData []byte // if set, StopRecording includes a thumbnail
|
||||
|
||||
// Recording tracking (guarded by recMu).
|
||||
recMu sync.Mutex
|
||||
recordings map[string]string // ID → file path
|
||||
@@ -187,10 +194,15 @@ func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*age
|
||||
_ = file.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &agentdesktop.RecordingArtifact{
|
||||
artifact := &agentdesktop.RecordingArtifact{
|
||||
Reader: file,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
if f.thumbnailData != nil {
|
||||
artifact.ThumbnailReader = io.NopCloser(bytes.NewReader(f.thumbnailData))
|
||||
artifact.ThumbnailSize = int64(len(f.thumbnailData))
|
||||
}
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) RecordActivity() {
|
||||
@@ -785,8 +797,8 @@ func TestRecordingStartStop(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestRecordingStartFails(t *testing.T) {
|
||||
@@ -847,8 +859,8 @@ func TestRecordingStartIdempotent(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestRecordingStopIdempotent(t *testing.T) {
|
||||
@@ -872,7 +884,7 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop twice - both should succeed with identical data.
|
||||
var bodies [2][]byte
|
||||
var videoParts [2][]byte
|
||||
for i := range 2 {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
|
||||
require.NoError(t, err)
|
||||
@@ -880,10 +892,10 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(recorder, request)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type"))
|
||||
bodies[i] = recorder.Body.Bytes()
|
||||
parts := parseMultipartParts(t, recorder.Header().Get("Content-Type"), recorder.Body.Bytes())
|
||||
videoParts[i] = parts["video/mp4"]
|
||||
}
|
||||
assert.Equal(t, bodies[0], bodies[1])
|
||||
assert.Equal(t, videoParts[0], videoParts[1])
|
||||
}
|
||||
|
||||
func TestRecordingStopInvalidIDFormat(t *testing.T) {
|
||||
@@ -1004,8 +1016,8 @@ func TestRecordingMultipleSimultaneous(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, expected[id], rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, expected[id], parts["video/mp4"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1112,8 +1124,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
firstData := rr.Body.Bytes()
|
||||
firstParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
firstData := firstParts["video/mp4"]
|
||||
require.NotEmpty(t, firstData)
|
||||
|
||||
// Step 3: Start again with the same ID - should succeed
|
||||
@@ -1128,8 +1140,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
secondData := rr.Body.Bytes()
|
||||
secondParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
secondData := secondParts["video/mp4"]
|
||||
require.NotEmpty(t, secondData)
|
||||
|
||||
// The two recordings should have different data because the
|
||||
@@ -1235,3 +1247,166 @@ func TestRecordingStopCorrupted(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording is corrupted.", respStop.Message)
|
||||
}
|
||||
|
||||
// parseMultipartParts parses a multipart/mixed response and returns
|
||||
// a map from Content-Type to body bytes.
|
||||
func parseMultipartParts(t *testing.T, contentType string, body []byte) map[string][]byte {
|
||||
t.Helper()
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
require.NoError(t, err, "parse Content-Type")
|
||||
boundary := params["boundary"]
|
||||
require.NotEmpty(t, boundary, "missing boundary")
|
||||
mr := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
parts := make(map[string][]byte)
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err, "unexpected multipart parse error")
|
||||
ct := part.Header.Get("Content-Type")
|
||||
data, readErr := io.ReadAll(part)
|
||||
require.NoError(t, readErr)
|
||||
parts[ct] = data
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// Create a fake JPEG header: 0xFF 0xD8 0xFF followed by 509 zero bytes.
|
||||
thumbnail := make([]byte, 512)
|
||||
thumbnail[0] = 0xff
|
||||
thumbnail[1] = 0xd8
|
||||
thumbnail[2] = 0xff
|
||||
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
thumbnailData: thumbnail,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 2, "expected exactly 2 parts (video + thumbnail)")
|
||||
|
||||
// The fake writes "fake-mp4-data-<id>-<counter>" as the MP4 content.
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
assert.Equal(t, thumbnail, parts["image/jpeg"])
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_NoThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 1, "expected exactly 1 part (video only)")
|
||||
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_OversizedThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// Create thumbnail data that exceeds MaxThumbnailSize.
|
||||
oversizedThumb := make([]byte, workspacesdk.MaxThumbnailSize+1)
|
||||
oversizedThumb[0] = 0xff
|
||||
oversizedThumb[1] = 0xd8
|
||||
oversizedThumb[2] = 0xff
|
||||
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
thumbnailData: oversizedThumb,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response contains only the video part.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 1, "expected exactly 1 part (video only, oversized thumbnail discarded)")
|
||||
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
}
|
||||
|
||||
@@ -105,6 +105,11 @@ type RecordingArtifact struct {
|
||||
Reader io.ReadCloser
|
||||
// Size is the byte length of the MP4 content.
|
||||
Size int64
|
||||
// ThumbnailReader is the JPEG thumbnail. May be nil if no
|
||||
// thumbnail was produced. Callers must close it when done.
|
||||
ThumbnailReader io.ReadCloser
|
||||
// ThumbnailSize is the byte length of the thumbnail.
|
||||
ThumbnailSize int64
|
||||
}
|
||||
|
||||
// DisplayConfig describes a running desktop session.
|
||||
|
||||
@@ -56,6 +56,7 @@ type screenshotOutput struct {
|
||||
type recordingProcess struct {
|
||||
cmd *exec.Cmd
|
||||
filePath string
|
||||
thumbPath string
|
||||
stopped bool
|
||||
killed bool // true when the process was SIGKILLed
|
||||
done chan struct{} // closed when cmd.Wait() returns
|
||||
@@ -383,13 +384,20 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
}
|
||||
// Completed recording - discard old file, start fresh.
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove old recording file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove old thumbnail file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, recordingID)
|
||||
}
|
||||
|
||||
@@ -406,6 +414,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
|
||||
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".thumb.jpg")
|
||||
|
||||
// Use a background context so the process outlives the HTTP
|
||||
// request that triggered it.
|
||||
@@ -419,6 +428,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
"--idle-speedup", "20",
|
||||
"--idle-min-duration", "0.35",
|
||||
"--idle-noise-tolerance", "-38dB",
|
||||
"--thumbnail", thumbPath,
|
||||
filePath)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
@@ -427,9 +437,10 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
rec := &recordingProcess{
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
done: make(chan struct{}),
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
thumbPath: thumbPath,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go func() {
|
||||
rec.waitErr = cmd.Wait()
|
||||
@@ -499,10 +510,35 @@ func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string)
|
||||
_ = f.Close()
|
||||
return nil, xerrors.Errorf("stat recording artifact: %w", err)
|
||||
}
|
||||
return &RecordingArtifact{
|
||||
artifact := &RecordingArtifact{
|
||||
Reader: f,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
// Attach thumbnail if the subprocess wrote one.
|
||||
thumbFile, err := os.Open(rec.thumbPath)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "thumbnail not available",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
thumbInfo, err := thumbFile.Stat()
|
||||
if err != nil {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail stat failed",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
if thumbInfo.Size() == 0 {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail file is empty",
|
||||
slog.F("thumbnail_path", rec.thumbPath))
|
||||
return artifact, nil
|
||||
}
|
||||
artifact.ThumbnailReader = thumbFile
|
||||
artifact.ThumbnailSize = thumbInfo.Size()
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
// lockedStopRecordingProcess stops a single recording via stopOnce.
|
||||
@@ -571,18 +607,33 @@ func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
|
||||
}
|
||||
info, err := os.Stat(rec.filePath)
|
||||
if err != nil {
|
||||
// File already removed or inaccessible; drop entry.
|
||||
// File already removed or inaccessible; clean up
|
||||
// any leftover thumbnail and drop the entry.
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
continue
|
||||
}
|
||||
if p.clock.Since(info.ModTime()) > time.Hour {
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale recording file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
}
|
||||
@@ -603,13 +654,14 @@ func (p *portableDesktop) Close() error {
|
||||
// Snapshot recording file paths and idle goroutine channels
|
||||
// for cleanup, then clear the map.
|
||||
type recEntry struct {
|
||||
id string
|
||||
filePath string
|
||||
idleDone chan struct{}
|
||||
id string
|
||||
filePath string
|
||||
thumbPath string
|
||||
idleDone chan struct{}
|
||||
}
|
||||
var allRecs []recEntry
|
||||
for id, rec := range p.recordings {
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, thumbPath: rec.thumbPath, idleDone: rec.idleDone})
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
session := p.session
|
||||
@@ -630,13 +682,20 @@ func (p *portableDesktop) Close() error {
|
||||
go func() {
|
||||
defer close(cleanupDone)
|
||||
for _, entry := range allRecs {
|
||||
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(entry.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(context.Background(), "failed to remove recording file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("file_path", entry.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(entry.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(context.Background(), "failed to remove thumbnail file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("thumbnail_path", entry.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
session.cancel()
|
||||
|
||||
@@ -2,6 +2,7 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -584,6 +585,7 @@ func TestPortableDesktop_StartRecording(t *testing.T) {
|
||||
joined := strings.Join(cmd, " ")
|
||||
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
|
||||
found = true
|
||||
assert.Contains(t, joined, "--thumbnail", "record command should include --thumbnail flag")
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -666,6 +668,66 @@ func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
|
||||
defer artifact.Reader.Close()
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
// No thumbnail file exists, so ThumbnailReader should be nil.
|
||||
assert.Nil(t, artifact.ThumbnailReader, "ThumbnailReader should be nil when no thumbnail file exists")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StopRecording_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write a dummy MP4 file at the expected path.
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(filePath) })
|
||||
|
||||
// Write a thumbnail file at the expected path.
|
||||
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".thumb.jpg")
|
||||
thumbContent := []byte("fake-jpeg-thumbnail")
|
||||
require.NoError(t, os.WriteFile(thumbPath, thumbContent, 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(thumbPath) })
|
||||
|
||||
artifact, err := pd.StopRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
defer artifact.Reader.Close()
|
||||
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
// Thumbnail should be attached.
|
||||
require.NotNil(t, artifact.ThumbnailReader, "ThumbnailReader should be non-nil when thumbnail file exists")
|
||||
defer artifact.ThumbnailReader.Close()
|
||||
assert.Equal(t, int64(len(thumbContent)), artifact.ThumbnailSize)
|
||||
|
||||
// Read and verify thumbnail content.
|
||||
thumbData, err := io.ReadAll(artifact.ThumbnailReader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, thumbContent, thumbData)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
|
||||
@@ -187,7 +187,11 @@ func (*Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Cl
|
||||
connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := c.Start(connectCtx); err != nil {
|
||||
// Use the parent ctx (not connectCtx) so the subprocess outlives
|
||||
// the connect/initialize handshake. connectCtx bounds only the
|
||||
// Initialize call below. The subprocess is cleaned up when the
|
||||
// Manager is closed or ctx is canceled.
|
||||
if err := c.Start(ctx); err != nil {
|
||||
_ = c.Close()
|
||||
return nil, xerrors.Errorf("start %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package agentmcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
@@ -8,6 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestSplitToolName(t *testing.T) {
|
||||
@@ -193,3 +199,118 @@ func TestConvertResult(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectServer_StdioProcessSurvivesConnect verifies that a stdio MCP
|
||||
// server subprocess remains alive after connectServer returns. This is a
|
||||
// regression test for a bug where the subprocess was tied to a short-lived
|
||||
// connectCtx and killed as soon as the context was canceled.
|
||||
func TestConnectServer_StdioProcessSurvivesConnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
|
||||
// Child process: act as a minimal MCP server over stdio.
|
||||
runFakeMCPServer()
|
||||
return
|
||||
}
|
||||
|
||||
// Get the path to the test binary so we can re-exec ourselves
|
||||
// as a fake MCP server subprocess.
|
||||
testBin, err := os.Executable()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := ServerConfig{
|
||||
Name: "fake",
|
||||
Transport: "stdio",
|
||||
Command: testBin,
|
||||
Args: []string{"-test.run=^TestConnectServer_StdioProcessSurvivesConnect$"},
|
||||
Env: map[string]string{"TEST_MCP_FAKE_SERVER": "1"},
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
m := &Manager{}
|
||||
client, err := m.connectServer(ctx, cfg)
|
||||
require.NoError(t, err, "connectServer should succeed")
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
// At this point connectServer has returned and its internal
|
||||
// connectCtx has been canceled. The subprocess must still be
|
||||
// alive. Verify by listing tools (requires a live server).
|
||||
listCtx, listCancel := context.WithTimeout(ctx, testutil.WaitShort)
|
||||
defer listCancel()
|
||||
result, err := client.ListTools(listCtx, mcp.ListToolsRequest{})
|
||||
require.NoError(t, err, "ListTools should succeed — server must be alive after connect")
|
||||
require.Len(t, result.Tools, 1)
|
||||
assert.Equal(t, "echo", result.Tools[0].Name)
|
||||
}
|
||||
|
||||
// runFakeMCPServer implements a minimal JSON-RPC / MCP server over
|
||||
// stdin/stdout, just enough for initialize + tools/list.
|
||||
func runFakeMCPServer() {
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
|
||||
var req struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id"`
|
||||
Method string `json:"method"`
|
||||
}
|
||||
if err := json.Unmarshal(line, &req); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var resp any
|
||||
switch req.Method {
|
||||
case "initialize":
|
||||
resp = map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": req.ID,
|
||||
"result": map[string]any{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{},
|
||||
},
|
||||
"serverInfo": map[string]any{
|
||||
"name": "fake-server",
|
||||
"version": "0.0.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
case "notifications/initialized":
|
||||
// No response needed for notifications.
|
||||
continue
|
||||
case "tools/list":
|
||||
resp = map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": req.ID,
|
||||
"result": map[string]any{
|
||||
"tools": []map[string]any{
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "echoes input",
|
||||
"inputSchema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
resp = map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": req.ID,
|
||||
"error": map[string]any{
|
||||
"code": -32601,
|
||||
"message": "method not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
out, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(os.Stdout, "%s\n", out)
|
||||
}
|
||||
}
|
||||
|
||||
+28
-19
@@ -3,11 +3,13 @@
|
||||
"enabled": true,
|
||||
"clientKind": "git",
|
||||
"useIgnoreFile": true,
|
||||
"defaultBranch": "main",
|
||||
"defaultBranch": "main"
|
||||
},
|
||||
"files": {
|
||||
"includes": ["**", "!**/pnpm-lock.yaml"],
|
||||
"ignoreUnknown": true,
|
||||
// static/*.html are Go templates with {{ }} directives that
|
||||
// Biome's HTML parser does not support.
|
||||
"includes": ["**", "!**/pnpm-lock.yaml", "!**/static/*.html"],
|
||||
"ignoreUnknown": true
|
||||
},
|
||||
"linter": {
|
||||
"rules": {
|
||||
@@ -15,7 +17,7 @@
|
||||
"noSvgWithoutTitle": "off",
|
||||
"useButtonType": "off",
|
||||
"useSemanticElements": "off",
|
||||
"noStaticElementInteractions": "off",
|
||||
"noStaticElementInteractions": "off"
|
||||
},
|
||||
"correctness": {
|
||||
"noUnusedImports": "warn",
|
||||
@@ -24,9 +26,9 @@
|
||||
"noUnusedVariables": {
|
||||
"level": "warn",
|
||||
"options": {
|
||||
"ignoreRestSiblings": true,
|
||||
},
|
||||
},
|
||||
"ignoreRestSiblings": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"style": {
|
||||
"noNonNullAssertion": "off",
|
||||
@@ -47,7 +49,7 @@
|
||||
"paths": {
|
||||
"react": {
|
||||
"message": "React 19 no longer requires forwardRef. Use ref as a prop instead.",
|
||||
"importNames": ["forwardRef"],
|
||||
"importNames": ["forwardRef"]
|
||||
},
|
||||
// "@mui/material/Alert": "Use components/Alert/Alert instead.",
|
||||
// "@mui/material/AlertTitle": "Use components/Alert/Alert instead.",
|
||||
@@ -115,10 +117,10 @@
|
||||
"@emotion/styled": "Use Tailwind CSS instead.",
|
||||
// "@emotion/cache": "Use Tailwind CSS instead.",
|
||||
// "components/Stack/Stack": "Use Tailwind flex utilities instead (e.g., <div className='flex flex-col gap-4'>).",
|
||||
"lodash": "Use lodash/<name> instead.",
|
||||
},
|
||||
},
|
||||
},
|
||||
"lodash": "Use lodash/<name> instead."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"suspicious": {
|
||||
"noArrayIndexKey": "off",
|
||||
@@ -129,14 +131,21 @@
|
||||
"noConsole": {
|
||||
"level": "error",
|
||||
"options": {
|
||||
"allow": ["error", "info", "warn"],
|
||||
},
|
||||
},
|
||||
"allow": ["error", "info", "warn"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"complexity": {
|
||||
"noImportantStyles": "off", // TODO: check and fix !important styles
|
||||
},
|
||||
},
|
||||
"noImportantStyles": "off" // TODO: check and fix !important styles
|
||||
}
|
||||
}
|
||||
},
|
||||
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json",
|
||||
"css": {
|
||||
"parser": {
|
||||
// Biome 2.3+ requires opt-in for @apply and other
|
||||
// Tailwind directives.
|
||||
"tailwindDirectives": true
|
||||
}
|
||||
},
|
||||
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json"
|
||||
}
|
||||
|
||||
@@ -87,6 +87,12 @@ func IsDevVersion(v string) bool {
|
||||
return strings.Contains(v, "-"+develPreRelease)
|
||||
}
|
||||
|
||||
// IsRCVersion returns true if the version has a release candidate
|
||||
// pre-release tag, e.g. "v2.31.0-rc.0".
|
||||
func IsRCVersion(v string) bool {
|
||||
return strings.Contains(v, "-rc.")
|
||||
}
|
||||
|
||||
// IsDev returns true if this is a development build.
|
||||
// CI builds are also considered development builds.
|
||||
func IsDev() bool {
|
||||
|
||||
@@ -102,3 +102,29 @@ func TestBuildInfo(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsRCVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected bool
|
||||
}{
|
||||
{"RC0", "v2.31.0-rc.0", true},
|
||||
{"RC1WithBuild", "v2.31.0-rc.1+abc123", true},
|
||||
{"RC10", "v2.31.0-rc.10", true},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", true},
|
||||
{"DevelVersion", "v2.31.0-devel+abc123", false},
|
||||
{"StableVersion", "v2.31.0", false},
|
||||
{"DevNoVersion", "v0.0.0-devel+abc123", false},
|
||||
{"BetaVersion", "v2.31.0-beta.1", false},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, buildinfo.IsRCVersion(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+22
-4
@@ -53,6 +53,8 @@ func workspaceAgent() *serpent.Command {
|
||||
slogJSONPath string
|
||||
slogStackdriverPath string
|
||||
blockFileTransfer bool
|
||||
blockReversePortForwarding bool
|
||||
blockLocalPortForwarding bool
|
||||
agentHeaderCommand string
|
||||
agentHeader []string
|
||||
devcontainers bool
|
||||
@@ -319,10 +321,12 @@ func workspaceAgent() *serpent.Command {
|
||||
SSHMaxTimeout: sshMaxTimeout,
|
||||
Subsystems: subsystems,
|
||||
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
Execer: execer,
|
||||
Devcontainers: devcontainers,
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
BlockReversePortForwarding: blockReversePortForwarding,
|
||||
BlockLocalPortForwarding: blockLocalPortForwarding,
|
||||
Execer: execer,
|
||||
Devcontainers: devcontainers,
|
||||
DevcontainerAPIOptions: []agentcontainers.Option{
|
||||
agentcontainers.WithSubAgentURL(agentAuth.agentURL.String()),
|
||||
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
|
||||
@@ -493,6 +497,20 @@ func workspaceAgent() *serpent.Command {
|
||||
Description: fmt.Sprintf("Block file transfer using known applications: %s.", strings.Join(agentssh.BlockedFileTransferCommands, ",")),
|
||||
Value: serpent.BoolOf(&blockFileTransfer),
|
||||
},
|
||||
{
|
||||
Flag: "block-reverse-port-forwarding",
|
||||
Default: "false",
|
||||
Env: "CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING",
|
||||
Description: "Block reverse port forwarding through the SSH server (ssh -R).",
|
||||
Value: serpent.BoolOf(&blockReversePortForwarding),
|
||||
},
|
||||
{
|
||||
Flag: "block-local-port-forwarding",
|
||||
Default: "false",
|
||||
Env: "CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING",
|
||||
Description: "Block local port forwarding through the SSH server (ssh -L).",
|
||||
Value: serpent.BoolOf(&blockLocalPortForwarding),
|
||||
},
|
||||
{
|
||||
Flag: "devcontainers-enable",
|
||||
Default: "true",
|
||||
|
||||
+194
@@ -0,0 +1,194 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) chatCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "chat",
|
||||
Short: "Manage agent chats",
|
||||
Long: "Commands for interacting with chats from within a workspace.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RootCmd) chatContextCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "context",
|
||||
Short: "Manage chat context",
|
||||
Long: "Add or clear context files and skills for an active chat session.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextAddCommand(),
|
||||
r.chatContextClearCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (*RootCmd) chatContextAddCommand() *serpent.Command {
|
||||
var (
|
||||
dir string
|
||||
chatID string
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "add",
|
||||
Short: "Add context to an active chat",
|
||||
Long: "Read instruction files and discover skills from a directory, then add " +
|
||||
"them as context to an active chat session. Multiple calls " +
|
||||
"are additive.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
if dir == "" && inv.Environ.Get("CODER") != "true" {
|
||||
return xerrors.New("this command must be run inside a Coder workspace (set --dir to override)")
|
||||
}
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
resolvedDir := dir
|
||||
if resolvedDir == "" {
|
||||
resolvedDir, err = os.Getwd()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get working directory: %w", err)
|
||||
}
|
||||
}
|
||||
resolvedDir, err = filepath.Abs(resolvedDir)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve directory: %w", err)
|
||||
}
|
||||
info, err := os.Stat(resolvedDir)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cannot read directory %q: %w", resolvedDir, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return xerrors.Errorf("%q is not a directory", resolvedDir)
|
||||
}
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(resolvedDir)
|
||||
if len(parts) == 0 {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "No context files or skills found in "+resolvedDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve chat ID from flag or auto-detect.
|
||||
resolvedChatID, err := parseChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.AddChatContext(ctx, agentsdk.AddChatContextRequest{
|
||||
ChatID: resolvedChatID,
|
||||
Parts: parts,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("add chat context: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Added %d context part(s) to chat %s\n", resp.Count, resp.ChatID)
|
||||
return nil
|
||||
},
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "Directory",
|
||||
Flag: "dir",
|
||||
Description: "Directory to read context files and skills from. Defaults to the current working directory.",
|
||||
Value: serpent.StringOf(&dir),
|
||||
},
|
||||
{
|
||||
Name: "Chat ID",
|
||||
Flag: "chat",
|
||||
Env: "CODER_CHAT_ID",
|
||||
Description: "Chat ID to add context to. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
|
||||
Value: serpent.StringOf(&chatID),
|
||||
},
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (*RootCmd) chatContextClearCommand() *serpent.Command {
|
||||
var chatID string
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "clear",
|
||||
Short: "Clear context from an active chat",
|
||||
Long: "Soft-delete all context-file and skill messages from an active chat. " +
|
||||
"The next turn will re-fetch default context from the agent.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
resolvedChatID, err := parseChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{
|
||||
ChatID: resolvedChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("clear chat context: %w", err)
|
||||
}
|
||||
|
||||
if resp.ChatID == uuid.Nil {
|
||||
_, _ = fmt.Fprintln(inv.Stdout, "No active chats to clear.")
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Cleared context from chat %s\n", resp.ChatID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Options: serpent.OptionSet{{
|
||||
Name: "Chat ID",
|
||||
Flag: "chat",
|
||||
Env: "CODER_CHAT_ID",
|
||||
Description: "Chat ID to clear context from. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
|
||||
Value: serpent.StringOf(&chatID),
|
||||
}},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// parseChatID returns the chat UUID from the flag value (which
|
||||
// serpent already populates from --chat or CODER_CHAT_ID). Returns
|
||||
// uuid.Nil if empty (the server will auto-detect).
|
||||
func parseChatID(flagValue string) (uuid.UUID, error) {
|
||||
if flagValue == "" {
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
parsed, err := uuid.Parse(flagValue)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("invalid chat ID %q: %w", flagValue, err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
)
|
||||
|
||||
func TestExpChatContextAdd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RequiresWorkspaceOrDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
|
||||
|
||||
err := inv.Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
})
|
||||
|
||||
t.Run("AllowsExplicitDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add", "--dir", t.TempDir())
|
||||
|
||||
err := inv.Run()
|
||||
if err != nil {
|
||||
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AllowsWorkspaceEnv", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
|
||||
inv.Environ.Set("CODER", "true")
|
||||
|
||||
err := inv.Run()
|
||||
if err != nil {
|
||||
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
}
|
||||
})
|
||||
}
|
||||
+29
-5
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -148,6 +149,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
|
||||
return []*serpent.Command{
|
||||
r.scaletestCmd(),
|
||||
r.errorExample(),
|
||||
r.chatCommand(),
|
||||
r.mcpCommand(),
|
||||
r.promptExample(),
|
||||
r.rptyCommand(),
|
||||
@@ -710,7 +712,7 @@ func (r *RootCmd) createHTTPClient(ctx context.Context, serverURL *url.URL, inv
|
||||
transport = wrapTransportWithTelemetryHeader(transport, inv)
|
||||
transport = wrapTransportWithUserAgentHeader(transport, inv)
|
||||
if !r.noVersionCheck {
|
||||
transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
transport = wrapTransportWithVersionCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
// Create a new client without any wrapped transport
|
||||
// otherwise it creates an infinite loop!
|
||||
basicClient := codersdk.New(serverURL)
|
||||
@@ -1434,6 +1436,21 @@ func defaultUpgradeMessage(version string) string {
|
||||
return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version)
|
||||
}
|
||||
|
||||
// serverVersionMessage returns a warning message if the server version
|
||||
// is a release candidate or development build. Returns empty string
|
||||
// for stable versions. RC is checked before devel because RC dev
|
||||
// builds (e.g. v2.33.0-rc.1-devel+hash) contain both tags.
|
||||
func serverVersionMessage(serverVersion string) string {
|
||||
switch {
|
||||
case buildinfo.IsRCVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a release candidate of Coder (%s)", serverVersion)
|
||||
case buildinfo.IsDevVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a development version of Coder (%s)", serverVersion)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// wrapTransportWithEntitlementsCheck adds a middleware to the HTTP transport
|
||||
// that checks for entitlement warnings and prints them to the user.
|
||||
func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.RoundTripper {
|
||||
@@ -1452,10 +1469,10 @@ func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.
|
||||
})
|
||||
}
|
||||
|
||||
// wrapTransportWithVersionMismatchCheck adds a middleware to the HTTP transport
|
||||
// that checks for version mismatches between the client and server. If a mismatch
|
||||
// is detected, a warning is printed to the user.
|
||||
func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
// wrapTransportWithVersionCheck adds a middleware to the HTTP transport
|
||||
// that checks the server version and warns about development builds,
|
||||
// release candidates, and client/server version mismatches.
|
||||
func wrapTransportWithVersionCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
var once sync.Once
|
||||
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
res, err := rt.RoundTrip(req)
|
||||
@@ -1467,9 +1484,16 @@ func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.In
|
||||
if serverVersion == "" {
|
||||
return
|
||||
}
|
||||
// Warn about non-stable server versions. Skip
|
||||
// during tests to avoid polluting golden files.
|
||||
if msg := serverVersionMessage(serverVersion); msg != "" && flag.Lookup("test.v") == nil {
|
||||
warning := pretty.Sprint(cliui.DefaultStyles.Warn, msg)
|
||||
_, _ = fmt.Fprintln(inv.Stderr, warning)
|
||||
}
|
||||
if buildinfo.VersionsMatch(clientVersion, serverVersion) {
|
||||
return
|
||||
}
|
||||
|
||||
upgradeMessage := defaultUpgradeMessage(semver.Canonical(serverVersion))
|
||||
if serverInfo, err := getBuildInfo(inv.Context()); err == nil {
|
||||
switch {
|
||||
|
||||
@@ -91,7 +91,7 @@ func Test_formatExamples(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
func Test_wrapTransportWithVersionCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoOutput", func(t *testing.T) {
|
||||
@@ -102,7 +102,7 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -131,7 +131,7 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
expectedUpgradeMessage := "My custom upgrade message"
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -159,6 +159,53 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
||||
require.Equal(t, expectedOutput, buf.String())
|
||||
})
|
||||
|
||||
t.Run("ServerStableVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &RootCmd{}
|
||||
cmd, err := r.Command(nil)
|
||||
require.NoError(t, err)
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
codersdk.BuildVersionHeader: []string{"v2.31.0"},
|
||||
},
|
||||
Body: io.NopCloser(nil),
|
||||
}, nil
|
||||
}), inv, "v2.31.0", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Empty(t, buf.String())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_serverVersionMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected string
|
||||
}{
|
||||
{"Stable", "v2.31.0", ""},
|
||||
{"Dev", "v0.0.0-devel+abc123", "the server is running a development version of Coder (v0.0.0-devel+abc123)"},
|
||||
{"RC", "v2.31.0-rc.1", "the server is running a release candidate of Coder (v2.31.0-rc.1)"},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", "the server is running a release candidate of Coder (v2.33.0-rc.1-devel+727ec00f7)"},
|
||||
{"Empty", "", ""},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, serverVersionMessage(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
|
||||
|
||||
+9
-1
@@ -79,6 +79,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/notifications/reports"
|
||||
"github.com/coder/coder/v2/coderd/oauthpki"
|
||||
"github.com/coder/coder/v2/coderd/objstore"
|
||||
"github.com/coder/coder/v2/coderd/pproflabel"
|
||||
"github.com/coder/coder/v2/coderd/prometheusmetrics"
|
||||
"github.com/coder/coder/v2/coderd/prometheusmetrics/insights"
|
||||
@@ -638,12 +639,19 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
vals.WorkspaceHostnameSuffix.String())
|
||||
}
|
||||
|
||||
objStore, err := objstore.FromConfig(ctx, vals.ObjectStore, r.globalConfig)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("initialize object store: %w", err)
|
||||
}
|
||||
defer objStore.Close()
|
||||
|
||||
options := &coderd.Options{
|
||||
AccessURL: vals.AccessURL.Value(),
|
||||
AppHostname: appHostname,
|
||||
AppHostnameRegex: appHostnameRegex,
|
||||
Logger: logger.Named("coderd"),
|
||||
Database: nil,
|
||||
ObjectStore: objStore,
|
||||
BaseDERPMap: derpMap,
|
||||
Pubsub: nil,
|
||||
CacheDir: cacheDir,
|
||||
@@ -1075,7 +1083,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
defer shutdownConns()
|
||||
|
||||
// Ensures that old database entries are cleaned up over time!
|
||||
purger := dbpurge.New(ctx, logger.Named("dbpurge"), options.Database, options.DeploymentValues, quartz.NewReal(), options.PrometheusRegistry)
|
||||
purger := dbpurge.New(ctx, logger.Named("dbpurge"), options.Database, options.DeploymentValues, quartz.NewReal(), options.PrometheusRegistry, objStore)
|
||||
defer purger.Close()
|
||||
|
||||
// Updates workspace usage
|
||||
|
||||
+99
-17
@@ -52,6 +52,10 @@ import (
|
||||
|
||||
const (
|
||||
disableUsageApp = "disable"
|
||||
|
||||
// Retry transient errors during SSH connection establishment.
|
||||
sshRetryInterval = 2 * time.Second
|
||||
sshMaxAttempts = 10 // initial + retries per step
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -62,6 +66,53 @@ var (
|
||||
workspaceNameRe = regexp.MustCompile(`[/.]+|--`)
|
||||
)
|
||||
|
||||
// isRetryableError checks for transient connection errors worth
|
||||
// retrying: DNS failures, connection refused, and server 5xx.
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil || xerrors.Is(err, context.Canceled) {
|
||||
return false
|
||||
}
|
||||
// Check connection errors before context.DeadlineExceeded because
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both.
|
||||
if codersdk.IsConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
if xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
return sdkErr.StatusCode() >= 500
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// retryWithInterval calls fn up to maxAttempts times, waiting
|
||||
// interval between attempts. Stops on success, non-retryable
|
||||
// error, or context cancellation.
|
||||
func retryWithInterval(ctx context.Context, logger slog.Logger, interval time.Duration, maxAttempts int, fn func() error) error {
|
||||
var lastErr error
|
||||
attempt := 0
|
||||
for r := retry.New(interval, interval); r.Wait(ctx); {
|
||||
lastErr = fn()
|
||||
if lastErr == nil || !isRetryableError(lastErr) {
|
||||
return lastErr
|
||||
}
|
||||
attempt++
|
||||
if attempt >= maxAttempts {
|
||||
break
|
||||
}
|
||||
logger.Warn(ctx, "transient error, retrying",
|
||||
slog.Error(lastErr),
|
||||
slog.F("attempt", attempt),
|
||||
)
|
||||
}
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
func (r *RootCmd) ssh() *serpent.Command {
|
||||
var (
|
||||
stdio bool
|
||||
@@ -277,10 +328,17 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
HostnameSuffix: hostnameSuffix,
|
||||
}
|
||||
|
||||
workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname(
|
||||
ctx, inv, client,
|
||||
inv.Args[0], cliConfig, disableAutostart)
|
||||
if err != nil {
|
||||
// Populated by the closure below.
|
||||
var workspace codersdk.Workspace
|
||||
var workspaceAgent codersdk.WorkspaceAgent
|
||||
resolveWorkspace := func() error {
|
||||
var err error
|
||||
workspace, workspaceAgent, err = findWorkspaceAndAgentByHostname(
|
||||
ctx, inv, client,
|
||||
inv.Args[0], cliConfig, disableAutostart)
|
||||
return err
|
||||
}
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, resolveWorkspace); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -306,8 +364,13 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
wait = false
|
||||
}
|
||||
|
||||
templateVersion, err := client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
|
||||
if err != nil {
|
||||
var templateVersion codersdk.TemplateVersion
|
||||
fetchVersion := func() error {
|
||||
var err error
|
||||
templateVersion, err = client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
|
||||
return err
|
||||
}
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, fetchVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -347,8 +410,12 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
// If we're in stdio mode, check to see if we can use Coder Connect.
|
||||
// We don't support Coder Connect over non-stdio coder ssh yet.
|
||||
if stdio && !forceNewTunnel {
|
||||
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
|
||||
if err != nil {
|
||||
var connInfo workspacesdk.AgentConnectionInfo
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
|
||||
var err error
|
||||
connInfo, err = wsClient.AgentConnectionInfoGeneric(ctx)
|
||||
return err
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("get agent connection info: %w", err)
|
||||
}
|
||||
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
|
||||
@@ -384,23 +451,27 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
})
|
||||
defer closeUsage()
|
||||
}
|
||||
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack)
|
||||
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack, logger)
|
||||
}
|
||||
}
|
||||
|
||||
if r.disableDirect {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
|
||||
}
|
||||
conn, err := wsClient.
|
||||
DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
|
||||
var conn workspacesdk.AgentConn
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
|
||||
var err error
|
||||
conn, err = wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
|
||||
Logger: logger,
|
||||
BlockEndpoints: r.disableDirect,
|
||||
EnableTelemetry: !r.disableNetworkTelemetry,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("dial agent: %w", err)
|
||||
}
|
||||
if err = stack.push("agent conn", conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
conn.AwaitReachable(ctx)
|
||||
@@ -1578,16 +1649,27 @@ func WithTestOnlyCoderConnectDialer(ctx context.Context, dialer coderConnectDial
|
||||
func testOrDefaultDialer(ctx context.Context) coderConnectDialer {
|
||||
dialer, ok := ctx.Value(coderConnectDialerContextKey{}).(coderConnectDialer)
|
||||
if !ok || dialer == nil {
|
||||
return &net.Dialer{}
|
||||
// Timeout prevents hanging on broken tunnels (OS default is very long).
|
||||
return &net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
return dialer
|
||||
}
|
||||
|
||||
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error {
|
||||
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack, logger slog.Logger) error {
|
||||
dialer := testOrDefaultDialer(ctx)
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("dial coder connect host: %w", err)
|
||||
var conn net.Conn
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
|
||||
var err error
|
||||
conn, err = dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("dial coder connect host %q over tcp: %w", addr, err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := stack.push("tcp conn", conn); err != nil {
|
||||
return err
|
||||
|
||||
+166
-1
@@ -5,7 +5,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -226,6 +228,41 @@ func TestCloserStack_Timeout(t *testing.T) {
|
||||
testutil.TryReceive(ctx, t, closed)
|
||||
}
|
||||
|
||||
func TestCloserStack_PushAfterClose_ConnClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
|
||||
|
||||
uut.close(xerrors.New("canceled"))
|
||||
|
||||
closes := new([]*fakeCloser)
|
||||
fc := &fakeCloser{closes: closes}
|
||||
err := uut.push("conn", fc)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, []*fakeCloser{fc}, *closes, "should close conn on failed push")
|
||||
}
|
||||
|
||||
func TestCoderConnectDialer_DefaultTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
dialer := testOrDefaultDialer(ctx)
|
||||
d, ok := dialer.(*net.Dialer)
|
||||
require.True(t, ok, "expected *net.Dialer")
|
||||
assert.Equal(t, 5*time.Second, d.Timeout)
|
||||
assert.Equal(t, 30*time.Second, d.KeepAlive)
|
||||
}
|
||||
|
||||
func TestCoderConnectDialer_Overridden(t *testing.T) {
|
||||
t.Parallel()
|
||||
custom := &net.Dialer{Timeout: 99 * time.Second}
|
||||
ctx := WithTestOnlyCoderConnectDialer(context.Background(), custom)
|
||||
|
||||
dialer := testOrDefaultDialer(ctx)
|
||||
assert.Equal(t, custom, dialer)
|
||||
}
|
||||
|
||||
func TestCoderConnectStdio(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -254,7 +291,7 @@ func TestCoderConnectStdio(t *testing.T) {
|
||||
|
||||
stdioDone := make(chan struct{})
|
||||
go func() {
|
||||
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack)
|
||||
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack, logger)
|
||||
assert.NoError(t, err)
|
||||
close(stdioDone)
|
||||
}()
|
||||
@@ -448,3 +485,131 @@ func Test_getWorkspaceAgent(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "available agents: [clark krypton zod]")
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsRetryableError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
retryable bool
|
||||
}{
|
||||
{"Nil", nil, false},
|
||||
{"ContextCanceled", context.Canceled, false},
|
||||
{"ContextDeadlineExceeded", context.DeadlineExceeded, false},
|
||||
{"WrappedContextCanceled", xerrors.Errorf("wrapped: %w", context.Canceled), false},
|
||||
{"DNSError", &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}, true},
|
||||
{"OpError", &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{}}, true},
|
||||
{"WrappedDNSError", xerrors.Errorf("connect: %w", &net.DNSError{Err: "no such host", Name: "example.com"}), true},
|
||||
{"SDKError_500", codersdk.NewTestError(http.StatusInternalServerError, "GET", "/api"), true},
|
||||
{"SDKError_502", codersdk.NewTestError(http.StatusBadGateway, "GET", "/api"), true},
|
||||
{"SDKError_503", codersdk.NewTestError(http.StatusServiceUnavailable, "GET", "/api"), true},
|
||||
{"SDKError_401", codersdk.NewTestError(http.StatusUnauthorized, "GET", "/api"), false},
|
||||
{"SDKError_403", codersdk.NewTestError(http.StatusForbidden, "GET", "/api"), false},
|
||||
{"SDKError_404", codersdk.NewTestError(http.StatusNotFound, "GET", "/api"), false},
|
||||
{"GenericError", xerrors.New("something went wrong"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, tt.retryable, isRetryableError(tt.err))
|
||||
})
|
||||
}
|
||||
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both
|
||||
// IsConnectionError and context.DeadlineExceeded. Verify it is retryable.
|
||||
t.Run("DialTimeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
|
||||
defer cancel()
|
||||
<-ctx.Done() // ensure deadline has fired
|
||||
_, err := (&net.Dialer{}).DialContext(ctx, "tcp", "127.0.0.1:1")
|
||||
require.Error(t, err)
|
||||
// Proves the ambiguity: this error matches BOTH checks.
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.ErrorAs(t, err, new(*net.OpError))
|
||||
assert.True(t, isRetryableError(err))
|
||||
// Also when wrapped, as runCoderConnectStdio does.
|
||||
assert.True(t, isRetryableError(xerrors.Errorf("dial coder connect: %w", err)))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryWithInterval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const interval = time.Millisecond
|
||||
const maxAttempts = 3
|
||||
|
||||
dnsErr := &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
|
||||
t.Run("Succeeds_FirstTry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, attempts)
|
||||
})
|
||||
|
||||
t.Run("Succeeds_AfterTransientFailures", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return dnsErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, attempts)
|
||||
})
|
||||
|
||||
t.Run("Stops_NonRetryableError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
return xerrors.New("permanent failure")
|
||||
})
|
||||
require.ErrorContains(t, err, "permanent failure")
|
||||
assert.Equal(t, 1, attempts)
|
||||
})
|
||||
|
||||
t.Run("Stops_MaxAttemptsExhausted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
return dnsErr
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, maxAttempts, attempts)
|
||||
})
|
||||
|
||||
t.Run("Stops_ContextCanceled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
cancel()
|
||||
return dnsErr
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, 1, attempts)
|
||||
})
|
||||
}
|
||||
|
||||
+6
@@ -39,6 +39,12 @@ OPTIONS:
|
||||
--block-file-transfer bool, $CODER_AGENT_BLOCK_FILE_TRANSFER (default: false)
|
||||
Block file transfer using known applications: nc,rsync,scp,sftp.
|
||||
|
||||
--block-local-port-forwarding bool, $CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING (default: false)
|
||||
Block local port forwarding through the SSH server (ssh -L).
|
||||
|
||||
--block-reverse-port-forwarding bool, $CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING (default: false)
|
||||
Block reverse port forwarding through the SSH server (ssh -R).
|
||||
|
||||
--boundary-log-proxy-socket-path string, $CODER_AGENT_BOUNDARY_LOG_PROXY_SOCKET_PATH (default: /tmp/boundary-audit.sock)
|
||||
The path for the boundary log proxy server Unix socket. Boundary
|
||||
should write audit logs to this socket.
|
||||
|
||||
+35
@@ -773,6 +773,41 @@ OIDC OPTIONS:
|
||||
requirement, and can lead to an insecure OIDC configuration. It is not
|
||||
recommended to use this flag.
|
||||
|
||||
OBJECT STORE OPTIONS:
|
||||
Configure the object storage backend for binary data (chat files, transcripts,
|
||||
etc.). Defaults to local filesystem storage.
|
||||
|
||||
--objectstore-backend string, $CODER_OBJECTSTORE_BACKEND (default: local)
|
||||
The storage backend for binary data such as chat files. Valid values:
|
||||
local, s3, gcs.
|
||||
|
||||
--objectstore-gcs-bucket string, $CODER_OBJECTSTORE_GCS_BUCKET
|
||||
GCS bucket name. Required when the backend is "gcs".
|
||||
|
||||
--objectstore-gcs-credentials-file string, $CODER_OBJECTSTORE_GCS_CREDENTIALS_FILE
|
||||
Path to a GCS service account key file. If empty, Application Default
|
||||
Credentials are used.
|
||||
|
||||
--objectstore-gcs-prefix string, $CODER_OBJECTSTORE_GCS_PREFIX
|
||||
Optional key prefix within the GCS bucket.
|
||||
|
||||
--objectstore-local-dir string, $CODER_OBJECTSTORE_LOCAL_DIR
|
||||
Root directory for the local filesystem object store backend. Only
|
||||
used when the backend is "local".
|
||||
|
||||
--objectstore-s3-bucket string, $CODER_OBJECTSTORE_S3_BUCKET
|
||||
S3 bucket name. Required when the backend is "s3".
|
||||
|
||||
--objectstore-s3-endpoint string, $CODER_OBJECTSTORE_S3_ENDPOINT
|
||||
Custom S3-compatible endpoint URL (e.g. for MinIO, R2, Cloudflare).
|
||||
Leave empty for standard AWS S3.
|
||||
|
||||
--objectstore-s3-prefix string, $CODER_OBJECTSTORE_S3_PREFIX
|
||||
Optional key prefix within the S3 bucket.
|
||||
|
||||
--objectstore-s3-region string, $CODER_OBJECTSTORE_S3_REGION
|
||||
AWS region for the S3 bucket.
|
||||
|
||||
PROVISIONING OPTIONS:
|
||||
Tune the behavior of the provisioner, which is responsible for creating,
|
||||
updating, and deleting workspace resources.
|
||||
|
||||
+34
@@ -908,3 +908,37 @@ retention:
|
||||
# build are always retained. Set to 0 to disable automatic deletion.
|
||||
# (default: 7d, type: duration)
|
||||
workspace_agent_logs: 168h0m0s
|
||||
# Configure the object storage backend for binary data (chat files, transcripts,
|
||||
# etc.). Defaults to local filesystem storage.
|
||||
objectStore:
|
||||
# The storage backend for binary data such as chat files. Valid values: local, s3,
|
||||
# gcs.
|
||||
# (default: local, type: string)
|
||||
backend: local
|
||||
# Root directory for the local filesystem object store backend. Only used when the
|
||||
# backend is "local".
|
||||
# (default: <unset>, type: string)
|
||||
local_dir: ""
|
||||
# S3 bucket name. Required when the backend is "s3".
|
||||
# (default: <unset>, type: string)
|
||||
s3_bucket: ""
|
||||
# AWS region for the S3 bucket.
|
||||
# (default: <unset>, type: string)
|
||||
s3_region: ""
|
||||
# Optional key prefix within the S3 bucket.
|
||||
# (default: <unset>, type: string)
|
||||
s3_prefix: ""
|
||||
# Custom S3-compatible endpoint URL (e.g. for MinIO, R2, Cloudflare). Leave empty
|
||||
# for standard AWS S3.
|
||||
# (default: <unset>, type: string)
|
||||
s3_endpoint: ""
|
||||
# GCS bucket name. Required when the backend is "gcs".
|
||||
# (default: <unset>, type: string)
|
||||
gcs_bucket: ""
|
||||
# Optional key prefix within the GCS bucket.
|
||||
# (default: <unset>, type: string)
|
||||
gcs_prefix: ""
|
||||
# Path to a GCS service account key file. If empty, Application Default
|
||||
# Credentials are used.
|
||||
# (default: <unset>, type: string)
|
||||
gcs_credentials_file: ""
|
||||
|
||||
@@ -134,6 +134,7 @@ func TestUserCreate(t *testing.T) {
|
||||
{
|
||||
name: "ServiceAccount",
|
||||
args: []string{"--service-account", "-u", "dean"},
|
||||
err: "Premium feature",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountLoginType",
|
||||
|
||||
@@ -77,8 +77,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
||||
level := make([]database.LogLevel, 0)
|
||||
outputLength := 0
|
||||
for _, logEntry := range req.Logs {
|
||||
output = append(output, logEntry.Output)
|
||||
outputLength += len(logEntry.Output)
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||
output = append(output, sanitizedOutput)
|
||||
outputLength += len(sanitizedOutput)
|
||||
|
||||
var dbLevel database.LogLevel
|
||||
switch logEntry.Level {
|
||||
|
||||
@@ -139,6 +139,59 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
|
||||
t.Run("SanitizesOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
now := dbtime.Now()
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
rawOutput := "before\x00middle\xc3\x28after"
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||
expectedOutputLength := int32(len(sanitizedOutput)) //nolint:gosec // Test-controlled string length is small.
|
||||
req := &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: logSource.ID[:],
|
||||
Logs: []*agentproto.Log{
|
||||
{
|
||||
CreatedAt: timestamppb.New(now),
|
||||
Level: agentproto.Log_WARN,
|
||||
Output: rawOutput,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), database.InsertWorkspaceAgentLogsParams{
|
||||
AgentID: agent.ID,
|
||||
LogSourceID: logSource.ID,
|
||||
CreatedAt: now,
|
||||
Output: []string{sanitizedOutput},
|
||||
Level: []database.LogLevel{database.LogLevelWarn},
|
||||
OutputLength: expectedOutputLength,
|
||||
}).Return([]database.WorkspaceAgentLog{
|
||||
{
|
||||
AgentID: agent.ID,
|
||||
CreatedAt: now,
|
||||
ID: 1,
|
||||
Output: sanitizedOutput,
|
||||
Level: database.LogLevelWarn,
|
||||
LogSourceID: logSource.ID,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
|
||||
})
|
||||
|
||||
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Generated
+406
@@ -1266,6 +1266,68 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/experimental/chats/config/retention-days": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Get chat retention days",
|
||||
"operationId": "get-chat-retention-days",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Update chat retention days",
|
||||
"operationId": "update-chat-retention-days",
|
||||
"parameters": [
|
||||
{
|
||||
"description": "Request body",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/experimental/watch-all-workspacebuilds": {
|
||||
"get": {
|
||||
"produces": [
|
||||
@@ -9452,6 +9514,212 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": [
|
||||
@@ -13177,6 +13445,12 @@ const docTemplate = `{
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -14175,6 +14449,9 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -14417,6 +14694,14 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatRetentionDaysResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ConnectionLatency": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -14496,6 +14781,9 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -15066,6 +15354,26 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -15617,6 +15925,9 @@ const docTemplate = `{
|
||||
"oauth2": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2Config"
|
||||
},
|
||||
"object_store": {
|
||||
"$ref": "#/definitions/codersdk.ObjectStoreConfig"
|
||||
},
|
||||
"oidc": {
|
||||
"$ref": "#/definitions/codersdk.OIDCConfig"
|
||||
},
|
||||
@@ -17633,6 +17944,47 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ObjectStoreConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"backend": {
|
||||
"description": "Backend selects the storage backend: \"local\" (default), \"s3\", or \"gcs\".",
|
||||
"type": "string"
|
||||
},
|
||||
"gcs_bucket": {
|
||||
"description": "GCSBucket is the GCS bucket name. Required when Backend is \"gcs\".",
|
||||
"type": "string"
|
||||
},
|
||||
"gcs_credentials_file": {
|
||||
"description": "GCSCredentialsFile is an optional path to a GCS service account\nkey file. If empty, Application Default Credentials are used.",
|
||||
"type": "string"
|
||||
},
|
||||
"gcs_prefix": {
|
||||
"description": "GCSPrefix is an optional key prefix within the GCS bucket.",
|
||||
"type": "string"
|
||||
},
|
||||
"local_dir": {
|
||||
"description": "LocalDir is the root directory for the local filesystem backend.\nOnly used when Backend is \"local\". Defaults to \u003cconfig-dir\u003e/objectstore/.",
|
||||
"type": "string"
|
||||
},
|
||||
"s3_bucket": {
|
||||
"description": "S3Bucket is the S3 bucket name. Required when Backend is \"s3\".",
|
||||
"type": "string"
|
||||
},
|
||||
"s3_endpoint": {
|
||||
"description": "S3Endpoint is a custom S3-compatible endpoint URL (for MinIO, R2, etc.).",
|
||||
"type": "string"
|
||||
},
|
||||
"s3_prefix": {
|
||||
"description": "S3Prefix is an optional key prefix within the S3 bucket.",
|
||||
"type": "string"
|
||||
},
|
||||
"s3_region": {
|
||||
"description": "S3Region is the AWS region for the S3 bucket.",
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.OptionType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -20946,6 +21298,14 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateChatRetentionDaysRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateCheckResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21187,6 +21547,23 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21642,6 +22019,35 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
Generated
+376
@@ -1103,6 +1103,60 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/experimental/chats/config/retention-days": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Get chat retention days",
|
||||
"operationId": "get-chat-retention-days",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"consumes": ["application/json"],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Update chat retention days",
|
||||
"operationId": "update-chat-retention-days",
|
||||
"parameters": [
|
||||
{
|
||||
"description": "Request body",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/experimental/watch-all-workspacebuilds": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
@@ -8377,6 +8431,190 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": ["application/json"],
|
||||
@@ -11755,6 +11993,12 @@
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -12739,6 +12983,9 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -12960,6 +13207,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatRetentionDaysResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ConnectionLatency": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -13039,6 +13294,9 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13575,6 +13833,26 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -14114,6 +14392,9 @@
|
||||
"oauth2": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2Config"
|
||||
},
|
||||
"object_store": {
|
||||
"$ref": "#/definitions/codersdk.ObjectStoreConfig"
|
||||
},
|
||||
"oidc": {
|
||||
"$ref": "#/definitions/codersdk.OIDCConfig"
|
||||
},
|
||||
@@ -16060,6 +16341,47 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ObjectStoreConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"backend": {
|
||||
"description": "Backend selects the storage backend: \"local\" (default), \"s3\", or \"gcs\".",
|
||||
"type": "string"
|
||||
},
|
||||
"gcs_bucket": {
|
||||
"description": "GCSBucket is the GCS bucket name. Required when Backend is \"gcs\".",
|
||||
"type": "string"
|
||||
},
|
||||
"gcs_credentials_file": {
|
||||
"description": "GCSCredentialsFile is an optional path to a GCS service account\nkey file. If empty, Application Default Credentials are used.",
|
||||
"type": "string"
|
||||
},
|
||||
"gcs_prefix": {
|
||||
"description": "GCSPrefix is an optional key prefix within the GCS bucket.",
|
||||
"type": "string"
|
||||
},
|
||||
"local_dir": {
|
||||
"description": "LocalDir is the root directory for the local filesystem backend.\nOnly used when Backend is \"local\". Defaults to \u003cconfig-dir\u003e/objectstore/.",
|
||||
"type": "string"
|
||||
},
|
||||
"s3_bucket": {
|
||||
"description": "S3Bucket is the S3 bucket name. Required when Backend is \"s3\".",
|
||||
"type": "string"
|
||||
},
|
||||
"s3_endpoint": {
|
||||
"description": "S3Endpoint is a custom S3-compatible endpoint URL (for MinIO, R2, etc.).",
|
||||
"type": "string"
|
||||
},
|
||||
"s3_prefix": {
|
||||
"description": "S3Prefix is an optional key prefix within the S3 bucket.",
|
||||
"type": "string"
|
||||
},
|
||||
"s3_region": {
|
||||
"description": "S3Region is the AWS region for the S3 bucket.",
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.OptionType": {
|
||||
"type": "string",
|
||||
"enum": ["string", "number", "bool", "list(string)"],
|
||||
@@ -19237,6 +19559,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateChatRetentionDaysRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateCheckResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19469,6 +19799,23 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19899,6 +20246,35 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": ["active", "dormant", "suspended"],
|
||||
|
||||
+8
-1
@@ -26,6 +26,11 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// Limit the count query to avoid a slow sequential scan due to joins
|
||||
// on a large table. Set to 0 to disable capping (but also see the note
|
||||
// in the SQL query).
|
||||
const auditLogCountCap = 2000
|
||||
|
||||
// @Summary Get audit logs
|
||||
// @ID get-audit-logs
|
||||
// @Security CoderSessionToken
|
||||
@@ -66,7 +71,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
countFilter.Username = ""
|
||||
}
|
||||
|
||||
// Use the same filters to count the number of audit logs
|
||||
countFilter.CountCap = auditLogCountCap
|
||||
count, err := api.Database.CountAuditLogs(ctx, countFilter)
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
@@ -81,6 +86,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: []codersdk.AuditLog{},
|
||||
Count: 0,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -98,6 +104,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: api.convertAuditLogs(ctx, dblogs),
|
||||
Count: count,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -71,6 +71,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/metricscache"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/oauth2provider"
|
||||
"github.com/coder/coder/v2/coderd/objstore"
|
||||
"github.com/coder/coder/v2/coderd/portsharing"
|
||||
"github.com/coder/coder/v2/coderd/pproflabel"
|
||||
"github.com/coder/coder/v2/coderd/prebuilds"
|
||||
@@ -158,6 +159,7 @@ type Options struct {
|
||||
AppHostnameRegex *regexp.Regexp
|
||||
Logger slog.Logger
|
||||
Database database.Store
|
||||
ObjectStore objstore.Store
|
||||
Pubsub pubsub.Pubsub
|
||||
RuntimeConfig *runtimeconfig.Manager
|
||||
|
||||
@@ -792,6 +794,7 @@ func New(options *Options) *API {
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
UsageTracker: options.WorkspaceUsageTracker,
|
||||
ObjectStore: options.ObjectStore,
|
||||
})
|
||||
gitSyncLogger := options.Logger.Named("gitsync")
|
||||
refresher := gitsync.NewRefresher(
|
||||
@@ -1189,6 +1192,8 @@ func New(options *Options) *API {
|
||||
r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold)
|
||||
r.Get("/workspace-ttl", api.getChatWorkspaceTTL)
|
||||
r.Put("/workspace-ttl", api.putChatWorkspaceTTL)
|
||||
r.Get("/retention-days", api.getChatRetentionDays)
|
||||
r.Put("/retention-days", api.putChatRetentionDays)
|
||||
r.Get("/template-allowlist", api.getChatTemplateAllowlist)
|
||||
r.Put("/template-allowlist", api.putChatTemplateAllowlist)
|
||||
})
|
||||
@@ -1243,6 +1248,7 @@ func New(options *Options) *API {
|
||||
r.Get("/git", api.watchChatGit)
|
||||
})
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Post("/tool-results", api.postChatToolResults)
|
||||
r.Post("/title/regenerate", api.regenerateChatTitle)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
@@ -1605,6 +1611,15 @@ func New(options *Options) *API {
|
||||
|
||||
r.Get("/gitsshkey", api.gitSSHKey)
|
||||
r.Put("/gitsshkey", api.regenerateGitSSHKey)
|
||||
r.Route("/secrets", func(r chi.Router) {
|
||||
r.Post("/", api.postUserSecret)
|
||||
r.Get("/", api.getUserSecrets)
|
||||
r.Route("/{name}", func(r chi.Router) {
|
||||
r.Get("/", api.getUserSecret)
|
||||
r.Patch("/", api.patchUserSecret)
|
||||
r.Delete("/", api.deleteUserSecret)
|
||||
})
|
||||
})
|
||||
r.Route("/notifications", func(r chi.Router) {
|
||||
r.Route("/preferences", func(r chi.Router) {
|
||||
r.Get("/", api.userNotificationPreferences)
|
||||
@@ -1650,6 +1665,10 @@ func New(options *Options) *API {
|
||||
r.Get("/gitsshkey", api.agentGitSSHKey)
|
||||
r.Post("/log-source", api.workspaceAgentPostLogSource)
|
||||
r.Get("/reinit", api.workspaceAgentReinit)
|
||||
r.Route("/experimental", func(r chi.Router) {
|
||||
r.Post("/chat-context", api.workspaceAgentAddChatContext)
|
||||
r.Delete("/chat-context", api.workspaceAgentClearChatContext)
|
||||
})
|
||||
r.Route("/tasks/{task}", func(r chi.Router) {
|
||||
r.Post("/log-snapshot", api.postWorkspaceAgentTaskLogSnapshot)
|
||||
})
|
||||
|
||||
@@ -147,6 +147,10 @@ func parseSwaggerComment(commentGroup *ast.CommentGroup) SwaggerComment {
|
||||
return c
|
||||
}
|
||||
|
||||
func isExperimentalEndpoint(route string) bool {
|
||||
return strings.HasPrefix(route, "/workspaceagents/me/experimental/")
|
||||
}
|
||||
|
||||
func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments []SwaggerComment) {
|
||||
assertUniqueRoutes(t, swaggerComments)
|
||||
assertSingleAnnotations(t, swaggerComments)
|
||||
@@ -165,6 +169,9 @@ func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments [
|
||||
if strings.HasSuffix(route, "/*") {
|
||||
return
|
||||
}
|
||||
if isExperimentalEndpoint(route) {
|
||||
return
|
||||
}
|
||||
|
||||
c := findSwaggerCommentByMethodAndRoute(swaggerComments, method, route)
|
||||
assert.NotNil(t, c, "Missing @Router annotation")
|
||||
|
||||
@@ -123,6 +123,10 @@ func UsersPagination(
|
||||
require.Contains(t, gotUsers[0].Name, "after")
|
||||
}
|
||||
|
||||
type UsersFilterOptions struct {
|
||||
CreateServiceAccounts bool
|
||||
}
|
||||
|
||||
// UsersFilter creates a set of users to run various filters against for
|
||||
// testing. It can be used to test filtering both users and group members.
|
||||
func UsersFilter(
|
||||
@@ -130,11 +134,16 @@ func UsersFilter(
|
||||
t *testing.T,
|
||||
client *codersdk.Client,
|
||||
db database.Store,
|
||||
options *UsersFilterOptions,
|
||||
setup func(users []codersdk.User),
|
||||
fetch func(ctx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
if options == nil {
|
||||
options = &UsersFilterOptions{}
|
||||
}
|
||||
|
||||
firstUser, err := client.User(setupCtx, codersdk.Me)
|
||||
require.NoError(t, err, "fetch me")
|
||||
|
||||
@@ -211,11 +220,13 @@ func UsersFilter(
|
||||
}
|
||||
|
||||
// Add some service accounts.
|
||||
for range 3 {
|
||||
_, user := CreateAnotherUserMutators(t, client, orgID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
|
||||
r.ServiceAccount = true
|
||||
})
|
||||
users = append(users, user)
|
||||
if options.CreateServiceAccounts {
|
||||
for range 3 {
|
||||
_, user := CreateAnotherUserMutators(t, client, orgID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
|
||||
r.ServiceAccount = true
|
||||
})
|
||||
users = append(users, user)
|
||||
}
|
||||
}
|
||||
|
||||
hashedPassword, err := userpassword.Hash("SomeStrongPassword!")
|
||||
|
||||
@@ -538,6 +538,12 @@ func WorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator,
|
||||
switch {
|
||||
case workspaceAgent.Status != codersdk.WorkspaceAgentConnected && workspaceAgent.LifecycleState == codersdk.WorkspaceAgentLifecycleOff:
|
||||
workspaceAgent.Health.Reason = "agent is not running"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentConnecting:
|
||||
// Note: the case above catches connecting+off as "not running".
|
||||
// This case handles connecting agents with a non-off lifecycle
|
||||
// (e.g. "created" or "starting"), where the agent binary has
|
||||
// not yet established a connection to coderd.
|
||||
workspaceAgent.Health.Reason = "agent has not yet connected"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentTimeout:
|
||||
workspaceAgent.Health.Reason = "agent is taking too long to connect"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentDisconnected:
|
||||
@@ -1234,6 +1240,8 @@ func buildAIBridgeThread(
|
||||
if rootIntc != nil {
|
||||
thread.Model = rootIntc.Model
|
||||
thread.Provider = rootIntc.Provider
|
||||
thread.CredentialKind = string(rootIntc.CredentialKind)
|
||||
thread.CredentialHint = rootIntc.CredentialHint
|
||||
// Get first user prompt from root interception.
|
||||
// A thread can only have one prompt, by definition, since we currently
|
||||
// only store the last prompt observed in an interception.
|
||||
@@ -1528,7 +1536,10 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
// Chat converts a database.Chat to a codersdk.Chat. It coalesces
|
||||
// nil slices and maps to empty values for JSON serialization and
|
||||
// derives RootChatID from the parent chain when not explicitly set.
|
||||
func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
||||
// When diffStatus is non-nil the response includes diff metadata.
|
||||
// When files is non-empty the response includes file metadata;
|
||||
// pass nil to omit the files field (e.g. list endpoints).
|
||||
func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database.GetChatFileMetadataByChatIDRow) codersdk.Chat {
|
||||
mcpServerIDs := c.MCPServerIDs
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
@@ -1581,6 +1592,19 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
||||
convertedDiffStatus := ChatDiffStatus(c.ID, diffStatus)
|
||||
chat.DiffStatus = &convertedDiffStatus
|
||||
}
|
||||
if len(files) > 0 {
|
||||
chat.Files = make([]codersdk.ChatFileMetadata, 0, len(files))
|
||||
for _, row := range files {
|
||||
chat.Files = append(chat.Files, codersdk.ChatFileMetadata{
|
||||
ID: row.ID,
|
||||
OwnerID: row.OwnerID,
|
||||
OrganizationID: row.OrganizationID,
|
||||
Name: row.Name,
|
||||
MimeType: row.Mimetype,
|
||||
CreatedAt: row.CreatedAt,
|
||||
})
|
||||
}
|
||||
}
|
||||
if c.LastInjectedContext.Valid {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
// Internal fields are stripped at write time in
|
||||
@@ -1604,9 +1628,9 @@ func ChatRows(rows []database.GetChatsRow, diffStatusesByChatID map[uuid.UUID]da
|
||||
for i, row := range rows {
|
||||
diffStatus, ok := diffStatusesByChatID[row.Chat.ID]
|
||||
if ok {
|
||||
result[i] = Chat(row.Chat, &diffStatus)
|
||||
result[i] = Chat(row.Chat, &diffStatus, nil)
|
||||
} else {
|
||||
result[i] = Chat(row.Chat, nil)
|
||||
result[i] = Chat(row.Chat, nil, nil)
|
||||
if diffStatusesByChatID != nil {
|
||||
emptyDiffStatus := ChatDiffStatus(row.Chat.ID, nil)
|
||||
result[i].DiffStatus = &emptyDiffStatus
|
||||
@@ -1699,3 +1723,41 @@ func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// UserSecret converts a database ListUserSecretsRow (metadata only,
|
||||
// no value) to an SDK UserSecret.
|
||||
func UserSecret(secret database.ListUserSecretsRow) codersdk.UserSecret {
|
||||
return codersdk.UserSecret{
|
||||
ID: secret.ID,
|
||||
Name: secret.Name,
|
||||
Description: secret.Description,
|
||||
EnvName: secret.EnvName,
|
||||
FilePath: secret.FilePath,
|
||||
CreatedAt: secret.CreatedAt,
|
||||
UpdatedAt: secret.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// UserSecretFromFull converts a full database UserSecret row to an
|
||||
// SDK UserSecret, omitting the value and encryption key ID.
|
||||
func UserSecretFromFull(secret database.UserSecret) codersdk.UserSecret {
|
||||
return codersdk.UserSecret{
|
||||
ID: secret.ID,
|
||||
Name: secret.Name,
|
||||
Description: secret.Description,
|
||||
EnvName: secret.EnvName,
|
||||
FilePath: secret.FilePath,
|
||||
CreatedAt: secret.CreatedAt,
|
||||
UpdatedAt: secret.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// UserSecrets converts a slice of database ListUserSecretsRow to
|
||||
// SDK UserSecret values.
|
||||
func UserSecrets(secrets []database.ListUserSecretsRow) []codersdk.UserSecret {
|
||||
result := make([]codersdk.UserSecret, 0, len(secrets))
|
||||
for _, s := range secrets {
|
||||
result = append(result, UserSecret(s))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -552,6 +552,10 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
RawMessage: json.RawMessage(`[{"type":"context-file","context_file_path":"/AGENTS.md"}]`),
|
||||
Valid: true,
|
||||
},
|
||||
DynamicTools: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`[{"name":"tool1","description":"test tool","inputSchema":{"type":"object"}}]`),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
// Only ChatID is needed here. This test checks that
|
||||
// Chat.DiffStatus is non-nil, not that every DiffStatus
|
||||
@@ -561,14 +565,26 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
ChatID: input.ID,
|
||||
}
|
||||
|
||||
got := db2sdk.Chat(input, diffStatus)
|
||||
fileRows := []database.GetChatFileMetadataByChatIDRow{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
OwnerID: input.OwnerID,
|
||||
OrganizationID: uuid.New(),
|
||||
Name: "test.png",
|
||||
Mimetype: "image/png",
|
||||
CreatedAt: now,
|
||||
},
|
||||
}
|
||||
|
||||
got := db2sdk.Chat(input, diffStatus, fileRows)
|
||||
|
||||
v := reflect.ValueOf(got)
|
||||
typ := v.Type()
|
||||
// HasUnread is populated by ChatRows (which joins the
|
||||
// read-cursor query), not by Chat, so it is expected
|
||||
// to remain zero here.
|
||||
skip := map[string]bool{"HasUnread": true}
|
||||
// read-cursor query), not by Chat. Warnings is a transient
|
||||
// field populated by handlers, not the converter. Both are
|
||||
// expected to remain zero here.
|
||||
skip := map[string]bool{"HasUnread": true, "Warnings": true}
|
||||
for i := range typ.NumField() {
|
||||
field := typ.Field(i)
|
||||
if skip[field.Name] {
|
||||
@@ -581,6 +597,112 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_FileMetadataConversion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ownerID := uuid.New()
|
||||
orgID := uuid.New()
|
||||
fileID := uuid.New()
|
||||
now := dbtime.Now()
|
||||
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: uuid.New(),
|
||||
Title: "file metadata test",
|
||||
Status: database.ChatStatusWaiting,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
rows := []database.GetChatFileMetadataByChatIDRow{
|
||||
{
|
||||
ID: fileID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
Name: "screenshot.png",
|
||||
Mimetype: "image/png",
|
||||
CreatedAt: now,
|
||||
},
|
||||
}
|
||||
|
||||
result := db2sdk.Chat(chat, nil, rows)
|
||||
|
||||
require.Len(t, result.Files, 1)
|
||||
f := result.Files[0]
|
||||
require.Equal(t, fileID, f.ID)
|
||||
require.Equal(t, ownerID, f.OwnerID, "OwnerID must be mapped from DB row")
|
||||
require.Equal(t, orgID, f.OrganizationID, "OrganizationID must be mapped from DB row")
|
||||
require.Equal(t, "screenshot.png", f.Name)
|
||||
require.Equal(t, "image/png", f.MimeType)
|
||||
require.Equal(t, now, f.CreatedAt)
|
||||
|
||||
// Verify JSON serialization uses snake_case for mime_type.
|
||||
data, err := json.Marshal(f)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(data), `"mime_type"`)
|
||||
require.NotContains(t, string(data), `"mimetype"`)
|
||||
}
|
||||
|
||||
func TestChat_NilFilesOmitted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
LastModelConfigID: uuid.New(),
|
||||
Title: "no files",
|
||||
Status: database.ChatStatusWaiting,
|
||||
CreatedAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
}
|
||||
|
||||
result := db2sdk.Chat(chat, nil, nil)
|
||||
require.Empty(t, result.Files)
|
||||
}
|
||||
|
||||
func TestChat_MultipleFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := dbtime.Now()
|
||||
file1 := uuid.New()
|
||||
file2 := uuid.New()
|
||||
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
LastModelConfigID: uuid.New(),
|
||||
Title: "multi file test",
|
||||
Status: database.ChatStatusWaiting,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
rows := []database.GetChatFileMetadataByChatIDRow{
|
||||
{
|
||||
ID: file1,
|
||||
OwnerID: chat.OwnerID,
|
||||
OrganizationID: uuid.New(),
|
||||
Name: "a.png",
|
||||
Mimetype: "image/png",
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
ID: file2,
|
||||
OwnerID: chat.OwnerID,
|
||||
OrganizationID: uuid.New(),
|
||||
Name: "b.txt",
|
||||
Mimetype: "text/plain",
|
||||
CreatedAt: now,
|
||||
},
|
||||
}
|
||||
|
||||
result := db2sdk.Chat(chat, nil, rows)
|
||||
require.Len(t, result.Files, 2)
|
||||
require.Equal(t, "a.png", result.Files[0].Name)
|
||||
require.Equal(t, "b.txt", result.Files[1].Name)
|
||||
}
|
||||
|
||||
func TestChatQueuedMessage_MalformedContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1708,6 +1708,17 @@ func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -2031,6 +2042,20 @@ func (q *querier) DeleteOldAuditLogs(ctx context.Context, arg database.DeleteOld
|
||||
return q.db.DeleteOldAuditLogs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) ([]database.DeleteOldChatFilesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.DeleteOldChatFiles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteOldChats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
@@ -2155,17 +2180,12 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.GetUserSecret(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, secret); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserSecret(ctx, id)
|
||||
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -2404,6 +2424,10 @@ func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
return q.db.GetActiveAISeatCount(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetActiveChatsByAgentID)(ctx, agentID)
|
||||
}
|
||||
|
||||
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
|
||||
return nil, err
|
||||
@@ -2583,6 +2607,10 @@ func (q *querier) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.C
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatFileMetadataByChatID)(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
files, err := q.db.GetChatFilesByIDs(ctx, ids)
|
||||
if err != nil {
|
||||
@@ -2623,6 +2651,14 @@ func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.Ch
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessageSummariesPerChat(ctx, createdAfter)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
@@ -2671,6 +2707,14 @@ func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModel
|
||||
return q.db.GetChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatModelConfigsForTelemetry(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
@@ -2700,6 +2744,15 @@ func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) (
|
||||
return q.db.GetChatQueuedMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
// Chat retention is a deployment-wide config read by dbpurge.
|
||||
// Only requires a valid actor in context.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return 0, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatRetentionDays(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
// The system prompt is a deployment-wide setting read during chat
|
||||
// creation by every authenticated user, so no RBAC policy check
|
||||
@@ -2778,6 +2831,14 @@ func (q *querier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) (
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByWorkspaceIDs)(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatsUpdatedAfter(ctx, updatedAfter)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
// Just like with the audit logs query, shortcut if the user is an owner.
|
||||
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
|
||||
@@ -4124,19 +4185,6 @@ func (q *querier) GetUserNotificationPreferences(ctx context.Context, userID uui
|
||||
return q.db.GetUserNotificationPreferences(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, id)
|
||||
if err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
@@ -5393,6 +5441,17 @@ func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg datab
|
||||
return q.db.InsertWorkspaceResourceMetadata(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.LinkChatFiles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -5509,7 +5568,7 @@ func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID u
|
||||
return q.db.ListUserChatCompactionThresholds(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(userID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
return nil, err
|
||||
@@ -5517,6 +5576,16 @@ func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]data
|
||||
return q.db.ListUserSecrets(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
// This query returns decrypted secret values and must only be called
|
||||
// from system contexts (provisioner, agent manifest). REST API
|
||||
// handlers should use ListUserSecrets (metadata only).
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListUserSecretsWithValues(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
workspace, err := q.db.GetWorkspaceByID(ctx, workspaceID)
|
||||
if err != nil {
|
||||
@@ -5674,6 +5743,17 @@ func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg databas
|
||||
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
@@ -5767,15 +5847,15 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
|
||||
return q.db.UpdateChatByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
// The batch heartbeat is a system-level operation filtered by
|
||||
// worker_id. Authorization is enforced by the AsChatd context
|
||||
// at the call site rather than per-row, because checking each
|
||||
// row individually would defeat the purpose of batching.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.UpdateChatHeartbeat(ctx, arg)
|
||||
return q.db.UpdateChatHeartbeats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
|
||||
@@ -6617,17 +6697,12 @@ func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRo
|
||||
return q.db.UpdateUserRoles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, arg.ID)
|
||||
if err != nil {
|
||||
func (q *querier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return q.db.UpdateUserSecret(ctx, arg)
|
||||
return q.db.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) {
|
||||
@@ -7029,6 +7104,13 @@ func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, incl
|
||||
return q.db.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatRetentionDays(ctx, retentionDays)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
|
||||
@@ -400,6 +400,17 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("LinkChatFiles", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.LinkChatFilesParams{
|
||||
ChatID: chat.ID,
|
||||
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
|
||||
FileIds: []uuid.UUID{uuid.New()},
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().LinkChatFiles(gomock.Any(), arg).Return(int32(0), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int32(0))
|
||||
}))
|
||||
s.Run("PinChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
@@ -467,6 +478,24 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), arg).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("GetActiveChatsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
agentID := uuid.New()
|
||||
dbm.EXPECT().GetActiveChatsByAgentID(gomock.Any(), agentID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(agentID).Asserts(chat, policy.ActionRead).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("SoftDeleteContextFileMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().SoftDeleteContextFileMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("ClearChatMessageProviderResponseIDsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().ClearChatMessageProviderResponseIDsByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerChatParams{
|
||||
OwnerID: uuid.New(),
|
||||
@@ -576,6 +605,35 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatFilesByIDs(gomock.Any(), []uuid.UUID{file.ID}).Return([]database.ChatFile{file}, nil).AnyTimes()
|
||||
check.Args([]uuid.UUID{file.ID}).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns([]database.ChatFile{file})
|
||||
}))
|
||||
s.Run("GetChatFileMetadataByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
file := testutil.Fake(s.T(), faker, database.ChatFile{})
|
||||
rows := []database.GetChatFileMetadataByChatIDRow{{
|
||||
ID: file.ID,
|
||||
Name: file.Name,
|
||||
Mimetype: file.Mimetype,
|
||||
CreatedAt: file.CreatedAt,
|
||||
OwnerID: file.OwnerID,
|
||||
OrganizationID: file.OrganizationID,
|
||||
}}
|
||||
dbm.EXPECT().GetChatFileMetadataByChatID(gomock.Any(), file.ID).Return(rows, nil).AnyTimes()
|
||||
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("DeleteOldChatFiles", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().DeleteOldChatFiles(gomock.Any(), database.DeleteOldChatFilesParams{}).Return(int64(0), nil).AnyTimes()
|
||||
check.Args(database.DeleteOldChatFilesParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
|
||||
}))
|
||||
s.Run("DeleteOldChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().DeleteOldChats(gomock.Any(), database.DeleteOldChatsParams{}).Return(int64(0), nil).AnyTimes()
|
||||
check.Args(database.DeleteOldChatsParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
|
||||
}))
|
||||
s.Run("GetChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(30), nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("UpsertChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatRetentionDays(gomock.Any(), int32(30)).Return(nil).AnyTimes()
|
||||
check.Args(int32(30)).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
@@ -818,15 +876,15 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
resultID := uuid.New()
|
||||
arg := database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{resultID},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
|
||||
}))
|
||||
s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -3972,6 +4030,20 @@ func (s *MethodTestSuite) TestSystemFunctions() {
|
||||
dbm.EXPECT().GetWorkspaceAgentsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceAgent{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatsUpdatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetChatsUpdatedAfter(gomock.Any(), ts).Return([]database.GetChatsUpdatedAfterRow{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatMessageSummariesPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetChatMessageSummariesPerChat(gomock.Any(), ts).Return([]database.GetChatMessageSummariesPerChatRow{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatModelConfigsForTelemetry", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatModelConfigsForTelemetry(gomock.Any()).Return([]database.GetChatModelConfigsForTelemetryRow{}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetWorkspaceAppsCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetWorkspaceAppsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceApp{}, nil).AnyTimes()
|
||||
@@ -5322,19 +5394,20 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("GetUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("ListUserSecrets", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
row := testutil.Fake(s.T(), faker, database.ListUserSecretsRow{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.ListUserSecretsRow{row}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns([]database.ListUserSecretsRow{row})
|
||||
}))
|
||||
s.Run("ListUserSecretsWithValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecretsWithValues(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceSystem, policy.ActionRead).
|
||||
Returns([]database.UserSecret{secret})
|
||||
}))
|
||||
s.Run("CreateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
@@ -5346,23 +5419,22 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionCreate).
|
||||
Returns(ret)
|
||||
}))
|
||||
s.Run("UpdateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{ID: secret.ID})
|
||||
arg := database.UpdateUserSecretParams{ID: secret.ID}
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserSecret(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
s.Run("UpdateUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
arg := database.UpdateUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().UpdateUserSecretByUserIDAndName(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(secret, policy.ActionUpdate).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionUpdate).
|
||||
Returns(updated)
|
||||
}))
|
||||
s.Run("DeleteUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserSecret(gomock.Any(), secret.ID).Return(nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead, secret, policy.ActionDelete).
|
||||
Returns()
|
||||
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
|
||||
Returns(int64(1))
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -1597,6 +1597,7 @@ func UserSecret(t testing.TB, db database.Store, seed database.UserSecret) datab
|
||||
Name: takeFirst(seed.Name, "secret-name"),
|
||||
Description: takeFirst(seed.Description, "secret description"),
|
||||
Value: takeFirst(seed.Value, "secret value"),
|
||||
ValueKeyID: seed.ValueKeyID,
|
||||
EnvName: takeFirst(seed.EnvName, "SECRET_ENV_NAME"),
|
||||
FilePath: takeFirst(seed.FilePath, "~/secret/file/path"),
|
||||
})
|
||||
@@ -1643,6 +1644,8 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
|
||||
ThreadParentInterceptionID: seed.ThreadParentInterceptionID,
|
||||
ThreadRootInterceptionID: seed.ThreadRootInterceptionID,
|
||||
ClientSessionID: seed.ClientSessionID,
|
||||
CredentialKind: takeFirst(seed.CredentialKind, database.CredentialKindCentralized),
|
||||
CredentialHint: takeFirst(seed.CredentialHint, ""),
|
||||
})
|
||||
if endedAt != nil {
|
||||
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
|
||||
@@ -280,6 +280,14 @@ func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Conte
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("ClearChatMessageProviderResponseIDsByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ClearChatMessageProviderResponseIDsByChatID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
|
||||
@@ -592,6 +600,22 @@ func (m queryMetricsStore) DeleteOldAuditLogs(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) ([]database.DeleteOldChatFilesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldChatFiles(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteOldChatFiles").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChatFiles").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldChats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteOldChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldConnectionLogs(ctx, arg)
|
||||
@@ -712,12 +736,12 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecret").Inc()
|
||||
return r0
|
||||
r0, r1 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -952,6 +976,14 @@ func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActiveChatsByAgentID(ctx, agentID)
|
||||
m.queryLatencies.WithLabelValues("GetActiveChatsByAgentID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveChatsByAgentID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
|
||||
@@ -1128,6 +1160,14 @@ func (m queryMetricsStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatFileMetadataByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatFileMetadataByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFileMetadataByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatFilesByIDs(ctx, ids)
|
||||
@@ -1152,6 +1192,14 @@ func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (da
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessageSummariesPerChat(ctx, createdAfter)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessageSummariesPerChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessageSummariesPerChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
|
||||
@@ -1200,6 +1248,14 @@ func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigsForTelemetry(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigsForTelemetry").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigsForTelemetry").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByID(ctx, id)
|
||||
@@ -1232,6 +1288,14 @@ func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uui
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatRetentionDays(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatRetentionDays").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatRetentionDays").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatSystemPrompt(ctx)
|
||||
@@ -1304,6 +1368,14 @@ func (m queryMetricsStore) GetChatsByWorkspaceIDs(ctx context.Context, ids []uui
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatsUpdatedAfter(ctx, updatedAfter)
|
||||
m.queryLatencies.WithLabelValues("GetChatsUpdatedAfter").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsUpdatedAfter").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
|
||||
@@ -2616,14 +2688,6 @@ func (m queryMetricsStore) GetUserNotificationPreferences(ctx context.Context, u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserSecret").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecretByUserIDAndName(ctx, arg)
|
||||
@@ -3776,6 +3840,14 @@ func (m queryMetricsStore) InsertWorkspaceResourceMetadata(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.LinkChatFiles(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("LinkChatFiles").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "LinkChatFiles").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeClients(ctx, arg)
|
||||
@@ -3904,7 +3976,7 @@ func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecrets(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecrets").Observe(time.Since(start).Seconds())
|
||||
@@ -3912,6 +3984,14 @@ func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecretsWithValues(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecretsWithValues").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserSecretsWithValues").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListWorkspaceAgentPortShares(ctx, workspaceID)
|
||||
@@ -4040,6 +4120,14 @@ func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, ar
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("SoftDeleteContextFileMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteContextFileMessages").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
|
||||
@@ -4120,11 +4208,11 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
|
||||
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4680,11 +4768,11 @@ func (m queryMetricsStore) UpdateUserRoles(ctx context.Context, arg database.Upd
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
func (m queryMetricsStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserSecret(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecret").Inc()
|
||||
r0, r1 := m.s.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecretByUserIDAndName").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4984,6 +5072,14 @@ func (m queryMetricsStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Cont
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatRetentionDays(ctx, retentionDays)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatRetentionDays").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatRetentionDays").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatSystemPrompt(ctx, value)
|
||||
|
||||
@@ -363,6 +363,20 @@ func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID mocks base method.
|
||||
func (m *MockStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClearChatMessageProviderResponseIDsByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID indicates an expected call of ClearChatMessageProviderResponseIDsByChatID.
|
||||
func (mr *MockStoreMockRecorder) ClearChatMessageProviderResponseIDsByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearChatMessageProviderResponseIDsByChatID", reflect.TypeOf((*MockStore)(nil).ClearChatMessageProviderResponseIDsByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// CountAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -984,6 +998,36 @@ func (mr *MockStoreMockRecorder) DeleteOldAuditLogs(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldAuditLogs", reflect.TypeOf((*MockStore)(nil).DeleteOldAuditLogs), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldChatFiles mocks base method.
|
||||
func (m *MockStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) ([]database.DeleteOldChatFilesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldChatFiles", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.DeleteOldChatFilesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOldChatFiles indicates an expected call of DeleteOldChatFiles.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldChatFiles(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChatFiles", reflect.TypeOf((*MockStore)(nil).DeleteOldChatFiles), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldChats mocks base method.
|
||||
func (m *MockStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldChats", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOldChats indicates an expected call of DeleteOldChats.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldChats(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChats", reflect.TypeOf((*MockStore)(nil).DeleteOldChats), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldConnectionLogs mocks base method.
|
||||
func (m *MockStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1199,18 +1243,19 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecret mocks base method.
|
||||
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecret", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteUserSecret indicates an expected call of DeleteUserSecret.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecret(ctx, id any) *gomock.Call {
|
||||
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecret", reflect.TypeOf((*MockStore)(nil).DeleteUserSecret), ctx, id)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).DeleteUserSecretByUserIDAndName), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteWebpushSubscriptionByUserIDAndEndpoint mocks base method.
|
||||
@@ -1637,6 +1682,21 @@ func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID mocks base method.
|
||||
func (m *MockStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveChatsByAgentID", ctx, agentID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID indicates an expected call of GetActiveChatsByAgentID.
|
||||
func (mr *MockStoreMockRecorder) GetActiveChatsByAgentID(ctx, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveChatsByAgentID", reflect.TypeOf((*MockStore)(nil).GetActiveChatsByAgentID), ctx, agentID)
|
||||
}
|
||||
|
||||
// GetActivePresetPrebuildSchedules mocks base method.
|
||||
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2072,6 +2132,21 @@ func (mr *MockStoreMockRecorder) GetChatFileByID(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileByID", reflect.TypeOf((*MockStore)(nil).GetChatFileByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatFileMetadataByChatID mocks base method.
|
||||
func (m *MockStore) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatFileMetadataByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].([]database.GetChatFileMetadataByChatIDRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatFileMetadataByChatID indicates an expected call of GetChatFileMetadataByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatFileMetadataByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileMetadataByChatID", reflect.TypeOf((*MockStore)(nil).GetChatFileMetadataByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatFilesByIDs mocks base method.
|
||||
func (m *MockStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2117,6 +2192,21 @@ func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageByID", reflect.TypeOf((*MockStore)(nil).GetChatMessageByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatMessageSummariesPerChat mocks base method.
|
||||
func (m *MockStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessageSummariesPerChat", ctx, createdAfter)
|
||||
ret0, _ := ret[0].([]database.GetChatMessageSummariesPerChatRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessageSummariesPerChat indicates an expected call of GetChatMessageSummariesPerChat.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessageSummariesPerChat(ctx, createdAfter any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageSummariesPerChat", reflect.TypeOf((*MockStore)(nil).GetChatMessageSummariesPerChat), ctx, createdAfter)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatID mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2207,6 +2297,21 @@ func (mr *MockStoreMockRecorder) GetChatModelConfigs(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetChatModelConfigsForTelemetry mocks base method.
|
||||
func (m *MockStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatModelConfigsForTelemetry", ctx)
|
||||
ret0, _ := ret[0].([]database.GetChatModelConfigsForTelemetryRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatModelConfigsForTelemetry indicates an expected call of GetChatModelConfigsForTelemetry.
|
||||
func (mr *MockStoreMockRecorder) GetChatModelConfigsForTelemetry(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigsForTelemetry", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigsForTelemetry), ctx)
|
||||
}
|
||||
|
||||
// GetChatProviderByID mocks base method.
|
||||
func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2267,6 +2372,21 @@ func (mr *MockStoreMockRecorder) GetChatQueuedMessages(ctx, chatID any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).GetChatQueuedMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatRetentionDays mocks base method.
|
||||
func (m *MockStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatRetentionDays", ctx)
|
||||
ret0, _ := ret[0].(int32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatRetentionDays indicates an expected call of GetChatRetentionDays.
|
||||
func (mr *MockStoreMockRecorder) GetChatRetentionDays(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatRetentionDays", reflect.TypeOf((*MockStore)(nil).GetChatRetentionDays), ctx)
|
||||
}
|
||||
|
||||
// GetChatSystemPrompt mocks base method.
|
||||
func (m *MockStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2402,6 +2522,21 @@ func (mr *MockStoreMockRecorder) GetChatsByWorkspaceIDs(ctx, ids any) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByWorkspaceIDs", reflect.TypeOf((*MockStore)(nil).GetChatsByWorkspaceIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetChatsUpdatedAfter mocks base method.
|
||||
func (m *MockStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatsUpdatedAfter", ctx, updatedAfter)
|
||||
ret0, _ := ret[0].([]database.GetChatsUpdatedAfterRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatsUpdatedAfter indicates an expected call of GetChatsUpdatedAfter.
|
||||
func (mr *MockStoreMockRecorder) GetChatsUpdatedAfter(ctx, updatedAfter any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).GetChatsUpdatedAfter), ctx, updatedAfter)
|
||||
}
|
||||
|
||||
// GetConnectionLogsOffset mocks base method.
|
||||
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4892,21 +5027,6 @@ func (mr *MockStoreMockRecorder) GetUserNotificationPreferences(ctx, userID any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserNotificationPreferences", reflect.TypeOf((*MockStore)(nil).GetUserNotificationPreferences), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserSecret mocks base method.
|
||||
func (m *MockStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserSecret", ctx, id)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserSecret indicates an expected call of GetUserSecret.
|
||||
func (mr *MockStoreMockRecorder) GetUserSecret(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSecret", reflect.TypeOf((*MockStore)(nil).GetUserSecret), ctx, id)
|
||||
}
|
||||
|
||||
// GetUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7066,6 +7186,21 @@ func (mr *MockStoreMockRecorder) InsertWorkspaceResourceMetadata(ctx, arg any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertWorkspaceResourceMetadata", reflect.TypeOf((*MockStore)(nil).InsertWorkspaceResourceMetadata), ctx, arg)
|
||||
}
|
||||
|
||||
// LinkChatFiles mocks base method.
|
||||
func (m *MockStore) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkChatFiles", ctx, arg)
|
||||
ret0, _ := ret[0].(int32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LinkChatFiles indicates an expected call of LinkChatFiles.
|
||||
func (mr *MockStoreMockRecorder) LinkChatFiles(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkChatFiles", reflect.TypeOf((*MockStore)(nil).LinkChatFiles), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeClients mocks base method.
|
||||
func (m *MockStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7382,10 +7517,10 @@ func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID an
|
||||
}
|
||||
|
||||
// ListUserSecrets mocks base method.
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecrets", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret0, _ := ret[0].([]database.ListUserSecretsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -7396,6 +7531,21 @@ func (mr *MockStoreMockRecorder) ListUserSecrets(ctx, userID any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecrets", reflect.TypeOf((*MockStore)(nil).ListUserSecrets), ctx, userID)
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues mocks base method.
|
||||
func (m *MockStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecretsWithValues", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues indicates an expected call of ListUserSecretsWithValues.
|
||||
func (mr *MockStoreMockRecorder) ListUserSecretsWithValues(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecretsWithValues", reflect.TypeOf((*MockStore)(nil).ListUserSecretsWithValues), ctx, userID)
|
||||
}
|
||||
|
||||
// ListWorkspaceAgentPortShares mocks base method.
|
||||
func (m *MockStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7660,6 +7810,20 @@ func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg)
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages mocks base method.
|
||||
func (m *MockStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SoftDeleteContextFileMessages", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages indicates an expected call of SoftDeleteContextFileMessages.
|
||||
func (mr *MockStoreMockRecorder) SoftDeleteContextFileMessages(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteContextFileMessages", reflect.TypeOf((*MockStore)(nil).SoftDeleteContextFileMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// TryAcquireLock mocks base method.
|
||||
func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7805,19 +7969,19 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
// UpdateChatHeartbeats mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeats", ctx, arg)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
|
||||
// UpdateChatHeartbeats indicates an expected call of UpdateChatHeartbeats.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeats(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeats", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLabelsByID mocks base method.
|
||||
@@ -8824,19 +8988,19 @@ func (mr *MockStoreMockRecorder) UpdateUserRoles(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserRoles", reflect.TypeOf((*MockStore)(nil).UpdateUserRoles), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserSecret mocks base method.
|
||||
func (m *MockStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
// UpdateUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecret", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserSecret indicates an expected call of UpdateUserSecret.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecret(ctx, arg any) *gomock.Call {
|
||||
// UpdateUserSecretByUserIDAndName indicates an expected call of UpdateUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecret", reflect.TypeOf((*MockStore)(nil).UpdateUserSecret), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).UpdateUserSecretByUserIDAndName), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserStatus mocks base method.
|
||||
@@ -9369,6 +9533,20 @@ func (mr *MockStoreMockRecorder) UpsertChatIncludeDefaultSystemPrompt(ctx, inclu
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatIncludeDefaultSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatIncludeDefaultSystemPrompt), ctx, includeDefaultSystemPrompt)
|
||||
}
|
||||
|
||||
// UpsertChatRetentionDays mocks base method.
|
||||
func (m *MockStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatRetentionDays", ctx, retentionDays)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatRetentionDays indicates an expected call of UpsertChatRetentionDays.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatRetentionDays(ctx, retentionDays any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatRetentionDays", reflect.TypeOf((*MockStore)(nil).UpsertChatRetentionDays), ctx, retentionDays)
|
||||
}
|
||||
|
||||
// UpsertChatSystemPrompt mocks base method.
|
||||
func (m *MockStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -3,6 +3,7 @@ package dbpurge
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/objstore"
|
||||
"github.com/coder/coder/v2/coderd/pproflabel"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
@@ -34,13 +36,22 @@ const (
|
||||
// long enough to cover the maximum interval of a heartbeat event (currently
|
||||
// 1 hour) plus some buffer.
|
||||
maxTelemetryHeartbeatAge = 24 * time.Hour
|
||||
// Batch sizes for chat purging. Both use 1000, which is smaller
|
||||
// than audit/connection log batches (10000), because chat_files
|
||||
// rows contain bytea blob data that make large batches heavier.
|
||||
chatsBatchSize = 1000
|
||||
chatFilesBatchSize = 1000
|
||||
)
|
||||
|
||||
// chatFilesNamespace is the object store namespace under which chat
|
||||
// files are stored.
|
||||
const chatFilesNamespace = "chatfiles"
|
||||
|
||||
// New creates a new periodically purging database instance.
|
||||
// It is the caller's responsibility to call Close on the returned instance.
|
||||
//
|
||||
// This is for cleaning up old, unused resources from the database that take up space.
|
||||
func New(ctx context.Context, logger slog.Logger, db database.Store, vals *codersdk.DeploymentValues, clk quartz.Clock, reg prometheus.Registerer) io.Closer {
|
||||
func New(ctx context.Context, logger slog.Logger, db database.Store, vals *codersdk.DeploymentValues, clk quartz.Clock, reg prometheus.Registerer, objStore objstore.Store) io.Closer {
|
||||
closed := make(chan struct{})
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(ctx)
|
||||
@@ -64,6 +75,22 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, vals *coder
|
||||
}, []string{"record_type"})
|
||||
reg.MustRegister(recordsPurged)
|
||||
|
||||
objStoreInflight := prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "dbpurge",
|
||||
Name: "objstore_delete_inflight",
|
||||
Help: "Number of object store files currently enqueued for deletion.",
|
||||
})
|
||||
reg.MustRegister(objStoreInflight)
|
||||
|
||||
objStoreDeleted := prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "dbpurge",
|
||||
Name: "objstore_files_deleted_total",
|
||||
Help: "Total number of object store files successfully deleted.",
|
||||
})
|
||||
reg.MustRegister(objStoreDeleted)
|
||||
|
||||
inst := &instance{
|
||||
cancel: cancelFunc,
|
||||
closed: closed,
|
||||
@@ -72,6 +99,9 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, vals *coder
|
||||
clk: clk,
|
||||
iterationDuration: iterationDuration,
|
||||
recordsPurged: recordsPurged,
|
||||
objStore: objStore,
|
||||
objStoreInflight: objStoreInflight,
|
||||
objStoreDeleted: objStoreDeleted,
|
||||
}
|
||||
|
||||
// Start the ticker with the initial delay.
|
||||
@@ -109,6 +139,17 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, vals *coder
|
||||
// purgeTick performs a single purge iteration. It returns an error if the
|
||||
// purge fails.
|
||||
func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.Time) error {
|
||||
// Read chat retention config outside the transaction to
|
||||
// avoid poisoning the tx if the stored value is corrupt.
|
||||
// A SQL-level cast error (e.g. non-numeric text) puts PG
|
||||
// into error state, failing all subsequent queries in the
|
||||
// same transaction.
|
||||
chatRetentionDays, err := db.GetChatRetentionDays(ctx)
|
||||
if err != nil {
|
||||
i.logger.Warn(ctx, "failed to read chat retention config, skipping chat purge", slog.Error(err))
|
||||
chatRetentionDays = 0
|
||||
}
|
||||
|
||||
// Start a transaction to grab advisory lock, we don't want to run
|
||||
// multiple purges at the same time (multiple replicas).
|
||||
return db.InTx(func(tx database.Store) error {
|
||||
@@ -213,12 +254,50 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
|
||||
}
|
||||
}
|
||||
|
||||
// Chat retention is configured via site_configs. When
|
||||
// enabled, old archived chats are deleted first, then
|
||||
// orphaned chat files. Deleting a chat cascades to
|
||||
// chat_file_links (removing references) but not to
|
||||
// chat_files directly, so files from deleted chats
|
||||
// become orphaned and are caught by DeleteOldChatFiles
|
||||
// in the same tick.
|
||||
var purgedChats int64
|
||||
var purgedChatFiles int64
|
||||
if chatRetentionDays > 0 {
|
||||
chatRetention := time.Duration(chatRetentionDays) * 24 * time.Hour
|
||||
deleteChatsBefore := start.Add(-chatRetention)
|
||||
|
||||
purgedChats, err = tx.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: deleteChatsBefore,
|
||||
LimitCount: chatsBatchSize,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to delete old chats: %w", err)
|
||||
}
|
||||
|
||||
deletedFiles, err := tx.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: deleteChatsBefore,
|
||||
LimitCount: chatFilesBatchSize,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to delete old chat files: %w", err)
|
||||
}
|
||||
purgedChatFiles = int64(len(deletedFiles))
|
||||
|
||||
// Collect object store keys from the deleted rows
|
||||
// and delete them in a background goroutine so
|
||||
// slow object store I/O does not hold the
|
||||
// advisory lock or block the next tick.
|
||||
i.deleteObjStoreKeys(ctx, deletedFiles)
|
||||
}
|
||||
i.logger.Debug(ctx, "purged old database entries",
|
||||
slog.F("workspace_agent_logs", purgedWorkspaceAgentLogs),
|
||||
slog.F("expired_api_keys", expiredAPIKeys),
|
||||
slog.F("aibridge_records", purgedAIBridgeRecords),
|
||||
slog.F("connection_logs", purgedConnectionLogs),
|
||||
slog.F("audit_logs", purgedAuditLogs),
|
||||
slog.F("chats", purgedChats),
|
||||
slog.F("chat_files", purgedChatFiles),
|
||||
slog.F("duration", i.clk.Since(start)),
|
||||
)
|
||||
|
||||
@@ -232,6 +311,8 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
|
||||
i.recordsPurged.WithLabelValues("aibridge_records").Add(float64(purgedAIBridgeRecords))
|
||||
i.recordsPurged.WithLabelValues("connection_logs").Add(float64(purgedConnectionLogs))
|
||||
i.recordsPurged.WithLabelValues("audit_logs").Add(float64(purgedAuditLogs))
|
||||
i.recordsPurged.WithLabelValues("chats").Add(float64(purgedChats))
|
||||
i.recordsPurged.WithLabelValues("chat_files").Add(float64(purgedChatFiles))
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -246,6 +327,13 @@ type instance struct {
|
||||
clk quartz.Clock
|
||||
iterationDuration *prometheus.HistogramVec
|
||||
recordsPurged *prometheus.CounterVec
|
||||
objStore objstore.Store
|
||||
objStoreInflight prometheus.Gauge
|
||||
objStoreDeleted prometheus.Counter
|
||||
|
||||
// objDeleteMu serializes background object store delete batches
|
||||
// so at most one goroutine is deleting at a time.
|
||||
objDeleteMu sync.Mutex
|
||||
}
|
||||
|
||||
func (i *instance) Close() error {
|
||||
@@ -253,3 +341,62 @@ func (i *instance) Close() error {
|
||||
<-i.closed
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteObjStoreKeys removes object store entries for the given
|
||||
// deleted chat file rows. The work runs in a background goroutine
|
||||
// guarded by a mutex so that slow object store I/O never blocks
|
||||
// the purge transaction or the next tick. At most one delete batch
|
||||
// runs at a time; if a batch is already in flight the new keys are
|
||||
// silently dropped (they will be orphan-collected on a future tick
|
||||
// if needed).
|
||||
func (i *instance) deleteObjStoreKeys(ctx context.Context, rows []database.DeleteOldChatFilesRow) {
|
||||
// Collect non-empty object store keys.
|
||||
var keys []string
|
||||
for _, r := range rows {
|
||||
if r.ObjectStoreKey.Valid && r.ObjectStoreKey.String != "" {
|
||||
keys = append(keys, r.ObjectStoreKey.String)
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Try to acquire the mutex without blocking. If another
|
||||
// delete batch is already running, skip this one.
|
||||
if !i.objDeleteMu.TryLock() {
|
||||
i.logger.Debug(ctx, "object store delete already in progress, skipping batch",
|
||||
slog.F("skipped_keys", len(keys)))
|
||||
return
|
||||
}
|
||||
|
||||
i.objStoreInflight.Add(float64(len(keys)))
|
||||
|
||||
go func() {
|
||||
defer i.objDeleteMu.Unlock()
|
||||
|
||||
var deleted int
|
||||
for _, key := range keys {
|
||||
if ctx.Err() != nil {
|
||||
remaining := len(keys) - deleted
|
||||
i.objStoreInflight.Sub(float64(remaining))
|
||||
i.logger.Debug(ctx, "context canceled during object store cleanup",
|
||||
slog.F("deleted", deleted),
|
||||
slog.F("remaining", remaining))
|
||||
return
|
||||
}
|
||||
if err := i.objStore.Delete(ctx, chatFilesNamespace, key); err != nil {
|
||||
i.logger.Warn(ctx, "failed to delete chat file from object store",
|
||||
slog.F("key", key),
|
||||
slog.Error(err))
|
||||
} else {
|
||||
deleted++
|
||||
}
|
||||
i.objStoreInflight.Dec()
|
||||
}
|
||||
|
||||
i.objStoreDeleted.Add(float64(deleted))
|
||||
i.logger.Debug(ctx, "deleted chat files from object store",
|
||||
slog.F("deleted", deleted),
|
||||
slog.F("failed", len(keys)-deleted))
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -53,8 +54,9 @@ func TestPurge(t *testing.T) {
|
||||
clk := quartz.NewMock(t)
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
mDB := dbmock.NewMockStore(gomock.NewController(t))
|
||||
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
|
||||
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).Return(nil).Times(2)
|
||||
purger := dbpurge.New(context.Background(), testutil.Logger(t), mDB, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
purger := dbpurge.New(context.Background(), testutil.Logger(t), mDB, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry(), nil)
|
||||
<-done // wait for doTick() to run.
|
||||
require.NoError(t, purger.Close())
|
||||
}
|
||||
@@ -88,7 +90,7 @@ func TestMetrics(t *testing.T) {
|
||||
Retention: codersdk.RetentionConfig{
|
||||
APIKeys: serpent.Duration(7 * 24 * time.Hour), // 7 days retention
|
||||
},
|
||||
}, clk, reg)
|
||||
}, clk, reg, nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
@@ -125,6 +127,16 @@ func TestMetrics(t *testing.T) {
|
||||
"record_type": "audit_logs",
|
||||
})
|
||||
require.GreaterOrEqual(t, auditLogs, 0)
|
||||
|
||||
chats := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
|
||||
"record_type": "chats",
|
||||
})
|
||||
require.GreaterOrEqual(t, chats, 0)
|
||||
|
||||
chatFiles := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
|
||||
"record_type": "chat_files",
|
||||
})
|
||||
require.GreaterOrEqual(t, chatFiles, 0)
|
||||
})
|
||||
|
||||
t.Run("FailedIteration", func(t *testing.T) {
|
||||
@@ -138,6 +150,7 @@ func TestMetrics(t *testing.T) {
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
|
||||
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).
|
||||
Return(xerrors.New("simulated database error")).
|
||||
MinTimes(1)
|
||||
@@ -145,7 +158,7 @@ func TestMetrics(t *testing.T) {
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, mDB, &codersdk.DeploymentValues{}, clk, reg)
|
||||
closer := dbpurge.New(ctx, logger, mDB, &codersdk.DeploymentValues{}, clk, reg, nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
@@ -235,7 +248,7 @@ func TestDeleteOldWorkspaceAgentStats(t *testing.T) {
|
||||
})
|
||||
|
||||
// when
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
|
||||
// then
|
||||
@@ -260,7 +273,7 @@ func TestDeleteOldWorkspaceAgentStats(t *testing.T) {
|
||||
|
||||
// Start a new purger to immediately trigger delete after rollup.
|
||||
_ = closer.Close()
|
||||
closer = dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
closer = dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
|
||||
// then
|
||||
@@ -355,7 +368,7 @@ func TestDeleteOldWorkspaceAgentLogs(t *testing.T) {
|
||||
Retention: codersdk.RetentionConfig{
|
||||
WorkspaceAgentLogs: serpent.Duration(7 * 24 * time.Hour),
|
||||
},
|
||||
}, clk, prometheus.NewRegistry())
|
||||
}, clk, prometheus.NewRegistry(), nil)
|
||||
|
||||
defer closer.Close()
|
||||
<-done // doTick() has now run.
|
||||
@@ -570,7 +583,7 @@ func TestDeleteOldWorkspaceAgentLogsRetention(t *testing.T) {
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{
|
||||
Retention: tc.retentionConfig,
|
||||
}, clk, prometheus.NewRegistry())
|
||||
}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
@@ -661,7 +674,7 @@ func TestDeleteOldProvisionerDaemons(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// when
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
|
||||
// then
|
||||
@@ -765,7 +778,7 @@ func TestDeleteOldAuditLogConnectionEvents(t *testing.T) {
|
||||
|
||||
// Run the purge
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
// Wait for tick
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
@@ -928,7 +941,7 @@ func TestDeleteOldTelemetryHeartbeats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
<-done // doTick() has now run.
|
||||
|
||||
@@ -1047,7 +1060,7 @@ func TestDeleteOldConnectionLogs(t *testing.T) {
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{
|
||||
Retention: tc.retentionConfig,
|
||||
}, clk, prometheus.NewRegistry())
|
||||
}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
@@ -1303,7 +1316,7 @@ func TestDeleteOldAIBridgeRecords(t *testing.T) {
|
||||
Retention: serpent.Duration(tc.retention),
|
||||
},
|
||||
},
|
||||
}, clk, prometheus.NewRegistry())
|
||||
}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
@@ -1390,7 +1403,7 @@ func TestDeleteOldAuditLogs(t *testing.T) {
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{
|
||||
Retention: tc.retentionConfig,
|
||||
}, clk, prometheus.NewRegistry())
|
||||
}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
@@ -1480,7 +1493,7 @@ func TestDeleteOldAuditLogs(t *testing.T) {
|
||||
Retention: codersdk.RetentionConfig{
|
||||
AuditLogs: serpent.Duration(retentionPeriod),
|
||||
},
|
||||
}, clk, prometheus.NewRegistry())
|
||||
}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
@@ -1600,7 +1613,7 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{
|
||||
Retention: tc.retentionConfig,
|
||||
}, clk, prometheus.NewRegistry())
|
||||
}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
@@ -1634,3 +1647,488 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
//nolint:paralleltest // It uses LockIDDBPurge.
|
||||
func TestDeleteOldChatFiles(t *testing.T) {
|
||||
now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// createChatFile inserts a chat file and backdates created_at.
|
||||
createChatFile := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, orgID uuid.UUID, createdAt time.Time) uuid.UUID {
|
||||
t.Helper()
|
||||
row, err := db.InsertChatFile(ctx, database.InsertChatFileParams{
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
Name: "test.png",
|
||||
Mimetype: "image/png",
|
||||
Data: []byte("fake-image-data"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chat_files SET created_at = $1 WHERE id = $2", createdAt, row.ID)
|
||||
require.NoError(t, err)
|
||||
return row.ID
|
||||
}
|
||||
|
||||
// createChat inserts a chat and optionally archives it, then
|
||||
// backdates updated_at to control the "archived since" window.
|
||||
createChat := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, modelConfigID uuid.UUID, archived bool, updatedAt time.Time) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelConfigID,
|
||||
Title: "test-chat",
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if archived {
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2", updatedAt, chat.ID)
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
// setupChatDeps creates the common dependencies needed for
|
||||
// chat-related tests: user, org, org member, provider, model config.
|
||||
type chatDeps struct {
|
||||
user database.User
|
||||
org database.Organization
|
||||
modelConfig database.ChatModelConfig
|
||||
}
|
||||
setupChatDeps := func(ctx context.Context, t *testing.T, db database.Store) chatDeps {
|
||||
t.Helper()
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
mc, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
ContextLimit: 8192,
|
||||
Options: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chatDeps{user: user, org: org, modelConfig: mc}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(t *testing.T)
|
||||
}{
|
||||
{
|
||||
name: "ChatRetentionDisabled",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Disable retention.
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(0))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an old archived chat and an orphaned old file.
|
||||
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
oldFileID := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
// Both should still exist.
|
||||
_, err = db.GetChatByID(ctx, oldChat.ID)
|
||||
require.NoError(t, err, "chat should not be deleted when retention is disabled")
|
||||
_, err = db.GetChatFileByID(ctx, oldFileID)
|
||||
require.NoError(t, err, "chat file should not be deleted when retention is disabled")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OldArchivedChatsDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Old archived chat (31 days) — should be deleted.
|
||||
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
// Insert a message so we can verify CASCADE.
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: oldChat.ID,
|
||||
CreatedBy: []uuid.UUID{deps.user.ID},
|
||||
ModelConfigID: []uuid.UUID{deps.modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
|
||||
Content: []string{`[{"type":"text","text":"hello"}]`},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
ProviderResponseID: []string{""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Recently archived chat (10 days) — should be retained.
|
||||
recentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
|
||||
|
||||
// Active chat — should be retained.
|
||||
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
// Old archived chat should be gone.
|
||||
_, err = db.GetChatByID(ctx, oldChat.ID)
|
||||
require.Error(t, err, "old archived chat should be deleted")
|
||||
|
||||
// Its messages should be gone too (CASCADE).
|
||||
msgs, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: oldChat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, msgs, "messages should be cascade-deleted")
|
||||
|
||||
// Recent archived and active chats should remain.
|
||||
_, err = db.GetChatByID(ctx, recentChat.ID)
|
||||
require.NoError(t, err, "recently archived chat should be retained")
|
||||
_, err = db.GetChatByID(ctx, activeChat.ID)
|
||||
require.NoError(t, err, "active chat should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OrphanedOldFilesDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// File A: 31 days old, NOT in any chat -> should be deleted.
|
||||
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
|
||||
// File B: 31 days old, in an active chat -> should be retained.
|
||||
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: activeChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileB},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// File C: 10 days old, NOT in any chat -> should be retained (too young).
|
||||
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-10*24*time.Hour))
|
||||
|
||||
// File near boundary: 29d23h old — close to threshold.
|
||||
fileBoundary := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-30*24*time.Hour).Add(time.Hour))
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileA)
|
||||
require.Error(t, err, "orphaned old file A should be deleted")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileB)
|
||||
require.NoError(t, err, "file B in active chat should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileC)
|
||||
require.NoError(t, err, "young file C should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileBoundary)
|
||||
require.NoError(t, err, "file near 30d boundary should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ArchivedChatFilesDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// File D: 31 days old, in a chat archived 31 days ago -> should be deleted.
|
||||
fileD := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
oldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: oldArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileD},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// LinkChatFiles does not update chats.updated_at, so backdate.
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-31*24*time.Hour), oldArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// File E: 31 days old, in a chat archived 10 days ago -> should be retained.
|
||||
fileE := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
recentArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: recentArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileE},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-10*24*time.Hour), recentArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// File F: 31 days old, in BOTH an active chat AND an old archived chat -> should be retained.
|
||||
fileF := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
anotherOldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: anotherOldArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileF},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-31*24*time.Hour), anotherOldArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeChatForF := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: activeChatForF.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileF},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry(), nil)
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileD)
|
||||
require.Error(t, err, "file D in old archived chat should be deleted")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileE)
|
||||
require.NoError(t, err, "file E in recently archived chat should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileF)
|
||||
require.NoError(t, err, "file F in active + old archived chat should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UnarchiveAfterFilePurge",
|
||||
run: func(t *testing.T) {
|
||||
// Validates that when dbpurge deletes chat_files rows,
|
||||
// the FK cascade on chat_file_links automatically
|
||||
// removes the stale links. Unarchiving a chat after
|
||||
// file purge should show only surviving files.
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create a chat with three attached files.
|
||||
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
|
||||
chat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err := db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: chat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileA, fileB, fileC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the chat.
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate dbpurge deleting files A and B. The FK
|
||||
// cascade on chat_file_links_file_id_fkey should
|
||||
// automatically remove the corresponding link rows.
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)", pq.Array([]uuid.UUID{fileA, fileB}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unarchive the chat.
|
||||
_, err = db.UnarchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Only file C should remain linked (FK cascade
|
||||
// removed the links for deleted files A and B).
|
||||
files, err := db.GetChatFileMetadataByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 1, "only surviving file should be linked")
|
||||
require.Equal(t, fileC, files[0].ID)
|
||||
|
||||
// Edge case: delete the last file too. The chat
|
||||
// should have zero linked files, not an error.
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = $1", fileC)
|
||||
require.NoError(t, err)
|
||||
_, err = db.UnarchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
files, err = db.GetChatFileMetadataByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, files, "all-files-deleted should yield empty result")
|
||||
|
||||
// Test parent+child cascade: deleting files should
|
||||
// clean up links for both parent and child chats
|
||||
// independently via FK cascade.
|
||||
parentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
childChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: deps.user.ID,
|
||||
LastModelConfigID: deps.modelConfig.ID,
|
||||
Title: "child-chat",
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Set root_chat_id to link child to parent.
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET root_chat_id = $1 WHERE id = $2", parentChat.ID, childChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attach different files to parent and child.
|
||||
parentFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
parentFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
childFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
childFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: parentChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{parentFileKeep, parentFileStale},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: childChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{childFileKeep, childFileStale},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive via parent (cascades to child).
|
||||
_, err = db.ArchiveChatByID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete one file from each chat.
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)",
|
||||
pq.Array([]uuid.UUID{parentFileStale, childFileStale}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unarchive via parent.
|
||||
_, err = db.UnarchiveChatByID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
parentFiles, err := db.GetChatFileMetadataByChatID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parentFiles, 1)
|
||||
require.Equal(t, parentFileKeep, parentFiles[0].ID,
|
||||
"parent should retain only non-stale file")
|
||||
|
||||
childFiles, err := db.GetChatFileMetadataByChatID(ctx, childChat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, childFiles, 1)
|
||||
require.Equal(t, childFileKeep, childFiles[0].ID,
|
||||
"child should retain only non-stale file")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BatchLimitFiles",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create 3 deletable orphaned files (all 31 days old).
|
||||
for range 3 {
|
||||
createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
}
|
||||
|
||||
// Delete with limit 2 — should delete 2, leave 1.
|
||||
deleted, err := db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), deleted, "should delete exactly 2 files")
|
||||
|
||||
// Delete again — should delete the remaining 1.
|
||||
deleted, err = db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), deleted, "should delete remaining 1 file")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BatchLimitChats",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create 3 deletable old archived chats.
|
||||
for range 3 {
|
||||
createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
}
|
||||
|
||||
// Delete with limit 2 — should delete 2, leave 1.
|
||||
deleted, err := db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), deleted, "should delete exactly 2 chats")
|
||||
|
||||
// Delete again — should delete the remaining 1.
|
||||
deleted, err = db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), deleted, "should delete remaining 1 chat")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.run(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Generated
+36
-4
@@ -293,7 +293,8 @@ CREATE TYPE chat_status AS ENUM (
|
||||
'running',
|
||||
'paused',
|
||||
'completed',
|
||||
'error'
|
||||
'error',
|
||||
'requires_action'
|
||||
);
|
||||
|
||||
CREATE TYPE connection_status AS ENUM (
|
||||
@@ -315,6 +316,11 @@ CREATE TYPE cors_behavior AS ENUM (
|
||||
'passthru'
|
||||
);
|
||||
|
||||
CREATE TYPE credential_kind AS ENUM (
|
||||
'centralized',
|
||||
'byok'
|
||||
);
|
||||
|
||||
CREATE TYPE crypto_key_feature AS ENUM (
|
||||
'workspace_apps_token',
|
||||
'workspace_apps_api_key',
|
||||
@@ -1101,7 +1107,9 @@ CREATE TABLE aibridge_interceptions (
|
||||
thread_root_id uuid,
|
||||
client_session_id character varying(256),
|
||||
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL,
|
||||
provider_name text DEFAULT ''::text NOT NULL
|
||||
provider_name text DEFAULT ''::text NOT NULL,
|
||||
credential_kind credential_kind DEFAULT 'centralized'::credential_kind NOT NULL,
|
||||
credential_hint character varying(15) DEFAULT ''::character varying NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
|
||||
@@ -1118,6 +1126,10 @@ COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related intercept
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.provider_name IS 'The provider instance name which may differ from provider when multiple instances of the same provider type exist.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
@@ -1269,6 +1281,11 @@ CREATE TABLE chat_diff_statuses (
|
||||
head_branch text
|
||||
);
|
||||
|
||||
CREATE TABLE chat_file_links (
|
||||
chat_id uuid NOT NULL,
|
||||
file_id uuid NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE chat_files (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
@@ -1276,7 +1293,8 @@ CREATE TABLE chat_files (
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
name text DEFAULT ''::text NOT NULL,
|
||||
mimetype text NOT NULL,
|
||||
data bytea NOT NULL
|
||||
data bytea,
|
||||
object_store_key text
|
||||
);
|
||||
|
||||
CREATE TABLE chat_messages (
|
||||
@@ -1413,7 +1431,8 @@ CREATE TABLE chats (
|
||||
agent_id uuid,
|
||||
pin_order integer DEFAULT 0 NOT NULL,
|
||||
last_read_message_id bigint,
|
||||
last_injected_context jsonb
|
||||
last_injected_context jsonb,
|
||||
dynamic_tools jsonb
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
@@ -3344,6 +3363,9 @@ ALTER TABLE ONLY boundary_usage_stats
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
|
||||
ALTER TABLE ONLY chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3734,6 +3756,8 @@ CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
|
||||
|
||||
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
|
||||
|
||||
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chat_files_org ON chat_files USING btree (organization_id);
|
||||
|
||||
CREATE INDEX idx_chat_files_owner ON chat_files USING btree (owner_id);
|
||||
@@ -3760,6 +3784,8 @@ CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_chats_labels ON chats USING gin (labels);
|
||||
|
||||
CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id);
|
||||
@@ -4036,6 +4062,12 @@ ALTER TABLE ONLY api_keys
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ const (
|
||||
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFileLinksChatID ForeignKeyConstraint = "chat_file_links_chat_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
ALTER TABLE chats ADD COLUMN file_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL;
|
||||
|
||||
UPDATE chats SET file_ids = (
|
||||
SELECT COALESCE(array_agg(cfl.file_id), '{}')
|
||||
FROM chat_file_links cfl
|
||||
WHERE cfl.chat_id = chats.id
|
||||
);
|
||||
|
||||
DROP TABLE chat_file_links;
|
||||
@@ -0,0 +1,17 @@
|
||||
CREATE TABLE chat_file_links (
|
||||
chat_id uuid NOT NULL,
|
||||
file_id uuid NOT NULL,
|
||||
UNIQUE (chat_id, file_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links (chat_id);
|
||||
|
||||
ALTER TABLE chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_chat_id_fkey
|
||||
FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_file_id_fkey
|
||||
FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE chats DROP COLUMN IF EXISTS file_ids;
|
||||
@@ -0,0 +1,31 @@
|
||||
-- First update any rows using the value we're about to remove.
|
||||
-- The column type is still the original chat_status at this point.
|
||||
UPDATE chats SET status = 'error' WHERE status = 'requires_action';
|
||||
|
||||
-- Drop the column (this is independent of the enum).
|
||||
ALTER TABLE chats DROP COLUMN IF EXISTS dynamic_tools;
|
||||
|
||||
-- Drop the partial index that references the chat_status enum type.
|
||||
-- It must be removed before the rename-create-cast-drop cycle
|
||||
-- because the index's WHERE clause (status = 'pending'::chat_status)
|
||||
-- would otherwise cause a cross-type comparison failure.
|
||||
DROP INDEX IF EXISTS idx_chats_pending;
|
||||
|
||||
-- Now recreate the enum without requires_action.
|
||||
-- We must use the rename-create-cast-drop pattern.
|
||||
ALTER TYPE chat_status RENAME TO chat_status_old;
|
||||
CREATE TYPE chat_status AS ENUM (
|
||||
'waiting',
|
||||
'pending',
|
||||
'running',
|
||||
'paused',
|
||||
'completed',
|
||||
'error'
|
||||
);
|
||||
ALTER TABLE chats ALTER COLUMN status DROP DEFAULT;
|
||||
ALTER TABLE chats ALTER COLUMN status TYPE chat_status USING status::text::chat_status;
|
||||
ALTER TABLE chats ALTER COLUMN status SET DEFAULT 'waiting';
|
||||
DROP TYPE chat_status_old;
|
||||
|
||||
-- Recreate the partial index.
|
||||
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TYPE chat_status ADD VALUE IF NOT EXISTS 'requires_action';
|
||||
|
||||
ALTER TABLE chats ADD COLUMN dynamic_tools JSONB DEFAULT NULL;
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE aibridge_interceptions
|
||||
DROP COLUMN IF EXISTS credential_kind,
|
||||
DROP COLUMN IF EXISTS credential_hint;
|
||||
|
||||
DROP TYPE IF EXISTS credential_kind;
|
||||
@@ -0,0 +1,12 @@
|
||||
CREATE TYPE credential_kind AS ENUM ('centralized', 'byok');
|
||||
|
||||
-- Records how each LLM request was authenticated and a masked credential
|
||||
-- identifier for audit purposes. Existing rows default to 'centralized'
|
||||
-- with an empty hint since we cannot retroactively determine their values.
|
||||
ALTER TABLE aibridge_interceptions
|
||||
ADD COLUMN credential_kind credential_kind NOT NULL DEFAULT 'centralized',
|
||||
-- Length capped as a safety measure to ensure only masked values are stored.
|
||||
ADD COLUMN credential_hint CHARACTER VARYING(15) NOT NULL DEFAULT '';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
|
||||
@@ -0,0 +1 @@
|
||||
DROP INDEX IF EXISTS idx_chats_agent_id;
|
||||
@@ -0,0 +1 @@
|
||||
CREATE INDEX idx_chats_agent_id ON chats(agent_id) WHERE agent_id IS NOT NULL;
|
||||
@@ -0,0 +1,7 @@
|
||||
-- Backfill any NULL data values before restoring NOT NULL would require
|
||||
-- reading from the object store, which is not possible in a migration.
|
||||
-- Instead, delete rows that only exist in the object store.
|
||||
DELETE FROM chat_files WHERE data IS NULL;
|
||||
|
||||
ALTER TABLE chat_files ALTER COLUMN data SET NOT NULL;
|
||||
ALTER TABLE chat_files DROP COLUMN object_store_key;
|
||||
@@ -0,0 +1,8 @@
|
||||
-- Add object_store_key to track files stored in external object storage.
|
||||
-- When non-NULL, the file data lives in the object store under this key
|
||||
-- and the data column may be NULL.
|
||||
ALTER TABLE chat_files ADD COLUMN object_store_key TEXT;
|
||||
|
||||
-- Make data nullable so new writes can skip the BYTEA column when
|
||||
-- storing in the object store.
|
||||
ALTER TABLE chat_files ALTER COLUMN data DROP NOT NULL;
|
||||
@@ -0,0 +1,5 @@
|
||||
INSERT INTO chat_file_links (chat_id, file_id)
|
||||
VALUES (
|
||||
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
|
||||
'00000000-0000-0000-0000-000000000099'
|
||||
);
|
||||
@@ -187,6 +187,10 @@ func (c ChatFile) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
|
||||
}
|
||||
|
||||
func (c GetChatFileMetadataByChatIDRow) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
|
||||
}
|
||||
|
||||
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
|
||||
switch s {
|
||||
case ApiKeyScopeCoderAll:
|
||||
|
||||
@@ -584,6 +584,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -720,6 +721,7 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -796,6 +798,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
&i.Chat.PinOrder,
|
||||
&i.Chat.LastReadMessageID,
|
||||
&i.Chat.LastInjectedContext,
|
||||
&i.Chat.DynamicTools,
|
||||
&i.HasUnread); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -866,6 +869,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
&i.AIBridgeInterception.CredentialKind,
|
||||
&i.AIBridgeInterception.CredentialHint,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -1129,6 +1134,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, a
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
&i.AIBridgeInterception.CredentialKind,
|
||||
&i.AIBridgeInterception.CredentialHint,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -145,5 +145,13 @@ func extractWhereClause(query string) string {
|
||||
// Remove SQL comments
|
||||
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
|
||||
|
||||
// Normalize indentation so subquery wrapping doesn't cause
|
||||
// mismatches.
|
||||
lines := strings.Split(whereClause, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = strings.TrimLeft(line, " \t")
|
||||
}
|
||||
whereClause = strings.Join(lines, "\n")
|
||||
|
||||
return strings.TrimSpace(whereClause)
|
||||
}
|
||||
|
||||
+86
-14
@@ -1290,12 +1290,13 @@ func AllChatModeValues() []ChatMode {
|
||||
type ChatStatus string
|
||||
|
||||
const (
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
ChatStatusRequiresAction ChatStatus = "requires_action"
|
||||
)
|
||||
|
||||
func (e *ChatStatus) Scan(src interface{}) error {
|
||||
@@ -1340,7 +1341,8 @@ func (e ChatStatus) Valid() bool {
|
||||
ChatStatusRunning,
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError:
|
||||
ChatStatusError,
|
||||
ChatStatusRequiresAction:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -1354,6 +1356,7 @@ func AllChatStatusValues() []ChatStatus {
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError,
|
||||
ChatStatusRequiresAction,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1543,6 +1546,64 @@ func AllCorsBehaviorValues() []CorsBehavior {
|
||||
}
|
||||
}
|
||||
|
||||
type CredentialKind string
|
||||
|
||||
const (
|
||||
CredentialKindCentralized CredentialKind = "centralized"
|
||||
CredentialKindByok CredentialKind = "byok"
|
||||
)
|
||||
|
||||
func (e *CredentialKind) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = CredentialKind(s)
|
||||
case string:
|
||||
*e = CredentialKind(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for CredentialKind: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullCredentialKind struct {
|
||||
CredentialKind CredentialKind `json:"credential_kind"`
|
||||
Valid bool `json:"valid"` // Valid is true if CredentialKind is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (ns *NullCredentialKind) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
ns.CredentialKind, ns.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
ns.Valid = true
|
||||
return ns.CredentialKind.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullCredentialKind) Value() (driver.Value, error) {
|
||||
if !ns.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return string(ns.CredentialKind), nil
|
||||
}
|
||||
|
||||
func (e CredentialKind) Valid() bool {
|
||||
switch e {
|
||||
case CredentialKindCentralized,
|
||||
CredentialKindByok:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AllCredentialKindValues() []CredentialKind {
|
||||
return []CredentialKind{
|
||||
CredentialKindCentralized,
|
||||
CredentialKindByok,
|
||||
}
|
||||
}
|
||||
|
||||
type CryptoKeyFeature string
|
||||
|
||||
const (
|
||||
@@ -4040,6 +4101,10 @@ type AIBridgeInterception struct {
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
// The provider instance name which may differ from provider when multiple instances of the same provider type exist.
|
||||
ProviderName string `db:"provider_name" json:"provider_name"`
|
||||
// How the request was authenticated: centralized or byok.
|
||||
CredentialKind CredentialKind `db:"credential_kind" json:"credential_kind"`
|
||||
// Masked credential identifier for audit (e.g. sk-a***efgh).
|
||||
CredentialHint string `db:"credential_hint" json:"credential_hint"`
|
||||
}
|
||||
|
||||
// Audit log of model thinking in intercepted requests in AI Bridge
|
||||
@@ -4180,6 +4245,7 @@ type Chat struct {
|
||||
PinOrder int32 `db:"pin_order" json:"pin_order"`
|
||||
LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"`
|
||||
LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"`
|
||||
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
@@ -4209,13 +4275,19 @@ type ChatDiffStatus struct {
|
||||
}
|
||||
|
||||
type ChatFile struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Mimetype string `db:"mimetype" json:"mimetype"`
|
||||
Data []byte `db:"data" json:"data"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Mimetype string `db:"mimetype" json:"mimetype"`
|
||||
Data []byte `db:"data" json:"data"`
|
||||
ObjectStoreKey sql.NullString `db:"object_store_key" json:"object_store_key"`
|
||||
}
|
||||
|
||||
type ChatFileLink struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
FileID uuid.UUID `db:"file_id" json:"file_id"`
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
|
||||
@@ -76,6 +76,7 @@ type sqlcQuerier interface {
|
||||
CleanTailnetLostPeers(ctx context.Context) error
|
||||
CleanTailnetTunnels(ctx context.Context) error
|
||||
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error
|
||||
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
|
||||
CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
|
||||
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
|
||||
@@ -128,6 +129,24 @@ type sqlcQuerier interface {
|
||||
// connection events (connect, disconnect, open, close) which are handled
|
||||
// separately by DeleteOldAuditLogConnectionEvents.
|
||||
DeleteOldAuditLogs(ctx context.Context, arg DeleteOldAuditLogsParams) (int64, error)
|
||||
// TODO(cian): Add indexes on chats(archived, updated_at) and
|
||||
// chat_files(created_at) for purge query performance.
|
||||
// See: https://github.com/coder/internal/issues/1438
|
||||
// Deletes chat files that are older than the given threshold and are
|
||||
// not referenced by any chat that is still active or was archived
|
||||
// within the same threshold window. This covers two cases:
|
||||
// 1. Orphaned files not linked to any chat.
|
||||
// 2. Files whose every referencing chat has been archived for longer
|
||||
// than the retention period.
|
||||
// Returns the deleted rows so callers can clean up associated object
|
||||
// store entries.
|
||||
DeleteOldChatFiles(ctx context.Context, arg DeleteOldChatFilesParams) ([]DeleteOldChatFilesRow, error)
|
||||
// Deletes chats that have been archived for longer than the given
|
||||
// threshold. Active (non-archived) chats are never deleted.
|
||||
// Related chat_messages, chat_diff_statuses, and
|
||||
// chat_queued_messages are removed via ON DELETE CASCADE.
|
||||
// Parent/root references on child chats are SET NULL.
|
||||
DeleteOldChats(ctx context.Context, arg DeleteOldChatsParams) (int64, error)
|
||||
DeleteOldConnectionLogs(ctx context.Context, arg DeleteOldConnectionLogsParams) (int64, error)
|
||||
// Delete all notification messages which have not been updated for over a week.
|
||||
DeleteOldNotificationMessages(ctx context.Context) error
|
||||
@@ -152,7 +171,7 @@ type sqlcQuerier interface {
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (int64, error)
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -199,6 +218,7 @@ type sqlcQuerier interface {
|
||||
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
|
||||
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
|
||||
GetActiveAISeatCount(ctx context.Context) (int64, error)
|
||||
GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error)
|
||||
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
|
||||
GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
|
||||
@@ -244,6 +264,10 @@ type sqlcQuerier interface {
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
|
||||
// GetChatFileMetadataByChatID returns lightweight file metadata for
|
||||
// all files linked to a chat. The data column is excluded to avoid
|
||||
// loading file content.
|
||||
GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]GetChatFileMetadataByChatIDRow, error)
|
||||
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
|
||||
// GetChatIncludeDefaultSystemPrompt preserves the legacy default
|
||||
// for deployments created before the explicit include-default toggle.
|
||||
@@ -251,16 +275,27 @@ type sqlcQuerier interface {
|
||||
// otherwise the setting defaults to true.
|
||||
GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
// Aggregates message-level metrics per chat for messages created
|
||||
// after the given timestamp. Uses message created_at so that
|
||||
// ongoing activity in long-running chats is captured each window.
|
||||
GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]GetChatMessageSummariesPerChatRow, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
|
||||
GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
|
||||
// Returns all model configurations for telemetry snapshot collection.
|
||||
GetChatModelConfigsForTelemetry(ctx context.Context) ([]GetChatModelConfigsForTelemetryRow, error)
|
||||
GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error)
|
||||
GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error)
|
||||
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
|
||||
// Returns the chat retention period in days. Chats archived longer
|
||||
// than this and orphaned chat files older than this are purged by
|
||||
// dbpurge. Returns 30 (days) when no value has been configured.
|
||||
// A value of 0 disables chat purging entirely.
|
||||
GetChatRetentionDays(ctx context.Context) (int32, error)
|
||||
GetChatSystemPrompt(ctx context.Context) (string, error)
|
||||
// GetChatSystemPromptConfig returns both chat system prompt settings in a
|
||||
// single read to avoid torn reads between separate site-config lookups.
|
||||
@@ -279,6 +314,10 @@ type sqlcQuerier interface {
|
||||
GetChatWorkspaceTTL(ctx context.Context) (string, error)
|
||||
GetChats(ctx context.Context, arg GetChatsParams) ([]GetChatsRow, error)
|
||||
GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]Chat, error)
|
||||
// Retrieves chats updated after the given timestamp for telemetry
|
||||
// snapshot collection. Uses updated_at so that long-running chats
|
||||
// still appear in each snapshot window while they are active.
|
||||
GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]GetChatsUpdatedAfterRow, error)
|
||||
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
|
||||
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
|
||||
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
|
||||
@@ -474,8 +513,10 @@ type sqlcQuerier interface {
|
||||
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
|
||||
GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error)
|
||||
GetRuntimeConfig(ctx context.Context, key string) (string, error)
|
||||
// Find chats that appear stuck (running but heartbeat has expired).
|
||||
// Used for recovery after coderd crashes or long hangs.
|
||||
// Find chats that appear stuck and need recovery. This covers:
|
||||
// 1. Running chats whose heartbeat has expired (worker crash).
|
||||
// 2. Chats awaiting client action (requires_action) past the
|
||||
// timeout threshold (client disappeared).
|
||||
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
|
||||
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
|
||||
GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error)
|
||||
@@ -594,7 +635,6 @@ type sqlcQuerier interface {
|
||||
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
|
||||
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
|
||||
GetUserNotificationPreferences(ctx context.Context, userID uuid.UUID) ([]NotificationPreference, error)
|
||||
GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error)
|
||||
GetUserSecretByUserIDAndName(ctx context.Context, arg GetUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
// GetUserStatusCounts returns the count of users in each status over time.
|
||||
// The time range is inclusively defined by the start_time and end_time parameters.
|
||||
@@ -778,6 +818,15 @@ type sqlcQuerier interface {
|
||||
InsertWorkspaceProxy(ctx context.Context, arg InsertWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
|
||||
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
|
||||
// LinkChatFiles inserts file associations into the chat_file_links
|
||||
// join table with deduplication (ON CONFLICT DO NOTHING). The INSERT
|
||||
// is conditional: it only proceeds when the total number of links
|
||||
// (existing + genuinely new) does not exceed max_file_links. Returns
|
||||
// the number of genuinely new file IDs that were NOT inserted due to
|
||||
// the cap. A return value of 0 means all files were linked (or were
|
||||
// already linked). A positive value means the cap blocked that many
|
||||
// new links.
|
||||
LinkChatFiles(ctx context.Context, arg LinkChatFilesParams) (int32, error)
|
||||
ListAIBridgeClients(ctx context.Context, arg ListAIBridgeClientsParams) ([]string, error)
|
||||
ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]ListAIBridgeInterceptionsRow, error)
|
||||
// Finds all unique AI Bridge interception telemetry summaries combinations
|
||||
@@ -805,7 +854,13 @@ type sqlcQuerier interface {
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
// Returns metadata only (no value or value_key_id) for the
|
||||
// REST API list and get endpoints.
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error)
|
||||
// Returns all columns including the secret value. Used by the
|
||||
// provisioner (build-time injection) and the agent manifest
|
||||
// (runtime injection).
|
||||
ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
|
||||
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
|
||||
OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error)
|
||||
@@ -842,11 +897,16 @@ type sqlcQuerier interface {
|
||||
SelectUsageEventsForPublishing(ctx context.Context, now time.Time) ([]UsageEvent, error)
|
||||
SoftDeleteChatMessageByID(ctx context.Context, id int64) error
|
||||
SoftDeleteChatMessagesAfterID(ctx context.Context, arg SoftDeleteChatMessagesAfterIDParams) error
|
||||
SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error
|
||||
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
|
||||
//
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
// released when the transaction ends.
|
||||
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
|
||||
// Unarchives a chat (and its children). Stale file references are
|
||||
// handled automatically by FK cascades on chat_file_links: when
|
||||
// dbpurge deletes a chat_files row, the corresponding
|
||||
// chat_file_links rows are cascade-deleted by PostgreSQL.
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// This will always work regardless of the current state of the template version.
|
||||
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
|
||||
@@ -857,9 +917,11 @@ type sqlcQuerier interface {
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
|
||||
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
|
||||
// Updates the cached injected context parts (AGENTS.md +
|
||||
// skills) on the chat row. Called only when context changes
|
||||
@@ -942,7 +1004,7 @@ type sqlcQuerier interface {
|
||||
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
|
||||
UpdateUserQuietHoursSchedule(ctx context.Context, arg UpdateUserQuietHoursScheduleParams) (User, error)
|
||||
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
|
||||
UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error)
|
||||
UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
|
||||
UpdateUserTaskNotificationAlertDismissed(ctx context.Context, arg UpdateUserTaskNotificationAlertDismissedParams) (bool, error)
|
||||
UpdateUserTerminalFont(ctx context.Context, arg UpdateUserTerminalFontParams) (UserConfig, error)
|
||||
@@ -986,6 +1048,7 @@ type sqlcQuerier interface {
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error
|
||||
UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error
|
||||
UpsertChatSystemPrompt(ctx context.Context, value string) error
|
||||
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
|
||||
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
|
||||
|
||||
@@ -7339,13 +7339,7 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, secretID, createdSecret.ID)
|
||||
|
||||
// 2. READ by ID
|
||||
readSecret, err := db.GetUserSecret(ctx, createdSecret.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, createdSecret.ID, readSecret.ID)
|
||||
assert.Equal(t, "workflow-secret", readSecret.Name)
|
||||
|
||||
// 3. READ by UserID and Name
|
||||
// 2. READ by UserID and Name
|
||||
readByNameParams := database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
@@ -7353,33 +7347,43 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
readByNameSecret, err := db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, createdSecret.ID, readByNameSecret.ID)
|
||||
assert.Equal(t, "workflow-secret", readByNameSecret.Name)
|
||||
|
||||
// 4. LIST
|
||||
// 3. LIST (metadata only)
|
||||
secrets, err := db.ListUserSecrets(ctx, testUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 1)
|
||||
assert.Equal(t, createdSecret.ID, secrets[0].ID)
|
||||
|
||||
// 5. UPDATE
|
||||
updateParams := database.UpdateUserSecretParams{
|
||||
ID: createdSecret.ID,
|
||||
Description: "Updated workflow description",
|
||||
Value: "updated-workflow-value",
|
||||
EnvName: "UPDATED_WORKFLOW_ENV",
|
||||
FilePath: "/updated/workflow/path",
|
||||
// 4. LIST with values
|
||||
secretsWithValues, err := db.ListUserSecretsWithValues(ctx, testUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secretsWithValues, 1)
|
||||
assert.Equal(t, "workflow-value", secretsWithValues[0].Value)
|
||||
|
||||
// 5. UPDATE (partial - only description)
|
||||
updateParams := database.UpdateUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
UpdateDescription: true,
|
||||
Description: "Updated workflow description",
|
||||
}
|
||||
|
||||
updatedSecret, err := db.UpdateUserSecret(ctx, updateParams)
|
||||
updatedSecret, err := db.UpdateUserSecretByUserIDAndName(ctx, updateParams)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated workflow description", updatedSecret.Description)
|
||||
assert.Equal(t, "updated-workflow-value", updatedSecret.Value)
|
||||
assert.Equal(t, "workflow-value", updatedSecret.Value) // Value unchanged
|
||||
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
|
||||
|
||||
// 6. DELETE
|
||||
err = db.DeleteUserSecret(ctx, createdSecret.ID)
|
||||
_, err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion
|
||||
_, err = db.GetUserSecret(ctx, createdSecret.ID)
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no rows in result set")
|
||||
|
||||
@@ -7449,9 +7453,13 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
})
|
||||
|
||||
// Verify both secrets exist
|
||||
_, err = db.GetUserSecret(ctx, secret1.ID)
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID, Name: secret1.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.GetUserSecret(ctx, secret2.ID)
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID, Name: secret2.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -7474,14 +7482,14 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
// Create secrets for users
|
||||
user1Secret := dbgen.UserSecret(t, db, database.UserSecret{
|
||||
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
||||
UserID: user1.ID,
|
||||
Name: "user1-secret",
|
||||
Description: "User 1's secret",
|
||||
Value: "user1-value",
|
||||
})
|
||||
|
||||
user2Secret := dbgen.UserSecret(t, db, database.UserSecret{
|
||||
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
||||
UserID: user2.ID,
|
||||
Name: "user2-secret",
|
||||
Description: "User 2's secret",
|
||||
@@ -7491,7 +7499,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
subject rbac.Subject
|
||||
secretID uuid.UUID
|
||||
lookupUserID uuid.UUID
|
||||
lookupName string
|
||||
expectedAccess bool
|
||||
}{
|
||||
{
|
||||
@@ -7501,7 +7510,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user1Secret.ID,
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
expectedAccess: true,
|
||||
},
|
||||
{
|
||||
@@ -7511,7 +7521,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user2Secret.ID,
|
||||
lookupUserID: user2.ID,
|
||||
lookupName: "user2-secret",
|
||||
expectedAccess: false,
|
||||
},
|
||||
{
|
||||
@@ -7521,7 +7532,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user1Secret.ID,
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
expectedAccess: false,
|
||||
},
|
||||
{
|
||||
@@ -7531,7 +7543,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.ScopedRoleOrgAdmin(org.ID)},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
secretID: user1Secret.ID,
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
expectedAccess: false,
|
||||
},
|
||||
}
|
||||
@@ -7543,8 +7556,10 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
|
||||
authCtx := dbauthz.As(ctx, tc.subject)
|
||||
|
||||
// Test GetUserSecret
|
||||
_, err := authDB.GetUserSecret(authCtx, tc.secretID)
|
||||
_, err := authDB.GetUserSecretByUserIDAndName(authCtx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: tc.lookupUserID,
|
||||
Name: tc.lookupName,
|
||||
})
|
||||
|
||||
if tc.expectedAccess {
|
||||
require.NoError(t, err, "expected access to be granted")
|
||||
@@ -9070,10 +9085,11 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
|
||||
|
||||
for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
|
||||
insertParams := database.InsertAIBridgeInterceptionParams{
|
||||
ID: uid,
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
Client: sql.NullString{String: "client", Valid: true},
|
||||
ID: uid,
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
Client: sql.NullString{String: "client", Valid: true},
|
||||
CredentialKind: database.CredentialKindCentralized,
|
||||
}
|
||||
|
||||
intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
|
||||
|
||||
+987
-303
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
-- name: InsertAIBridgeInterception :one
|
||||
INSERT INTO aibridge_interceptions (
|
||||
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
|
||||
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id, credential_kind, credential_hint
|
||||
) VALUES (
|
||||
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
|
||||
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid, @credential_kind, @credential_hint
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
|
||||
@@ -149,94 +149,105 @@ VALUES (
|
||||
RETURNING *;
|
||||
|
||||
-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
|
||||
-- name: DeleteOldAuditLogConnectionEvents :exec
|
||||
DELETE FROM audit_logs
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
-- name: InsertChatFile :one
|
||||
INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data)
|
||||
VALUES (@owner_id::uuid, @organization_id::uuid, @name::text, @mimetype::text, @data::bytea)
|
||||
INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data, object_store_key)
|
||||
VALUES (@owner_id::uuid, @organization_id::uuid, @name::text, @mimetype::text, @data::bytea, @object_store_key::text)
|
||||
RETURNING id, owner_id, organization_id, created_at, name, mimetype;
|
||||
|
||||
-- name: GetChatFileByID :one
|
||||
@@ -8,3 +8,50 @@ SELECT * FROM chat_files WHERE id = @id::uuid;
|
||||
|
||||
-- name: GetChatFilesByIDs :many
|
||||
SELECT * FROM chat_files WHERE id = ANY(@ids::uuid[]);
|
||||
|
||||
-- name: GetChatFileMetadataByChatID :many
|
||||
-- GetChatFileMetadataByChatID returns lightweight file metadata for
|
||||
-- all files linked to a chat. The data column is excluded to avoid
|
||||
-- loading file content.
|
||||
SELECT cf.id, cf.owner_id, cf.organization_id, cf.name, cf.mimetype, cf.created_at
|
||||
FROM chat_files cf
|
||||
JOIN chat_file_links cfl ON cfl.file_id = cf.id
|
||||
WHERE cfl.chat_id = @chat_id::uuid
|
||||
ORDER BY cf.created_at ASC;
|
||||
|
||||
-- TODO(cian): Add indexes on chats(archived, updated_at) and
|
||||
-- chat_files(created_at) for purge query performance.
|
||||
-- See: https://github.com/coder/internal/issues/1438
|
||||
-- name: DeleteOldChatFiles :many
|
||||
-- Deletes chat files that are older than the given threshold and are
|
||||
-- not referenced by any chat that is still active or was archived
|
||||
-- within the same threshold window. This covers two cases:
|
||||
-- 1. Orphaned files not linked to any chat.
|
||||
-- 2. Files whose every referencing chat has been archived for longer
|
||||
-- than the retention period.
|
||||
-- Returns the deleted rows so callers can clean up associated object
|
||||
-- store entries.
|
||||
WITH kept_file_ids AS (
|
||||
-- NOTE: This uses updated_at as a proxy for archive time
|
||||
-- because there is no archived_at column. Correctness
|
||||
-- requires that updated_at is never backdated on archived
|
||||
-- chats. See ArchiveChatByID.
|
||||
SELECT DISTINCT cfl.file_id
|
||||
FROM chat_file_links cfl
|
||||
JOIN chats c ON c.id = cfl.chat_id
|
||||
WHERE c.archived = false
|
||||
OR c.updated_at >= @before_time::timestamptz
|
||||
),
|
||||
deletable AS (
|
||||
SELECT cf.id
|
||||
FROM chat_files cf
|
||||
LEFT JOIN kept_file_ids k ON cf.id = k.file_id
|
||||
WHERE cf.created_at < @before_time::timestamptz
|
||||
AND k.file_id IS NULL
|
||||
ORDER BY cf.created_at ASC
|
||||
LIMIT @limit_count
|
||||
)
|
||||
DELETE FROM chat_files
|
||||
USING deletable
|
||||
WHERE chat_files.id = deletable.id
|
||||
RETURNING chat_files.id, chat_files.object_store_key;
|
||||
|
||||
@@ -10,9 +10,14 @@ FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: UnarchiveChatByID :many
|
||||
-- Unarchives a chat (and its children). Stale file references are
|
||||
-- handled automatically by FK cascades on chat_file_links: when
|
||||
-- dbpurge deletes a chat_files row, the corresponding
|
||||
-- chat_file_links rows are cascade-deleted by PostgreSQL.
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = false, updated_at = NOW()
|
||||
UPDATE chats SET
|
||||
archived = false,
|
||||
updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
@@ -394,7 +399,8 @@ INSERT INTO chats (
|
||||
mode,
|
||||
status,
|
||||
mcp_server_ids,
|
||||
labels
|
||||
labels,
|
||||
dynamic_tools
|
||||
) VALUES (
|
||||
@owner_id::uuid,
|
||||
sqlc.narg('workspace_id')::uuid,
|
||||
@@ -407,7 +413,8 @@ INSERT INTO chats (
|
||||
sqlc.narg('mode')::chat_mode,
|
||||
@status::chat_status,
|
||||
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]),
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb)
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb),
|
||||
sqlc.narg('dynamic_tools')::jsonb
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -567,6 +574,43 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: LinkChatFiles :one
|
||||
-- LinkChatFiles inserts file associations into the chat_file_links
|
||||
-- join table with deduplication (ON CONFLICT DO NOTHING). The INSERT
|
||||
-- is conditional: it only proceeds when the total number of links
|
||||
-- (existing + genuinely new) does not exceed max_file_links. Returns
|
||||
-- the number of genuinely new file IDs that were NOT inserted due to
|
||||
-- the cap. A return value of 0 means all files were linked (or were
|
||||
-- already linked). A positive value means the cap blocked that many
|
||||
-- new links.
|
||||
WITH current AS (
|
||||
SELECT COUNT(*) AS cnt
|
||||
FROM chat_file_links
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
),
|
||||
new_links AS (
|
||||
SELECT @chat_id::uuid AS chat_id, unnest(@file_ids::uuid[]) AS file_id
|
||||
),
|
||||
genuinely_new AS (
|
||||
SELECT nl.chat_id, nl.file_id
|
||||
FROM new_links nl
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM chat_file_links cfl
|
||||
WHERE cfl.chat_id = nl.chat_id AND cfl.file_id = nl.file_id
|
||||
)
|
||||
),
|
||||
inserted AS (
|
||||
INSERT INTO chat_file_links (chat_id, file_id)
|
||||
SELECT gn.chat_id, gn.file_id
|
||||
FROM genuinely_new gn, current c
|
||||
WHERE c.cnt + (SELECT COUNT(*) FROM genuinely_new) <= @max_file_links::int
|
||||
ON CONFLICT (chat_id, file_id) DO NOTHING
|
||||
RETURNING file_id
|
||||
)
|
||||
SELECT
|
||||
(SELECT COUNT(*)::int FROM genuinely_new) -
|
||||
(SELECT COUNT(*)::int FROM inserted) AS rejected_new_files;
|
||||
|
||||
-- name: AcquireChats :many
|
||||
-- Acquires up to @num_chats pending chats for processing. Uses SKIP LOCKED
|
||||
-- to prevent multiple replicas from acquiring the same chat.
|
||||
@@ -627,27 +671,34 @@ RETURNING
|
||||
*;
|
||||
|
||||
-- name: GetStaleChats :many
|
||||
-- Find chats that appear stuck (running but heartbeat has expired).
|
||||
-- Used for recovery after coderd crashes or long hangs.
|
||||
-- Find chats that appear stuck and need recovery. This covers:
|
||||
-- 1. Running chats whose heartbeat has expired (worker crash).
|
||||
-- 2. Chats awaiting client action (requires_action) past the
|
||||
-- timeout threshold (client disappeared).
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz;
|
||||
(status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz)
|
||||
OR (status = 'requires_action'::chat_status
|
||||
AND updated_at < @stale_threshold::timestamptz);
|
||||
|
||||
-- name: UpdateChatHeartbeat :execrows
|
||||
-- Bumps the heartbeat timestamp for a running chat so that other
|
||||
-- replicas know the worker is still alive.
|
||||
-- name: UpdateChatHeartbeats :many
|
||||
-- Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
-- provided they are still running and owned by the specified
|
||||
-- worker. Returns the IDs that were actually updated so the
|
||||
-- caller can detect stolen or completed chats via set-difference.
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = NOW()
|
||||
heartbeat_at = @now::timestamptz
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
id = ANY(@ids::uuid[])
|
||||
AND worker_id = @worker_id::uuid
|
||||
AND status = 'running'::chat_status;
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id;
|
||||
|
||||
-- name: GetChatDiffStatusByChatID :one
|
||||
SELECT
|
||||
@@ -883,7 +934,8 @@ SELECT
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -913,7 +965,8 @@ SELECT
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -948,7 +1001,8 @@ WITH chat_costs AS (
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM chat_messages cm
|
||||
JOIN chats c ON c.id = cm.chat_id
|
||||
WHERE c.owner_id = @owner_id::uuid
|
||||
@@ -965,7 +1019,8 @@ SELECT
|
||||
cc.total_input_tokens,
|
||||
cc.total_output_tokens,
|
||||
cc.total_cache_read_tokens,
|
||||
cc.total_cache_creation_tokens
|
||||
cc.total_cache_creation_tokens,
|
||||
cc.total_runtime_ms
|
||||
FROM chat_costs cc
|
||||
LEFT JOIN chats rc ON rc.id = cc.root_chat_id
|
||||
ORDER BY cc.total_cost_micros DESC;
|
||||
@@ -991,7 +1046,8 @@ WITH chat_cost_users AS (
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -1025,6 +1081,7 @@ SELECT
|
||||
total_output_tokens,
|
||||
total_cache_read_tokens,
|
||||
total_cache_creation_tokens,
|
||||
total_runtime_ms,
|
||||
COUNT(*) OVER()::bigint AS total_count
|
||||
FROM
|
||||
chat_cost_users
|
||||
@@ -1174,3 +1231,88 @@ LIMIT 1;
|
||||
UPDATE chats
|
||||
SET last_read_message_id = @last_read_message_id::bigint
|
||||
WHERE id = @id::uuid;
|
||||
|
||||
-- name: DeleteOldChats :execrows
|
||||
-- Deletes chats that have been archived for longer than the given
|
||||
-- threshold. Active (non-archived) chats are never deleted.
|
||||
-- Related chat_messages, chat_diff_statuses, and
|
||||
-- chat_queued_messages are removed via ON DELETE CASCADE.
|
||||
-- Parent/root references on child chats are SET NULL.
|
||||
WITH deletable AS (
|
||||
SELECT id
|
||||
FROM chats
|
||||
WHERE archived = true
|
||||
AND updated_at < @before_time::timestamptz
|
||||
ORDER BY updated_at ASC
|
||||
LIMIT @limit_count
|
||||
)
|
||||
DELETE FROM chats
|
||||
USING deletable
|
||||
WHERE chats.id = deletable.id
|
||||
AND chats.archived = true;
|
||||
|
||||
-- name: GetChatsUpdatedAfter :many
|
||||
-- Retrieves chats updated after the given timestamp for telemetry
|
||||
-- snapshot collection. Uses updated_at so that long-running chats
|
||||
-- still appear in each snapshot window while they are active.
|
||||
SELECT
|
||||
id, owner_id, created_at, updated_at, status,
|
||||
(parent_chat_id IS NOT NULL)::bool AS has_parent,
|
||||
root_chat_id, workspace_id,
|
||||
mode, archived, last_model_config_id
|
||||
FROM chats
|
||||
WHERE updated_at > @updated_after;
|
||||
|
||||
-- name: GetChatMessageSummariesPerChat :many
|
||||
-- Aggregates message-level metrics per chat for messages created
|
||||
-- after the given timestamp. Uses message created_at so that
|
||||
-- ongoing activity in long-running chats is captured each window.
|
||||
SELECT
|
||||
cm.chat_id,
|
||||
COUNT(*)::bigint AS message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'user')::bigint AS user_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'assistant')::bigint AS assistant_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'tool')::bigint AS tool_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'system')::bigint AS system_message_count,
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.reasoning_tokens), 0)::bigint AS total_reasoning_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms,
|
||||
COUNT(DISTINCT cm.model_config_id)::bigint AS distinct_model_count,
|
||||
COUNT(*) FILTER (WHERE cm.compressed)::bigint AS compressed_message_count
|
||||
FROM chat_messages cm
|
||||
WHERE cm.created_at > @created_after
|
||||
AND cm.deleted = false
|
||||
GROUP BY cm.chat_id;
|
||||
|
||||
-- name: GetChatModelConfigsForTelemetry :many
|
||||
-- Returns all model configurations for telemetry snapshot collection.
|
||||
SELECT id, provider, model, context_limit, enabled, is_default
|
||||
FROM chat_model_configs
|
||||
WHERE deleted = false;
|
||||
-- name: GetActiveChatsByAgentID :many
|
||||
SELECT *
|
||||
FROM chats
|
||||
WHERE agent_id = @agent_id::uuid
|
||||
AND archived = false
|
||||
-- Active statuses only: waiting, pending, running, paused,
|
||||
-- requires_action.
|
||||
-- Excludes completed and error (terminal states).
|
||||
AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action')
|
||||
ORDER BY updated_at DESC;
|
||||
|
||||
-- name: ClearChatMessageProviderResponseIDsByChatID :exec
|
||||
UPDATE chat_messages
|
||||
SET provider_response_id = NULL
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND provider_response_id IS NOT NULL;
|
||||
|
||||
-- name: SoftDeleteContextFileMessages :exec
|
||||
UPDATE chat_messages SET deleted = true
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND content::jsonb @> '[{"type": "context-file"}]';
|
||||
|
||||
@@ -133,111 +133,113 @@ OFFSET
|
||||
@offset_opt;
|
||||
|
||||
-- name: CountConnectionLogs :one
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
|
||||
-- name: DeleteOldConnectionLogs :execrows
|
||||
WITH old_logs AS (
|
||||
|
||||
@@ -236,3 +236,20 @@ VALUES ('agents_workspace_ttl', @workspace_ttl::text)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET value = @workspace_ttl::text
|
||||
WHERE site_configs.key = 'agents_workspace_ttl';
|
||||
|
||||
-- name: GetChatRetentionDays :one
|
||||
-- Returns the chat retention period in days. Chats archived longer
|
||||
-- than this and orphaned chat files older than this are purged by
|
||||
-- dbpurge. Returns 30 (days) when no value has been configured.
|
||||
-- A value of 0 disables chat purging entirely.
|
||||
SELECT COALESCE(
|
||||
(SELECT value::integer FROM site_configs
|
||||
WHERE key = 'agents_chat_retention_days'),
|
||||
30
|
||||
) :: integer AS retention_days;
|
||||
|
||||
-- name: UpsertChatRetentionDays :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES ('agents_chat_retention_days', CAST(@retention_days AS integer)::text)
|
||||
ON CONFLICT (key) DO UPDATE SET value = CAST(@retention_days AS integer)::text
|
||||
WHERE site_configs.key = 'agents_chat_retention_days';
|
||||
|
||||
@@ -1,14 +1,26 @@
|
||||
-- name: GetUserSecretByUserIDAndName :one
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2;
|
||||
|
||||
-- name: GetUserSecret :one
|
||||
SELECT * FROM user_secrets
|
||||
WHERE id = $1;
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
-- name: ListUserSecrets :many
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
-- Returns metadata only (no value or value_key_id) for the
|
||||
-- REST API list and get endpoints.
|
||||
SELECT
|
||||
id, user_id, name, description,
|
||||
env_name, file_path,
|
||||
created_at, updated_at
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: ListUserSecretsWithValues :many
|
||||
-- Returns all columns including the secret value. Used by the
|
||||
-- provisioner (build-time injection) and the agent manifest
|
||||
-- (runtime injection).
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: CreateUserSecret :one
|
||||
@@ -18,23 +30,32 @@ INSERT INTO user_secrets (
|
||||
name,
|
||||
description,
|
||||
value,
|
||||
value_key_id,
|
||||
env_name,
|
||||
file_path
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7
|
||||
@id,
|
||||
@user_id,
|
||||
@name,
|
||||
@description,
|
||||
@value,
|
||||
@value_key_id,
|
||||
@env_name,
|
||||
@file_path
|
||||
) RETURNING *;
|
||||
|
||||
-- name: UpdateUserSecret :one
|
||||
-- name: UpdateUserSecretByUserIDAndName :one
|
||||
UPDATE user_secrets
|
||||
SET
|
||||
description = $2,
|
||||
value = $3,
|
||||
env_name = $4,
|
||||
file_path = $5,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $1
|
||||
value = CASE WHEN @update_value::bool THEN @value ELSE value END,
|
||||
value_key_id = CASE WHEN @update_value::bool THEN @value_key_id ELSE value_key_id END,
|
||||
description = CASE WHEN @update_description::bool THEN @description ELSE description END,
|
||||
env_name = CASE WHEN @update_env_name::bool THEN @env_name ELSE env_name END,
|
||||
file_path = CASE WHEN @update_file_path::bool THEN @file_path ELSE file_path END,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = @user_id AND name = @name
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserSecret :exec
|
||||
-- name: DeleteUserSecretByUserIDAndName :execrows
|
||||
DELETE FROM user_secrets
|
||||
WHERE id = $1;
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
@@ -247,6 +247,7 @@ sql:
|
||||
mcp_server_tool_snapshots: MCPServerToolSnapshots
|
||||
mcp_server_config_id: MCPServerConfigID
|
||||
mcp_server_ids: MCPServerIDs
|
||||
max_file_links: MaxFileLinks
|
||||
icon_url: IconURL
|
||||
oauth2_client_id: OAuth2ClientID
|
||||
oauth2_client_secret: OAuth2ClientSecret
|
||||
|
||||
@@ -16,6 +16,7 @@ const (
|
||||
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
|
||||
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
UniqueChatFileLinksChatIDFileIDKey UniqueConstraint = "chat_file_links_chat_id_file_id_key" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
|
||||
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
|
||||
|
||||
+450
-118
@@ -137,8 +137,9 @@ func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub.
|
||||
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
logger := api.Logger.Named("chat_watcher")
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat watch stream.",
|
||||
@@ -146,54 +147,44 @@ func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatEvent(
|
||||
func(ctx context.Context, payload pubsub.ChatEvent, err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// The encoder is only written from the SubscribeWithErr callback,
|
||||
// which delivers serially per subscription. Do not add a second
|
||||
// write path without introducing synchronization.
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatWatchEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatWatchEvent(
|
||||
func(ctx context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "chat event subscription error", slog.Error(err))
|
||||
logger.Error(ctx, "chat watch event subscription error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: payload,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat event", slog.Error(err))
|
||||
if err := encoder.Encode(payload); err != nil {
|
||||
logger.Debug(ctx, "failed to send chat watch event", slog.Error(err))
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
},
|
||||
))
|
||||
if err != nil {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Internal error subscribing to chat events.",
|
||||
Detail: err.Error(),
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat subscribe error event", slog.Error(err))
|
||||
}
|
||||
logger.Error(ctx, "failed to subscribe to chat watch events", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, "Failed to subscribe to chat events.")
|
||||
return
|
||||
}
|
||||
defer cancelSubscribe()
|
||||
|
||||
// Send initial ping to signal the connection is ready.
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypePing,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat ping event", slog.Error(err))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
}
|
||||
}
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to
|
||||
@@ -398,12 +389,26 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Cap the raw request body to prevent excessive memory use
|
||||
// from large dynamic tool schemas.
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
||||
|
||||
var req codersdk.CreateChatRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, titleSource, inputError := createChatInputFromRequest(ctx, api.Database, req)
|
||||
// Validate per-chat system prompt length.
|
||||
const maxSystemPromptLen = 10000
|
||||
if len(req.SystemPrompt) > maxSystemPromptLen {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "System prompt exceeds maximum length.",
|
||||
Detail: fmt.Sprintf("System prompt must be at most %d characters, got %d.", maxSystemPromptLen, len(req.SystemPrompt)),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, titleSource, fileIDs, inputError := createChatInputFromRequest(ctx, api.Database, req)
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, *inputError)
|
||||
return
|
||||
@@ -478,15 +483,60 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UnsafeDynamicTools) > 250 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Too many dynamic tools.",
|
||||
Detail: "Maximum 250 dynamic tools per chat.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate that dynamic tool names are non-empty and unique
|
||||
// within the list. Name collision with built-in tools is
|
||||
// checked at chatloop time when the full tool set is known.
|
||||
if len(req.UnsafeDynamicTools) > 0 {
|
||||
seenNames := make(map[string]struct{}, len(req.UnsafeDynamicTools))
|
||||
for _, dt := range req.UnsafeDynamicTools {
|
||||
if dt.Name == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Dynamic tool name must not be empty.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if _, exists := seenNames[dt.Name]; exists {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Duplicate dynamic tool name.",
|
||||
Detail: fmt.Sprintf("Tool %q appears more than once.", dt.Name),
|
||||
})
|
||||
return
|
||||
}
|
||||
seenNames[dt.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var dynamicToolsJSON json.RawMessage
|
||||
if len(req.UnsafeDynamicTools) > 0 {
|
||||
var err error
|
||||
dynamicToolsJSON, err = json.Marshal(req.UnsafeDynamicTools)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to marshal dynamic tools.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: apiKey.UserID,
|
||||
WorkspaceID: workspaceSelection.WorkspaceID,
|
||||
Title: title,
|
||||
ModelConfigID: modelConfigID,
|
||||
SystemPrompt: api.resolvedChatSystemPrompt(ctx),
|
||||
SystemPrompt: req.SystemPrompt,
|
||||
InitialUserContent: contentBlocks,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: labels,
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
if err != nil {
|
||||
if maybeWriteLimitErr(ctx, rw, err) {
|
||||
@@ -514,7 +564,32 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, db2sdk.Chat(chat, nil))
|
||||
// Link any user-uploaded files referenced in the initial
|
||||
// message to this newly created chat (best-effort; cap
|
||||
// enforced in SQL).
|
||||
unlinked, capExceeded := api.linkFilesToChat(ctx, chat.ID, fileIDs)
|
||||
|
||||
// Re-read the chat so the response reflects the authoritative
|
||||
// database state (file links are deduped in the join table).
|
||||
chat, err = api.Database.GetChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to read back chat after creation.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chatFiles := api.fetchChatFileMetadata(ctx, chat.ID)
|
||||
response := db2sdk.Chat(chat, nil, chatFiles)
|
||||
if len(unlinked) > 0 {
|
||||
if capExceeded {
|
||||
response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked)))
|
||||
} else {
|
||||
response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked)))
|
||||
}
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -717,6 +792,7 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
||||
TotalOutputTokens: summary.TotalOutputTokens,
|
||||
TotalCacheReadTokens: summary.TotalCacheReadTokens,
|
||||
TotalCacheCreationTokens: summary.TotalCacheCreationTokens,
|
||||
TotalRuntimeMs: summary.TotalRuntimeMs,
|
||||
ByModel: modelBreakdowns,
|
||||
ByChat: chatBreakdowns,
|
||||
}
|
||||
@@ -1290,7 +1366,11 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, diffStatus))
|
||||
|
||||
// Hydrate file metadata for all files linked to this chat.
|
||||
chatFiles := api.fetchChatFileMetadata(ctx, chat.ID)
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, diffStatus, chatFiles))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -1780,7 +1860,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
contentBlocks, _, fileIDs, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: inputError.Message,
|
||||
@@ -1819,6 +1899,20 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
busyBehavior := chatd.SendMessageBusyBehaviorQueue
|
||||
switch req.BusyBehavior {
|
||||
case codersdk.ChatBusyBehaviorInterrupt:
|
||||
busyBehavior = chatd.SendMessageBusyBehaviorInterrupt
|
||||
case codersdk.ChatBusyBehaviorQueue, "":
|
||||
// Default to queue.
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid busy_behavior value.",
|
||||
Detail: `Must be "queue" or "interrupt".`,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sendResult, sendErr := api.chatDaemon.SendMessage(
|
||||
ctx,
|
||||
chatd.SendMessageOptions{
|
||||
@@ -1826,7 +1920,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
CreatedBy: apiKey.UserID,
|
||||
Content: contentBlocks,
|
||||
ModelConfigID: req.ModelConfigID,
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
BusyBehavior: busyBehavior,
|
||||
MCPServerIDs: req.MCPServerIDs,
|
||||
},
|
||||
)
|
||||
@@ -1848,6 +1942,9 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Link any user-uploaded files referenced in this message
|
||||
// to the chat (best-effort; cap enforced in SQL).
|
||||
unlinked, capExceeded := api.linkFilesToChat(ctx, chatID, fileIDs)
|
||||
response := codersdk.CreateChatMessageResponse{Queued: sendResult.Queued}
|
||||
if sendResult.Queued {
|
||||
if sendResult.QueuedMessage != nil {
|
||||
@@ -1857,6 +1954,13 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
message := convertChatMessage(sendResult.Message)
|
||||
response.Message = &message
|
||||
}
|
||||
if len(unlinked) > 0 {
|
||||
if capExceeded {
|
||||
response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked)))
|
||||
} else {
|
||||
response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked)))
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, response)
|
||||
}
|
||||
@@ -1890,7 +1994,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
contentBlocks, _, fileIDs, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: inputError.Message,
|
||||
@@ -1929,8 +2033,20 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
message := convertChatMessage(editResult.Message)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, message)
|
||||
// Link any user-uploaded files referenced in the edited
|
||||
// message to the chat (best-effort; cap enforced in SQL).
|
||||
unlinked, capExceeded := api.linkFilesToChat(ctx, chat.ID, fileIDs)
|
||||
response := codersdk.EditChatMessageResponse{
|
||||
Message: convertChatMessage(editResult.Message),
|
||||
}
|
||||
if len(unlinked) > 0 {
|
||||
if capExceeded {
|
||||
response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked)))
|
||||
} else {
|
||||
response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked)))
|
||||
}
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -2051,6 +2167,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_streamer").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -2073,7 +2190,22 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
// Subscribe before accepting the WebSocket so that failures
|
||||
// can still be reported as normal HTTP errors.
|
||||
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
// Subscribe only fails today when the receiver is nil, which
|
||||
// the chatDaemon == nil guard above already catches. This is
|
||||
// defensive against future Subscribe failure modes.
|
||||
if !ok {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer cancelSub()
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat stream.",
|
||||
@@ -2081,41 +2213,30 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
if !ok {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream unavailable event", slog.Error(err))
|
||||
}
|
||||
// Ensure the WebSocket is closed so senderClosed
|
||||
// completes and the handler can return.
|
||||
<-senderClosed
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// Mark the chat as read when the stream connects and again
|
||||
// when it disconnects so we avoid per-message API calls while
|
||||
// messages are actively streaming.
|
||||
api.markChatAsRead(ctx, chatID)
|
||||
defer api.markChatAsRead(context.WithoutCancel(ctx), chatID)
|
||||
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
return sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: batch,
|
||||
})
|
||||
return encoder.Encode(batch)
|
||||
}
|
||||
|
||||
drainChatStreamBatch := func(
|
||||
@@ -2148,7 +2269,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
end = len(snapshot)
|
||||
}
|
||||
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -2157,8 +2278,6 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
case firstEvent, ok := <-events:
|
||||
if !ok {
|
||||
return
|
||||
@@ -2168,7 +2287,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
chatStreamBatchSize,
|
||||
)
|
||||
if err := sendChatStreamBatch(batch); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if streamClosed {
|
||||
@@ -2183,6 +2302,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_interrupt").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon != nil {
|
||||
chat = api.chatDaemon.InterruptChat(ctx, chat)
|
||||
@@ -2196,8 +2316,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if updateErr != nil {
|
||||
api.Logger.Error(ctx, "failed to mark chat as waiting",
|
||||
slog.F("chat_id", chatID), slog.Error(updateErr))
|
||||
logger.Error(ctx, "failed to mark chat as waiting", slog.Error(updateErr))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to interrupt chat.",
|
||||
Detail: updateErr.Error(),
|
||||
@@ -2207,7 +2326,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
chat = updatedChat
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, nil))
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, nil, nil))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -2251,7 +2370,7 @@ func (api *API) regenerateChatTitle(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(updatedChat, nil))
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(updatedChat, nil, nil))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -2869,6 +2988,8 @@ const (
|
||||
maxChatFileSize = 10 << 20
|
||||
// maxChatFileName is the maximum length of an uploaded file name.
|
||||
maxChatFileName = 255
|
||||
// chatFilesNamespace is the object store namespace for chat files.
|
||||
chatFilesNamespace = "chatfiles"
|
||||
)
|
||||
|
||||
// allowedChatFileMIMETypes lists the content types accepted for chat
|
||||
@@ -3107,6 +3228,70 @@ func (api *API) putChatWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// @Summary Get chat retention days
|
||||
// @ID get-chat-retention-days
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Produce json
|
||||
// @Success 200 {object} codersdk.ChatRetentionDaysResponse
|
||||
// @Router /experimental/chats/config/retention-days [get]
|
||||
// @x-apidocgen {"skip": true}
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
retentionDays, err := api.Database.GetChatRetentionDays(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat retention days.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatRetentionDaysResponse{
|
||||
RetentionDays: retentionDays,
|
||||
})
|
||||
}
|
||||
|
||||
// Keep in sync with retentionDaysMaximum in
|
||||
// site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.tsx.
|
||||
const retentionDaysMaximum = 3650 // ~10 years
|
||||
|
||||
// @Summary Update chat retention days
|
||||
// @ID update-chat-retention-days
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Accept json
|
||||
// @Param request body codersdk.UpdateChatRetentionDaysRequest true "Request body"
|
||||
// @Success 204
|
||||
// @Router /experimental/chats/config/retention-days [put]
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) putChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
var req codersdk.UpdateChatRetentionDaysRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if req.RetentionDays < 0 || req.RetentionDays > retentionDaysMaximum {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Retention days must be between 0 and %d.", retentionDaysMaximum),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := api.Database.UpsertChatRetentionDays(ctx, req.RetentionDays); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to update chat retention days.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
@@ -3476,35 +3661,6 @@ func (api *API) deleteUserChatCompactionThreshold(rw http.ResponseWriter, r *htt
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *API) resolvedChatSystemPrompt(ctx context.Context) string {
|
||||
config, err := api.Database.GetChatSystemPromptConfig(ctx)
|
||||
if err != nil {
|
||||
// We intentionally fail open here. When the prompt configuration
|
||||
// cannot be read, returning the built-in default keeps the chat
|
||||
// grounded instead of sending no system guidance at all.
|
||||
api.Logger.Error(ctx, "failed to fetch chat system prompt configuration, using default", slog.Error(err))
|
||||
return chatd.DefaultSystemPrompt
|
||||
}
|
||||
|
||||
sanitizedCustom := chatd.SanitizePromptText(config.ChatSystemPrompt)
|
||||
if sanitizedCustom == "" && strings.TrimSpace(config.ChatSystemPrompt) != "" {
|
||||
api.Logger.Warn(ctx, "custom system prompt became empty after sanitization, omitting custom portion")
|
||||
}
|
||||
|
||||
var parts []string
|
||||
if config.IncludeDefaultSystemPrompt {
|
||||
parts = append(parts, chatd.DefaultSystemPrompt)
|
||||
}
|
||||
if sanitizedCustom != "" {
|
||||
parts = append(parts, sanitizedCustom)
|
||||
}
|
||||
result := strings.Join(parts, "\n\n")
|
||||
if result == "" {
|
||||
api.Logger.Warn(ctx, "resolved system prompt is empty, no system prompt will be injected into chats")
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (api *API) postChatFile(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
@@ -3630,12 +3786,21 @@ func (api *API) postChatFile(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
key := uuid.New().String()
|
||||
if err := api.ObjectStore.Write(ctx, chatFilesNamespace, key, data); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to save chat file.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chatFile, err := api.Database.InsertChatFile(ctx, database.InsertChatFileParams{
|
||||
OwnerID: apiKey.UserID,
|
||||
OrganizationID: orgID,
|
||||
Name: filename,
|
||||
Mimetype: detected,
|
||||
Data: data,
|
||||
ObjectStoreKey: key,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -3682,6 +3847,27 @@ func (api *API) chatFileByID(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Set("Content-Disposition", "inline")
|
||||
}
|
||||
rw.Header().Set("Cache-Control", "private, max-age=31536000, immutable")
|
||||
|
||||
// Serve from object store, falling back to the database BYTEA
|
||||
// column for files that predate the migration.
|
||||
if chatFile.ObjectStoreKey.Valid && chatFile.ObjectStoreKey.String != "" {
|
||||
rc, info, err := api.ObjectStore.Read(ctx, chatFilesNamespace, chatFile.ObjectStoreKey.String)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to read chat file from storage.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer rc.Close()
|
||||
rw.Header().Set("Content-Length", strconv.FormatInt(info.Size, 10))
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
if _, err := io.Copy(rw, rc); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to stream chat file response", slog.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Length", strconv.Itoa(len(chatFile.Data)))
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
if _, err := rw.Write(chatFile.Data); err != nil {
|
||||
@@ -3692,6 +3878,7 @@ func (api *API) chatFileByID(rw http.ResponseWriter, r *http.Request) {
|
||||
func createChatInputFromRequest(ctx context.Context, db database.Store, req codersdk.CreateChatRequest) (
|
||||
[]codersdk.ChatMessagePart,
|
||||
string,
|
||||
[]uuid.UUID,
|
||||
*codersdk.Response,
|
||||
) {
|
||||
return createChatInputFromParts(ctx, db, req.Content, "content")
|
||||
@@ -3702,14 +3889,15 @@ func createChatInputFromParts(
|
||||
db database.Store,
|
||||
parts []codersdk.ChatInputPart,
|
||||
fieldName string,
|
||||
) ([]codersdk.ChatMessagePart, string, *codersdk.Response) {
|
||||
) ([]codersdk.ChatMessagePart, string, []uuid.UUID, *codersdk.Response) {
|
||||
if len(parts) == 0 {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Content is required.",
|
||||
Detail: "Content cannot be empty.",
|
||||
}
|
||||
}
|
||||
|
||||
var fileIDs []uuid.UUID
|
||||
content := make([]codersdk.ChatMessagePart, 0, len(parts))
|
||||
textParts := make([]string, 0, len(parts))
|
||||
for i, part := range parts {
|
||||
@@ -3717,7 +3905,7 @@ func createChatInputFromParts(
|
||||
case string(codersdk.ChatInputPartTypeText):
|
||||
text := strings.TrimSpace(part.Text)
|
||||
if text == "" {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].text cannot be empty.", fieldName, i),
|
||||
}
|
||||
@@ -3726,7 +3914,7 @@ func createChatInputFromParts(
|
||||
textParts = append(textParts, text)
|
||||
case string(codersdk.ChatInputPartTypeFile):
|
||||
if part.FileID == uuid.Nil {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_id is required for file parts.", fieldName, i),
|
||||
}
|
||||
@@ -3737,20 +3925,23 @@ func createChatInputFromParts(
|
||||
chatFile, err := db.GetChatFileByID(ctx, part.FileID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_id references a file that does not exist.", fieldName, i),
|
||||
}
|
||||
}
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Internal error.",
|
||||
Detail: fmt.Sprintf("Failed to retrieve file for %s[%d].", fieldName, i),
|
||||
}
|
||||
}
|
||||
content = append(content, codersdk.ChatMessageFile(part.FileID, chatFile.Mimetype))
|
||||
fileIDs = append(fileIDs, part.FileID)
|
||||
// file-reference parts carry inline code snippets, not uploaded
|
||||
// files. They have no FileID and are excluded from file tracking.
|
||||
case string(codersdk.ChatInputPartTypeFileReference):
|
||||
if part.FileName == "" {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_name cannot be empty for file-reference.", fieldName, i),
|
||||
}
|
||||
@@ -3768,7 +3959,7 @@ func createChatInputFromParts(
|
||||
}
|
||||
textParts = append(textParts, sb.String())
|
||||
default:
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf(
|
||||
"%s[%d].type %q is not supported.",
|
||||
@@ -3783,13 +3974,13 @@ func createChatInputFromParts(
|
||||
// Allow file-only messages. The titleSource may be empty
|
||||
// when only file parts are provided, callers handle this.
|
||||
if len(content) == 0 {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Content is required.",
|
||||
Detail: fmt.Sprintf("%s must include at least one text or file part.", fieldName),
|
||||
}
|
||||
}
|
||||
titleSource := strings.TrimSpace(strings.Join(textParts, " "))
|
||||
return content, titleSource, nil
|
||||
return content, titleSource, fileIDs, nil
|
||||
}
|
||||
|
||||
func chatTitleFromMessage(message string) string {
|
||||
@@ -3824,6 +4015,70 @@ func truncateRunes(value string, maxLen int) string {
|
||||
return string(runes[:maxLen])
|
||||
}
|
||||
|
||||
// linkFilesToChat inserts file-link rows into the chat_file_links
|
||||
// join table. Cap enforcement and dedup are handled atomically in
|
||||
// SQL. On success returns (nil, false). On failure returns the full
|
||||
// input fileIDs slice — linking is all-or-nothing because the
|
||||
// SQL operates on the batch atomically. capExceeded indicates
|
||||
// whether the failure was due to the cap being exceeded (true)
|
||||
// or a database error (false).
|
||||
// Failures are logged but never block the caller.
|
||||
func (api *API) linkFilesToChat(ctx context.Context, chatID uuid.UUID, fileIDs []uuid.UUID) (unlinked []uuid.UUID, capExceeded bool) {
|
||||
if len(fileIDs) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
rejected, err := api.Database.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: chatID,
|
||||
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
|
||||
FileIds: fileIDs,
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "failed to link files to chat",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("file_ids", fileIDs),
|
||||
slog.Error(err),
|
||||
)
|
||||
return fileIDs, false
|
||||
}
|
||||
if rejected > 0 {
|
||||
api.Logger.Warn(ctx, "file cap reached, files not linked",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("file_ids", fileIDs),
|
||||
slog.F("max_file_links", codersdk.MaxChatFileIDs),
|
||||
)
|
||||
return fileIDs, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// fileLinkCapWarning builds a user-facing warning when a batch
|
||||
// of file IDs was atomically rejected because the resulting
|
||||
// array would exceed the per-chat file cap.
|
||||
func fileLinkCapWarning(count int) string {
|
||||
return fmt.Sprintf("file linking skipped: batch of %d file(s) would exceed limit of %d", count, codersdk.MaxChatFileIDs)
|
||||
}
|
||||
|
||||
// fileLinkErrorWarning builds a user-facing warning when a
|
||||
// database error prevented linking files to a chat.
|
||||
func fileLinkErrorWarning(count int) string {
|
||||
return fmt.Sprintf("%d file(s) could not be linked due to a server error", count)
|
||||
}
|
||||
|
||||
// fetchChatFileMetadata returns metadata for all files linked to
|
||||
// the given chat. Errors are logged and result in a nil return
|
||||
// (callers treat file metadata as best-effort).
|
||||
func (api *API) fetchChatFileMetadata(ctx context.Context, chatID uuid.UUID) []database.GetChatFileMetadataByChatIDRow {
|
||||
rows, err := api.Database.GetChatFileMetadataByChatID(ctx, chatID)
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "failed to fetch chat file metadata",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
func convertChatCostModelBreakdown(model database.GetChatCostPerModelRow) codersdk.ChatCostModelBreakdown {
|
||||
displayName := strings.TrimSpace(model.DisplayName)
|
||||
if displayName == "" {
|
||||
@@ -3840,6 +4095,7 @@ func convertChatCostModelBreakdown(model database.GetChatCostPerModelRow) coders
|
||||
TotalOutputTokens: model.TotalOutputTokens,
|
||||
TotalCacheReadTokens: model.TotalCacheReadTokens,
|
||||
TotalCacheCreationTokens: model.TotalCacheCreationTokens,
|
||||
TotalRuntimeMs: model.TotalRuntimeMs,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3853,6 +4109,7 @@ func convertChatCostChatBreakdown(chat database.GetChatCostPerChatRow) codersdk.
|
||||
TotalOutputTokens: chat.TotalOutputTokens,
|
||||
TotalCacheReadTokens: chat.TotalCacheReadTokens,
|
||||
TotalCacheCreationTokens: chat.TotalCacheCreationTokens,
|
||||
TotalRuntimeMs: chat.TotalRuntimeMs,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3869,6 +4126,7 @@ func convertChatCostUserRollup(user database.GetChatCostPerUserRow) codersdk.Cha
|
||||
TotalOutputTokens: user.TotalOutputTokens,
|
||||
TotalCacheReadTokens: user.TotalCacheReadTokens,
|
||||
TotalCacheCreationTokens: user.TotalCacheCreationTokens,
|
||||
TotalRuntimeMs: user.TotalRuntimeMs,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5568,3 +5826,77 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
RecentPRs: prEntries,
|
||||
})
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) postChatToolResults(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
// Cap the raw request body to prevent excessive memory use.
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
||||
var req codersdk.SubmitToolResultsRequest
|
||||
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Results) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "At least one tool result is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Fast-path check outside the transaction. The authoritative
|
||||
// check happens inside SubmitToolResults under a row lock.
|
||||
if chat.Status != database.ChatStatusRequiresAction {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Chat is not waiting for tool results.",
|
||||
Detail: fmt.Sprintf("Chat status is %q, expected %q.", chat.Status, database.ChatStatusRequiresAction),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var dynamicTools json.RawMessage
|
||||
if chat.DynamicTools.Valid {
|
||||
dynamicTools = chat.DynamicTools.RawMessage
|
||||
}
|
||||
|
||||
err := api.chatDaemon.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
||||
ChatID: chat.ID,
|
||||
UserID: apiKey.UserID,
|
||||
ModelConfigID: chat.LastModelConfigID,
|
||||
Results: req.Results,
|
||||
DynamicTools: dynamicTools,
|
||||
})
|
||||
if err != nil {
|
||||
var validationErr *chatd.ToolResultValidationError
|
||||
var conflictErr *chatd.ToolResultStatusConflictError
|
||||
switch {
|
||||
case errors.As(err, &conflictErr):
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Chat is not waiting for tool results.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
case errors.As(err, &validationErr):
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: validationErr.Message,
|
||||
Detail: validationErr.Detail,
|
||||
})
|
||||
default:
|
||||
api.Logger.Error(ctx, "tool results submission failed",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error submitting tool results.",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
+1008
-122
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,4 @@
|
||||
package coderd
|
||||
|
||||
// InsertAgentChatTestModelConfig exposes insertAgentChatTestModelConfig for external tests.
|
||||
var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig
|
||||
@@ -148,7 +148,7 @@ func TestGetOrgMembersFilter(t *testing.T) {
|
||||
setupCtx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
|
||||
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
|
||||
res, err := client.OrganizationMembersPaginated(testCtx, first.OrganizationID, req)
|
||||
require.NoError(t, err)
|
||||
reduced := make([]codersdk.ReducedUser, len(res.Members))
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user