Compare commits


1 Commit

Author SHA1 Message Date
Stephen Kirby 9589ebb15a chore(coderd/database): add pruning for orphaned chat file_ids entries
Add a new DeleteOrphanedChatFiles SQL query that removes chat_files rows
older than 24 hours that are not referenced by any non-deleted chat
message's content. Register this as a periodic job in the dbpurge
package, following the existing pattern used by other purge operations.

The query scans the JSONB content array of chat_messages for file_id
references and preserves any chat_files still in use.
2026-04-02 15:38:14 +00:00
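
The commit message describes the query and the periodic job without showing either. A minimal sketch of what they might look like, assuming Postgres, a created_at column on chat_files, a boolean deleted flag on chat_messages, and content parts shaped like {"file_id": "..."}; none of these names, nor the 10-minute interval, are confirmed by this diff:

package dbpurge

import (
	"context"
	"database/sql"
	"log"
	"time"
)

// deleteOrphanedChatFiles sketches the query described above: remove
// chat_files rows older than 24 hours that no non-deleted chat
// message still references via a file_id entry in its JSONB content
// array. Table and column names are assumptions, not the actual
// migration.
const deleteOrphanedChatFiles = `
DELETE FROM chat_files AS cf
WHERE cf.created_at < NOW() - INTERVAL '24 hours'
  AND NOT EXISTS (
      SELECT 1
      FROM chat_messages AS cm,
           jsonb_array_elements(cm.content) AS part
      WHERE NOT cm.deleted
        AND part->>'file_id' = cf.id::text
  );`

// runPurgeLoop illustrates the periodic-job pattern the message
// refers to: a ticker loop that runs the purge until ctx is done.
func runPurgeLoop(ctx context.Context, db *sql.DB, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if _, err := db.ExecContext(ctx, deleteOrphanedChatFiles); err != nil {
				log.Printf("dbpurge: delete orphaned chat files: %v", err)
			}
		}
	}
}

The real dbpurge package wires this through its store interface alongside the other purge operations; the loop above only illustrates the shape of that pattern.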
313 changed files with 5311 additions and 21124 deletions
+3 -4
View File
@@ -31,8 +31,7 @@ updates:
patterns:
- "golang.org/x/*"
ignore:
# Patch updates are handled by the security-patch-prs workflow so this
# lane stays focused on broader dependency updates.
# Ignore patch updates for all dependencies
- dependency-name: "*"
update-types:
- version-update:semver-patch
@@ -57,7 +56,7 @@ updates:
labels: []
ignore:
# We need to coordinate terraform updates with the version hardcoded in
# our Go code. These are handled by the security-patch-prs workflow.
# our Go code.
- dependency-name: "terraform"
- package-ecosystem: "npm"
@@ -118,11 +117,11 @@ updates:
interval: "weekly"
commit-message:
prefix: "chore"
labels: []
groups:
coder-modules:
patterns:
- "coder/*/coder"
labels: []
ignore:
- dependency-name: "*"
update-types:
+17 -17
View File
@@ -35,7 +35,7 @@ jobs:
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -157,7 +157,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -247,7 +247,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -272,7 +272,7 @@ jobs:
if: ${{ !cancelled() }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -327,7 +327,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -379,7 +379,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -575,7 +575,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -637,7 +637,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -709,7 +709,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -736,7 +736,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -769,7 +769,7 @@ jobs:
name: ${{ matrix.variant.name }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -849,7 +849,7 @@ jobs:
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -930,7 +930,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -1005,7 +1005,7 @@ jobs:
if: always()
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -1043,7 +1043,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -1097,7 +1097,7 @@ jobs:
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -1479,7 +1479,7 @@ jobs:
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+1 -1
View File
@@ -23,7 +23,7 @@ jobs:
steps:
- name: Dependabot metadata
id: metadata
uses: dependabot/fetch-metadata@ffa630c65fa7e0ecfa0625b5ceda64399aea1b36 # v3.0.0
uses: dependabot/fetch-metadata@21025c705c08248db411dc16f3619e6b5f9ea21a # v2.5.0
with:
github-token: "${{ secrets.GITHUB_TOKEN }}"
+3 -3
View File
@@ -36,7 +36,7 @@ jobs:
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -65,7 +65,7 @@ jobs:
packages: write # to retag image as dogfood
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -142,7 +142,7 @@ jobs:
needs: deploy
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+1 -1
View File
@@ -38,7 +38,7 @@ jobs:
if: github.repository_owner == 'coder'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+2 -2
View File
@@ -26,7 +26,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -125,7 +125,7 @@ jobs:
id-token: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+1 -1
View File
@@ -28,7 +28,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
packages: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+5 -5
View File
@@ -39,7 +39,7 @@ jobs:
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -76,7 +76,7 @@ jobs:
runs-on: "ubuntu-latest"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -184,7 +184,7 @@ jobs:
pull-requests: write # needed for commenting on PRs
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -228,7 +228,7 @@ jobs:
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -288,7 +288,7 @@ jobs:
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+1 -1
View File
@@ -14,7 +14,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+3 -3
View File
@@ -81,7 +81,7 @@ jobs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -673,7 +673,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -749,7 +749,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+2 -2
View File
@@ -20,7 +20,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -47,6 +47,6 @@ jobs:
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
with:
sarif_file: results.sarif
-354
View File
@@ -1,354 +0,0 @@
name: security-backport
on:
pull_request_target:
types:
- labeled
- unlabeled
- closed
workflow_dispatch:
inputs:
pull_request:
description: Pull request number to backport.
required: true
type: string
permissions:
contents: write
pull-requests: write
issues: write
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || inputs.pull_request }}
cancel-in-progress: false
env:
LATEST_BRANCH: release/2.31
STABLE_BRANCH: release/2.30
STABLE_1_BRANCH: release/2.29
jobs:
label-policy:
if: github.event_name == 'pull_request_target'
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Apply security backport label policy
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
pr_json="$(gh pr view "${pr_number}" --json number,title,url,baseRefName,labels)"
PR_JSON="${pr_json}" \
python3 - <<'PY'
import json
import os
import subprocess
import sys
pr = json.loads(os.environ["PR_JSON"])
pr_number = pr["number"]
labels = [label["name"] for label in pr.get("labels", [])]
def has(label: str) -> bool:
return label in labels
def ensure_label(label: str) -> None:
if not has(label):
subprocess.run(
["gh", "pr", "edit", str(pr_number), "--add-label", label],
check=False,
)
def remove_label(label: str) -> None:
if has(label):
subprocess.run(
["gh", "pr", "edit", str(pr_number), "--remove-label", label],
check=False,
)
def comment(body: str) -> None:
subprocess.run(
["gh", "pr", "comment", str(pr_number), "--body", body],
check=True,
)
if not has("security:patch"):
remove_label("status:needs-severity")
sys.exit(0)
severity_labels = [
label
for label in ("severity:medium", "severity:high", "severity:critical")
if has(label)
]
if len(severity_labels) == 0:
ensure_label("status:needs-severity")
comment(
"This PR is labeled `security:patch` but is missing a severity "
"label. Add one of `severity:medium`, `severity:high`, or "
"`severity:critical` before backport automation can proceed."
)
sys.exit(0)
if len(severity_labels) > 1:
comment(
"This PR has multiple severity labels. Keep exactly one of "
"`severity:medium`, `severity:high`, or `severity:critical`."
)
sys.exit(1)
remove_label("status:needs-severity")
target_labels = [
label
for label in ("backport:stable", "backport:stable-1")
if has(label)
]
has_none = has("backport:none")
if has_none and target_labels:
comment(
"`backport:none` cannot be combined with other backport labels. "
"Remove `backport:none` or remove the explicit backport targets."
)
sys.exit(1)
if not has_none and not target_labels:
ensure_label("backport:stable")
ensure_label("backport:stable-1")
comment(
"Applied default backport labels `backport:stable` and "
"`backport:stable-1` for a qualifying security patch."
)
PY
backport:
if: >
github.event_name == 'workflow_dispatch' ||
(
github.event_name == 'pull_request_target' &&
github.event.pull_request.merged == true
)
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Resolve PR metadata
id: metadata
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
INPUT_PR_NUMBER: ${{ inputs.pull_request }}
LATEST_BRANCH: ${{ env.LATEST_BRANCH }}
STABLE_BRANCH: ${{ env.STABLE_BRANCH }}
STABLE_1_BRANCH: ${{ env.STABLE_1_BRANCH }}
run: |
set -euo pipefail
if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
pr_number="${INPUT_PR_NUMBER}"
else
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
fi
case "${pr_number}" in
''|*[!0-9]*)
echo "A valid pull request number is required."
exit 1
;;
esac
pr_json="$(gh pr view "${pr_number}" --json number,title,url,mergeCommit,baseRefName,labels,mergedAt,author)"
PR_JSON="${pr_json}" \
python3 - <<'PY'
import json
import os
import sys
pr = json.loads(os.environ["PR_JSON"])
github_output = os.environ["GITHUB_OUTPUT"]
labels = [label["name"] for label in pr.get("labels", [])]
if "security:patch" not in labels:
print("Not a security patch PR; skipping.")
sys.exit(0)
severity_labels = [
label
for label in ("severity:medium", "severity:high", "severity:critical")
if label in labels
]
if len(severity_labels) != 1:
raise SystemExit(
"Merged security patch PR must have exactly one severity label."
)
if not pr.get("mergedAt"):
raise SystemExit(f"PR #{pr['number']} is not merged.")
if "backport:none" in labels:
target_pairs = []
else:
mapping = {
"backport:stable": os.environ["STABLE_BRANCH"],
"backport:stable-1": os.environ["STABLE_1_BRANCH"],
}
target_pairs = []
for label_name, branch in mapping.items():
if label_name in labels and branch and branch != pr["baseRefName"]:
target_pairs.append({"label": label_name, "branch": branch})
with open(github_output, "a", encoding="utf-8") as f:
f.write(f"pr_number={pr['number']}\n")
f.write(f"merge_sha={pr['mergeCommit']['oid']}\n")
f.write(f"title={pr['title']}\n")
f.write(f"url={pr['url']}\n")
f.write(f"author={pr['author']['login']}\n")
f.write(f"severity_label={severity_labels[0]}\n")
f.write(f"target_pairs={json.dumps(target_pairs)}\n")
PY
- name: Backport to release branches
if: ${{ steps.metadata.outputs.target_pairs != '[]' }}
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ steps.metadata.outputs.pr_number }}
MERGE_SHA: ${{ steps.metadata.outputs.merge_sha }}
PR_TITLE: ${{ steps.metadata.outputs.title }}
PR_URL: ${{ steps.metadata.outputs.url }}
PR_AUTHOR: ${{ steps.metadata.outputs.author }}
SEVERITY_LABEL: ${{ steps.metadata.outputs.severity_label }}
TARGET_PAIRS: ${{ steps.metadata.outputs.target_pairs }}
run: |
set -euo pipefail
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git remote set-url origin "https://x-access-token:${GH_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"
git fetch origin --prune
merge_parent_count="$(git rev-list --parents -n 1 "${MERGE_SHA}" | awk '{print NF-1}')"
failures=()
successes=()
while IFS=$'\t' read -r backport_label target_branch; do
[ -n "${target_branch}" ] || continue
safe_branch_name="${target_branch//\//-}"
head_branch="backport/${safe_branch_name}/pr-${PR_NUMBER}"
existing_pr="$(gh pr list \
--base "${target_branch}" \
--head "${head_branch}" \
--state all \
--json number,url \
--jq '.[0]')"
if [ -n "${existing_pr}" ] && [ "${existing_pr}" != "null" ]; then
pr_url="$(printf '%s' "${existing_pr}" | jq -r '.url')"
successes+=("${target_branch}:existing:${pr_url}")
continue
fi
git checkout -B "${head_branch}" "origin/${target_branch}"
if [ "${merge_parent_count}" -gt 1 ]; then
cherry_pick_args=(-m 1 "${MERGE_SHA}")
else
cherry_pick_args=("${MERGE_SHA}")
fi
if ! git cherry-pick -x "${cherry_pick_args[@]}"; then
git cherry-pick --abort || true
gh pr edit "${PR_NUMBER}" --add-label "backport:conflict" || true
gh pr comment "${PR_NUMBER}" --body \
"Automatic backport to \`${target_branch}\` conflicted. The original author or release manager should resolve it manually."
failures+=("${target_branch}:cherry-pick failed")
continue
fi
git push --force-with-lease origin "${head_branch}"
body_file="$(mktemp)"
printf '%s\n' \
"Automated backport of [#${PR_NUMBER}](${PR_URL})." \
"" \
"- Source PR: #${PR_NUMBER}" \
"- Source commit: ${MERGE_SHA}" \
"- Target branch: ${target_branch}" \
"- Severity: ${SEVERITY_LABEL}" \
> "${body_file}"
pr_url="$(gh pr create \
--base "${target_branch}" \
--head "${head_branch}" \
--title "${PR_TITLE} (backport to ${target_branch})" \
--body-file "${body_file}")"
backport_pr_number="$(gh pr list \
--base "${target_branch}" \
--head "${head_branch}" \
--state open \
--json number \
--jq '.[0].number')"
gh pr edit "${backport_pr_number}" \
--add-label "security:patch" \
--add-label "${SEVERITY_LABEL}" \
--add-label "${backport_label}" || true
successes+=("${target_branch}:created:${pr_url}")
done < <(
python3 - <<'PY'
import json
import os
for pair in json.loads(os.environ["TARGET_PAIRS"]):
print(f"{pair['label']}\t{pair['branch']}")
PY
)
summary_file="$(mktemp)"
{
echo "## Security backport summary"
echo
if [ "${#successes[@]}" -gt 0 ]; then
echo "### Created or existing"
for entry in "${successes[@]}"; do
echo "- ${entry}"
done
echo
fi
if [ "${#failures[@]}" -gt 0 ]; then
echo "### Failures"
for entry in "${failures[@]}"; do
echo "- ${entry}"
done
fi
} | tee -a "${GITHUB_STEP_SUMMARY}" > "${summary_file}"
gh pr comment "${PR_NUMBER}" --body-file "${summary_file}"
if [ "${#failures[@]}" -gt 0 ]; then
printf 'Backport failures:\n%s\n' "${failures[@]}" >&2
exit 1
fi
-214
View File
@@ -1,214 +0,0 @@
name: security-patch-prs
on:
workflow_dispatch:
schedule:
- cron: "0 3 * * 1-5"
permissions:
contents: write
pull-requests: write
jobs:
patch:
strategy:
fail-fast: false
matrix:
lane:
- gomod
- terraform
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Patch Go dependencies
if: matrix.lane == 'gomod'
shell: bash
run: |
set -euo pipefail
go get -u=patch ./...
go mod tidy
# Guardrail: do not auto-edit replace directives.
if git diff --unified=0 -- go.mod | grep -E '^[+-]replace '; then
echo "Refusing to auto-edit go.mod replace directives"
exit 1
fi
# Guardrail: only go.mod / go.sum may change.
extra="$(git diff --name-only | grep -Ev '^(go\.mod|go\.sum)$' || true)"
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
- name: Patch bundled Terraform
if: matrix.lane == 'terraform'
shell: bash
run: |
set -euo pipefail
current="$(
grep -oE 'NewVersion\("[0-9]+\.[0-9]+\.[0-9]+"\)' \
provisioner/terraform/install.go \
| head -1 \
| grep -oE '[0-9]+\.[0-9]+\.[0-9]+'
)"
series="$(echo "$current" | cut -d. -f1,2)"
latest="$(
curl -fsSL https://releases.hashicorp.com/terraform/index.json \
| jq -r --arg series "$series" '
.versions
| keys[]
| select(startswith($series + "."))
' \
| sort -V \
| tail -1
)"
test -n "$latest"
[ "$latest" != "$current" ] || exit 0
CURRENT_TERRAFORM_VERSION="$current" \
LATEST_TERRAFORM_VERSION="$latest" \
python3 - <<'PY'
from pathlib import Path
import os
current = os.environ["CURRENT_TERRAFORM_VERSION"]
latest = os.environ["LATEST_TERRAFORM_VERSION"]
updates = {
"scripts/Dockerfile.base": (
f"terraform/{current}/",
f"terraform/{latest}/",
),
"provisioner/terraform/install.go": (
f'NewVersion("{current}")',
f'NewVersion("{latest}")',
),
"install.sh": (
f'TERRAFORM_VERSION="{current}"',
f'TERRAFORM_VERSION="{latest}"',
),
}
for path_str, (before, after) in updates.items():
path = Path(path_str)
content = path.read_text()
if before not in content:
raise SystemExit(f"did not find expected text in {path_str}: {before}")
path.write_text(content.replace(before, after))
PY
# Guardrail: only the Terraform-version files may change.
extra="$(git diff --name-only | grep -Ev '^(scripts/Dockerfile.base|provisioner/terraform/install.go|install.sh)$' || true)"
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
- name: Validate Go dependency patch
if: matrix.lane == 'gomod'
shell: bash
run: |
set -euo pipefail
go test ./...
- name: Validate Terraform patch
if: matrix.lane == 'terraform'
shell: bash
run: |
set -euo pipefail
go test ./provisioner/terraform/...
docker build -f scripts/Dockerfile.base .
- name: Skip PR creation when there are no changes
id: changes
shell: bash
run: |
set -euo pipefail
if git diff --quiet; then
echo "has_changes=false" >> "${GITHUB_OUTPUT}"
else
echo "has_changes=true" >> "${GITHUB_OUTPUT}"
fi
- name: Commit changes
if: steps.changes.outputs.has_changes == 'true'
shell: bash
run: |
set -euo pipefail
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git checkout -B "secpatch/${{ matrix.lane }}"
git add -A
git commit -m "security: patch ${{ matrix.lane }}"
- name: Push branch
if: steps.changes.outputs.has_changes == 'true'
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
git push --force-with-lease \
"https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git" \
"HEAD:refs/heads/secpatch/${{ matrix.lane }}"
- name: Create or update PR
if: steps.changes.outputs.has_changes == 'true'
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
branch="secpatch/${{ matrix.lane }}"
title="security: patch ${{ matrix.lane }}"
body="$(cat <<'EOF'
Automated security patch PR for `${{ matrix.lane }}`.
Scope:
- gomod: patch-level Go dependency updates only
- terraform: bundled Terraform patch updates only
Guardrails:
- no application-code edits
- no auto-editing of go.mod replace directives
- CI must pass
EOF
)"
existing_pr="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
if [[ -n "${existing_pr}" ]]; then
gh pr edit "${existing_pr}" \
--title "${title}" \
--body "${body}"
pr_number="${existing_pr}"
else
gh pr create \
--base main \
--head "${branch}" \
--title "${title}" \
--body "${body}"
pr_number="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
fi
for label in security dependencies automated-pr; do
if gh label list --json name --jq '.[].name' | grep -Fxq "${label}"; then
gh pr edit "${pr_number}" --add-label "${label}"
fi
done
+3 -3
View File
@@ -27,7 +27,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -40,7 +40,7 @@ jobs:
uses: ./.github/actions/setup-go
- name: Initialize CodeQL
uses: github/codeql-action/init@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
uses: github/codeql-action/init@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
with:
languages: go, javascript
@@ -50,7 +50,7 @@ jobs:
rm Makefile
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
uses: github/codeql-action/analyze@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
- name: Send Slack notification on failure
if: ${{ failure() }}
+3 -3
View File
@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -96,7 +96,7 @@ jobs:
contents: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
@@ -120,7 +120,7 @@ jobs:
actions: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
+1 -1
View File
@@ -21,7 +21,7 @@ jobs:
pull-requests: write # required to post PR review comments by the action
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
-3
View File
@@ -110,9 +110,6 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
- For experimental or unstable API paths, skip public doc generation with
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
keeps them out of the published API reference until they stabilize.
- Experimental chat endpoints in `coderd/exp_chats.go` omit swagger
annotations entirely. Do not add `@Summary`, `@Router`, or other
swagger comments to handlers in that file.
### Database Query Naming
+2 -2
View File
@@ -398,7 +398,7 @@ func (a *agent) init() {
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
desktop := agentdesktop.NewPortableDesktop(
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(), nil,
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
)
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
a.mcpManager = agentmcp.NewManager(a.logger.Named("mcp"))
@@ -1366,7 +1366,7 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
// lifecycle transition to avoid delaying Ready.
// This runs inside the tracked goroutine so it
// is properly awaited on shutdown.
if mcpErr := a.mcpManager.Connect(a.gracefulCtx, a.contextConfigAPI.MCPConfigFiles()); mcpErr != nil {
if mcpErr := a.mcpManager.Connect(a.gracefulCtx, a.contextConfigAPI.Config().MCPConfigFiles); mcpErr != nil {
a.logger.Warn(ctx, "failed to connect to workspace MCP servers", slog.Error(mcpErr))
}
})
+7 -7
View File
@@ -83,14 +83,14 @@ func TestContextConfigAPI_InitOnce(t *testing.T) {
return ""
})
mcpFiles1 := a.contextConfigAPI.MCPConfigFiles()
require.NotEmpty(t, mcpFiles1)
require.Contains(t, mcpFiles1[0], dir1)
cfg1 := a.contextConfigAPI.Config()
require.NotEmpty(t, cfg1.MCPConfigFiles)
require.Contains(t, cfg1.MCPConfigFiles[0], dir1)
// Simulate manifest update on reconnection -- no field
// Simulate manifest update on reconnection no field
// reassignment needed, the lazy closure picks it up.
a.manifest.Store(&agentsdk.Manifest{Directory: dir2})
mcpFiles2 := a.contextConfigAPI.MCPConfigFiles()
require.NotEmpty(t, mcpFiles2)
require.Contains(t, mcpFiles2[0], dir2)
cfg2 := a.contextConfigAPI.Config()
require.NotEmpty(t, cfg2.MCPConfigFiles)
require.Contains(t, cfg2.MCPConfigFiles[0], dir2)
}
+22 -251
View File
@@ -2,17 +2,13 @@ package agentcontextconfig
import (
"cmp"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
@@ -26,47 +22,9 @@ const (
EnvMCPConfigFiles = "CODER_AGENT_EXP_MCP_CONFIG_FILES"
)
const (
maxInstructionFileBytes = 64 * 1024
maxSkillMetaBytes = 64 * 1024
)
// markdownCommentPattern strips HTML comments from instruction
// file content for security (prevents hidden prompt injection).
var markdownCommentPattern = regexp.MustCompile(`<!--[\s\S]*?-->`)
// invisibleRunePattern strips invisible Unicode characters that
// could be used for prompt injection.
//
//nolint:gocritic // Non-ASCII char ranges are intentional for invisible Unicode stripping.
var invisibleRunePattern = regexp.MustCompile(
"[\u00ad\u034f\u061c\u070f" +
"\u115f\u1160\u17b4\u17b5" +
"\u180b-\u180f" +
"\u200b\u200d\u200e\u200f" +
"\u202a-\u202e" +
"\u2060-\u206f" +
"\u3164" +
"\ufe00-\ufe0f" +
"\ufeff" +
"\uffa0" +
"\ufff0-\ufff8]",
)
// skillNamePattern validates kebab-case skill names.
var skillNamePattern = regexp.MustCompile(
`^[a-z0-9]+(-[a-z0-9]+)*$`,
)
// Default values for agent-internal configuration. These are
// used when the corresponding env vars are unset.
const (
DefaultInstructionsDir = "~/.coder"
DefaultInstructionsFile = "AGENTS.md"
DefaultSkillsDir = ".agents/skills"
DefaultSkillMetaFile = "SKILL.md"
DefaultMCPConfigFile = ".mcp.json"
)
// Defaults are defined in codersdk/workspacesdk so both
// the agent and server can reference them without a
// cross-layer import.
// API exposes the resolved context configuration through the
// agent's HTTP API.
@@ -84,61 +42,33 @@ func NewAPI(workingDir func() string) *API {
return &API{workingDir: workingDir}
}
// Config reads env vars, resolves paths, reads instruction files,
// and discovers skills. Returns the HTTP response and the resolved
// MCP config file paths (used only agent-internally). Exported
// for use by tests.
func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) {
// Config reads env vars and resolves paths. Exported for use
// by the MCP manager and tests.
func Config(workingDir string) workspacesdk.ContextConfigResponse {
// TrimSpace all env vars before cmp.Or so that a
// whitespace-only value falls through to the default
// consistently. ResolvePaths also trims each comma-
// separated entry, but without pre-trimming here a
// bare " " would bypass cmp.Or and produce nil.
instructionsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsDirs)), DefaultInstructionsDir)
instructionsFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsFile)), DefaultInstructionsFile)
skillsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillsDirs)), DefaultSkillsDir)
skillMetaFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillMetaFile)), DefaultSkillMetaFile)
mcpConfigFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvMCPConfigFiles)), DefaultMCPConfigFile)
resolvedInstructionsDirs := ResolvePaths(instructionsDir, workingDir)
resolvedSkillsDirs := ResolvePaths(skillsDir, workingDir)
// Read instruction files from each configured directory.
parts := readInstructionFiles(resolvedInstructionsDirs, instructionsFile)
// Also check the working directory for the instruction file,
// unless it was already covered by InstructionsDirs.
if workingDir != "" {
seenDirs := make(map[string]struct{}, len(resolvedInstructionsDirs))
for _, d := range resolvedInstructionsDirs {
seenDirs[d] = struct{}{}
}
if _, ok := seenDirs[workingDir]; !ok {
if entry, found := readInstructionFileFromDir(workingDir, instructionsFile); found {
parts = append(parts, entry)
}
}
}
// Discover skills from each configured skills directory.
skillParts := discoverSkills(resolvedSkillsDirs, skillMetaFile)
parts = append(parts, skillParts...)
// Guarantee non-nil slice to signal agent support.
if parts == nil {
parts = []codersdk.ChatMessagePart{}
}
instructionsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsDirs)), workspacesdk.DefaultInstructionsDir)
instructionsFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsFile)), workspacesdk.DefaultInstructionsFile)
skillsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillsDirs)), workspacesdk.DefaultSkillsDir)
skillMetaFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillMetaFile)), workspacesdk.DefaultSkillMetaFile)
mcpConfigFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvMCPConfigFiles)), workspacesdk.DefaultMCPConfigFile)
return workspacesdk.ContextConfigResponse{
Parts: parts,
}, ResolvePaths(mcpConfigFile, workingDir)
InstructionsDirs: ResolvePaths(instructionsDir, workingDir),
InstructionsFile: instructionsFile,
SkillsDirs: ResolvePaths(skillsDir, workingDir),
SkillMetaFile: skillMetaFile,
MCPConfigFiles: ResolvePaths(mcpConfigFile, workingDir),
}
}
// MCPConfigFiles returns the resolved MCP configuration file
// paths for the agent's MCP manager.
func (api *API) MCPConfigFiles() []string {
_, mcpFiles := Config(api.workingDir())
return mcpFiles
// Config returns the resolved config for use by other agent
// components (e.g. MCP manager).
func (api *API) Config() workspacesdk.ContextConfigResponse {
return Config(api.workingDir())
}
// Routes returns the HTTP handler for the context config
@@ -150,164 +80,5 @@ func (api *API) Routes() http.Handler {
}
func (api *API) handleGet(rw http.ResponseWriter, r *http.Request) {
response, _ := Config(api.workingDir())
httpapi.Write(r.Context(), rw, http.StatusOK, response)
}
// readInstructionFiles reads instruction files from each given
// directory. Missing directories are silently skipped. Duplicate
// directories are deduplicated.
func readInstructionFiles(dirs []string, fileName string) []codersdk.ChatMessagePart {
var parts []codersdk.ChatMessagePart
seen := make(map[string]struct{}, len(dirs))
for _, dir := range dirs {
if _, ok := seen[dir]; ok {
continue
}
seen[dir] = struct{}{}
if part, found := readInstructionFileFromDir(dir, fileName); found {
parts = append(parts, part)
}
}
return parts
}
// readInstructionFileFromDir scans a directory for a file matching
// fileName (case-insensitive) and reads its contents.
func readInstructionFileFromDir(dir, fileName string) (codersdk.ChatMessagePart, bool) {
dirEntries, err := os.ReadDir(dir)
if err != nil {
return codersdk.ChatMessagePart{}, false
}
for _, e := range dirEntries {
if e.IsDir() {
continue
}
if strings.EqualFold(strings.TrimSpace(e.Name()), fileName) {
filePath := filepath.Join(dir, e.Name())
content, truncated, ok := readAndSanitizeFile(filePath, maxInstructionFileBytes)
if !ok {
return codersdk.ChatMessagePart{}, false
}
if content == "" {
return codersdk.ChatMessagePart{}, false
}
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: filePath,
ContextFileContent: content,
ContextFileTruncated: truncated,
}, true
}
}
return codersdk.ChatMessagePart{}, false
}
// readAndSanitizeFile reads the file at path, capping the read
// at maxBytes to avoid unbounded memory allocation. It sanitizes
// the content (strips HTML comments and invisible Unicode) and
// returns the result. Returns false if the file cannot be read.
func readAndSanitizeFile(path string, maxBytes int64) (content string, truncated bool, ok bool) {
f, err := os.Open(path)
if err != nil {
return "", false, false
}
defer f.Close()
// Read at most maxBytes+1 to detect truncation without
// allocating the entire file into memory.
raw, err := io.ReadAll(io.LimitReader(f, maxBytes+1))
if err != nil {
return "", false, false
}
truncated = int64(len(raw)) > maxBytes
if truncated {
raw = raw[:maxBytes]
}
s := sanitizeInstructionMarkdown(string(raw))
if s == "" {
return "", truncated, true
}
return s, truncated, true
}
// sanitizeInstructionMarkdown strips HTML comments, invisible
// Unicode characters, and CRLF line endings from instruction
// file content.
func sanitizeInstructionMarkdown(content string) string {
content = strings.ReplaceAll(content, "\r\n", "\n")
content = strings.ReplaceAll(content, "\r", "\n")
content = markdownCommentPattern.ReplaceAllString(content, "")
content = invisibleRunePattern.ReplaceAllString(content, "")
return strings.TrimSpace(content)
}
// discoverSkills walks the given skills directories and returns
// metadata for every valid skill it finds. Body and supporting
// file lists are NOT included; chatd fetches those on demand
// via read_skill. Missing directories or individual errors are
// silently skipped.
func discoverSkills(skillsDirs []string, metaFile string) []codersdk.ChatMessagePart {
seen := make(map[string]struct{})
var parts []codersdk.ChatMessagePart
for _, skillsDir := range skillsDirs {
entries, err := os.ReadDir(skillsDir)
if err != nil {
continue
}
for _, entry := range entries {
if !entry.IsDir() {
continue
}
metaPath := filepath.Join(skillsDir, entry.Name(), metaFile)
f, err := os.Open(metaPath)
if err != nil {
continue
}
raw, err := io.ReadAll(io.LimitReader(f, maxSkillMetaBytes+1))
_ = f.Close()
if err != nil {
continue
}
if int64(len(raw)) > maxSkillMetaBytes {
raw = raw[:maxSkillMetaBytes]
}
name, description, _, err := workspacesdk.ParseSkillFrontmatter(string(raw))
if err != nil {
continue
}
// The directory name must match the declared name.
if name != entry.Name() {
continue
}
if !skillNamePattern.MatchString(name) {
continue
}
// First occurrence wins across directories.
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
skillDir := filepath.Join(skillsDir, entry.Name())
parts = append(parts, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: name,
SkillDescription: description,
SkillDir: skillDir,
ContextFileSkillMetaFile: metaFile,
})
}
}
return parts
httpapi.Write(r.Context(), rw, http.StatusOK, api.Config())
}
+36 -358
View File
@@ -1,28 +1,15 @@
package agentcontextconfig_test
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentcontextconfig"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// filterParts returns only the parts matching the given type.
func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartType) []codersdk.ChatMessagePart {
var out []codersdk.ChatMessagePart
for _, p := range parts {
if p.Type == t {
out = append(out, p)
}
}
return out
}
func TestConfig(t *testing.T) {
t.Run("Defaults", func(t *testing.T) {
fakeHome := t.TempDir()
@@ -37,13 +24,19 @@ func TestConfig(t *testing.T) {
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := platformAbsPath("work")
cfg, mcpFiles := agentcontextconfig.Config(workDir)
cfg := agentcontextconfig.Config(workDir)
// Parts is always non-nil.
require.NotNil(t, cfg.Parts)
require.Equal(t, workspacesdk.DefaultInstructionsFile, cfg.InstructionsFile)
require.Equal(t, workspacesdk.DefaultSkillMetaFile, cfg.SkillMetaFile)
// Default instructions dir is "~/.coder" which resolves
// to the home directory.
require.Equal(t, []string{filepath.Join(fakeHome, ".coder")}, cfg.InstructionsDirs)
// Default skills dir is ".agents/skills" (relative),
// resolved against the working directory.
require.Equal(t, []string{filepath.Join(workDir, ".agents", "skills")}, cfg.SkillsDirs)
// Default MCP config file is ".mcp.json" (relative),
// resolved against the working directory.
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles)
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, cfg.MCPConfigFiles)
})
t.Run("CustomEnvVars", func(t *testing.T) {
@@ -51,8 +44,8 @@ func TestConfig(t *testing.T) {
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
optInstructions := t.TempDir()
optSkills := t.TempDir()
optInstructions := platformAbsPath("opt", "instructions")
optSkills := platformAbsPath("opt", "skills")
optMCP := platformAbsPath("opt", "mcp.json")
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
@@ -61,58 +54,32 @@ func TestConfig(t *testing.T) {
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
// Create files matching the custom names so we can
// verify the env vars actually change lookup behavior.
require.NoError(t, os.WriteFile(filepath.Join(optInstructions, "CUSTOM.md"), []byte("custom instructions"), 0o600))
skillDir := filepath.Join(optSkills, "my-skill")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(skillDir, "META.yaml"),
[]byte("---\nname: my-skill\ndescription: custom meta\n---\n"),
0o600,
))
workDir := platformAbsPath("work")
cfg, mcpFiles := agentcontextconfig.Config(workDir)
cfg := agentcontextconfig.Config(workDir)
require.Equal(t, []string{optMCP}, mcpFiles)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "custom instructions", ctxFiles[0].ContextFileContent)
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, skillParts, 1)
require.Equal(t, "my-skill", skillParts[0].SkillName)
require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile)
require.Equal(t, "CUSTOM.md", cfg.InstructionsFile)
require.Equal(t, "META.yaml", cfg.SkillMetaFile)
require.Equal(t, []string{optInstructions}, cfg.InstructionsDirs)
require.Equal(t, []string{optSkills}, cfg.SkillsDirs)
require.Equal(t, []string{optMCP}, cfg.MCPConfigFiles)
})
t.Run("WhitespaceInFileNames", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
// Create a file matching the trimmed name.
require.NoError(t, os.WriteFile(filepath.Join(fakeHome, "CLAUDE.md"), []byte("hello"), 0o600))
workDir := platformAbsPath("work")
cfg := agentcontextconfig.Config(workDir)
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "hello", ctxFiles[0].ContextFileContent)
require.Equal(t, "CLAUDE.md", cfg.InstructionsFile)
})
t.Run("CommaSeparatedDirs", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
a := t.TempDir()
b := t.TempDir()
a := platformAbsPath("opt", "a")
b := platformAbsPath("opt", "b")
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
@@ -120,300 +87,10 @@ func TestConfig(t *testing.T) {
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
// Put instruction files in both dirs.
require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600))
require.NoError(t, os.WriteFile(filepath.Join(b, "AGENTS.md"), []byte("from b"), 0o600))
workDir := platformAbsPath("work")
cfg := agentcontextconfig.Config(workDir)
workDir := t.TempDir()
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 2)
require.Equal(t, "from a", ctxFiles[0].ContextFileContent)
require.Equal(t, "from b", ctxFiles[1].ContextFileContent)
})
t.Run("ReadsInstructionFiles", func(t *testing.T) {
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
// Create ~/.coder/AGENTS.md
coderDir := filepath.Join(fakeHome, ".coder")
require.NoError(t, os.MkdirAll(coderDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(coderDir, "AGENTS.md"),
[]byte("home instructions"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.NotNil(t, cfg.Parts)
require.Len(t, ctxFiles, 1)
require.Equal(t, "home instructions", ctxFiles[0].ContextFileContent)
require.Equal(t, filepath.Join(coderDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
require.False(t, ctxFiles[0].ContextFileTruncated)
})
t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
// Create AGENTS.md in the working directory.
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("project instructions"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
// Should find the working dir file (not in instruction dirs).
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.NotNil(t, cfg.Parts)
require.Len(t, ctxFiles, 1)
require.Equal(t, "project instructions", ctxFiles[0].ContextFileContent)
require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
})
t.Run("TruncatesLargeInstructionFile", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
largeContent := strings.Repeat("a", 64*1024+100)
require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.True(t, ctxFiles[0].ContextFileTruncated)
require.Len(t, ctxFiles[0].ContextFileContent, 64*1024)
})
t.Run("SanitizesHTMLComments", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("visible\n<!-- hidden -->content"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent)
})
t.Run("SanitizesInvisibleUnicode", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
// U+200B (zero-width space) should be stripped.
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("before\u200bafter"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent)
})
t.Run("NormalizesCRLF", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("line1\r\nline2\rline3"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent)
})
t.Run("DiscoversSkills", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
skillsDir := filepath.Join(workDir, ".agents", "skills")
t.Setenv(agentcontextconfig.EnvSkillsDirs, skillsDir)
// Create a valid skill.
skillDir := filepath.Join(skillsDir, "my-skill")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(skillDir, "SKILL.md"),
[]byte("---\nname: my-skill\ndescription: A test skill\n---\nSkill body"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, skillParts, 1)
require.Equal(t, "my-skill", skillParts[0].SkillName)
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
require.Equal(t, skillDir, skillParts[0].SkillDir)
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
})
t.Run("SkipsMissingDirs", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
nonExistent := filepath.Join(t.TempDir(), "does-not-exist")
t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent)
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
cfg, _ := agentcontextconfig.Config(workDir)
// Non-nil empty slice (signals agent supports new format).
require.NotNil(t, cfg.Parts)
require.Empty(t, cfg.Parts)
})
t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
optMCP := platformAbsPath("opt", "custom.json")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
workDir := t.TempDir()
_, mcpFiles := agentcontextconfig.Config(workDir)
require.Equal(t, []string{optMCP}, mcpFiles)
})
t.Run("SkillNameMustMatchDir", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
skillsDir := filepath.Join(workDir, "skills")
t.Setenv(agentcontextconfig.EnvSkillsDirs, skillsDir)
// Skill name in frontmatter doesn't match directory name.
skillDir := filepath.Join(skillsDir, "wrong-dir-name")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(skillDir, "SKILL.md"),
[]byte("---\nname: actual-name\ndescription: mismatch\n---\n"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
require.Empty(t, skillParts)
})
t.Run("DuplicateSkillsFirstWins", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
skillsDir1 := filepath.Join(workDir, "skills1")
skillsDir2 := filepath.Join(workDir, "skills2")
t.Setenv(agentcontextconfig.EnvSkillsDirs, skillsDir1+","+skillsDir2)
// Same skill name in both directories.
for _, dir := range []string{skillsDir1, skillsDir2} {
skillDir := filepath.Join(dir, "dup-skill")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(skillDir, "SKILL.md"),
[]byte("---\nname: dup-skill\ndescription: from "+filepath.Base(dir)+"\n---\n"),
0o600,
))
}
cfg, _ := agentcontextconfig.Config(workDir)
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, skillParts, 1)
require.Equal(t, "from skills1", skillParts[0].SkillDescription)
require.Equal(t, []string{a, b}, cfg.InstructionsDirs)
})
}
@@ -427,13 +104,14 @@ func TestNewAPI_LazyDirectory(t *testing.T) {
dir := ""
api := agentcontextconfig.NewAPI(func() string { return dir })
// Before directory is set, MCP paths resolve to nothing.
mcpFiles := api.MCPConfigFiles()
require.Empty(t, mcpFiles)
// Before directory is set, relative paths resolve to nothing.
cfg := api.Config()
require.Empty(t, cfg.SkillsDirs)
require.Empty(t, cfg.MCPConfigFiles)
// After setting the directory, MCPConfigFiles() picks it up.
// After setting the directory, Config() picks it up lazily.
dir = platformAbsPath("work")
mcpFiles = api.MCPConfigFiles()
require.NotEmpty(t, mcpFiles)
require.Equal(t, []string{filepath.Join(dir, ".mcp.json")}, mcpFiles)
cfg = api.Config()
require.NotEmpty(t, cfg.SkillsDirs)
require.Equal(t, []string{filepath.Join(dir, ".agents", "skills")}, cfg.SkillsDirs)
}
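// Illustrative sketch (names hypothetical) of the lazy-directory
// pattern the test above exercises: the API captures a func() string
// rather than a fixed path, so relative paths resolve against
// whatever the directory is at call time.
type lazyPaths struct {
	workDir func() string
}

func (l *lazyPaths) skillsDirs() []string {
	dir := l.workDir()
	if dir == "" {
		return nil // directory not set yet; nothing to resolve
	}
	return []string{filepath.Join(dir, ".agents", "skills")}
}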
-156
View File
@@ -1,17 +1,12 @@
package agentdesktop
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"strconv"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentssh"
@@ -52,9 +47,6 @@ type API struct {
logger slog.Logger
desktop Desktop
clock quartz.Clock
closeMu sync.Mutex
closed bool
}
// NewAPI creates a new desktop streaming API.
@@ -74,10 +66,6 @@ func (a *API) Routes() http.Handler {
r := chi.NewRouter()
r.Get("/vnc", a.handleDesktopVNC)
r.Post("/action", a.handleAction)
r.Route("/recording", func(r chi.Router) {
r.Post("/start", a.handleRecordingStart)
r.Post("/stop", a.handleRecordingStop)
})
return r
}
@@ -128,9 +116,6 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
handlerStart := a.clock.Now()
// Update last desktop action timestamp for idle recording monitor.
a.desktop.RecordActivity()
// Ensure the desktop is running and grab native dimensions.
cfg, err := a.desktop.Start(ctx)
if err != nil {
@@ -495,150 +480,9 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
// Close shuts down the desktop session if one is running.
func (a *API) Close() error {
a.closeMu.Lock()
if a.closed {
a.closeMu.Unlock()
return nil
}
a.closed = true
a.closeMu.Unlock()
return a.desktop.Close()
}
// decodeRecordingRequest decodes and validates a recording request
// from the HTTP body, returning the recording ID. Returns false if
// the request was invalid and an error response was already written.
func (*API) decodeRecordingRequest(rw http.ResponseWriter, r *http.Request) (string, bool) {
ctx := r.Context()
var req struct {
RecordingID string `json:"recording_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to decode request body.",
Detail: err.Error(),
})
return "", false
}
if req.RecordingID == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing recording_id.",
})
return "", false
}
if _, err := uuid.Parse(req.RecordingID); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid recording_id format.",
Detail: "recording_id must be a valid UUID.",
})
return "", false
}
return req.RecordingID, true
}
func (a *API) handleRecordingStart(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
recordingID, ok := a.decodeRecordingRequest(rw, r)
if !ok {
return
}
a.closeMu.Lock()
if a.closed {
a.closeMu.Unlock()
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
Message: "Desktop API is shutting down.",
})
return
}
a.closeMu.Unlock()
if err := a.desktop.StartRecording(ctx, recordingID); err != nil {
if errors.Is(err, ErrDesktopClosed) {
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
Message: "Desktop API is shutting down.",
})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start recording.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
Message: "Recording started.",
})
}
func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
recordingID, ok := a.decodeRecordingRequest(rw, r)
if !ok {
return
}
a.closeMu.Lock()
if a.closed {
a.closeMu.Unlock()
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
Message: "Desktop API is shutting down.",
})
return
}
a.closeMu.Unlock()
// Stop recording (idempotent).
// Use a context detached from the HTTP request so that if the
// connection drops, the recording process can still shut down
// gracefully. WithoutCancel preserves request-scoped values.
stopCtx, stopCancel := context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
defer stopCancel()
artifact, err := a.desktop.StopRecording(stopCtx, recordingID)
if err != nil {
if errors.Is(err, ErrUnknownRecording) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "Recording not found.",
Detail: err.Error(),
})
return
}
if errors.Is(err, ErrRecordingCorrupted) {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Recording is corrupted.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to stop recording.",
Detail: err.Error(),
})
return
}
defer artifact.Reader.Close()
if artifact.Size > workspacesdk.MaxRecordingSize {
a.logger.Warn(ctx, "recording file exceeds maximum size",
slog.F("recording_id", recordingID),
slog.F("size", artifact.Size),
slog.F("max_size", workspacesdk.MaxRecordingSize),
)
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
Message: "Recording file exceeds maximum allowed size.",
})
return
}
rw.Header().Set("Content-Type", "video/mp4")
rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
rw.WriteHeader(http.StatusOK)
_, _ = io.Copy(rw, artifact.Reader)
}
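// Standalone illustration of the detached-context pattern used in
// handleRecordingStop above: context.WithoutCancel ignores the
// parent's cancellation while preserving request-scoped values, and
// the child timeout re-bounds the remaining work.
func demoWithoutCancel() {
	parent, cancel := context.WithCancel(context.Background())
	detached, stop := context.WithTimeout(context.WithoutCancel(parent), 30*time.Second)
	defer stop()
	cancel()
	// parent.Err() is context.Canceled here, while detached.Err()
	// stays nil until its own 30-second timeout elapses.
	_ = detached
}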
// coordFromAction extracts the coordinate pair from a DesktopAction,
// returning an error if the coordinate field is missing.
func coordFromAction(action DesktopAction) (x, y int, err error) {
-661
View File
@@ -4,17 +4,12 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"slices"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
@@ -26,16 +21,6 @@ import (
"github.com/coder/quartz"
)
// Test recording UUIDs used across tests.
const (
testRecIDDefault = "870e1f02-8118-4300-a37e-4adb0117baf3"
testRecIDStartIdempotent = "250a2ffb-a5e5-4c94-9754-4d6a4ab7ba20"
testRecIDStopIdempotent = "38f8a378-f98f-4758-a4ae-950b44cf989a"
testRecIDConcurrentA = "8dc173eb-23c6-4601-a485-b6dfb2a42c3a"
testRecIDConcurrentB = "fea490d4-70f0-4798-a181-29d65ce25ae1"
testRecIDRestart = "75173a0d-b018-4e2e-a771-defa3fc6af69"
)
// Ensure fakeDesktop satisfies the Desktop interface at compile time.
var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
@@ -58,14 +43,6 @@ type fakeDesktop struct {
lastTyped string
lastKeyDown string
lastKeyUp string
// Recording tracking (guarded by recMu).
recMu sync.Mutex
recordings map[string]string // ID → file path
stopCalls []string // recording IDs passed to StopRecording
recStopCh chan string // optional: signaled when StopRecording is called
startCount int // incremented on each new recording start
activityCount int // incremented by RecordActivity
}
func (f *fakeDesktop) Start(context.Context) (agentdesktop.DisplayConfig, error) {
@@ -130,140 +107,11 @@ func (f *fakeDesktop) CursorPosition(context.Context) (x int, y int, err error)
return f.cursorPos[0], f.cursorPos[1], nil
}
func (f *fakeDesktop) StartRecording(_ context.Context, recordingID string) error {
f.recMu.Lock()
defer f.recMu.Unlock()
if f.recordings == nil {
f.recordings = make(map[string]string)
}
if path, ok := f.recordings[recordingID]; ok {
// Check if already stopped (file still exists but stop was
// called). For the fake, a stopped recording means its ID
// appears in stopCalls. In that case, remove the old file
// and start fresh.
stopped := slices.Contains(f.stopCalls, recordingID)
if !stopped {
// Active recording - no-op.
return nil
}
// Completed recording - discard old file, start fresh.
_ = os.Remove(path)
delete(f.recordings, recordingID)
}
f.startCount++
tmpFile, err := os.CreateTemp("", "fake-recording-*.mp4")
if err != nil {
return err
}
_, _ = tmpFile.Write([]byte(fmt.Sprintf("fake-mp4-data-%s-%d", recordingID, f.startCount)))
_ = tmpFile.Close()
f.recordings[recordingID] = tmpFile.Name()
return nil
}
func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*agentdesktop.RecordingArtifact, error) {
f.recMu.Lock()
defer f.recMu.Unlock()
if f.recordings == nil {
return nil, agentdesktop.ErrUnknownRecording
}
path, ok := f.recordings[recordingID]
if !ok {
return nil, agentdesktop.ErrUnknownRecording
}
f.stopCalls = append(f.stopCalls, recordingID)
if f.recStopCh != nil {
select {
case f.recStopCh <- recordingID:
default:
}
}
file, err := os.Open(path)
if err != nil {
return nil, err
}
info, err := file.Stat()
if err != nil {
_ = file.Close()
return nil, err
}
return &agentdesktop.RecordingArtifact{
Reader: file,
Size: info.Size(),
}, nil
}
func (f *fakeDesktop) RecordActivity() {
f.recMu.Lock()
f.activityCount++
f.recMu.Unlock()
}
func (f *fakeDesktop) Close() error {
f.closed = true
f.recMu.Lock()
defer f.recMu.Unlock()
for _, path := range f.recordings {
_ = os.Remove(path)
}
return nil
}
// failStartRecordingDesktop wraps fakeDesktop and overrides
// StartRecording to always return an error.
type failStartRecordingDesktop struct {
fakeDesktop
startRecordingErr error
}
func (f *failStartRecordingDesktop) StartRecording(_ context.Context, _ string) error {
return f.startRecordingErr
}
// corruptedStopDesktop wraps fakeDesktop and overrides
// StopRecording to always return ErrRecordingCorrupted.
type corruptedStopDesktop struct {
fakeDesktop
}
func (*corruptedStopDesktop) StopRecording(_ context.Context, _ string) (*agentdesktop.RecordingArtifact, error) {
return nil, agentdesktop.ErrRecordingCorrupted
}
// oversizedFakeDesktop wraps fakeDesktop and expands recording files
// beyond MaxRecordingSize when StopRecording is called.
type oversizedFakeDesktop struct {
fakeDesktop
}
func (f *oversizedFakeDesktop) StopRecording(ctx context.Context, recordingID string) (*agentdesktop.RecordingArtifact, error) {
artifact, err := f.fakeDesktop.StopRecording(ctx, recordingID)
if err != nil {
return nil, err
}
// Close the original reader since we're going to re-open after truncation.
artifact.Reader.Close()
// Look up the path from the fakeDesktop recordings.
f.fakeDesktop.recMu.Lock()
path := f.fakeDesktop.recordings[recordingID]
f.fakeDesktop.recMu.Unlock()
// Expand the file to exceed the maximum recording size.
if err := os.Truncate(path, workspacesdk.MaxRecordingSize+1); err != nil {
return nil, err
}
// Re-open the truncated file.
file, err := os.Open(path)
if err != nil {
return nil, err
}
return &agentdesktop.RecordingArtifact{
Reader: file,
Size: workspacesdk.MaxRecordingSize + 1,
}, nil
}
func TestHandleDesktopVNC_StartError(t *testing.T) {
t.Parallel()
@@ -286,37 +134,6 @@ func TestHandleDesktopVNC_StartError(t *testing.T) {
assert.Equal(t, "Failed to start desktop session.", resp.Message)
}
func TestHandleAction_CallsRecordActivity(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
body := agentdesktop.DesktopAction{
Action: "left_click",
Coordinate: &[2]int{100, 200},
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
fake.recMu.Lock()
count := fake.activityCount
fake.recMu.Unlock()
assert.Equal(t, 1, count, "handleAction should call RecordActivity exactly once")
}
func TestHandleAction_Screenshot(t *testing.T) {
t.Parallel()
@@ -757,481 +574,3 @@ func TestHandleAction_CursorPositionReturnsDeclaredCoordinates(t *testing.T) {
// Native (960,540) in 1920x1080 should map to declared space in 1280x720.
assert.Equal(t, "x=640,y=360", resp.Output)
}
func TestRecordingStartStop(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start recording.
startBody, err := json.Marshal(map[string]string{"recording_id": testRecIDDefault})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Stop recording.
stopBody, err := json.Marshal(map[string]string{"recording_id": testRecIDDefault})
require.NoError(t, err)
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes())
}
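// Hedged client-side sketch of the two endpoints the test above
// exercises. The base URL and destination path are assumptions; in
// the agent these routes are mounted under the desktop API.
func stopAndSave(base, recordingID, dst string) error {
	body, err := json.Marshal(map[string]string{"recording_id": recordingID})
	if err != nil {
		return err
	}
	resp, err := http.Post(base+"/recording/stop", "application/json", bytes.NewReader(body))
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("stop recording: %s", resp.Status)
	}
	f, err := os.Create(dst) // response body is the raw MP4 content
	if err != nil {
		return err
	}
	defer f.Close()
	_, err = io.Copy(f, resp.Body)
	return err
}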
func TestRecordingStartFails(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &failStartRecordingDesktop{
fakeDesktop: fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
},
startRecordingErr: xerrors.New("start recording error"),
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
var resp codersdk.Response
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "Failed to start recording.", resp.Message)
}
func TestRecordingStartIdempotent(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start same recording twice - both should succeed.
for range 2 {
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStartIdempotent})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
}
// Stop once, verify normal response.
stopBody, err := json.Marshal(map[string]string{"recording_id": testRecIDStartIdempotent})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes())
}
func TestRecordingStopIdempotent(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start recording.
startBody, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Stop twice - both should succeed with identical data.
var bodies [2][]byte
for i := range 2 {
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
require.NoError(t, err)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
handler.ServeHTTP(recorder, request)
require.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type"))
bodies[i] = recorder.Body.Bytes()
}
assert.Equal(t, bodies[0], bodies[1])
}
func TestRecordingStopInvalidIDFormat(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
body, err := json.Marshal(map[string]string{"recording_id": "not-a-uuid"})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusBadRequest, rr.Code)
}
func TestRecordingStopUnknownRecording(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Send a valid UUID that was never started - should reach
// StopRecording, get ErrUnknownRecording, and return 404.
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusNotFound, rr.Code)
var resp codersdk.Response
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "Recording not found.", resp.Message)
}
func TestRecordingStopOversizedFile(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &oversizedFakeDesktop{
fakeDesktop: fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start recording.
recID := uuid.New().String()
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Stop recording - file exceeds max size, expect 413.
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusRequestEntityTooLarge, rr.Code)
var resp codersdk.Response
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "Recording file exceeds maximum allowed size.", resp.Message)
}
func TestRecordingMultipleSimultaneous(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start two recordings with different IDs.
for _, id := range []string{testRecIDConcurrentA, testRecIDConcurrentB} {
body, err := json.Marshal(map[string]string{"recording_id": id})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
}
// Stop both and verify each returns its own data.
expected := map[string][]byte{
testRecIDConcurrentA: []byte("fake-mp4-data-" + testRecIDConcurrentA + "-1"),
testRecIDConcurrentB: []byte("fake-mp4-data-" + testRecIDConcurrentB + "-2"),
}
for _, id := range []string{testRecIDConcurrentA, testRecIDConcurrentB} {
body, err := json.Marshal(map[string]string{"recording_id": id})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
assert.Equal(t, expected[id], rr.Body.Bytes())
}
}
func TestRecordingStartMalformedBody(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader([]byte("not json")))
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusBadRequest, rr.Code)
}
func TestRecordingStartEmptyID(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
body, err := json.Marshal(map[string]string{"recording_id": ""})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusBadRequest, rr.Code)
}
func TestRecordingStopEmptyID(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
body, err := json.Marshal(map[string]string{"recording_id": ""})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusBadRequest, rr.Code)
}
func TestRecordingStopMalformedBody(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader([]byte("not json")))
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusBadRequest, rr.Code)
}
func TestRecordingStartAfterCompleted(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Step 1: Start recording.
startBody, err := json.Marshal(map[string]string{"recording_id": testRecIDRestart})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Step 2: Stop recording (gets first MP4 data).
stopBody, err := json.Marshal(map[string]string{"recording_id": testRecIDRestart})
require.NoError(t, err)
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
firstData := rr.Body.Bytes()
require.NotEmpty(t, firstData)
// Step 3: Start again with the same ID - should succeed
// (old file discarded, new recording started).
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Step 4: Stop again - should return NEW MP4 data.
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
secondData := rr.Body.Bytes()
require.NotEmpty(t, secondData)
// The two recordings should have different data because the
// fake increments a counter on each fresh start.
assert.NotEqual(t, firstData, secondData,
"restarted recording should produce different data")
}
func TestRecordingStartAfterClose(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
handler := api.Routes()
// Close the API before sending the request.
api.Close()
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusServiceUnavailable, rr.Code)
var resp codersdk.Response
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "Desktop API is shutting down.", resp.Message)
}
func TestRecordingStartDesktopClosed(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// StartRecording returns ErrDesktopClosed to simulate a race
// where the desktop is closed between the API-level check and
// the desktop-level StartRecording call.
fake := &failStartRecordingDesktop{
fakeDesktop: fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
},
startRecordingErr: agentdesktop.ErrDesktopClosed,
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusServiceUnavailable, rr.Code)
var resp codersdk.Response
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "Desktop API is shutting down.", resp.Message)
}
func TestRecordingStopCorrupted(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &corruptedStopDesktop{
fakeDesktop: fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start a recording so the stop has something to find.
recID := uuid.New().String()
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Stop returns ErrRecordingCorrupted.
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
var respStop codersdk.Response
err = json.NewDecoder(rr.Body).Decode(&respStop)
require.NoError(t, err)
assert.Equal(t, "Recording is corrupted.", respStop.Message)
}
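// Taken together, the tests above pin the recording endpoints' status
// mapping: 400 for malformed bodies, empty IDs, and non-UUID IDs; 404
// for unknown recordings; 413 for artifacts over MaxRecordingSize;
// 500 for corrupted recordings and other start/stop failures; and 503
// once the API or the desktop has been closed.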
-45
View File
@@ -2,10 +2,7 @@ package agentdesktop
import (
"context"
"io"
"net"
"golang.org/x/xerrors"
)
// Desktop abstracts a virtual desktop session running inside a workspace.
@@ -61,52 +58,10 @@ type Desktop interface {
// CursorPosition returns the current cursor coordinates.
CursorPosition(ctx context.Context) (x, y int, err error)
// RecordActivity marks the desktop as having received user
// interaction, resetting the idle-recording timer.
RecordActivity()
// StartRecording begins recording the desktop to an MP4 file
// using the caller-provided recording ID. Safe to call
// repeatedly - active recordings continue unchanged, stopped
// recordings are discarded and restarted. Concurrent recordings
// are supported.
StartRecording(ctx context.Context, recordingID string) error
// StopRecording finalizes the recording identified by the given
// ID. Idempotent - safe to call on an already-stopped recording.
// Returns a RecordingArtifact that the caller can stream. The
// caller must close the artifact when done. Returns an error if
// the recording ID is unknown.
StopRecording(ctx context.Context, recordingID string) (*RecordingArtifact, error)
// Close shuts down the desktop session and cleans up resources.
Close() error
}
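// Hedged caller-side sketch of the recording contract above, valid
// for any Desktop implementation (dst is wherever the MP4 should go):
func saveRecording(ctx context.Context, d Desktop, recordingID string, dst io.Writer) error {
	if err := d.StartRecording(ctx, recordingID); err != nil {
		return err
	}
	// ... desktop activity happens here ...
	artifact, err := d.StopRecording(ctx, recordingID)
	if err != nil {
		return err
	}
	defer artifact.Reader.Close() // callers must close the artifact
	_, err = io.Copy(dst, artifact.Reader)
	return err
}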
// ErrUnknownRecording is returned by StopRecording when the
// recording ID is not recognized.
var ErrUnknownRecording = xerrors.New("unknown recording ID")
// ErrDesktopClosed is returned when an operation is attempted on a
// closed desktop session.
var ErrDesktopClosed = xerrors.New("desktop closed")
// ErrRecordingCorrupted is returned by StopRecording when the
// recording process was force-killed and the artifact is likely
// incomplete or corrupt.
var ErrRecordingCorrupted = xerrors.New("recording corrupted: process was force-killed")
// RecordingArtifact is a finalized recording returned by StopRecording.
// The caller streams the artifact and must call Close when done. The
// artifact remains valid even if the same recording ID is restarted
// or the desktop is closed while the caller is reading.
type RecordingArtifact struct {
// Reader is the MP4 content. Callers must close it when done.
Reader io.ReadCloser
// Size is the byte length of the MP4 content.
Size int64
}
// DisplayConfig describes a running desktop session.
type DisplayConfig struct {
Width int // native width in pixels
+16 -385
View File
@@ -3,7 +3,6 @@ package agentdesktop
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"os"
@@ -12,7 +11,6 @@ import (
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
"golang.org/x/xerrors"
@@ -20,7 +18,6 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
)
// portableDesktopOutput is the JSON output from
@@ -52,65 +49,32 @@ type screenshotOutput struct {
Data string `json:"data"`
}
// recordingProcess tracks a single desktop recording subprocess.
type recordingProcess struct {
cmd *exec.Cmd
filePath string
stopped bool
killed bool // true when the process was SIGKILLed
done chan struct{} // closed when cmd.Wait() returns
waitErr error // set before done is closed
stopOnce sync.Once
idleCancel context.CancelFunc // cancels the per-recording idle goroutine
idleDone chan struct{} // closed when idle goroutine exits
}
// maxConcurrentRecordings is the maximum number of active (non-stopped)
// recordings allowed at once. This prevents resource exhaustion.
const maxConcurrentRecordings = 5
// idleTimeout is the duration of desktop inactivity after which all
// active recordings are automatically stopped.
const idleTimeout = 10 * time.Minute
// portableDesktop implements Desktop by shelling out to the
// portabledesktop CLI via agentexec.Execer.
type portableDesktop struct {
logger slog.Logger
execer agentexec.Execer
scriptBinDir string // coder script bin directory
clock quartz.Clock
mu sync.Mutex
session *desktopSession // nil until started
binPath string // resolved path to binary, cached
closed bool
recordings map[string]*recordingProcess // guarded by mu
lastDesktopActionAt atomic.Int64
mu sync.Mutex
session *desktopSession // nil until started
binPath string // resolved path to binary, cached
closed bool
}
// NewPortableDesktop creates a Desktop backed by the portabledesktop
// CLI binary, using execer to spawn child processes. scriptBinDir is
// the coder script bin directory checked for the binary. If clk is
// nil, a real clock is used.
// the coder script bin directory checked for the binary.
func NewPortableDesktop(
logger slog.Logger,
execer agentexec.Execer,
scriptBinDir string,
clk quartz.Clock,
) Desktop {
if clk == nil {
clk = quartz.NewReal()
}
pd := &portableDesktop{
return &portableDesktop{
logger: logger,
execer: execer,
scriptBinDir: scriptBinDir,
clock: clk,
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
return pd
}
// Start launches the desktop session (idempotent).
@@ -119,7 +83,7 @@ func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
defer p.mu.Unlock()
if p.closed {
return DisplayConfig{}, ErrDesktopClosed
return DisplayConfig{}, xerrors.New("desktop is closed")
}
if err := p.ensureBinary(ctx); err != nil {
@@ -349,328 +313,23 @@ func (p *portableDesktop) CursorPosition(ctx context.Context) (x int, y int, err
return result.X, result.Y, nil
}
// StartRecording begins recording the desktop to an MP4 file.
// Idempotency is three-state: an active recording is a no-op, a
// completed recording is discarded and restarted, and an unknown
// ID starts a new recording.
func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string) error {
// Ensure the desktop session is running before acquiring the
// recording lock. Start is independently locked and idempotent.
if _, err := p.Start(ctx); err != nil {
return xerrors.Errorf("ensure desktop session: %w", err)
}
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return ErrDesktopClosed
}
// Three-state idempotency:
// - Active recording → no-op, continue recording.
// - Completed recording → discard old file, start fresh.
// - Unknown ID → fall through to start a new recording.
if rec, ok := p.recordings[recordingID]; ok {
if !rec.stopped {
select {
case <-rec.done:
// Process exited unexpectedly; treat as completed
// so we fall through to discard the old file and
// restart.
default:
// Active recording - no-op, continue recording.
return nil
}
}
// Completed recording - discard old file, start fresh.
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
p.logger.Warn(ctx, "failed to remove old recording file",
slog.F("recording_id", recordingID),
slog.F("file_path", rec.filePath),
slog.Error(err),
)
}
delete(p.recordings, recordingID)
}
// Check concurrent recording limit.
if p.lockedActiveRecordingCount() >= maxConcurrentRecordings {
return xerrors.Errorf("too many concurrent recordings (max %d)", maxConcurrentRecordings)
}
// GC sweep: remove stopped recordings with stale files.
p.lockedCleanStaleRecordings(ctx)
if err := p.ensureBinary(ctx); err != nil {
return xerrors.Errorf("ensure portabledesktop binary: %w", err)
}
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
// Use a background context so the process outlives the HTTP
// request that triggered it.
procCtx, procCancel := context.WithCancel(context.Background())
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
cmd := p.execer.CommandContext(procCtx, p.binPath, "record",
// The following options speed up the recording when the desktop is
// idle. They were taken from an example in the portabledesktop repo;
// there is likely room to tune the values further.
"--idle-speedup", "20",
"--idle-min-duration", "0.35",
"--idle-noise-tolerance", "-38dB",
filePath)
if err := cmd.Start(); err != nil {
procCancel()
return xerrors.Errorf("start recording process: %w", err)
}
rec := &recordingProcess{
cmd: cmd,
filePath: filePath,
done: make(chan struct{}),
}
go func() {
rec.waitErr = cmd.Wait()
close(rec.done)
// Avoid a context resource leak by canceling once Wait returns.
procCancel()
}()
p.recordings[recordingID] = rec
p.logger.Info(ctx, "started desktop recording",
slog.F("recording_id", recordingID),
slog.F("file_path", filePath),
slog.F("pid", cmd.Process.Pid),
)
// Record activity so a recording started on an already-idle
// desktop does not stop immediately.
p.lastDesktopActionAt.Store(p.clock.Now().UnixNano())
// Spawn a per-recording idle goroutine.
idleCtx, idleCancel := context.WithCancel(context.Background())
rec.idleCancel = idleCancel
rec.idleDone = make(chan struct{})
go func() {
defer close(rec.idleDone)
p.monitorRecordingIdle(idleCtx, rec)
}()
return nil
}
// StopRecording finalizes the recording. Idempotent - safe to call
// on an already-stopped recording. Returns a RecordingArtifact
// that the caller can stream. The caller must close the Reader
// on the returned artifact to avoid leaking file descriptors.
func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string) (*RecordingArtifact, error) {
p.mu.Lock()
rec, ok := p.recordings[recordingID]
if !ok {
p.mu.Unlock()
return nil, ErrUnknownRecording
}
p.lockedStopRecordingProcess(ctx, rec, false)
killed := rec.killed
p.mu.Unlock()
p.logger.Info(ctx, "stopped desktop recording",
slog.F("recording_id", recordingID),
slog.F("file_path", rec.filePath),
)
if killed {
return nil, ErrRecordingCorrupted
}
// Open the file and return an artifact. Each call opens a fresh
// file descriptor so the caller is insulated from restarts and
// desktop close.
f, err := os.Open(rec.filePath)
if err != nil {
return nil, xerrors.Errorf("open recording artifact: %w", err)
}
info, err := f.Stat()
if err != nil {
_ = f.Close()
return nil, xerrors.Errorf("stat recording artifact: %w", err)
}
return &RecordingArtifact{
Reader: f,
Size: info.Size(),
}, nil
}
// lockedStopRecordingProcess stops a single recording via stopOnce.
// It sends SIGINT, waits up to 15 seconds for graceful exit, then
// SIGKILLs. When force is true the process is SIGKILLed immediately
// without attempting a graceful shutdown. Must be called while p.mu
// is held; the lock is held for the full duration so that no
// concurrent StopRecording caller can read rec.stopped = true
// before the process has finished writing the MP4 file.
//
//nolint:revive // force flag keeps shared stopOnce/cleanup logic in one place.
func (p *portableDesktop) lockedStopRecordingProcess(ctx context.Context, rec *recordingProcess, force bool) {
rec.stopOnce.Do(func() {
if force {
_ = rec.cmd.Process.Kill()
rec.killed = true
} else {
_ = interruptRecordingProcess(rec.cmd.Process)
timer := p.clock.NewTimer(15*time.Second, "agentdesktop", "stop_timeout")
defer timer.Stop()
select {
case <-rec.done:
case <-ctx.Done():
_ = rec.cmd.Process.Kill()
rec.killed = true
case <-timer.C:
_ = rec.cmd.Process.Kill()
rec.killed = true
}
}
rec.stopped = true
if rec.idleCancel != nil {
rec.idleCancel()
}
})
// NOTE: We intentionally do not wait on rec.done here.
// If goleak is added to this package's tests, this may
// need revisiting to avoid flakes.
}
// lockedActiveRecordingCount returns the number of recordings that
// are still actively running. Must be called while p.mu is held.
// The max concurrency is low (maxConcurrentRecordings = 5), so a
// full scan is cheap and avoids maintaining a separate counter.
func (p *portableDesktop) lockedActiveRecordingCount() int {
active := 0
for _, rec := range p.recordings {
if rec.stopped {
continue
}
select {
case <-rec.done:
default:
active++
}
}
return active
}
// lockedCleanStaleRecordings removes stopped recordings whose temp
// files are older than one hour. Must be called while p.mu is held.
func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
for id, rec := range p.recordings {
if !rec.stopped {
continue
}
info, err := os.Stat(rec.filePath)
if err != nil {
// File already removed or inaccessible; drop entry.
delete(p.recordings, id)
continue
}
if p.clock.Since(info.ModTime()) > time.Hour {
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
p.logger.Warn(ctx, "failed to remove stale recording file",
slog.F("recording_id", id),
slog.F("file_path", rec.filePath),
slog.Error(err),
)
}
delete(p.recordings, id)
}
}
}
// Close shuts down the desktop session and cleans up resources.
func (p *portableDesktop) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
p.closed = true
// Force-kill all active recordings. The stopOnce inside
// lockedStopRecordingProcess makes this safe for
// already-stopped recordings.
for _, rec := range p.recordings {
p.lockedStopRecordingProcess(context.Background(), rec, true)
}
// Snapshot recording file paths and idle goroutine channels
// for cleanup, then clear the map.
type recEntry struct {
id string
filePath string
idleDone chan struct{}
}
var allRecs []recEntry
for id, rec := range p.recordings {
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
delete(p.recordings, id)
}
session := p.session
p.session = nil
p.mu.Unlock()
// Wait for all per-recording idle goroutines to exit.
for _, entry := range allRecs {
if entry.idleDone != nil {
<-entry.idleDone
}
}
// Remove all recording files and wait for the session to
// exit with a timeout so a slow filesystem or hung process
// cannot block agent shutdown indefinitely.
cleanupDone := make(chan struct{})
go func() {
defer close(cleanupDone)
for _, entry := range allRecs {
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
p.logger.Warn(context.Background(), "failed to remove recording file on close",
slog.F("recording_id", entry.id),
slog.F("file_path", entry.filePath),
slog.Error(err),
)
}
}
if session != nil {
session.cancel()
if err := session.cmd.Process.Kill(); err != nil {
p.logger.Warn(context.Background(), "failed to kill portabledesktop process",
slog.Error(err),
)
}
if err := session.cmd.Wait(); err != nil {
var exitErr *exec.ExitError
if !errors.As(err, &exitErr) {
p.logger.Warn(context.Background(), "portabledesktop process exited with error",
slog.Error(err),
)
}
}
}
}()
timer := p.clock.NewTimer(15*time.Second, "agentdesktop", "close_cleanup_timeout")
defer timer.Stop()
select {
case <-cleanupDone:
case <-timer.C:
p.logger.Warn(context.Background(), "timed out waiting for close cleanup")
if p.session != nil {
p.session.cancel()
// Xvnc is a child process — killing it cleans up the X
// session.
_ = p.session.cmd.Process.Kill()
_ = p.session.cmd.Wait()
p.session = nil
}
return nil
}
// RecordActivity marks the desktop as having received user
// interaction, resetting the idle-recording timer.
func (p *portableDesktop) RecordActivity() {
p.lastDesktopActionAt.Store(p.clock.Now().UnixNano())
}
// runCmd executes a portabledesktop subcommand and returns combined
// output. The caller must have previously called ensureBinary.
func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, error) {
@@ -738,31 +397,3 @@ func (p *portableDesktop) ensureBinary(ctx context.Context) error {
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
}
// monitorRecordingIdle watches for desktop inactivity and stops the
// given recording when the idle timeout is reached.
func (p *portableDesktop) monitorRecordingIdle(ctx context.Context, rec *recordingProcess) {
timer := p.clock.NewTimer(idleTimeout, "agentdesktop", "recording_idle")
defer timer.Stop()
for {
select {
case <-timer.C:
lastNano := p.lastDesktopActionAt.Load()
lastAction := time.Unix(0, lastNano)
elapsed := p.clock.Since(lastAction)
if elapsed >= idleTimeout {
p.mu.Lock()
p.lockedStopRecordingProcess(context.Background(), rec, false)
p.mu.Unlock()
return
}
// Activity happened; reset with remaining budget.
timer.Reset(idleTimeout-elapsed, "agentdesktop", "recording_idle")
case <-rec.done:
return
case <-ctx.Done():
return
}
}
}
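// Worked instance of the reset arithmetic above, with idleTimeout =
// 10*time.Minute: if the timer fires but the last action was only 7
// minutes ago, elapsed = 7m, so the timer is re-armed for
// idleTimeout-elapsed = 3m and ends up firing exactly idleTimeout
// after that last action (absent further activity):
//
//	elapsed := p.clock.Since(lastAction)                               // 7m
//	timer.Reset(idleTimeout-elapsed, "agentdesktop", "recording_idle") // 3m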
@@ -9,17 +9,13 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/pty"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// recordedExecer implements agentexec.Execer by recording every
@@ -90,7 +86,6 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
clock: quartz.NewReal(),
}
ctx := t.Context()
@@ -122,7 +117,6 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
ctx := t.Context()
@@ -165,7 +159,6 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
ctx := t.Context()
@@ -191,7 +184,6 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
ctx := t.Context()
@@ -290,7 +282,6 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
err := tt.invoke(t.Context(), pd)
@@ -298,6 +289,7 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
cmds := rec.allCommands()
require.NotEmpty(t, cmds, "expected at least one command")
// Find at least one recorded command that contains
// all expected argument substrings.
found := false
@@ -375,7 +367,6 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
err := tt.invoke(t.Context(), pd)
@@ -432,7 +423,6 @@ func TestPortableDesktop_Close(t *testing.T) {
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
clock: quartz.NewReal(),
}
ctx := t.Context()
@@ -455,7 +445,7 @@ func TestPortableDesktop_Close(t *testing.T) {
// Subsequent Start must fail.
_, err = pd.Start(ctx)
require.Error(t, err)
assert.Contains(t, err.Error(), "desktop closed")
assert.Contains(t, err.Error(), "desktop is closed")
}
// --- ensureBinary tests ---
@@ -549,410 +539,7 @@ func TestEnsureBinary_NotFound(t *testing.T) {
assert.Contains(t, err.Error(), "not found")
}
func TestPortableDesktop_StartRecording(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID := uuid.New().String()
err := pd.StartRecording(ctx, recID)
require.NoError(t, err)
cmds := rec.allCommands()
require.NotEmpty(t, cmds)
// Find the record command (not the up command).
found := false
for _, cmd := range cmds {
joined := strings.Join(cmd, " ")
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
found = true
break
}
}
assert.True(t, found, "expected a record command with the recording ID")
require.NoError(t, pd.Close())
}
func TestPortableDesktop_StartRecording_ConcurrentLimit(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
for i := range maxConcurrentRecordings {
err := pd.StartRecording(ctx, uuid.New().String())
require.NoError(t, err, "recording %d should succeed", i)
}
err := pd.StartRecording(ctx, uuid.New().String())
require.Error(t, err)
assert.Contains(t, err.Error(), "too many concurrent recordings")
require.NoError(t, pd.Close())
}
func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID := uuid.New().String()
err := pd.StartRecording(ctx, recID)
require.NoError(t, err)
// Write a dummy MP4 file at the expected path so StopRecording
// can open it as an artifact.
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
t.Cleanup(func() { _ = os.Remove(filePath) })
artifact, err := pd.StopRecording(ctx, recID)
require.NoError(t, err)
defer artifact.Reader.Close()
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
require.NoError(t, pd.Close())
}
func TestPortableDesktop_StopRecording_UnknownID(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
_, err := pd.StopRecording(ctx, uuid.New().String())
require.ErrorIs(t, err, ErrUnknownRecording)
require.NoError(t, pd.Close())
}
// Ensure that portableDesktop satisfies the Desktop interface at
// compile time. This uses the unexported type so it lives in the
// internal test package.
var _ Desktop = (*portableDesktop)(nil)
func TestPortableDesktop_IdleTimeout_StopsRecordings(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewMock(t)
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID := uuid.New().String()
// Install the trap before StartRecording so it is guaranteed
// to catch the idle monitor's NewTimer call regardless of
// goroutine scheduling.
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
err := pd.StartRecording(ctx, recID)
require.NoError(t, err)
// Verify recording is active.
pd.mu.Lock()
require.False(t, pd.recordings[recID].stopped)
pd.mu.Unlock()
// Wait for the idle monitor timer to be created and release
// it so the monitor enters its select loop.
trap.MustWait(ctx).MustRelease(ctx)
trap.Close()
// The stop-all path calls lockedStopRecordingProcess which
// creates a per-recording 15s stop_timeout timer.
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
// Advance past idle timeout to trigger the stop-all.
clk.Advance(idleTimeout)
// Wait for the stop timer to be created, then release it.
stopTrap.MustWait(ctx).MustRelease(ctx)
stopTrap.Close()
// The recording process should now be stopped.
require.Eventually(t, func() bool {
pd.mu.Lock()
defer pd.mu.Unlock()
rec, ok := pd.recordings[recID]
return ok && rec.stopped
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, pd.Close())
}
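// Minimal sketch of the quartz trap idiom these tests rely on:
// install the trap before the code under test runs so its NewTimer
// call is caught regardless of goroutine scheduling, release the
// trapped call, then advance the mock clock deterministically.
//
//	clk := quartz.NewMock(t)
//	trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
//	defer trap.Close()
//	startCodeThatCreatesTheTimer()      // hypothetical stand-in
//	trap.MustWait(ctx).MustRelease(ctx) // the timer now exists
//	clk.Advance(idleTimeout)            // fire it deterministically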
func TestPortableDesktop_IdleTimeout_ActivityResetsTimer(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewMock(t)
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID := uuid.New().String()
// Install the trap before StartRecording so it is guaranteed
// to catch the idle monitor's NewTimer call regardless of
// goroutine scheduling.
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
err := pd.StartRecording(ctx, recID)
require.NoError(t, err)
// Wait for the idle monitor timer to be created.
trap.MustWait(ctx).MustRelease(ctx)
trap.Close()
// Advance most of the way but not past the timeout.
clk.Advance(idleTimeout - time.Minute)
// Record activity to reset the timer.
pd.RecordActivity()
// Trap the Reset call that the idle monitor makes when it
// sees recent activity.
resetTrap := clk.Trap().TimerReset("agentdesktop", "recording_idle")
// Advance past the original idle timeout deadline. The
// monitor should see the recent activity and reset instead
// of stopping.
clk.Advance(time.Minute)
resetTrap.MustWait(ctx).MustRelease(ctx)
resetTrap.Close()
// Recording should still be active because activity was
// recorded.
pd.mu.Lock()
require.False(t, pd.recordings[recID].stopped)
pd.mu.Unlock()
require.NoError(t, pd.Close())
}
func TestPortableDesktop_IdleTimeout_MultipleRecordings(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewMock(t)
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID1 := uuid.New().String()
recID2 := uuid.New().String()
// Trap idle timer creation for both recordings.
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
err := pd.StartRecording(ctx, recID1)
require.NoError(t, err)
// Wait for first recording's idle timer.
trap.MustWait(ctx).MustRelease(ctx)
err = pd.StartRecording(ctx, recID2)
require.NoError(t, err)
// Wait for second recording's idle timer.
trap.MustWait(ctx).MustRelease(ctx)
trap.Close()
// Trap the stop timers that will be created when idle fires.
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
// Advance past idle timeout.
clk.Advance(idleTimeout)
// Wait for both stop timers.
stopTrap.MustWait(ctx).MustRelease(ctx)
stopTrap.MustWait(ctx).MustRelease(ctx)
stopTrap.Close()
// Both recordings should be stopped.
require.Eventually(t, func() bool {
pd.mu.Lock()
defer pd.mu.Unlock()
r1, ok1 := pd.recordings[recID1]
r2, ok2 := pd.recordings[recID2]
return ok1 && r1.stopped && ok2 && r2.stopped
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, pd.Close())
}
func TestPortableDesktop_StartRecording_ReturnsErrDesktopClosed(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
// Start and close the desktop so it's in the closed state.
ctx := t.Context()
_, err := pd.Start(ctx)
require.NoError(t, err)
require.NoError(t, pd.Close())
// StartRecording should now return ErrDesktopClosed.
err = pd.StartRecording(ctx, uuid.New().String())
require.ErrorIs(t, err, ErrDesktopClosed)
}
func TestPortableDesktop_Start_ReturnsErrDesktopClosed(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: quartz.NewReal(),
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(pd.clock.Now().UnixNano())
ctx := t.Context()
_, err := pd.Start(ctx)
require.NoError(t, err)
require.NoError(t, pd.Close())
_, err = pd.Start(ctx)
require.ErrorIs(t, err, ErrDesktopClosed)
}
@@ -1,12 +0,0 @@
//go:build !windows
package agentdesktop
import "os"
// interruptRecordingProcess sends a SIGINT to the recording process
// for graceful shutdown. On Unix, os.Interrupt is delivered as
// SIGINT which lets the recorder finalize the MP4 container.
func interruptRecordingProcess(p *os.Process) error {
return p.Signal(os.Interrupt)
}
@@ -1,10 +0,0 @@
package agentdesktop
import "os"
// interruptRecordingProcess kills the recording process directly
// because os.Process.Signal(os.Interrupt) is not supported on
// Windows and returns an error without delivering a signal.
func interruptRecordingProcess(p *os.Process) error {
return p.Kill()
}
+1 -5
View File
@@ -187,11 +187,7 @@ func (*Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Cl
connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
defer cancel()
// Use the parent ctx (not connectCtx) so the subprocess outlives
// the connect/initialize handshake. connectCtx bounds only the
// Initialize call below. The subprocess is cleaned up when the
// Manager is closed or ctx is canceled.
if err := c.Start(ctx); err != nil {
if err := c.Start(connectCtx); err != nil {
_ = c.Close()
return nil, xerrors.Errorf("start %q: %w", cfg.Name, err)
}
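// Generic illustration (not this package's code) of the lifetime bug
// fixed above: a subprocess started with a short-lived context is
// killed as soon as that context is canceled, even if the caller
// still needs it afterwards.
func demoContextLifetime() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	cmd := exec.CommandContext(ctx, "sleep", "60")
	_ = cmd.Start()
	cancel()       // the subprocess is killed shortly after this
	_ = cmd.Wait() // returns an error like "signal: killed"
}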
-121
View File
@@ -1,11 +1,6 @@
package agentmcp
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"testing"
"github.com/mark3labs/mcp-go/mcp"
@@ -13,7 +8,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil"
)
func TestSplitToolName(t *testing.T) {
@@ -199,118 +193,3 @@ func TestConvertResult(t *testing.T) {
})
}
}
// TestConnectServer_StdioProcessSurvivesConnect verifies that a stdio MCP
// server subprocess remains alive after connectServer returns. This is a
// regression test for a bug where the subprocess was tied to a short-lived
// connectCtx and killed as soon as the context was canceled.
func TestConnectServer_StdioProcessSurvivesConnect(t *testing.T) {
t.Parallel()
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
// Child process: act as a minimal MCP server over stdio.
runFakeMCPServer()
return
}
// Get the path to the test binary so we can re-exec ourselves
// as a fake MCP server subprocess.
testBin, err := os.Executable()
require.NoError(t, err)
cfg := ServerConfig{
Name: "fake",
Transport: "stdio",
Command: testBin,
Args: []string{"-test.run=^TestConnectServer_StdioProcessSurvivesConnect$"},
Env: map[string]string{"TEST_MCP_FAKE_SERVER": "1"},
}
ctx := testutil.Context(t, testutil.WaitLong)
m := &Manager{}
client, err := m.connectServer(ctx, cfg)
require.NoError(t, err, "connectServer should succeed")
t.Cleanup(func() { _ = client.Close() })
// At this point connectServer has returned and its internal
// connectCtx has been canceled. The subprocess must still be
// alive. Verify by listing tools (requires a live server).
listCtx, listCancel := context.WithTimeout(ctx, testutil.WaitShort)
defer listCancel()
result, err := client.ListTools(listCtx, mcp.ListToolsRequest{})
require.NoError(t, err, "ListTools should succeed — server must be alive after connect")
require.Len(t, result.Tools, 1)
assert.Equal(t, "echo", result.Tools[0].Name)
}
// runFakeMCPServer implements a minimal JSON-RPC / MCP server over
// stdin/stdout, just enough for initialize + tools/list.
func runFakeMCPServer() {
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
line := scanner.Bytes()
var req struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id"`
Method string `json:"method"`
}
if err := json.Unmarshal(line, &req); err != nil {
continue
}
var resp any
switch req.Method {
case "initialize":
resp = map[string]any{
"jsonrpc": "2.0",
"id": req.ID,
"result": map[string]any{
"protocolVersion": "2025-03-26",
"capabilities": map[string]any{
"tools": map[string]any{},
},
"serverInfo": map[string]any{
"name": "fake-server",
"version": "0.0.1",
},
},
}
case "notifications/initialized":
// No response needed for notifications.
continue
case "tools/list":
resp = map[string]any{
"jsonrpc": "2.0",
"id": req.ID,
"result": map[string]any{
"tools": []map[string]any{
{
"name": "echo",
"description": "echoes input",
"inputSchema": map[string]any{
"type": "object",
"properties": map[string]any{},
},
},
},
},
}
default:
resp = map[string]any{
"jsonrpc": "2.0",
"id": req.ID,
"error": map[string]any{
"code": -32601,
"message": "method not found",
},
}
}
out, err := json.Marshal(resp)
if err != nil {
continue
}
_, _ = fmt.Fprintf(os.Stdout, "%s\n", out)
}
}
+19 -28
View File
@@ -3,13 +3,11 @@
"enabled": true,
"clientKind": "git",
"useIgnoreFile": true,
"defaultBranch": "main"
"defaultBranch": "main",
},
"files": {
// static/*.html are Go templates with {{ }} directives that
// Biome's HTML parser does not support.
"includes": ["**", "!**/pnpm-lock.yaml", "!**/static/*.html"],
"ignoreUnknown": true
"includes": ["**", "!**/pnpm-lock.yaml"],
"ignoreUnknown": true,
},
"linter": {
"rules": {
@@ -17,7 +15,7 @@
"noSvgWithoutTitle": "off",
"useButtonType": "off",
"useSemanticElements": "off",
"noStaticElementInteractions": "off"
"noStaticElementInteractions": "off",
},
"correctness": {
"noUnusedImports": "warn",
@@ -26,9 +24,9 @@
"noUnusedVariables": {
"level": "warn",
"options": {
"ignoreRestSiblings": true
}
}
"ignoreRestSiblings": true,
},
},
},
"style": {
"noNonNullAssertion": "off",
@@ -49,7 +47,7 @@
"paths": {
"react": {
"message": "React 19 no longer requires forwardRef. Use ref as a prop instead.",
"importNames": ["forwardRef"]
"importNames": ["forwardRef"],
},
// "@mui/material/Alert": "Use components/Alert/Alert instead.",
// "@mui/material/AlertTitle": "Use components/Alert/Alert instead.",
@@ -117,10 +115,10 @@
"@emotion/styled": "Use Tailwind CSS instead.",
// "@emotion/cache": "Use Tailwind CSS instead.",
// "components/Stack/Stack": "Use Tailwind flex utilities instead (e.g., <div className='flex flex-col gap-4'>).",
"lodash": "Use lodash/<name> instead."
}
}
}
"lodash": "Use lodash/<name> instead.",
},
},
},
},
"suspicious": {
"noArrayIndexKey": "off",
@@ -131,21 +129,14 @@
"noConsole": {
"level": "error",
"options": {
"allow": ["error", "info", "warn"]
}
}
"allow": ["error", "info", "warn"],
},
},
},
"complexity": {
"noImportantStyles": "off" // TODO: check and fix !important styles
}
}
"noImportantStyles": "off", // TODO: check and fix !important styles
},
},
},
"css": {
"parser": {
// Biome 2.3+ requires opt-in for @apply and other
// Tailwind directives.
"tailwindDirectives": true
}
},
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json"
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json",
}
+6 -33
View File
@@ -17,7 +17,6 @@ import (
"strings"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"gopkg.in/natefinch/lumberjack.v2"
@@ -273,14 +272,11 @@ func workspaceAgent() *serpent.Command {
logger.Info(ctx, "agent devcontainer detection not enabled")
}
reinitCtx, reinitCancel := context.WithCancel(ctx)
defer reinitCancel()
reinitEvents := agentsdk.WaitForReinitLoop(reinitCtx, logger, client)
reinitEvents := agentsdk.WaitForReinitLoop(ctx, logger, client)
var (
lastOwnerID uuid.UUID
lastErr error
mustExit bool
lastErr error
mustExit bool
)
for {
prometheusRegistry := prometheus.NewRegistry()
@@ -347,32 +343,9 @@ func workspaceAgent() *serpent.Command {
case <-ctx.Done():
logger.Info(ctx, "agent shutting down", slog.Error(context.Cause(ctx)))
mustExit = true
case event, ok := <-reinitEvents:
switch {
case !ok:
// Channel closed — the reinit loop exited
// (terminal 409 or context expired). Keep
// running the current agent until the parent
// context is canceled.
logger.Info(ctx, "reinit channel closed, running without reinit capability")
reinitEvents = nil
<-ctx.Done()
mustExit = true
case event.OwnerID != uuid.Nil && event.OwnerID == lastOwnerID:
// Duplicate reinit for same owner — already
// reinitialized. Cancel the reinit loop
// goroutine and keep the current agent.
logger.Info(ctx, "skipping redundant reinit, owner unchanged",
slog.F("owner_id", event.OwnerID))
reinitCancel()
reinitEvents = nil
<-ctx.Done()
mustExit = true
default:
lastOwnerID = event.OwnerID
logger.Info(ctx, "agent received instruction to reinitialize",
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
}
case event := <-reinitEvents:
logger.Info(ctx, "agent received instruction to reinitialize",
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
}
lastErr = agnt.Close()
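
The deleted branches above leaned on the nil-channel idiom: after the events channel closes, assigning nil to the variable makes its select case block forever, disabling it while the remaining cases keep working. A generic sketch of the idiom, independent of the agent code:

package main

import "context"

// drain handles events until ctx is canceled. A receive on a nil
// channel blocks forever, so setting events to nil after close
// permanently disables that case without leaving the loop.
func drain(ctx context.Context, events <-chan int) {
	for {
		select {
		case <-ctx.Done():
			return
		case ev, ok := <-events:
			if !ok {
				events = nil // closed: disable this case
				continue
			}
			_ = ev // handle the event
		}
	}
}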
+1 -1
View File
@@ -104,7 +104,7 @@ func (b *Builder) Build(inv *serpent.Invocation) (log slog.Logger, closeLog func
addSinkIfProvided := func(sinkFn func(io.Writer) slog.Sink, loc string) error {
switch loc {
case "", "/dev/null":
case "":
case "/dev/stdout":
sinks = append(sinks, sinkFn(inv.Stdout))
-3
View File
@@ -1401,9 +1401,6 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
// Setup our workspace agent connection.
config := workspacetraffic.Config{
AgentID: agent.ID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agent.Name,
BytesPerTick: bytesPerTick,
Duration: strategy.timeout,
TickInterval: tickInterval,
+12
View File
@@ -16,9 +16,21 @@ OPTIONS:
Always prompt all parameters. Does not pull parameter values from
active template version.
--description string
Set the description of the template. Overrides any value from
README.md frontmatter.
-d, --directory string (default: .)
Specify the directory to create from; use '-' to read tar from stdin.
--display-name string
Set the display name of the template. Overrides any value from
README.md frontmatter.
--icon string
Set the icon of the template. Overrides any value from README.md
frontmatter.
--ignore-lockfile bool (default: false)
Ignore warnings about not having a .terraform.lock.hcl file present in
the template.
+1 -1
View File
@@ -85,7 +85,7 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
AgentName: a.AgentName,
Type: connectionType,
Code: code,
IP: logIP,
Ip: logIP,
ConnectionID: uuid.NullUUID{
UUID: connectionID,
Valid: true,
+1 -1
View File
@@ -152,7 +152,7 @@ func TestConnectionLog(t *testing.T) {
Int32: tt.status,
Valid: *tt.action == agentproto.Connection_DISCONNECT,
},
IP: expectedIP,
Ip: expectedIP,
Type: agentProtoConnectionTypeToConnectionLog(t, *tt.typ),
DisconnectReason: sql.NullString{
String: tt.reason,
+2 -59
View File
@@ -10205,26 +10205,12 @@ const docTemplate = `{
],
"summary": "Get workspace agent reinitialization",
"operationId": "get-workspace-agent-reinitialization",
"parameters": [
{
"type": "boolean",
"description": "Opt in to durable reinit checks",
"name": "wait",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
}
},
"409": {
"description": "Conflict",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
}
},
"security": [
@@ -12661,16 +12647,11 @@ const docTemplate = `{
"agentsdk.ReinitializationEvent": {
"type": "object",
"properties": {
"owner_id": {
"type": "string",
"format": "uuid"
},
"reason": {
"$ref": "#/definitions/agentsdk.ReinitializationReason"
},
"workspace_id": {
"type": "string",
"format": "uuid"
"workspaceID": {
"type": "string"
}
}
},
@@ -13133,12 +13114,6 @@ const docTemplate = `{
"codersdk.AIBridgeSessionThreadsTokenUsage": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"input_tokens": {
"type": "integer"
},
@@ -13154,12 +13129,6 @@ const docTemplate = `{
"codersdk.AIBridgeSessionTokenUsageSummary": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"input_tokens": {
"type": "integer"
},
@@ -13206,12 +13175,6 @@ const docTemplate = `{
"codersdk.AIBridgeTokenUsage": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"created_at": {
"type": "string",
"format": "date-time"
@@ -14175,9 +14138,6 @@ const docTemplate = `{
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -14499,9 +14459,6 @@ const docTemplate = `{
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -14589,17 +14546,6 @@ const docTemplate = `{
}
}
},
"codersdk.CreateFirstUserOnboardingInfo": {
"type": "object",
"properties": {
"newsletter_marketing": {
"type": "boolean"
},
"newsletter_releases": {
"type": "boolean"
}
}
},
"codersdk.CreateFirstUserRequest": {
"type": "object",
"required": [
@@ -14614,9 +14560,6 @@ const docTemplate = `{
"name": {
"type": "string"
},
"onboarding_info": {
"$ref": "#/definitions/codersdk.CreateFirstUserOnboardingInfo"
},
"password": {
"type": "string"
},
+2 -59
View File
@@ -9038,26 +9038,12 @@
"tags": ["Agents"],
"summary": "Get workspace agent reinitialization",
"operationId": "get-workspace-agent-reinitialization",
"parameters": [
{
"type": "boolean",
"description": "Opt in to durable reinit checks",
"name": "wait",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
}
},
"409": {
"description": "Conflict",
"schema": {
"$ref": "#/definitions/codersdk.Response"
}
}
},
"security": [
@@ -11243,16 +11229,11 @@
"agentsdk.ReinitializationEvent": {
"type": "object",
"properties": {
"owner_id": {
"type": "string",
"format": "uuid"
},
"reason": {
"$ref": "#/definitions/agentsdk.ReinitializationReason"
},
"workspace_id": {
"type": "string",
"format": "uuid"
"workspaceID": {
"type": "string"
}
}
},
@@ -11711,12 +11692,6 @@
"codersdk.AIBridgeSessionThreadsTokenUsage": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"input_tokens": {
"type": "integer"
},
@@ -11732,12 +11707,6 @@
"codersdk.AIBridgeSessionTokenUsageSummary": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"input_tokens": {
"type": "integer"
},
@@ -11784,12 +11753,6 @@
"codersdk.AIBridgeTokenUsage": {
"type": "object",
"properties": {
"cache_read_input_tokens": {
"type": "integer"
},
"cache_write_input_tokens": {
"type": "integer"
},
"created_at": {
"type": "string",
"format": "date-time"
@@ -12739,9 +12702,6 @@
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -13042,9 +13002,6 @@
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -13129,17 +13086,6 @@
}
}
},
"codersdk.CreateFirstUserOnboardingInfo": {
"type": "object",
"properties": {
"newsletter_marketing": {
"type": "boolean"
},
"newsletter_releases": {
"type": "boolean"
}
}
},
"codersdk.CreateFirstUserRequest": {
"type": "object",
"required": ["email", "password", "username"],
@@ -13150,9 +13096,6 @@
"name": {
"type": "string"
},
"onboarding_info": {
"$ref": "#/definitions/codersdk.CreateFirstUserOnboardingInfo"
},
"password": {
"type": "string"
},
+1 -8
View File
@@ -26,11 +26,6 @@ import (
"github.com/coder/coder/v2/codersdk"
)
// Limit the count query to avoid a slow sequential scan due to joins
// on a large table. Set to 0 to disable capping (but also see the note
// in the SQL query).
const auditLogCountCap = 2000
// @Summary Get audit logs
// @ID get-audit-logs
// @Security CoderSessionToken
@@ -71,7 +66,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
countFilter.Username = ""
}
countFilter.CountCap = auditLogCountCap
// Use the same filters to count the number of audit logs
count, err := api.Database.CountAuditLogs(ctx, countFilter)
if dbauthz.IsNotAuthorizedError(err) {
httpapi.Forbidden(rw)
@@ -86,7 +81,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: []codersdk.AuditLog{},
Count: 0,
CountCap: auditLogCountCap,
})
return
}
@@ -104,7 +98,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: api.convertAuditLogs(ctx, dblogs),
Count: count,
CountCap: auditLogCountCap,
})
}
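
The removed auditLogCountCap bounded the COUNT(*) that backs pagination. The usual SQL trick is to count over a LIMITed subquery so the scan stops at the cap; below is a hedged sketch with illustrative table and parameter names, not the actual Coder query. NULLIF turns a cap of 0 into LIMIT NULL, which Postgres treats as no limit.

// countAuditLogsCapped is an illustrative sketch, not the real query.
const countAuditLogsCapped = `
SELECT COUNT(*)
FROM (
    SELECT 1
    FROM audit_logs
    -- filters would go here
    LIMIT NULLIF($1::bigint, 0) -- cap; 0 means "no cap" via LIMIT NULL
) AS capped;
`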
+1 -10
View File
@@ -168,7 +168,6 @@ type Options struct {
ConnectionLogger connectionlog.ConnectionLogger
AgentConnectionUpdateFrequency time.Duration
AgentInactiveDisconnectTimeout time.Duration
ChatdInstructionLookupTimeout time.Duration
AWSCertificates awsidentity.Certificates
Authorizer rbac.Authorizer
AzureCertificates x509.VerifyOptions
@@ -783,10 +782,9 @@ func New(options *Options) *API {
ReplicaID: api.ID,
SubscribeFn: options.ChatSubscribeFn,
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
ProviderAPIKeys: ChatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
@@ -1223,13 +1221,6 @@ func New(options *Options) *API {
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
})
})
r.Route("/user-provider-configs", func(r chi.Router) {
r.Get("/", api.listUserChatProviderConfigs)
r.Route("/{providerConfig}", func(r chi.Router) {
r.Put("/", api.upsertUserChatProviderKey)
r.Delete("/", api.deleteUserChatProviderKey)
})
})
r.Route("/{chat}", func(r chi.Router) {
r.Use(httpmw.ExtractChatParam(options.Database))
r.Get("/", api.getChat)
+6 -8
View File
@@ -149,13 +149,12 @@ type Options struct {
OneTimePasscodeValidityPeriod time.Duration
// IncludeProvisionerDaemon when true means to start an in-memory provisionerD
IncludeProvisionerDaemon bool
ChatdInstructionLookupTimeout time.Duration
ProvisionerDaemonVersion string
ProvisionerDaemonTags map[string]string
MetricsCacheRefreshInterval time.Duration
AgentStatsRefreshInterval time.Duration
DeploymentValues *codersdk.DeploymentValues
IncludeProvisionerDaemon bool
ProvisionerDaemonVersion string
ProvisionerDaemonTags map[string]string
MetricsCacheRefreshInterval time.Duration
AgentStatsRefreshInterval time.Duration
DeploymentValues *codersdk.DeploymentValues
// Set update check options to enable update check.
UpdateCheckOptions *updatecheck.Options
@@ -576,7 +575,6 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
// Force a long disconnection timeout to ensure
// agents are not marked as disconnected during slow tests.
AgentInactiveDisconnectTimeout: testutil.WaitShort,
ChatdInstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
AccessURL: accessURL,
AppHostname: options.AppHostname,
AppHostnameRegex: appHostnameRegex,
+2 -2
View File
@@ -90,8 +90,8 @@ func (m *FakeConnectionLogger) Contains(t testing.TB, expected database.UpsertCo
t.Logf("connection log %d: expected Code %d, got %d", idx+1, expected.Code.Int32, cl.Code.Int32)
continue
}
if expected.IP.Valid && cl.IP.IPNet.String() != expected.IP.IPNet.String() {
t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.IP.IPNet, cl.IP.IPNet)
if expected.Ip.Valid && cl.Ip.IPNet.String() != expected.Ip.IPNet.String() {
t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.Ip.IPNet, cl.Ip.IPNet)
continue
}
if expected.UserAgent.Valid && cl.UserAgent.String != expected.UserAgent.String {
-2
View File
@@ -10,7 +10,6 @@ const (
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckValidCredentialPolicy CheckConstraint = "valid_credential_policy" // chat_providers
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
@@ -33,5 +32,4 @@ const (
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys
)
+21 -45
View File
@@ -1037,10 +1037,8 @@ func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSess
StartedAt: row.StartedAt,
Threads: row.Threads,
TokenUsageSummary: codersdk.AIBridgeSessionTokenUsageSummary{
InputTokens: row.InputTokens,
OutputTokens: row.OutputTokens,
CacheReadInputTokens: row.CacheReadInputTokens,
CacheWriteInputTokens: row.CacheWriteInputTokens,
InputTokens: row.InputTokens,
OutputTokens: row.OutputTokens,
},
}
// Ensure non-nil slices for JSON serialization.
@@ -1064,15 +1062,13 @@ func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSess
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
return codersdk.AIBridgeTokenUsage{
ID: usage.ID,
InterceptionID: usage.InterceptionID,
ProviderResponseID: usage.ProviderResponseID,
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
CacheWriteInputTokens: usage.CacheWriteInputTokens,
Metadata: jsonOrEmptyMap(usage.Metadata),
CreatedAt: usage.CreatedAt,
ID: usage.ID,
InterceptionID: usage.InterceptionID,
ProviderResponseID: usage.ProviderResponseID,
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
Metadata: jsonOrEmptyMap(usage.Metadata),
CreatedAt: usage.CreatedAt,
}
}
@@ -1183,11 +1179,9 @@ func AIBridgeSessionThreads(
PageStartedAt: pageStartedAt,
PageEndedAt: pageEndedAt,
TokenUsageSummary: codersdk.AIBridgeSessionThreadsTokenUsage{
InputTokens: session.InputTokens,
OutputTokens: session.OutputTokens,
CacheReadInputTokens: session.CacheReadInputTokens,
CacheWriteInputTokens: session.CacheWriteInputTokens,
Metadata: sessionTokenMeta,
InputTokens: session.InputTokens,
OutputTokens: session.OutputTokens,
Metadata: sessionTokenMeta,
},
Threads: threads,
}
@@ -1320,19 +1314,17 @@ func buildAIBridgeThread(
// aggregateTokenUsage sums token usage rows and aggregates metadata.
func aggregateTokenUsage(tokens []database.AIBridgeTokenUsage) codersdk.AIBridgeSessionThreadsTokenUsage {
var inputTokens, outputTokens, cacheRead, cacheWrite int64
var inputTokens, outputTokens int64
for _, tu := range tokens {
inputTokens += tu.InputTokens
outputTokens += tu.OutputTokens
cacheRead += tu.CacheReadInputTokens
cacheWrite += tu.CacheWriteInputTokens
// TODO: once https://github.com/coder/aibridge/issues/150 lands we
// should aggregate the other token types.
}
return codersdk.AIBridgeSessionThreadsTokenUsage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
CacheReadInputTokens: cacheRead,
CacheWriteInputTokens: cacheWrite,
Metadata: aggregateTokenMetadata(tokens),
InputTokens: inputTokens,
OutputTokens: outputTokens,
Metadata: aggregateTokenMetadata(tokens),
}
}
@@ -1528,10 +1520,7 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
// Chat converts a database.Chat to a codersdk.Chat. It coalesces
// nil slices and maps to empty values for JSON serialization and
// derives RootChatID from the parent chain when not explicitly set.
// When diffStatus is non-nil the response includes diff metadata.
// When files is non-empty the response includes file metadata;
// pass nil to omit the files field (e.g. list endpoints).
func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database.GetChatFileMetadataByChatIDRow) codersdk.Chat {
func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
mcpServerIDs := c.MCPServerIDs
if mcpServerIDs == nil {
mcpServerIDs = []uuid.UUID{}
@@ -1584,19 +1573,6 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database
convertedDiffStatus := ChatDiffStatus(c.ID, diffStatus)
chat.DiffStatus = &convertedDiffStatus
}
if len(files) > 0 {
chat.Files = make([]codersdk.ChatFileMetadata, 0, len(files))
for _, row := range files {
chat.Files = append(chat.Files, codersdk.ChatFileMetadata{
ID: row.ID,
OwnerID: row.OwnerID,
OrganizationID: row.OrganizationID,
Name: row.Name,
MimeType: row.Mimetype,
CreatedAt: row.CreatedAt,
})
}
}
if c.LastInjectedContext.Valid {
var parts []codersdk.ChatMessagePart
// Internal fields are stripped at write time in
@@ -1620,9 +1596,9 @@ func ChatRows(rows []database.GetChatsRow, diffStatusesByChatID map[uuid.UUID]da
for i, row := range rows {
diffStatus, ok := diffStatusesByChatID[row.Chat.ID]
if ok {
result[i] = Chat(row.Chat, &diffStatus, nil)
result[i] = Chat(row.Chat, &diffStatus)
} else {
result[i] = Chat(row.Chat, nil, nil)
result[i] = Chat(row.Chat, nil)
if diffStatusesByChatID != nil {
emptyDiffStatus := ChatDiffStatus(row.Chat.ID, nil)
result[i].DiffStatus = &emptyDiffStatus
+9 -131
View File
@@ -259,13 +259,11 @@ func TestAIBridgeInterception(t *testing.T) {
},
tokenUsages: []database.AIBridgeTokenUsage{
{
ID: uuid.New(),
InterceptionID: interceptionID,
ProviderResponseID: "resp-123",
InputTokens: 100,
OutputTokens: 200,
CacheReadInputTokens: 50,
CacheWriteInputTokens: 10,
ID: uuid.New(),
InterceptionID: interceptionID,
ProviderResponseID: "resp-123",
InputTokens: 100,
OutputTokens: 200,
Metadata: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"cache":"hit"}`),
Valid: true,
@@ -415,8 +413,6 @@ func TestAIBridgeInterception(t *testing.T) {
require.Equal(t, tu.ProviderResponseID, result.TokenUsages[i].ProviderResponseID)
require.Equal(t, tu.InputTokens, result.TokenUsages[i].InputTokens)
require.Equal(t, tu.OutputTokens, result.TokenUsages[i].OutputTokens)
require.Equal(t, tu.CacheReadInputTokens, result.TokenUsages[i].CacheReadInputTokens)
require.Equal(t, tu.CacheWriteInputTokens, result.TokenUsages[i].CacheWriteInputTokens)
}
// Verify user prompts are converted correctly.
@@ -561,26 +557,14 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
ChatID: input.ID,
}
fileRows := []database.GetChatFileMetadataByChatIDRow{
{
ID: uuid.New(),
OwnerID: input.OwnerID,
OrganizationID: uuid.New(),
Name: "test.png",
Mimetype: "image/png",
CreatedAt: now,
},
}
got := db2sdk.Chat(input, diffStatus, fileRows)
got := db2sdk.Chat(input, diffStatus)
v := reflect.ValueOf(got)
typ := v.Type()
// HasUnread is populated by ChatRows (which joins the
// read-cursor query), not by Chat. Warnings is a transient
// field populated by handlers, not the converter. Both are
// expected to remain zero here.
skip := map[string]bool{"HasUnread": true, "Warnings": true}
// read-cursor query), not by Chat, so it is expected
// to remain zero here.
skip := map[string]bool{"HasUnread": true}
for i := range typ.NumField() {
field := typ.Field(i)
if skip[field.Name] {
@@ -593,112 +577,6 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
}
}
func TestChat_FileMetadataConversion(t *testing.T) {
t.Parallel()
ownerID := uuid.New()
orgID := uuid.New()
fileID := uuid.New()
now := dbtime.Now()
chat := database.Chat{
ID: uuid.New(),
OwnerID: ownerID,
LastModelConfigID: uuid.New(),
Title: "file metadata test",
Status: database.ChatStatusWaiting,
CreatedAt: now,
UpdatedAt: now,
}
rows := []database.GetChatFileMetadataByChatIDRow{
{
ID: fileID,
OwnerID: ownerID,
OrganizationID: orgID,
Name: "screenshot.png",
Mimetype: "image/png",
CreatedAt: now,
},
}
result := db2sdk.Chat(chat, nil, rows)
require.Len(t, result.Files, 1)
f := result.Files[0]
require.Equal(t, fileID, f.ID)
require.Equal(t, ownerID, f.OwnerID, "OwnerID must be mapped from DB row")
require.Equal(t, orgID, f.OrganizationID, "OrganizationID must be mapped from DB row")
require.Equal(t, "screenshot.png", f.Name)
require.Equal(t, "image/png", f.MimeType)
require.Equal(t, now, f.CreatedAt)
// Verify JSON serialization uses snake_case for mime_type.
data, err := json.Marshal(f)
require.NoError(t, err)
require.Contains(t, string(data), `"mime_type"`)
require.NotContains(t, string(data), `"mimetype"`)
}
func TestChat_NilFilesOmitted(t *testing.T) {
t.Parallel()
chat := database.Chat{
ID: uuid.New(),
OwnerID: uuid.New(),
LastModelConfigID: uuid.New(),
Title: "no files",
Status: database.ChatStatusWaiting,
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
}
result := db2sdk.Chat(chat, nil, nil)
require.Empty(t, result.Files)
}
func TestChat_MultipleFiles(t *testing.T) {
t.Parallel()
now := dbtime.Now()
file1 := uuid.New()
file2 := uuid.New()
chat := database.Chat{
ID: uuid.New(),
OwnerID: uuid.New(),
LastModelConfigID: uuid.New(),
Title: "multi file test",
Status: database.ChatStatusWaiting,
CreatedAt: now,
UpdatedAt: now,
}
rows := []database.GetChatFileMetadataByChatIDRow{
{
ID: file1,
OwnerID: chat.OwnerID,
OrganizationID: uuid.New(),
Name: "a.png",
Mimetype: "image/png",
CreatedAt: now,
},
{
ID: file2,
OwnerID: chat.OwnerID,
OrganizationID: uuid.New(),
Name: "b.txt",
Mimetype: "text/plain",
CreatedAt: now,
},
}
result := db2sdk.Chat(chat, nil, rows)
require.Len(t, result.Files, 2)
require.Equal(t, "a.png", result.Files[0].Name)
require.Equal(t, "b.txt", result.Files[1].Name)
}
func TestChatQueuedMessage_MalformedContent(t *testing.T) {
t.Parallel()
+26 -78
View File
@@ -1627,13 +1627,6 @@ func (q *querier) BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg datab
return q.db.BatchUpdateWorkspaceNextStartAt(ctx, arg)
}
func (q *querier) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return err
}
return q.db.BatchUpsertConnectionLogs(ctx, arg)
}
func (q *querier) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceNotificationMessage); err != nil {
return 0, err
@@ -2088,6 +2081,13 @@ func (q *querier) DeleteOrganizationMember(ctx context.Context, arg database.Del
}, q.db.DeleteOrganizationMember)(ctx, arg)
}
func (q *querier) DeleteOrphanedChatFiles(ctx context.Context, before time.Time) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return 0, err
}
return q.db.DeleteOrphanedChatFiles(ctx, before)
}
func (q *querier) DeleteProvisionerKey(ctx context.Context, id uuid.UUID) error {
return deleteQ(q.log, q.auth, q.db.GetProvisionerKeyByID, q.db.DeleteProvisionerKey)(ctx, id)
}
@@ -2144,17 +2144,6 @@ func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg dat
return q.db.DeleteUserChatCompactionThreshold(ctx, arg)
}
func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return err
}
return q.db.DeleteUserChatProviderKey(ctx, arg)
}
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
// First get the secret to check ownership
secret, err := q.GetUserSecret(ctx, id)
@@ -2583,10 +2572,6 @@ func (q *querier) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.C
return file, nil
}
func (q *querier) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatFileMetadataByChatID)(ctx, chatID)
}
func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
files, err := q.db.GetChatFilesByIDs(ctx, ids)
if err != nil {
@@ -3657,18 +3642,18 @@ func (q *querier) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database
return q.db.GetTailnetPeers(ctx, id)
}
func (q *querier) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
func (q *querier) GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
}
return q.db.GetTailnetTunnelPeerBindingsBatch(ctx, ids)
return q.db.GetTailnetTunnelPeerBindings(ctx, srcID)
}
func (q *querier) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
func (q *querier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
}
return q.db.GetTailnetTunnelPeerIDsBatch(ctx, ids)
return q.db.GetTailnetTunnelPeerIDs(ctx, srcID)
}
func (q *querier) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task, error) {
@@ -4046,17 +4031,6 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
return q.db.GetUserChatCustomPrompt(ctx, userID)
}
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, userID)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
return nil, err
}
return q.db.GetUserChatProviderKeys(ctx, userID)
}
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
return 0, err
@@ -5397,17 +5371,6 @@ func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg datab
return q.db.InsertWorkspaceResourceMetadata(ctx, arg)
}
func (q *querier) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return 0, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.LinkChatFiles(ctx, arg)
}
func (q *querier) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
@@ -5782,15 +5745,15 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
return q.db.UpdateChatByID(ctx, arg)
}
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
// The batch heartbeat is a system-level operation filtered by
// worker_id. Authorization is enforced by the AsChatd context
// at the call site rather than per-row, because checking each
// row individually would defeat the purpose of batching.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return nil, err
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return 0, err
}
return q.db.UpdateChatHeartbeats(ctx, arg)
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.UpdateChatHeartbeat(ctx, arg)
}
func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
@@ -6498,17 +6461,6 @@ func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.U
return q.db.UpdateUserChatCustomPrompt(ctx, arg)
}
func (q *querier) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return database.UserChatProviderKey{}, err
}
return q.db.UpdateUserChatProviderKey(ctx, arg)
}
func (q *querier) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
return deleteQ(q.log, q.auth, q.db.GetUserByID, q.db.UpdateUserDeletedByID)(ctx, id)
}
@@ -7087,6 +7039,13 @@ func (q *querier) UpsertChatWorkspaceTTL(ctx context.Context, workspaceTtl strin
return q.db.UpsertChatWorkspaceTTL(ctx, workspaceTtl)
}
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return database.ConnectionLog{}, err
}
return q.db.UpsertConnectionLog(ctx, arg)
}
func (q *querier) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
@@ -7229,17 +7188,6 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error {
return q.db.UpsertTemplateUsageStats(ctx)
}
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return database.UserChatProviderKey{}, err
}
return q.db.UpsertUserChatProviderKey(ctx, arg)
}
func (q *querier) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
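
DeleteOrphanedChatFiles is gated on rbac.ResourceSystem above because it runs as a background job rather than on behalf of a user. For context, a ticker-driven caller might look like the following sketch; the interval and error handling are illustrative, and the real dbpurge package wires its own scheduling.

import (
	"context"
	"time"
)

// runChatFilePurge is a hypothetical periodic caller of the new query.
func runChatFilePurge(ctx context.Context, db database.Store) {
	ticker := time.NewTicker(10 * time.Minute) // illustrative interval
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// Prune chat_files older than 24h that no non-deleted
			// message still references.
			if _, err := db.DeleteOrphanedChatFiles(ctx, time.Now().Add(-24*time.Hour)); err != nil {
				continue // a transient DB error should not stop the loop
			}
		}
	}
}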
+17 -68
View File
@@ -338,9 +338,10 @@ func (s *MethodTestSuite) TestAuditLogs() {
}
func (s *MethodTestSuite) TestConnectionLogs() {
s.Run("BatchUpsertConnectionLogs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.BatchUpsertConnectionLogsParams{}
dbm.EXPECT().BatchUpsertConnectionLogs(gomock.Any(), arg).Return(nil).AnyTimes()
s.Run("UpsertConnectionLog", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.WorkspaceTable{})
arg := database.UpsertConnectionLogParams{Ip: defaultIPAddress(), Type: database.ConnectionTypeSsh, WorkspaceID: ws.ID, OrganizationID: ws.OrganizationID, ConnectionStatus: database.ConnectionStatusConnected, WorkspaceOwnerID: ws.OwnerID}
dbm.EXPECT().UpsertConnectionLog(gomock.Any(), arg).Return(database.ConnectionLog{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceConnectionLog, policy.ActionUpdate)
}))
s.Run("GetConnectionLogsOffset", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
@@ -400,17 +401,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
}))
s.Run("LinkChatFiles", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.LinkChatFilesParams{
ChatID: chat.ID,
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
FileIds: []uuid.UUID{uuid.New()},
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().LinkChatFiles(gomock.Any(), arg).Return(int32(0), nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int32(0))
}))
s.Run("PinChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
@@ -587,19 +577,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatFilesByIDs(gomock.Any(), []uuid.UUID{file.ID}).Return([]database.ChatFile{file}, nil).AnyTimes()
check.Args([]uuid.UUID{file.ID}).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns([]database.ChatFile{file})
}))
s.Run("GetChatFileMetadataByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
file := testutil.Fake(s.T(), faker, database.ChatFile{})
rows := []database.GetChatFileMetadataByChatIDRow{{
ID: file.ID,
Name: file.Name,
Mimetype: file.Mimetype,
CreatedAt: file.CreatedAt,
OwnerID: file.OwnerID,
OrganizationID: file.OrganizationID,
}}
dbm.EXPECT().GetChatFileMetadataByChatID(gomock.Any(), file.ID).Return(rows, nil).AnyTimes()
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(rows)
}))
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
@@ -842,15 +819,15 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
resultID := uuid.New()
arg := database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{resultID},
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: uuid.New(),
Now: time.Now(),
}
dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
}))
s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
@@ -2430,36 +2407,6 @@ func (s *MethodTestSuite) TestUser() {
dbm.EXPECT().GetUserChatCustomPrompt(gomock.Any(), u.ID).Return("my custom prompt", nil).AnyTimes()
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("my custom prompt")
}))
s.Run("GetUserChatProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID})
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().GetUserChatProviderKeys(gomock.Any(), u.ID).Return([]database.UserChatProviderKey{key}, nil).AnyTimes()
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserChatProviderKey{key})
}))
s.Run("DeleteUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.DeleteUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New()}
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().DeleteUserChatProviderKey(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns()
}))
s.Run("UpdateUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.UpdateUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "updated-api-key"}
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().UpdateUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
}))
s.Run("UpsertUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.UpsertUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "upserted-api-key"}
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
}))
s.Run("UpdateUserChatCustomPrompt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
uc := database.UserConfig{UserID: u.ID, Key: "chat_custom_prompt", Value: "my custom prompt"}
@@ -3773,11 +3720,13 @@ func (s *MethodTestSuite) TestTailnetFunctions() {
check.Args(uuid.New()).
Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
}))
s.Run("GetTailnetTunnelPeerBindingsBatch", s.Subtest(func(_ database.Store, check *expects) {
check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
s.Run("GetTailnetTunnelPeerBindings", s.Subtest(func(_ database.Store, check *expects) {
check.Args(uuid.New()).
Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
}))
s.Run("GetTailnetTunnelPeerIDsBatch", s.Subtest(func(_ database.Store, check *expects) {
check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
s.Run("GetTailnetTunnelPeerIDs", s.Subtest(func(_ database.Store, check *expects) {
check.Args(uuid.New()).
Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
}))
s.Run("GetAllTailnetCoordinators", s.Subtest(func(_ database.Store, check *expects) {
check.Args().
+10 -56
View File
@@ -76,7 +76,7 @@ func AuditLog(t testing.TB, db database.Store, seed database.AuditLog) database.
}
func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnectionLogParams) database.ConnectionLog {
arg := database.UpsertConnectionLogParams{
log, err := db.UpsertConnectionLog(genCtx, database.UpsertConnectionLogParams{
ID: takeFirst(seed.ID, uuid.New()),
Time: takeFirst(seed.Time, dbtime.Now()),
OrganizationID: takeFirst(seed.OrganizationID, uuid.New()),
@@ -89,7 +89,7 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti
Int32: takeFirst(seed.Code.Int32, 0),
Valid: takeFirst(seed.Code.Valid, false),
},
IP: pqtype.Inet{
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
@@ -117,53 +117,9 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti
Valid: takeFirst(seed.DisconnectReason.Valid, false),
},
ConnectionStatus: takeFirst(seed.ConnectionStatus, database.ConnectionStatusConnected),
}
var disconnectTime sql.NullTime
if arg.ConnectionStatus == database.ConnectionStatusDisconnected {
disconnectTime = sql.NullTime{Time: arg.Time, Valid: true}
}
err := db.BatchUpsertConnectionLogs(genCtx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{arg.ID},
ConnectTime: []time.Time{arg.Time},
OrganizationID: []uuid.UUID{arg.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{arg.WorkspaceOwnerID},
WorkspaceID: []uuid.UUID{arg.WorkspaceID},
WorkspaceName: []string{arg.WorkspaceName},
AgentName: []string{arg.AgentName},
Type: []database.ConnectionType{arg.Type},
Code: []int32{arg.Code.Int32},
CodeValid: []bool{arg.Code.Valid},
Ip: []pqtype.Inet{arg.IP},
UserAgent: []string{arg.UserAgent.String},
UserID: []uuid.UUID{arg.UserID.UUID},
SlugOrPort: []string{arg.SlugOrPort.String},
ConnectionID: []uuid.UUID{arg.ConnectionID.UUID},
DisconnectReason: []string{arg.DisconnectReason.String},
DisconnectTime: []time.Time{disconnectTime.Time},
})
require.NoError(t, err, "insert connection log")
// Query back the actual row from the database. On upsert
// conflict the DB keeps the original row's ID, so we can't
// rely on arg.ID. Match on the conflict key for rows with a
// connection_id, or by primary key for NULL connection_id.
rows, err := db.GetConnectionLogsOffset(genCtx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err, "query connection logs")
for _, row := range rows {
if arg.ConnectionID.Valid {
if row.ConnectionLog.ConnectionID == arg.ConnectionID &&
row.ConnectionLog.WorkspaceID == arg.WorkspaceID &&
row.ConnectionLog.AgentName == arg.AgentName {
return row.ConnectionLog
}
} else if row.ConnectionLog.ID == arg.ID {
return row.ConnectionLog
}
}
require.Failf(t, "connection log not found", "id=%s", arg.ID)
return database.ConnectionLog{} // unreachable
return log
}
func Template(t testing.TB, db database.Store, seed database.Template) database.Template {
@@ -1657,15 +1613,13 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
func AIBridgeTokenUsage(t testing.TB, db database.Store, seed database.InsertAIBridgeTokenUsageParams) database.AIBridgeTokenUsage {
usage, err := db.InsertAIBridgeTokenUsage(genCtx, database.InsertAIBridgeTokenUsageParams{
ID: takeFirst(seed.ID, uuid.New()),
InterceptionID: takeFirst(seed.InterceptionID, uuid.New()),
ProviderResponseID: takeFirst(seed.ProviderResponseID, "provider_response_id"),
InputTokens: takeFirst(seed.InputTokens, 100),
OutputTokens: takeFirst(seed.OutputTokens, 100),
CacheReadInputTokens: seed.CacheReadInputTokens,
CacheWriteInputTokens: seed.CacheWriteInputTokens,
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
ID: takeFirst(seed.ID, uuid.New()),
InterceptionID: takeFirst(seed.InterceptionID, uuid.New()),
ProviderResponseID: takeFirst(seed.ProviderResponseID, "provider_response_id"),
InputTokens: takeFirst(seed.InputTokens, 100),
OutputTokens: takeFirst(seed.OutputTokens, 100),
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
})
require.NoError(t, err, "insert aibridge token usage")
return usage
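
The dbgen change above drops the batch-insert-then-query-back workaround: UpsertConnectionLog returns the row the database actually kept, original ID included. That behavior comes from Postgres's ON CONFLICT ... RETURNING; a hedged sketch with a made-up schema:

// upsertConnectionLog is an illustrative sketch, not the real query
// or schema. RETURNING yields the row as stored, so on a conflict the
// caller receives the pre-existing id instead of the one it supplied.
const upsertConnectionLog = `
INSERT INTO connection_logs (id, connection_id, connect_time)
VALUES ($1, $2, $3)
ON CONFLICT (connection_id)
DO UPDATE SET connect_time = connection_logs.connect_time -- no-op update so RETURNING fires
RETURNING *;
`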
+28 -68
View File
@@ -208,14 +208,6 @@ func (m queryMetricsStore) BatchUpdateWorkspaceNextStartAt(ctx context.Context,
return r0
}
func (m queryMetricsStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
start := time.Now()
r0 := m.s.BatchUpsertConnectionLogs(ctx, arg)
m.queryLatencies.WithLabelValues("BatchUpsertConnectionLogs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpsertConnectionLogs").Inc()
return r0
}
func (m queryMetricsStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.BulkMarkNotificationMessagesFailed(ctx, arg)
@@ -648,6 +640,14 @@ func (m queryMetricsStore) DeleteOrganizationMember(ctx context.Context, arg dat
return r0
}
func (m queryMetricsStore) DeleteOrphanedChatFiles(ctx context.Context, before time.Time) (int64, error) {
start := time.Now()
r0, r1 := m.s.DeleteOrphanedChatFiles(ctx, before)
m.queryLatencies.WithLabelValues("DeleteOrphanedChatFiles").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOrphanedChatFiles").Inc()
return r0, r1
}
func (m queryMetricsStore) DeleteProvisionerKey(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteProvisionerKey(ctx, id)
@@ -704,14 +704,6 @@ func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context
return r0
}
func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
start := time.Now()
r0 := m.s.DeleteUserChatProviderKey(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteUserChatProviderKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatProviderKey").Inc()
return r0
}
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteUserSecret(ctx, id)
@@ -1128,14 +1120,6 @@ func (m queryMetricsStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (d
return r0, r1
}
func (m queryMetricsStore) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatFileMetadataByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatFileMetadataByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFileMetadataByChatID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
start := time.Now()
r0, r1 := m.s.GetChatFilesByIDs(ctx, ids)
@@ -2232,19 +2216,19 @@ func (m queryMetricsStore) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([
return r0, r1
}
func (m queryMetricsStore) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
func (m queryMetricsStore) GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTailnetTunnelPeerBindingsBatch(ctx, ids)
m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerBindingsBatch").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerBindingsBatch").Inc()
r0, r1 := m.s.GetTailnetTunnelPeerBindings(ctx, srcID)
m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerBindings").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerBindings").Inc()
return r0, r1
}
func (m queryMetricsStore) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
func (m queryMetricsStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTailnetTunnelPeerIDsBatch(ctx, ids)
m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerIDsBatch").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerIDsBatch").Inc()
r0, r1 := m.s.GetTailnetTunnelPeerIDs(ctx, srcID)
m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerIDs").Inc()
return r0, r1
}
@@ -2552,14 +2536,6 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
return r0, r1
}
func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID)
m.queryLatencies.WithLabelValues("GetUserChatProviderKeys").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatProviderKeys").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
@@ -3784,14 +3760,6 @@ func (m queryMetricsStore) InsertWorkspaceResourceMetadata(ctx context.Context,
return r0, r1
}
func (m queryMetricsStore) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
start := time.Now()
r0, r1 := m.s.LinkChatFiles(ctx, arg)
m.queryLatencies.WithLabelValues("LinkChatFiles").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "LinkChatFiles").Inc()
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeClients(ctx, arg)
@@ -4136,11 +4104,11 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
return r0, r1
}
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatHeartbeats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeats").Inc()
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
return r0, r1
}
@@ -4600,14 +4568,6 @@ func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg d
return r0, r1
}
func (m queryMetricsStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserChatProviderKey(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserChatProviderKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatProviderKey").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.UpdateUserDeletedByID(ctx, id)
@@ -5048,6 +5008,14 @@ func (m queryMetricsStore) UpsertChatWorkspaceTTL(ctx context.Context, workspace
return r0
}
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
start := time.Now()
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertConnectionLog").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertConnectionLog").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
start := time.Now()
r0 := m.s.UpsertDefaultProxy(ctx, arg)
@@ -5192,14 +5160,6 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error {
return r0
}
func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertUserChatProviderKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatProviderKey").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
start := time.Now()
r0 := m.s.UpsertWebpushVAPIDKeys(ctx, arg)
+51 -124
View File
@@ -233,20 +233,6 @@ func (mr *MockStoreMockRecorder) BatchUpdateWorkspaceNextStartAt(ctx, arg any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateWorkspaceNextStartAt", reflect.TypeOf((*MockStore)(nil).BatchUpdateWorkspaceNextStartAt), ctx, arg)
}
// BatchUpsertConnectionLogs mocks base method.
func (m *MockStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BatchUpsertConnectionLogs", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// BatchUpsertConnectionLogs indicates an expected call of BatchUpsertConnectionLogs.
func (mr *MockStoreMockRecorder) BatchUpsertConnectionLogs(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpsertConnectionLogs", reflect.TypeOf((*MockStore)(nil).BatchUpsertConnectionLogs), ctx, arg)
}
// BulkMarkNotificationMessagesFailed mocks base method.
func (m *MockStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
m.ctrl.T.Helper()
@@ -1084,6 +1070,21 @@ func (mr *MockStoreMockRecorder) DeleteOrganizationMember(ctx, arg any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOrganizationMember", reflect.TypeOf((*MockStore)(nil).DeleteOrganizationMember), ctx, arg)
}
// DeleteOrphanedChatFiles mocks base method.
func (m *MockStore) DeleteOrphanedChatFiles(ctx context.Context, before time.Time) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteOrphanedChatFiles", ctx, before)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteOrphanedChatFiles indicates an expected call of DeleteOrphanedChatFiles.
func (mr *MockStoreMockRecorder) DeleteOrphanedChatFiles(ctx, before any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOrphanedChatFiles", reflect.TypeOf((*MockStore)(nil).DeleteOrphanedChatFiles), ctx, before)
}
// DeleteProvisionerKey mocks base method.
func (m *MockStore) DeleteProvisionerKey(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -1185,20 +1186,6 @@ func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg)
}
// DeleteUserChatProviderKey mocks base method.
func (m *MockStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteUserChatProviderKey", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteUserChatProviderKey indicates an expected call of DeleteUserChatProviderKey.
func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
}
// DeleteUserSecret mocks base method.
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -2072,21 +2059,6 @@ func (mr *MockStoreMockRecorder) GetChatFileByID(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileByID", reflect.TypeOf((*MockStore)(nil).GetChatFileByID), ctx, id)
}
// GetChatFileMetadataByChatID mocks base method.
func (m *MockStore) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatFileMetadataByChatID", ctx, chatID)
ret0, _ := ret[0].([]database.GetChatFileMetadataByChatIDRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatFileMetadataByChatID indicates an expected call of GetChatFileMetadataByChatID.
func (mr *MockStoreMockRecorder) GetChatFileMetadataByChatID(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileMetadataByChatID", reflect.TypeOf((*MockStore)(nil).GetChatFileMetadataByChatID), ctx, chatID)
}
// GetChatFilesByIDs mocks base method.
func (m *MockStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
m.ctrl.T.Helper()
@@ -4142,34 +4114,34 @@ func (mr *MockStoreMockRecorder) GetTailnetPeers(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetPeers", reflect.TypeOf((*MockStore)(nil).GetTailnetPeers), ctx, id)
}
// GetTailnetTunnelPeerBindingsBatch mocks base method.
func (m *MockStore) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
// GetTailnetTunnelPeerBindings mocks base method.
func (m *MockStore) GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerBindingsBatch", ctx, ids)
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerBindingsBatchRow)
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerBindings", ctx, srcID)
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerBindingsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTailnetTunnelPeerBindingsBatch indicates an expected call of GetTailnetTunnelPeerBindingsBatch.
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerBindingsBatch(ctx, ids any) *gomock.Call {
// GetTailnetTunnelPeerBindings indicates an expected call of GetTailnetTunnelPeerBindings.
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerBindings(ctx, srcID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerBindingsBatch", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerBindingsBatch), ctx, ids)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerBindings", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerBindings), ctx, srcID)
}
// GetTailnetTunnelPeerIDsBatch mocks base method.
func (m *MockStore) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
// GetTailnetTunnelPeerIDs mocks base method.
func (m *MockStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerIDsBatch", ctx, ids)
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerIDsBatchRow)
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerIDs", ctx, srcID)
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerIDsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTailnetTunnelPeerIDsBatch indicates an expected call of GetTailnetTunnelPeerIDsBatch.
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerIDsBatch(ctx, ids any) *gomock.Call {
// GetTailnetTunnelPeerIDs indicates an expected call of GetTailnetTunnelPeerIDs.
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerIDs(ctx, srcID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerIDsBatch", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerIDsBatch), ctx, ids)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerIDs", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerIDs), ctx, srcID)
}
// GetTaskByID mocks base method.
@@ -4772,21 +4744,6 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
}
// GetUserChatProviderKeys mocks base method.
func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserChatProviderKeys", ctx, userID)
ret0, _ := ret[0].([]database.UserChatProviderKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserChatProviderKeys indicates an expected call of GetUserChatProviderKeys.
func (mr *MockStoreMockRecorder) GetUserChatProviderKeys(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatProviderKeys", reflect.TypeOf((*MockStore)(nil).GetUserChatProviderKeys), ctx, userID)
}
// GetUserChatSpendInPeriod mocks base method.
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
m.ctrl.T.Helper()
@@ -7081,21 +7038,6 @@ func (mr *MockStoreMockRecorder) InsertWorkspaceResourceMetadata(ctx, arg any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertWorkspaceResourceMetadata", reflect.TypeOf((*MockStore)(nil).InsertWorkspaceResourceMetadata), ctx, arg)
}
// LinkChatFiles mocks base method.
func (m *MockStore) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkChatFiles", ctx, arg)
ret0, _ := ret[0].(int32)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkChatFiles indicates an expected call of LinkChatFiles.
func (mr *MockStoreMockRecorder) LinkChatFiles(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkChatFiles", reflect.TypeOf((*MockStore)(nil).LinkChatFiles), ctx, arg)
}
// ListAIBridgeClients mocks base method.
func (m *MockStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
m.ctrl.T.Helper()
@@ -7835,19 +7777,19 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
}
// UpdateChatHeartbeats mocks base method.
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
// UpdateChatHeartbeat mocks base method.
func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatHeartbeats", ctx, arg)
ret0, _ := ret[0].([]uuid.UUID)
ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatHeartbeats indicates an expected call of UpdateChatHeartbeats.
func (mr *MockStoreMockRecorder) UpdateChatHeartbeats(ctx, arg any) *gomock.Call {
// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeats", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeats), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
}
// UpdateChatLabelsByID mocks base method.
@@ -8678,21 +8620,6 @@ func (mr *MockStoreMockRecorder) UpdateUserChatCustomPrompt(ctx, arg any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCustomPrompt), ctx, arg)
}
// UpdateUserChatProviderKey mocks base method.
func (m *MockStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUserChatProviderKey", ctx, arg)
ret0, _ := ret[0].(database.UserChatProviderKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateUserChatProviderKey indicates an expected call of UpdateUserChatProviderKey.
func (mr *MockStoreMockRecorder) UpdateUserChatProviderKey(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateUserChatProviderKey), ctx, arg)
}
// UpdateUserDeletedByID mocks base method.
func (m *MockStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -9486,6 +9413,21 @@ func (mr *MockStoreMockRecorder) UpsertChatWorkspaceTTL(ctx, workspaceTtl any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatWorkspaceTTL", reflect.TypeOf((*MockStore)(nil).UpsertChatWorkspaceTTL), ctx, workspaceTtl)
}
// UpsertConnectionLog mocks base method.
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertConnectionLog", ctx, arg)
ret0, _ := ret[0].(database.ConnectionLog)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertConnectionLog indicates an expected call of UpsertConnectionLog.
func (mr *MockStoreMockRecorder) UpsertConnectionLog(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertConnectionLog", reflect.TypeOf((*MockStore)(nil).UpsertConnectionLog), ctx, arg)
}
// UpsertDefaultProxy mocks base method.
func (m *MockStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
m.ctrl.T.Helper()
@@ -9744,21 +9686,6 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx)
}
// UpsertUserChatProviderKey mocks base method.
func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertUserChatProviderKey", ctx, arg)
ret0, _ := ret[0].(database.UserChatProviderKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertUserChatProviderKey indicates an expected call of UpsertUserChatProviderKey.
func (mr *MockStoreMockRecorder) UpsertUserChatProviderKey(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpsertUserChatProviderKey), ctx, arg)
}
// UpsertWebpushVAPIDKeys mocks base method.
func (m *MockStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
m.ctrl.T.Helper()
+11
View File
@@ -34,6 +34,9 @@ const (
// long enough to cover the maximum interval of a heartbeat event (currently
// 1 hour) plus some buffer.
maxTelemetryHeartbeatAge = 24 * time.Hour
// Chat files older than this threshold that are not referenced by
// any non-deleted chat message are considered orphaned and deleted.
maxOrphanedChatFileAge = 24 * time.Hour
)
// New creates a new periodically purging database instance.
@@ -213,12 +216,19 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
}
}
deleteOrphanedChatFilesBefore := start.Add(-maxOrphanedChatFileAge)
purgedChatFiles, err := tx.DeleteOrphanedChatFiles(ctx, deleteOrphanedChatFilesBefore)
if err != nil {
return xerrors.Errorf("failed to delete orphaned chat files: %w", err)
}
i.logger.Debug(ctx, "purged old database entries",
slog.F("workspace_agent_logs", purgedWorkspaceAgentLogs),
slog.F("expired_api_keys", expiredAPIKeys),
slog.F("aibridge_records", purgedAIBridgeRecords),
slog.F("connection_logs", purgedConnectionLogs),
slog.F("audit_logs", purgedAuditLogs),
slog.F("orphaned_chat_files", purgedChatFiles),
slog.F("duration", i.clk.Since(start)),
)
@@ -232,6 +242,7 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
i.recordsPurged.WithLabelValues("aibridge_records").Add(float64(purgedAIBridgeRecords))
i.recordsPurged.WithLabelValues("connection_logs").Add(float64(purgedConnectionLogs))
i.recordsPurged.WithLabelValues("audit_logs").Add(float64(purgedAuditLogs))
i.recordsPurged.WithLabelValues("orphaned_chat_files").Add(float64(purgedChatFiles))
}
return nil
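The SQL behind DeleteOrphanedChatFiles is not shown in this compare view. Based on its doc comment in querier.go below (file references live as file-type parts carrying a "file_id" key inside the JSONB content array of chat_messages), a plausible sqlc-style reconstruction, offered strictly as a sketch:
// Hypothetical reconstruction; the real query text is not part of this
// diff, and the exact non-deleted-message predicate is an assumption.
const deleteOrphanedChatFiles = `
-- name: DeleteOrphanedChatFiles :execrows
DELETE FROM chat_files cf
WHERE cf.created_at < $1 -- cutoff supplied by dbpurge: now() - 24h
AND NOT EXISTS (
SELECT 1
FROM chat_messages cm
CROSS JOIN LATERAL jsonb_array_elements(cm.content) AS part
WHERE part->>'type' = 'file'
AND part->>'file_id' = cf.id::text
-- the real query presumably also filters out deleted messages
)
`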
+68
View File
@@ -1630,6 +1630,74 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
}
}
//nolint:paralleltest // It uses LockIDDBPurge.
func TestDeleteOrphanedChatFiles(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
clk := quartz.NewMock(t)
now := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC)
clk.Set(now).MustWait(ctx)
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "test", Model: "test", DisplayName: "Test", Enabled: true, Options: json.RawMessage("{}"),
})
require.NoError(t, err)
chat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID, LastModelConfigID: modelCfg.ID, Title: "test chat",
})
require.NoError(t, err)
referencedFile, err := db.InsertChatFile(ctx, database.InsertChatFileParams{
OwnerID: user.ID, OrganizationID: org.ID, Name: "ref.png", Mimetype: "image/png", Data: []byte("ref"),
})
require.NoError(t, err)
fileContent := fmt.Sprintf("[{\"type\":\"file\",\"media_type\":\"image/png\",\"file_id\":\"%s\"}]", referencedFile.ID)
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: chat.ID, CreatedBy: []uuid.UUID{user.ID}, ModelConfigID: []uuid.UUID{modelCfg.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleUser}, Content: []string{fileContent},
ContentVersion: []int16{1}, Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{0}, OutputTokens: []int64{0}, TotalTokens: []int64{0}, ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0}, CacheReadTokens: []int64{0}, ContextLimit: []int64{0},
Compressed: []bool{false}, TotalCostMicros: []int64{0}, RuntimeMs: []int64{0}, ProviderResponseID: []string{""},
})
require.NoError(t, err)
orphanedFile, err := db.InsertChatFile(ctx, database.InsertChatFileParams{
OwnerID: user.ID, OrganizationID: org.ID, Name: "orphan.png", Mimetype: "image/png", Data: []byte("orphan"),
})
require.NoError(t, err)
recentOrphanedFile, err := db.InsertChatFile(ctx, database.InsertChatFileParams{
OwnerID: user.ID, OrganizationID: org.ID, Name: "recent.png", Mimetype: "image/png", Data: []byte("recent"),
})
require.NoError(t, err)
// Backdate old files past the 24h threshold.
_, err = sqlDB.ExecContext(ctx, "UPDATE chat_files SET created_at = $1 WHERE id = $2", now.Add(-48*time.Hour), orphanedFile.ID)
require.NoError(t, err)
_, err = sqlDB.ExecContext(ctx, "UPDATE chat_files SET created_at = $1 WHERE id = $2", now.Add(-48*time.Hour), referencedFile.ID)
require.NoError(t, err)
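// The referenced file is backdated past the threshold too, so the
// assertions below show it survives because it is referenced, not
// merely because it is recent.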
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
defer closer.Close()
<-done
_, err = db.GetChatFileByID(ctx, orphanedFile.ID)
require.Error(t, err, "orphaned file should be deleted after purge")
_, err = db.GetChatFileByID(ctx, referencedFile.ID)
require.NoError(t, err, "referenced file should still exist after purge")
_, err = db.GetChatFileByID(ctx, recentOrphanedFile.ID)
require.NoError(t, err, "recent orphaned file should still exist after purge")
}
// ptr is a helper to create a pointer to a value.
func ptr[T any](v T) *T {
return &v
+3 -55
View File
@@ -1134,9 +1134,7 @@ CREATE TABLE aibridge_token_usages (
input_tokens bigint NOT NULL,
output_tokens bigint NOT NULL,
metadata jsonb,
created_at timestamp with time zone NOT NULL,
cache_read_input_tokens bigint DEFAULT 0 NOT NULL,
cache_write_input_tokens bigint DEFAULT 0 NOT NULL
created_at timestamp with time zone NOT NULL
);
COMMENT ON TABLE aibridge_token_usages IS 'Audit log of tokens used by intercepted requests in AI Bridge';
@@ -1269,11 +1267,6 @@ CREATE TABLE chat_diff_statuses (
head_branch text
);
CREATE TABLE chat_file_links (
chat_id uuid NOT NULL,
file_id uuid NOT NULL
);
CREATE TABLE chat_files (
id uuid DEFAULT gen_random_uuid() NOT NULL,
owner_id uuid NOT NULL,
@@ -1348,11 +1341,7 @@ CREATE TABLE chat_providers (
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
base_url text DEFAULT ''::text NOT NULL,
central_api_key_enabled boolean DEFAULT true NOT NULL,
allow_user_api_key boolean DEFAULT false NOT NULL,
allow_central_api_key_fallback boolean DEFAULT false NOT NULL,
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text]))),
CONSTRAINT valid_credential_policy CHECK (((central_api_key_enabled OR allow_user_api_key) AND ((NOT allow_central_api_key_fallback) OR (central_api_key_enabled AND allow_user_api_key))))
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text])))
);
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
@@ -2763,17 +2752,6 @@ COMMENT ON TABLE usage_events_daily IS 'usage_events_daily is a daily rollup of
COMMENT ON COLUMN usage_events_daily.day IS 'The date of the summed usage events, always in UTC.';
CREATE TABLE user_chat_provider_keys (
id uuid DEFAULT gen_random_uuid() NOT NULL,
user_id uuid NOT NULL,
chat_provider_id uuid NOT NULL,
api_key text NOT NULL,
api_key_key_id text,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
CONSTRAINT user_chat_provider_keys_api_key_check CHECK ((api_key <> ''::text))
);
CREATE TABLE user_configs (
user_id uuid NOT NULL,
key character varying(256) NOT NULL,
@@ -2815,8 +2793,7 @@ CREATE TABLE user_secrets (
env_name text DEFAULT ''::text NOT NULL,
file_path text DEFAULT ''::text NOT NULL,
created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL,
value_key_id text
updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL
);
CREATE TABLE user_status_changes (
@@ -3349,9 +3326,6 @@ ALTER TABLE ONLY boundary_usage_stats
ALTER TABLE ONLY chat_diff_statuses
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
ALTER TABLE ONLY chat_file_links
ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
ALTER TABLE ONLY chat_files
ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
@@ -3574,12 +3548,6 @@ ALTER TABLE ONLY usage_events_daily
ALTER TABLE ONLY usage_events
ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
ALTER TABLE ONLY user_configs
ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
@@ -3742,8 +3710,6 @@ CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links USING btree (chat_id);
CREATE INDEX idx_chat_files_org ON chat_files USING btree (organization_id);
CREATE INDEX idx_chat_files_owner ON chat_files USING btree (owner_id);
@@ -4046,12 +4012,6 @@ ALTER TABLE ONLY api_keys
ALTER TABLE ONLY chat_diff_statuses
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_file_links
ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_file_links
ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_files
ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
@@ -4298,15 +4258,6 @@ ALTER TABLE ONLY templates
ALTER TABLE ONLY templates
ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_configs
ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
@@ -4325,9 +4276,6 @@ ALTER TABLE ONLY user_links
ALTER TABLE ONLY user_secrets
ADD CONSTRAINT user_secrets_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_secrets
ADD CONSTRAINT user_secrets_value_key_id_fkey FOREIGN KEY (value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY user_status_changes
ADD CONSTRAINT user_status_changes_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
@@ -10,8 +10,6 @@ const (
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatFileLinksChatID ForeignKeyConstraint = "chat_file_links_chat_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
@@ -94,16 +92,12 @@ const (
ForeignKeyTemplateVersionsTemplateID ForeignKeyConstraint = "template_versions_template_id_fkey" // ALTER TABLE ONLY template_versions ADD CONSTRAINT template_versions_template_id_fkey FOREIGN KEY (template_id) REFERENCES templates(id) ON DELETE CASCADE;
ForeignKeyTemplatesCreatedBy ForeignKeyConstraint = "templates_created_by_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE RESTRICT;
ForeignKeyTemplatesOrganizationID ForeignKeyConstraint = "templates_organization_id_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ForeignKeyUserChatProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_chat_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyUserChatProviderKeysChatProviderID ForeignKeyConstraint = "user_chat_provider_keys_chat_provider_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
ForeignKeyUserChatProviderKeysUserID ForeignKeyConstraint = "user_chat_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyUserConfigsUserID ForeignKeyConstraint = "user_configs_user_id_fkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyUserDeletedUserID ForeignKeyConstraint = "user_deleted_user_id_fkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
ForeignKeyUserLinksOauthAccessTokenKeyID ForeignKeyConstraint = "user_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyUserLinksOauthRefreshTokenKeyID ForeignKeyConstraint = "user_links_oauth_refresh_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyUserLinksUserID ForeignKeyConstraint = "user_links_user_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyUserSecretsUserID ForeignKeyConstraint = "user_secrets_user_id_fkey" // ALTER TABLE ONLY user_secrets ADD CONSTRAINT user_secrets_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyUserSecretsValueKeyID ForeignKeyConstraint = "user_secrets_value_key_id_fkey" // ALTER TABLE ONLY user_secrets ADD CONSTRAINT user_secrets_value_key_id_fkey FOREIGN KEY (value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyUserStatusChangesUserID ForeignKeyConstraint = "user_status_changes_user_id_fkey" // ALTER TABLE ONLY user_status_changes ADD CONSTRAINT user_status_changes_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
ForeignKeyWebpushSubscriptionsUserID ForeignKeyConstraint = "webpush_subscriptions_user_id_fkey" // ALTER TABLE ONLY webpush_subscriptions ADD CONSTRAINT webpush_subscriptions_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyWorkspaceAgentDevcontainersSubagentID ForeignKeyConstraint = "workspace_agent_devcontainers_subagent_id_fkey" // ALTER TABLE ONLY workspace_agent_devcontainers ADD CONSTRAINT workspace_agent_devcontainers_subagent_id_fkey FOREIGN KEY (subagent_id) REFERENCES workspace_agents(id) ON DELETE CASCADE;
@@ -1,8 +0,0 @@
DROP TABLE IF EXISTS user_chat_provider_keys;
ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy;
ALTER TABLE chat_providers
DROP COLUMN IF EXISTS central_api_key_enabled,
DROP COLUMN IF EXISTS allow_user_api_key,
DROP COLUMN IF EXISTS allow_central_api_key_fallback;
@@ -1,24 +0,0 @@
ALTER TABLE chat_providers
ADD COLUMN central_api_key_enabled BOOLEAN NOT NULL DEFAULT TRUE,
ADD COLUMN allow_user_api_key BOOLEAN NOT NULL DEFAULT FALSE,
ADD COLUMN allow_central_api_key_fallback BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE chat_providers
ADD CONSTRAINT valid_credential_policy CHECK (
(central_api_key_enabled OR allow_user_api_key) AND
(
NOT allow_central_api_key_fallback OR
(central_api_key_enabled AND allow_user_api_key)
)
);
CREATE TABLE user_chat_provider_keys (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
chat_provider_id UUID NOT NULL REFERENCES chat_providers(id) ON DELETE CASCADE,
api_key TEXT NOT NULL CHECK (api_key != ''),
api_key_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (user_id, chat_provider_id)
);
@@ -1,3 +0,0 @@
ALTER TABLE user_secrets
DROP CONSTRAINT user_secrets_value_key_id_fkey,
DROP COLUMN value_key_id;
@@ -1,5 +0,0 @@
ALTER TABLE user_secrets
ADD COLUMN value_key_id TEXT;
ALTER TABLE ONLY user_secrets
ADD CONSTRAINT user_secrets_value_key_id_fkey FOREIGN KEY (value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
@@ -1,3 +0,0 @@
ALTER TABLE aibridge_token_usages
DROP COLUMN cache_read_input_tokens,
DROP COLUMN cache_write_input_tokens;
@@ -1,26 +0,0 @@
ALTER TABLE aibridge_token_usages
ADD COLUMN cache_read_input_tokens BIGINT NOT NULL DEFAULT 0,
ADD COLUMN cache_write_input_tokens BIGINT NOT NULL DEFAULT 0;
-- Backfill from metadata JSONB. Old rows stored cache tokens under
-- provider-specific keys; new rows use the dedicated columns above.
UPDATE aibridge_token_usages
SET
-- Cache-read metadata keys by provider:
-- Anthropic (/v1/messages): "cache_read_input"
-- OpenAI (/v1/responses): "input_cached"
-- OpenAI (/v1/chat/completions): "prompt_cached"
cache_read_input_tokens = GREATEST(
COALESCE((metadata->>'cache_read_input')::bigint, 0),
COALESCE((metadata->>'input_cached')::bigint, 0),
COALESCE((metadata->>'prompt_cached')::bigint, 0)
),
-- Cache-write metadata keys by provider:
-- Anthropic (/v1/messages): "cache_creation_input"
-- OpenAI does not report cache-write tokens.
cache_write_input_tokens = COALESCE((metadata->>'cache_creation_input')::bigint, 0)
WHERE metadata IS NOT NULL
AND cache_read_input_tokens = 0
AND cache_write_input_tokens = 0;
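-- Example: an old Anthropic row whose metadata contains
-- {"cache_read_input": 128, "cache_creation_input": 64} backfills to
-- cache_read_input_tokens = 128 and cache_write_input_tokens = 64.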
@@ -1,9 +0,0 @@
ALTER TABLE chats ADD COLUMN file_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL;
UPDATE chats SET file_ids = (
SELECT COALESCE(array_agg(cfl.file_id), '{}')
FROM chat_file_links cfl
WHERE cfl.chat_id = chats.id
);
DROP TABLE chat_file_links;
@@ -1,17 +0,0 @@
CREATE TABLE chat_file_links (
chat_id uuid NOT NULL,
file_id uuid NOT NULL,
UNIQUE (chat_id, file_id)
);
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links (chat_id);
ALTER TABLE chat_file_links
ADD CONSTRAINT chat_file_links_chat_id_fkey
FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE chat_file_links
ADD CONSTRAINT chat_file_links_file_id_fkey
FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
ALTER TABLE chats DROP COLUMN IF EXISTS file_ids;
@@ -1,16 +0,0 @@
INSERT INTO user_chat_provider_keys (
user_id,
chat_provider_id,
api_key,
created_at,
updated_at
)
SELECT
id,
'0a8b2f84-b5a8-4c44-8c9f-e58c44a534a7',
'fixture-test-key',
'2025-01-01 00:00:00+00',
'2025-01-01 00:00:00+00'
FROM users
ORDER BY created_at, id
LIMIT 1;
@@ -1,5 +0,0 @@
INSERT INTO chat_file_links (chat_id, file_id)
VALUES (
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
'00000000-0000-0000-0000-000000000099'
);
-30
View File
@@ -10,7 +10,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"golang.org/x/exp/maps"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
@@ -187,10 +186,6 @@ func (c ChatFile) RBACObject() rbac.Object {
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
}
func (c GetChatFileMetadataByChatIDRow) RBACObject() rbac.Object {
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
}
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
switch s {
case ApiKeyScopeCoderAll:
@@ -928,28 +923,3 @@ func WorkspaceIdentityFromWorkspace(w Workspace) WorkspaceIdentity {
func (r GetWorkspaceAgentAndWorkspaceByIDRow) RBACObject() rbac.Object {
return r.WorkspaceTable.RBACObject()
}
// UpsertConnectionLogParams contains the parameters for upserting a
// connection log entry. This struct is hand-maintained (not generated
// by sqlc) because the single-row UpsertConnectionLog query was
// removed in favor of BatchUpsertConnectionLogs, but the struct is
// still used as the canonical connection log event type throughout
// the codebase.
type UpsertConnectionLogParams struct {
ID uuid.UUID `db:"id" json:"id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"`
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
AgentName string `db:"agent_name" json:"agent_name"`
Type ConnectionType `db:"type" json:"type"`
Code sql.NullInt32 `db:"code" json:"code"`
IP pqtype.Inet `db:"ip" json:"ip"`
UserAgent sql.NullString `db:"user_agent" json:"user_agent"`
UserID uuid.NullUUID `db:"user_id" json:"user_id"`
SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"`
ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"`
DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"`
Time time.Time `db:"time" json:"time"`
ConnectionStatus ConnectionStatus `db:"connection_status" json:"connection_status"`
}
-4
View File
@@ -584,7 +584,6 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
arg.DateTo,
arg.BuildReason,
arg.RequestID,
arg.CountCap,
)
if err != nil {
return 0, err
@@ -721,7 +720,6 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
arg.WorkspaceID,
arg.ConnectionID,
arg.Status,
arg.CountCap,
)
if err != nil {
return 0, err
@@ -1031,8 +1029,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg Lis
&i.Threads,
&i.InputTokens,
&i.OutputTokens,
&i.CacheReadInputTokens,
&i.CacheWriteInputTokens,
&i.LastPrompt,
); err != nil {
return nil, err
@@ -145,13 +145,5 @@ func extractWhereClause(query string) string {
// Remove SQL comments
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
// Normalize indentation so subquery wrapping doesn't cause
// mismatches.
lines := strings.Split(whereClause, "\n")
for i, line := range lines {
lines[i] = strings.TrimLeft(line, " \t")
}
whereClause = strings.Join(lines, "\n")
return strings.TrimSpace(whereClause)
}
+20 -41
View File
@@ -4055,13 +4055,11 @@ type AIBridgeTokenUsage struct {
ID uuid.UUID `db:"id" json:"id"`
InterceptionID uuid.UUID `db:"interception_id" json:"interception_id"`
// The ID for the response in which the tokens were used, produced by the provider.
ProviderResponseID string `db:"provider_response_id" json:"provider_response_id"`
InputTokens int64 `db:"input_tokens" json:"input_tokens"`
OutputTokens int64 `db:"output_tokens" json:"output_tokens"`
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
CacheReadInputTokens int64 `db:"cache_read_input_tokens" json:"cache_read_input_tokens"`
CacheWriteInputTokens int64 `db:"cache_write_input_tokens" json:"cache_write_input_tokens"`
ProviderResponseID string `db:"provider_response_id" json:"provider_response_id"`
InputTokens int64 `db:"input_tokens" json:"input_tokens"`
OutputTokens int64 `db:"output_tokens" json:"output_tokens"`
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
// Audit log of tool calls in intercepted requests in AI Bridge
@@ -4218,11 +4216,6 @@ type ChatFile struct {
Data []byte `db:"data" json:"data"`
}
type ChatFileLink struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
FileID uuid.UUID `db:"file_id" json:"file_id"`
}
type ChatMessage struct {
ID int64 `db:"id" json:"id"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
@@ -4271,15 +4264,12 @@ type ChatProvider struct {
DisplayName string `db:"display_name" json:"display_name"`
APIKey string `db:"api_key" json:"api_key"`
// The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
Enabled bool `db:"enabled" json:"enabled"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
BaseUrl string `db:"base_url" json:"base_url"`
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
Enabled bool `db:"enabled" json:"enabled"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
BaseUrl string `db:"base_url" json:"base_url"`
}
type ChatQueuedMessage struct {
@@ -5232,16 +5222,6 @@ type User struct {
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
}
type UserChatProviderKey struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
APIKey string `db:"api_key" json:"api_key"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
type UserConfig struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
Key string `db:"key" json:"key"`
@@ -5271,16 +5251,15 @@ type UserLink struct {
}
type UserSecret struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
Value string `db:"value" json:"value"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
Value string `db:"value" json:"value"`
EnvName string `db:"env_name" json:"env_name"`
FilePath string `db:"file_path" json:"file_path"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// Tracks the history of user status changes
+11 -25
View File
@@ -65,7 +65,6 @@ type sqlcQuerier interface {
BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error
BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
BatchUpsertConnectionLogs(ctx context.Context, arg BatchUpsertConnectionLogsParams) error
BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error)
BulkMarkNotificationMessagesSent(ctx context.Context, arg BulkMarkNotificationMessagesSentParams) (int64, error)
// Calculates the telemetry summary for a given provider, model, and client
@@ -144,6 +143,11 @@ type sqlcQuerier interface {
DeleteOldWorkspaceAgentLogs(ctx context.Context, threshold time.Time) (int64, error)
DeleteOldWorkspaceAgentStats(ctx context.Context) error
DeleteOrganizationMember(ctx context.Context, arg DeleteOrganizationMemberParams) error
// Deletes chat_files rows older than the given threshold that are
// not referenced by any non-deleted chat message. File references
// live inside the JSONB content array of chat_messages as
// {"file_id": "<uuid>"} entries in file-type parts.
DeleteOrphanedChatFiles(ctx context.Context, before time.Time) (int64, error)
DeleteProvisionerKey(ctx context.Context, id uuid.UUID) error
DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error
DeleteRuntimeConfig(ctx context.Context, key string) error
@@ -151,7 +155,6 @@ type sqlcQuerier interface {
DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error)
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
@@ -244,10 +247,6 @@ type sqlcQuerier interface {
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
// GetChatFileMetadataByChatID returns lightweight file metadata for
// all files linked to a chat. The data column is excluded to avoid
// loading file content.
GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]GetChatFileMetadataByChatIDRow, error)
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
// GetChatIncludeDefaultSystemPrompt preserves the legacy default
// for deployments created before the explicit include-default toggle.
@@ -482,8 +481,8 @@ type sqlcQuerier interface {
// Used for recovery after coderd crashes or long hangs.
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error)
GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerIDsBatchRow, error)
GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerBindingsRow, error)
GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerIDsRow, error)
GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error)
GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByOwnerIDAndNameParams) (Task, error)
GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error)
@@ -583,7 +582,6 @@ type sqlcQuerier interface {
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error)
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
// Returns the minimum (most restrictive) group limit for a user.
@@ -782,15 +780,6 @@ type sqlcQuerier interface {
InsertWorkspaceProxy(ctx context.Context, arg InsertWorkspaceProxyParams) (WorkspaceProxy, error)
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
// LinkChatFiles inserts file associations into the chat_file_links
// join table with deduplication (ON CONFLICT DO NOTHING). The INSERT
// is conditional: it only proceeds when the total number of links
// (existing + genuinely new) does not exceed max_file_links. Returns
// the number of genuinely new file IDs that were NOT inserted due to
// the cap. A return value of 0 means all files were linked (or were
// already linked). A positive value means the cap blocked that many
// new links.
LinkChatFiles(ctx context.Context, arg LinkChatFilesParams) (int32, error)
ListAIBridgeClients(ctx context.Context, arg ListAIBridgeClientsParams) ([]string, error)
ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]ListAIBridgeInterceptionsRow, error)
// Finds all unique AI Bridge interception telemetry summaries combinations
@@ -870,11 +859,9 @@ type sqlcQuerier interface {
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
// Bumps the heartbeat timestamp for the given set of chat IDs,
// provided they are still running and owned by the specified
// worker. Returns the IDs that were actually updated so the
// caller can detect stolen or completed chats via set-difference.
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
// Bumps the heartbeat timestamp for a running chat so that other
// replicas know the worker is still alive.
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
// Updates the cached injected context parts (AGENTS.md +
// skills) on the chat row. Called only when context changes
@@ -945,7 +932,6 @@ type sqlcQuerier interface {
UpdateUsageEventsPostPublish(ctx context.Context, arg UpdateUsageEventsPostPublishParams) error
UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error)
UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error)
UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error)
UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error
UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error
UpdateUserHashedOneTimePasscode(ctx context.Context, arg UpdateUserHashedOneTimePasscodeParams) error
@@ -1007,6 +993,7 @@ type sqlcQuerier interface {
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
UpsertChatWorkspaceTTL(ctx context.Context, workspaceTtl string) error
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
// The default proxy is implied and not actually stored in the database.
// So we need to store its configuration here for display purposes.
// The functional values are immutable and controlled implicitly.
@@ -1033,7 +1020,6 @@ type sqlcQuerier interface {
// used to store the data, and the minutes are summed for each user and template
// combination. The result is stored in the template_usage_stats table.
UpsertTemplateUsageStats(ctx context.Context) error
UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error)
UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
UpsertWorkspaceApp(ctx context.Context, arg UpsertWorkspaceAppParams) (WorkspaceApp, error)
+243 -535
View File
@@ -1261,11 +1261,10 @@ func TestGetAuthorizedChats(t *testing.T) {
// Create FK dependencies: a chat provider and model config.
ctx := testutil.Context(t, testutil.WaitMedium)
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
})
require.NoError(t, err)
@@ -3566,11 +3565,9 @@ func connectionOnlyIDs[T database.ConnectionLog | database.GetConnectionLogsOffs
return ids
}
func TestBatchUpsertConnectionLogs(t *testing.T) {
func TestUpsertConnectionLog(t *testing.T) {
t.Parallel()
createWorkspace := func(t *testing.T, db database.Store) database.WorkspaceTable {
t.Helper()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
@@ -3586,536 +3583,253 @@ func TestBatchUpsertConnectionLogs(t *testing.T) {
})
}
// zeroTime is the sentinel value that the SQL treats as "no
// connect/disconnect time provided".
zeroTime := time.Time{}
defaultIP := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
t.Run("SingleConnect", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
connectTime := dbtime.Now()
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{connectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{false},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime))
require.False(t, rows[0].ConnectionLog.DisconnectTime.Valid,
"disconnect_time should be NULL for a connect-only event")
})
t.Run("ConnectThenDisconnect", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
connectionID := uuid.New()
agentName := "test-agent"
// 1. Insert a 'connect' event.
connectTime := dbtime.Now()
// Insert connect.
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{connectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{false},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
// Insert disconnect for same connection.
disconnectTime := connectTime.Add(time.Second)
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{zeroTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{1},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{"test disconnect"},
DisconnectTime: []time.Time{disconnectTime},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
row := rows[0].ConnectionLog
require.True(t, connectTime.Equal(row.ConnectTime))
require.True(t, row.DisconnectTime.Valid)
require.True(t, disconnectTime.Equal(row.DisconnectTime.Time))
require.Equal(t, "test disconnect", row.DisconnectReason.String)
require.Equal(t, int32(1), row.Code.Int32)
})
t.Run("DuplicateConnectIsNoOp", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
connectTime := dbtime.Now()
mkParams := func(ct time.Time, ip pqtype.Inet) database.BatchUpsertConnectionLogsParams {
return database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{ct},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{false},
Ip: []pqtype.Inet{ip},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
}
}
err := db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime, defaultIP))
require.NoError(t, err)
rows1, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows1, 1)
// Second connect with later time and different IP.
otherIP := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(10, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
connectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
Valid: true,
}
err = db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime.Add(time.Second), otherIP))
log1, err := db.UpsertConnectionLog(ctx, connectParams)
require.NoError(t, err)
require.Equal(t, connectParams.ID, log1.ID)
require.False(t, log1.DisconnectTime.Valid, "DisconnectTime should not be set on connect")
// Check that one row exists.
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
// 2. Insert a 'disconnected' event for the same connection.
disconnectTime := connectTime.Add(time.Second)
disconnectParams := database.UpsertConnectionLogParams{
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
WorkspaceID: ws.ID,
AgentName: agentName,
ConnectionStatus: database.ConnectionStatusDisconnected,
// Updated to:
Time: disconnectTime,
DisconnectReason: sql.NullString{String: "test disconnect", Valid: true},
Code: sql.NullInt32{Int32: 1, Valid: true},
// Ignored
ID: uuid.New(),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceName: ws.Name,
Type: database.ConnectionTypeSsh,
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 254),
},
Valid: true,
},
}
log2, err := db.UpsertConnectionLog(ctx, disconnectParams)
require.NoError(t, err)
rows2, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows2, 1)
// Updated
require.Equal(t, log1.ID, log2.ID)
require.True(t, log2.DisconnectTime.Valid)
require.True(t, disconnectTime.Equal(log2.DisconnectTime.Time))
require.Equal(t, disconnectParams.DisconnectReason.String, log2.DisconnectReason.String)
// The LEAST logic should pick the earlier connect_time; IP and
// other fields are not updated on conflict.
require.True(t, connectTime.Equal(rows2[0].ConnectionLog.ConnectTime),
"connect_time should remain the original (earlier) value")
rows, err = db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, rows, 1)
})
t.Run("OrderIndependentConnectTime", func(t *testing.T) {
t.Run("ConnectDoesNotUpdate", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
connectionID := uuid.New()
agentName := "test-agent"
// 1. Insert a 'connect' event.
connectTime := dbtime.Now()
connectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
}
log, err := db.UpsertConnectionLog(ctx, connectParams)
require.NoError(t, err)
// 2. Insert another 'connect' event for the same connection.
connectTime2 := connectTime.Add(time.Second)
connectParams2 := database.UpsertConnectionLogParams{
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
WorkspaceID: ws.ID,
AgentName: agentName,
ConnectionStatus: database.ConnectionStatusConnected,
// Ignored
ID: uuid.New(),
Time: connectTime2,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceName: ws.Name,
Type: database.ConnectionTypeSsh,
Code: sql.NullInt32{Int32: 0, Valid: false},
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 254),
},
Valid: true,
},
}
origLog, err := db.UpsertConnectionLog(ctx, connectParams2)
require.NoError(t, err)
require.Equal(t, log, origLog, "connect update should be a no-op")
// Check that still only one row exists.
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, rows, 1)
require.Equal(t, log, rows[0].ConnectionLog)
})
t.Run("DisconnectThenConnect", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connectionID := uuid.New()
agentName := "test-agent"
// Insert just a 'disconnect' event.
disconnectTime := dbtime.Now()
connectTime := disconnectTime.Add(-5 * time.Second)
// Disconnect arrives first.
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{disconnectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{"bye"},
DisconnectTime: []time.Time{disconnectTime},
})
require.NoError(t, err)
// Connect arrives second with the real (earlier) connect_time.
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{connectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{false},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime),
"LEAST should pick the earlier connect_time")
})
t.Run("DisconnectFieldsAreWriteOnce", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
disconnectTime := dbtime.Now()
mkDisconnect := func(reason string, code int32) database.BatchUpsertConnectionLogsParams {
return database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{disconnectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{code},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{reason},
DisconnectTime: []time.Time{disconnectTime},
}
disconnectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: disconnectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusDisconnected,
DisconnectReason: sql.NullString{String: "server shutting down", Valid: true},
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
}
err := db.BatchUpsertConnectionLogs(ctx, mkDisconnect("first reason", 1))
_, err := db.UpsertConnectionLog(ctx, disconnectParams)
require.NoError(t, err)
// Second disconnect with different reason and code.
err = db.BatchUpsertConnectionLogs(ctx, mkDisconnect("second reason", 2))
firstRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, firstRows, 1)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
row := rows[0].ConnectionLog
require.Equal(t, "first reason", row.DisconnectReason.String,
"disconnect_reason should not be overwritten")
require.Equal(t, int32(1), row.Code.Int32,
"code should not be overwritten")
})
// We expect the connection event to be marked as closed with the start
// and close time being the same.
require.True(t, firstRows[0].ConnectionLog.DisconnectTime.Valid)
require.Equal(t, disconnectTime, firstRows[0].ConnectionLog.DisconnectTime.Time.UTC())
require.Equal(t, firstRows[0].ConnectionLog.ConnectTime.UTC(), firstRows[0].ConnectionLog.DisconnectTime.Time.UTC())
t.Run("ConnectAfterDisconnectIsNoOp", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
disconnectTime := dbtime.Now()
// Insert disconnect first.
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{disconnectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{42},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{"server shutdown"},
DisconnectTime: []time.Time{disconnectTime},
})
require.NoError(t, err)
rows1, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows1, 1)
require.True(t, rows1[0].ConnectionLog.DisconnectTime.Valid)
require.Equal(t, "server shutdown", rows1[0].ConnectionLog.DisconnectReason.String)
require.Equal(t, int32(42), rows1[0].ConnectionLog.Code.Int32)
// Insert connect for same connection_id.
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{disconnectTime.Add(time.Second)},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{false},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
rows2, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows2, 1)
row := rows2[0].ConnectionLog
require.True(t, row.DisconnectTime.Valid,
"disconnect_time should not be cleared by a later connect")
require.Equal(t, "server shutdown", row.DisconnectReason.String,
"disconnect_reason should not be cleared")
require.Equal(t, int32(42), row.Code.Int32,
"code should not be cleared")
})
t.Run("CodeZeroPreserved", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
now := dbtime.Now()
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{now},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{"normal"},
DisconnectTime: []time.Time{now},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
require.True(t, rows[0].ConnectionLog.Code.Valid, "code should be non-NULL")
require.Equal(t, int32(0), rows[0].ConnectionLog.Code.Int32,
"code=0 should be preserved, not treated as NULL")
})
t.Run("CodeNullWhenInvalid", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
now := dbtime.Now()
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{now},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{99},
CodeValid: []bool{false},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
require.False(t, rows[0].ConnectionLog.Code.Valid,
"code should be NULL when code_valid is false")
})
t.Run("NullConnectionIDEvents", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
now := dbtime.Now()
// Insert two web events with NULL connection_id (uuid.Nil →
// NULL via NULLIF) for the same workspace/agent.
for i := range 2 {
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{now.Add(time.Duration(i) * time.Second)},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{200},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{"Mozilla/5.0"},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{"web-terminal"},
ConnectionID: []uuid.UUID{uuid.Nil},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
// Now insert a 'connect' event for the same connection.
// This should be a no-op.
connectTime := disconnectTime.Add(time.Second)
connectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
DisconnectReason: sql.NullString{String: "reconnected", Valid: true},
Code: sql.NullInt32{Int32: 0, Valid: false},
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
}
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
_, err = db.UpsertConnectionLog(ctx, connectParams)
require.NoError(t, err)
require.Len(t, rows, 2,
"NULL connection_id rows should not conflict with each other")
})
t.Run("MultipleIndependentConnections", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
now := dbtime.Now()
secondRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, secondRows, 1)
require.Equal(t, firstRows, secondRows)
n := 5
ids := make([]uuid.UUID, n)
connectTimes := make([]time.Time, n)
orgIDs := make([]uuid.UUID, n)
ownerIDs := make([]uuid.UUID, n)
wsIDs := make([]uuid.UUID, n)
wsNames := make([]string, n)
agentNames := make([]string, n)
types := make([]database.ConnectionType, n)
codes := make([]int32, n)
codeValids := make([]bool, n)
ips := make([]pqtype.Inet, n)
userAgents := make([]string, n)
userIDs := make([]uuid.UUID, n)
slugOrPorts := make([]string, n)
connIDs := make([]uuid.UUID, n)
disconnectReasons := make([]string, n)
disconnectTimes := make([]time.Time, n)
for i := range n {
ids[i] = uuid.New()
connectTimes[i] = now.Add(time.Duration(i) * time.Second)
orgIDs[i] = ws.OrganizationID
ownerIDs[i] = ws.OwnerID
wsIDs[i] = ws.ID
wsNames[i] = ws.Name
agentNames[i] = "agent"
types[i] = database.ConnectionTypeSsh
codes[i] = 0
codeValids[i] = false
ips[i] = defaultIP
userAgents[i] = ""
userIDs[i] = uuid.Nil
slugOrPorts[i] = ""
connIDs[i] = uuid.New()
disconnectReasons[i] = ""
disconnectTimes[i] = zeroTime
// Upsert a disconnection, which should also be a no-op.
disconnectParams.DisconnectReason = sql.NullString{
String: "updated close reason",
Valid: true,
}
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: ids,
ConnectTime: connectTimes,
OrganizationID: orgIDs,
WorkspaceOwnerID: ownerIDs,
WorkspaceID: wsIDs,
WorkspaceName: wsNames,
AgentName: agentNames,
Type: types,
Code: codes,
CodeValid: codeValids,
Ip: ips,
UserAgent: userAgents,
UserID: userIDs,
SlugOrPort: slugOrPorts,
ConnectionID: connIDs,
DisconnectReason: disconnectReasons,
DisconnectTime: disconnectTimes,
})
_, err = db.UpsertConnectionLog(ctx, disconnectParams)
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
thirdRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, rows, n, "each unique connection_id should produce its own row")
require.Len(t, secondRows, 1)
// The close reason shouldn't be updated
require.Equal(t, secondRows, thirdRows)
})
}
@@ -9742,11 +9456,10 @@ func TestInsertChatMessages(t *testing.T) {
provider := "openai"
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: provider,
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
Provider: provider,
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
})
require.NoError(t, err)
@@ -9908,11 +9621,10 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) {
// A chat_providers row is required as a FK for model configs.
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
})
require.NoError(t, err)
@@ -10280,11 +9992,10 @@ func TestGetPRInsights(t *testing.T) {
user := dbgen.User(t, store, database.User{})
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "anthropic",
DisplayName: "Anthropic",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
Provider: "anthropic",
DisplayName: "Anthropic",
APIKey: "test-key",
Enabled: true,
})
require.NoError(t, err)
@@ -10805,11 +10516,10 @@ func TestChatPinOrderQueries(t *testing.T) {
// timed test context doesn't tick during DB init.
bg := context.Background()
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
})
require.NoError(t, err)
@@ -10986,11 +10696,10 @@ func TestChatLabels(t *testing.T) {
owner := dbgen.User(t, db, database.User{})
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
})
require.NoError(t, err)
@@ -11198,11 +10907,10 @@ func TestChatHasUnread(t *testing.T) {
user := dbgen.User(t, store, database.User{})
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
})
require.NoError(t, err)
File diff suppressed because it is too large.
+18 -9
View File
@@ -31,9 +31,9 @@ WHERE aibridge_interceptions.id = (
-- name: InsertAIBridgeTokenUsage :one
INSERT INTO aibridge_token_usages (
id, interception_id, provider_response_id, input_tokens, output_tokens, cache_read_input_tokens, cache_write_input_tokens, metadata, created_at
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
) VALUES (
@id, @interception_id, @provider_response_id, @input_tokens, @output_tokens, @cache_read_input_tokens, @cache_write_input_tokens, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
@id, @interception_id, @provider_response_id, @input_tokens, @output_tokens, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
)
RETURNING *;
@@ -299,8 +299,21 @@ token_aggregates AS (
SELECT
COALESCE(SUM(tu.input_tokens), 0) AS token_count_input,
COALESCE(SUM(tu.output_tokens), 0) AS token_count_output,
COALESCE(SUM(tu.cache_read_input_tokens), 0) AS token_count_cached_read,
COALESCE(SUM(tu.cache_write_input_tokens), 0) AS token_count_cached_written,
-- Cached tokens are stored in the metadata JSON; extract them if available.
-- Read tokens may be stored in:
-- - cache_read_input (Anthropic)
-- - prompt_cached (OpenAI)
COALESCE(SUM(
COALESCE((tu.metadata->>'cache_read_input')::bigint, 0) +
COALESCE((tu.metadata->>'prompt_cached')::bigint, 0)
), 0) AS token_count_cached_read,
-- Written tokens may be stored in:
-- - cache_creation_input (Anthropic)
-- Note that cache_ephemeral_5m_input and cache_ephemeral_1h_input on
-- Anthropic are included in the cache_creation_input field.
COALESCE(SUM(
COALESCE((tu.metadata->>'cache_creation_input')::bigint, 0)
), 0) AS token_count_cached_written,
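-- Illustrative (hypothetical) metadata payloads, assuming flat JSON
-- keys as extracted above:
--   Anthropic: {"cache_read_input": 128, "cache_creation_input": 512}
--   OpenAI:    {"prompt_cached": 256}
-- These two rows would aggregate to token_count_cached_read = 384
-- (128 + 256) and token_count_cached_written = 512.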
COUNT(tu.id) AS token_usages_count
FROM
interceptions_in_range i
@@ -539,8 +552,6 @@ SELECT
sp.threads,
COALESCE(st.input_tokens, 0)::bigint AS input_tokens,
COALESCE(st.output_tokens, 0)::bigint AS output_tokens,
COALESCE(st.cache_read_input_tokens, 0)::bigint AS cache_read_input_tokens,
COALESCE(st.cache_write_input_tokens, 0)::bigint AS cache_write_input_tokens,
COALESCE(slp.prompt, '') AS last_prompt
FROM
session_page sp
@@ -562,9 +573,7 @@ LEFT JOIN LATERAL (
-- Aggregate tokens only for this session's interceptions.
SELECT
COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens,
COALESCE(SUM(tu.cache_read_input_tokens), 0)::bigint AS cache_read_input_tokens,
COALESCE(SUM(tu.cache_write_input_tokens), 0)::bigint AS cache_write_input_tokens
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
FROM aibridge_token_usages tu
WHERE tu.interception_id = ANY(sr.interception_ids)
) st ON true
+88 -99
View File
@@ -149,105 +149,94 @@ VALUES (
RETURNING *;
-- name: CountAuditLogs :one
SELECT COUNT(*) FROM (
SELECT 1
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
ELSE true
END
-- Filter organization_id
AND CASE
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN @resource_target::text != '' THEN resource_target = @resource_target
ELSE true
END
-- Filter action
AND CASE
WHEN @action::text != '' THEN action = @action::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower(@username)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @email::text != '' THEN users.email = @email
ELSE true
END
-- Filter by date_from
AND CASE
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
ELSE true
END
-- Filter by date_to
AND CASE
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
ELSE true
END
-- Filter request_id
AND CASE
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
-- Avoid a slow scan on a large table with joins. The caller
-- passes the count cap and we add 1 so the frontend can detect
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
-- -> NULL + 1 = NULL).
-- NOTE: This is parameterized so we can easily change the cap
-- from, e.g., 2000 to 5000. However, if capping is ever disabled
-- permanently on a large table, use a literal NULL (or drop the
-- LIMIT) instead so the PG planner can plan parallel execution
-- for potentially large wins.
LIMIT NULLIF(@count_cap::int, 0) + 1
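-- Worked example (values illustrative): with @count_cap = 2000 the
-- LIMIT is 2001, and a full 2001-row result tells the frontend to
-- render "... of 2000+"; with @count_cap = 0, NULLIF(0, 0) is NULL,
-- NULL + 1 is NULL, and the LIMIT disappears, so the count is exact.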
) AS limited_count;
SELECT COUNT(*)
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
ELSE true
END
-- Filter organization_id
AND CASE
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN @resource_target::text != '' THEN resource_target = @resource_target
ELSE true
END
-- Filter action
AND CASE
WHEN @action::text != '' THEN action = @action::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower(@username)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @email::text != '' THEN users.email = @email
ELSE true
END
-- Filter by date_from
AND CASE
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
ELSE true
END
-- Filter by date_to
AND CASE
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
ELSE true
END
-- Filter request_id
AND CASE
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
;
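-- Example invocation (hypothetical values): to count all 'create'
-- actions on workspaces in one organization, set @resource_type =
-- 'workspace', @action = 'create', and @organization_id to the org's
-- uuid, while leaving every other filter at its zero-value sentinel
-- (empty strings, the all-zero uuid, '0001-01-01 00:00:00Z') so each
-- remaining CASE falls through to true.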
-- name: DeleteOldAuditLogConnectionEvents :exec
DELETE FROM audit_logs
+14 -9
View File
@@ -9,12 +9,17 @@ SELECT * FROM chat_files WHERE id = @id::uuid;
-- name: GetChatFilesByIDs :many
SELECT * FROM chat_files WHERE id = ANY(@ids::uuid[]);
-- name: GetChatFileMetadataByChatID :many
-- GetChatFileMetadataByChatID returns lightweight file metadata for
-- all files linked to a chat. The data column is excluded to avoid
-- loading file content.
SELECT cf.id, cf.owner_id, cf.organization_id, cf.name, cf.mimetype, cf.created_at
FROM chat_files cf
JOIN chat_file_links cfl ON cfl.file_id = cf.id
WHERE cfl.chat_id = @chat_id::uuid
ORDER BY cf.created_at ASC;
-- name: DeleteOrphanedChatFiles :execrows
-- Deletes chat_files rows older than the given threshold that are
-- not referenced by any non-deleted chat message. File references
-- live inside the JSONB content array of chat_messages as
-- {"file_id": "<uuid>"} entries in file-type parts.
DELETE FROM chat_files
WHERE created_at < @before::timestamptz
AND NOT EXISTS (
SELECT 1
FROM chat_messages cm,
jsonb_array_elements(cm.content) AS elem
WHERE (elem ->> 'file_id')::uuid = chat_files.id
AND cm.deleted = false
);
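-- Illustrative (hypothetical) content shape this scan matches, assuming
-- file parts store their reference under a file_id key:
--   [{"type": "text", "text": "see attachment"},
--    {"type": "file", "file_id": "<uuid of a chat_files row>"}]
-- jsonb_array_elements() expands each part; (elem ->> 'file_id')::uuid
-- is NULL for non-file parts, so only genuine references keep a file
-- from being deleted.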
+2 -11
View File
@@ -40,10 +40,7 @@ INSERT INTO chat_providers (
base_url,
api_key_key_id,
created_by,
enabled,
central_api_key_enabled,
allow_user_api_key,
allow_central_api_key_fallback
enabled
) VALUES (
@provider::text,
@display_name::text,
@@ -51,10 +48,7 @@ INSERT INTO chat_providers (
@base_url::text,
sqlc.narg('api_key_key_id')::text,
sqlc.narg('created_by')::uuid,
@enabled::boolean,
@central_api_key_enabled::boolean,
@allow_user_api_key::boolean,
@allow_central_api_key_fallback::boolean
@enabled::boolean
)
RETURNING
*;
@@ -68,9 +62,6 @@ SET
base_url = @base_url::text,
api_key_key_id = sqlc.narg('api_key_key_id')::text,
enabled = @enabled::boolean,
central_api_key_enabled = @central_api_key_enabled::boolean,
allow_user_api_key = @allow_user_api_key::boolean,
allow_central_api_key_fallback = @allow_central_api_key_fallback::boolean,
updated_at = NOW()
WHERE
id = @id::uuid
+11 -57
View File
@@ -567,43 +567,6 @@ WHERE
RETURNING
*;
-- name: LinkChatFiles :one
-- LinkChatFiles inserts file associations into the chat_file_links
-- join table with deduplication (ON CONFLICT DO NOTHING). The INSERT
-- is conditional: it only proceeds when the total number of links
-- (existing + genuinely new) does not exceed max_file_links. Returns
-- the number of genuinely new file IDs that were NOT inserted due to
-- the cap. A return value of 0 means all files were linked (or were
-- already linked). A positive value means the cap blocked that many
-- new links.
WITH current AS (
SELECT COUNT(*) AS cnt
FROM chat_file_links
WHERE chat_id = @chat_id::uuid
),
new_links AS (
SELECT @chat_id::uuid AS chat_id, unnest(@file_ids::uuid[]) AS file_id
),
genuinely_new AS (
SELECT nl.chat_id, nl.file_id
FROM new_links nl
WHERE NOT EXISTS (
SELECT 1 FROM chat_file_links cfl
WHERE cfl.chat_id = nl.chat_id AND cfl.file_id = nl.file_id
)
),
inserted AS (
INSERT INTO chat_file_links (chat_id, file_id)
SELECT gn.chat_id, gn.file_id
FROM genuinely_new gn, current c
WHERE c.cnt + (SELECT COUNT(*) FROM genuinely_new) <= @max_file_links::int
ON CONFLICT (chat_id, file_id) DO NOTHING
RETURNING file_id
)
SELECT
(SELECT COUNT(*)::int FROM genuinely_new) -
(SELECT COUNT(*)::int FROM inserted) AS rejected_new_files;
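-- Worked example (hypothetical counts): a chat already has 2 links and
-- @file_ids carries 3 ids, 1 of them already linked, so genuinely_new
-- has 2 rows. With @max_file_links = 3, the guard 2 + 2 <= 3 fails, the
-- INSERT inserts nothing, and the query returns 2 rejected new files;
-- with @max_file_links = 4 it inserts both and returns 0.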
-- name: AcquireChats :many
-- Acquires up to @num_chats pending chats for processing. Uses SKIP LOCKED
-- to prevent multiple replicas from acquiring the same chat.
@@ -674,20 +637,17 @@ WHERE
status = 'running'::chat_status
AND heartbeat_at < @stale_threshold::timestamptz;
-- name: UpdateChatHeartbeats :many
-- Bumps the heartbeat timestamp for the given set of chat IDs,
-- provided they are still running and owned by the specified
-- worker. Returns the IDs that were actually updated so the
-- caller can detect stolen or completed chats via set-difference.
-- name: UpdateChatHeartbeat :execrows
-- Bumps the heartbeat timestamp for a running chat so that other
-- replicas know the worker is still alive.
UPDATE
chats
SET
heartbeat_at = @now::timestamptz
heartbeat_at = NOW()
WHERE
id = ANY(@ids::uuid[])
id = @id::uuid
AND worker_id = @worker_id::uuid
AND status = 'running'::chat_status
RETURNING id;
AND status = 'running'::chat_status;
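-- Usage sketch (hypothetical): a worker refreshing its lease passes its
-- own @worker_id and the chat's @id; an :execrows result of 0 indicates
-- the chat is no longer running under this worker (completed or taken
-- over), which a caller could treat as a signal to stop processing.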
-- name: GetChatDiffStatusByChatID :one
SELECT
@@ -923,8 +883,7 @@ SELECT
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
FROM
chat_messages cm
JOIN
@@ -954,8 +913,7 @@ SELECT
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
FROM
chat_messages cm
JOIN
@@ -990,8 +948,7 @@ WITH chat_costs AS (
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
FROM chat_messages cm
JOIN chats c ON c.id = cm.chat_id
WHERE c.owner_id = @owner_id::uuid
@@ -1008,8 +965,7 @@ SELECT
cc.total_input_tokens,
cc.total_output_tokens,
cc.total_cache_read_tokens,
cc.total_cache_creation_tokens,
cc.total_runtime_ms
cc.total_cache_creation_tokens
FROM chat_costs cc
LEFT JOIN chats rc ON rc.id = cc.root_chat_id
ORDER BY cc.total_cost_micros DESC;
@@ -1035,8 +991,7 @@ WITH chat_cost_users AS (
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
FROM
chat_messages cm
JOIN
@@ -1070,7 +1025,6 @@ SELECT
total_output_tokens,
total_cache_read_tokens,
total_cache_creation_tokens,
total_runtime_ms,
COUNT(*) OVER()::bigint AS total_count
FROM
chat_cost_users
+154 -176
View File
@@ -133,113 +133,111 @@ OFFSET
@offset_opt;
-- name: CountConnectionLogs :one
SELECT COUNT(*) AS count FROM (
SELECT 1
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = @organization_id
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN @workspace_owner :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = @workspace_owner_id
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN @workspace_owner_email :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = @workspace_owner_email AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN @type :: text != '' THEN
type = @type :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower(@username) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @user_email :: text != '' THEN
users.email = @user_email
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= @connected_after
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= @connected_before
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = @workspace_id
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = @connection_id
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
-- NOTE: See the CountAuditLogs LIMIT note.
LIMIT NULLIF(@count_cap::int, 0) + 1
) AS limited_count;
SELECT
COUNT(*) AS count
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = @organization_id
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN @workspace_owner :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = @workspace_owner_id
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN @workspace_owner_email :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = @workspace_owner_email AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN @type :: text != '' THEN
type = @type :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower(@username) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @user_email :: text != '' THEN
users.email = @user_email
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= @connected_after
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= @connected_before
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = @workspace_id
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = @connection_id
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
;
-- name: DeleteOldConnectionLogs :execrows
WITH old_logs AS (
@@ -253,75 +251,55 @@ DELETE FROM connection_logs
USING old_logs
WHERE connection_logs.id = old_logs.id;
-- name: BatchUpsertConnectionLogs :exec
-- name: UpsertConnectionLog :one
INSERT INTO connection_logs (
id, connect_time, organization_id, workspace_owner_id, workspace_id,
workspace_name, agent_name, type, code, ip, user_agent, user_id,
slug_or_port, connection_id, disconnect_reason, disconnect_time
)
SELECT
u.id,
u.connect_time,
u.organization_id,
u.workspace_owner_id,
u.workspace_id,
u.workspace_name,
u.agent_name,
u.type,
-- Use the validity flag to distinguish "no code" (NULL) from a
-- legitimate zero exit code.
CASE WHEN u.code_valid THEN u.code ELSE NULL END,
u.ip,
NULLIF(u.user_agent, ''),
NULLIF(u.user_id, '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(u.slug_or_port, ''),
NULLIF(u.connection_id, '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(u.disconnect_reason, ''),
NULLIF(u.disconnect_time, '0001-01-01 00:00:00Z'::timestamptz)
FROM (
SELECT
unnest(sqlc.arg('id')::uuid[]) AS id,
unnest(sqlc.arg('connect_time')::timestamptz[]) AS connect_time,
unnest(sqlc.arg('organization_id')::uuid[]) AS organization_id,
unnest(sqlc.arg('workspace_owner_id')::uuid[]) AS workspace_owner_id,
unnest(sqlc.arg('workspace_id')::uuid[]) AS workspace_id,
unnest(sqlc.arg('workspace_name')::text[]) AS workspace_name,
unnest(sqlc.arg('agent_name')::text[]) AS agent_name,
unnest(sqlc.arg('type')::connection_type[]) AS type,
unnest(sqlc.arg('code')::int4[]) AS code,
unnest(sqlc.arg('code_valid')::bool[]) AS code_valid,
unnest(sqlc.arg('ip')::inet[]) AS ip,
unnest(sqlc.arg('user_agent')::text[]) AS user_agent,
unnest(sqlc.arg('user_id')::uuid[]) AS user_id,
unnest(sqlc.arg('slug_or_port')::text[]) AS slug_or_port,
unnest(sqlc.arg('connection_id')::uuid[]) AS connection_id,
unnest(sqlc.arg('disconnect_reason')::text[]) AS disconnect_reason,
unnest(sqlc.arg('disconnect_time')::timestamptz[]) AS disconnect_time
) AS u
id,
connect_time,
organization_id,
workspace_owner_id,
workspace_id,
workspace_name,
agent_name,
type,
code,
ip,
user_agent,
user_id,
slug_or_port,
connection_id,
disconnect_reason,
disconnect_time
) VALUES
($1, @time, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14,
-- If we've only received a disconnect event, mark the event as immediately
-- closed.
CASE
WHEN @connection_status::connection_status = 'disconnected'
THEN @time :: timestamp with time zone
ELSE NULL
END)
ON CONFLICT (connection_id, workspace_id, agent_name)
DO UPDATE SET
-- Pick the earliest real connect_time. The zero sentinel
-- ('0001-01-01') means the batch didn't know the connect_time
-- (e.g. a pure disconnect event), so we keep the existing value.
connect_time = CASE
WHEN EXCLUDED.connect_time = '0001-01-01 00:00:00Z'::timestamptz
THEN connection_logs.connect_time
WHEN connection_logs.connect_time = '0001-01-01 00:00:00Z'::timestamptz
THEN EXCLUDED.connect_time
ELSE LEAST(connection_logs.connect_time, EXCLUDED.connect_time)
END,
disconnect_time = CASE
WHEN connection_logs.disconnect_time IS NULL
THEN EXCLUDED.disconnect_time
ELSE connection_logs.disconnect_time
END,
disconnect_reason = CASE
WHEN connection_logs.disconnect_reason IS NULL
THEN EXCLUDED.disconnect_reason
ELSE connection_logs.disconnect_reason
END,
code = CASE
WHEN connection_logs.code IS NULL
THEN EXCLUDED.code
ELSE connection_logs.code
END;
-- No-op if the connection is still open.
disconnect_time = CASE
WHEN @connection_status::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.disconnect_time IS NULL
THEN EXCLUDED.connect_time
ELSE connection_logs.disconnect_time
END,
disconnect_reason = CASE
WHEN @connection_status::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.disconnect_reason IS NULL
THEN EXCLUDED.disconnect_reason
ELSE connection_logs.disconnect_reason
END,
code = CASE
WHEN @connection_status::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.code IS NULL
THEN EXCLUDED.code
ELSE connection_logs.code
END
RETURNING *;
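-- Illustrative event flow (hypothetical times): a 'connected' event at
-- t1 inserts the row with connect_time = t1 and disconnect_time = NULL.
-- A 'disconnected' event at t2 for the same (connection_id,
-- workspace_id, agent_name) takes the conflict path: disconnect_time,
-- disconnect_reason, and code are write-once, each filled from the
-- first event that provides it and never overwritten afterwards.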
+22 -19
View File
@@ -96,6 +96,28 @@ DELETE
FROM tailnet_tunnels
WHERE coordinator_id = $1 and src_id = $2;
-- name: GetTailnetTunnelPeerIDs :many
SELECT dst_id as peer_id, coordinator_id, updated_at
FROM tailnet_tunnels
WHERE tailnet_tunnels.src_id = $1
UNION
SELECT src_id as peer_id, coordinator_id, updated_at
FROM tailnet_tunnels
WHERE tailnet_tunnels.dst_id = $1;
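-- Example: if peer C appears as src_id in tunnels to A1 and A2, and as
-- dst_id in a tunnel from A3, querying with C yields peers {A1, A2, A3}:
-- the first branch collects C's dst_ids, the second its src_ids, and
-- UNION deduplicates.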
-- name: GetTailnetTunnelPeerBindings :many
SELECT id AS peer_id, coordinator_id, updated_at, node, status
FROM tailnet_peers
WHERE id IN (
SELECT dst_id as peer_id
FROM tailnet_tunnels
WHERE tailnet_tunnels.src_id = $1
UNION
SELECT src_id as peer_id
FROM tailnet_tunnels
WHERE tailnet_tunnels.dst_id = $1
);
-- For PG Coordinator HTMLDebug
-- name: GetAllTailnetCoordinators :many
@@ -106,22 +128,3 @@ SELECT * FROM tailnet_peers;
-- name: GetAllTailnetTunnels :many
SELECT * FROM tailnet_tunnels;
-- name: GetTailnetTunnelPeerIDsBatch :many
SELECT src_id AS lookup_id, dst_id AS peer_id, coordinator_id, updated_at
FROM tailnet_tunnels WHERE src_id = ANY(@ids :: uuid[])
UNION ALL
SELECT dst_id AS lookup_id, src_id AS peer_id, coordinator_id, updated_at
FROM tailnet_tunnels WHERE dst_id = ANY(@ids :: uuid[]);
-- name: GetTailnetTunnelPeerBindingsBatch :many
SELECT tp.id AS peer_id, tp.coordinator_id, tp.updated_at, tp.node, tp.status,
tunnels.lookup_id
FROM (
SELECT dst_id AS peer_id, src_id AS lookup_id
FROM tailnet_tunnels WHERE src_id = ANY(@ids :: uuid[])
UNION
SELECT src_id AS peer_id, dst_id AS lookup_id
FROM tailnet_tunnels WHERE dst_id = ANY(@ids :: uuid[])
) tunnels
INNER JOIN tailnet_peers tp ON tp.id = tunnels.peer_id;
@@ -1,20 +0,0 @@
-- name: GetUserChatProviderKeys :many
SELECT * FROM user_chat_provider_keys WHERE user_id = @user_id ORDER BY created_at ASC, id ASC;
-- name: UpsertUserChatProviderKey :one
INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id)
VALUES (@user_id, @chat_provider_id, @api_key, sqlc.narg('api_key_key_id')::text)
ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET
api_key = @api_key,
api_key_key_id = sqlc.narg('api_key_key_id')::text,
updated_at = NOW()
RETURNING *;
-- name: UpdateUserChatProviderKey :one
UPDATE user_chat_provider_keys
SET api_key = @api_key, api_key_key_id = sqlc.narg('api_key_key_id')::text, updated_at = NOW()
WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id
RETURNING *;
-- name: DeleteUserChatProviderKey :exec
DELETE FROM user_chat_provider_keys WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id;
-1
View File
@@ -247,7 +247,6 @@ sql:
mcp_server_tool_snapshots: MCPServerToolSnapshots
mcp_server_config_id: MCPServerConfigID
mcp_server_ids: MCPServerIDs
max_file_links: MaxFileLinks
icon_url: IconURL
oauth2_client_id: OAuth2ClientID
oauth2_client_secret: OAuth2ClientSecret
-3
View File
@@ -16,7 +16,6 @@ const (
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
UniqueChatFileLinksChatIDFileIDKey UniqueConstraint = "chat_file_links_chat_id_file_id_key" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
@@ -91,8 +90,6 @@ const (
UniqueTemplatesPkey UniqueConstraint = "templates_pkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_pkey PRIMARY KEY (id);
UniqueUsageEventsDailyPkey UniqueConstraint = "usage_events_daily_pkey" // ALTER TABLE ONLY usage_events_daily ADD CONSTRAINT usage_events_daily_pkey PRIMARY KEY (day, event_type);
UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
UniqueUserChatProviderKeysPkey UniqueConstraint = "user_chat_provider_keys_pkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
UniqueUserChatProviderKeysUserIDChatProviderIDKey UniqueConstraint = "user_chat_provider_keys_user_id_chat_provider_id_key" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
UniqueUserDeletedPkey UniqueConstraint = "user_deleted_pkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_pkey PRIMARY KEY (id);
UniqueUserLinksPkey UniqueConstraint = "user_links_pkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type);
+107 -616
View File
File diff suppressed because it is too large.
+10 -1552
View File
File diff suppressed because it is too large.
+7 -8
View File
@@ -39,14 +39,13 @@ func TestChatParam(t *testing.T) {
t.Helper()
_, err := db.InsertChatProvider(context.Background(), database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-api-key",
BaseUrl: "https://api.openai.com/v1",
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
Enabled: true,
CentralApiKeyEnabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-api-key",
BaseUrl: "https://api.openai.com/v1",
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
Enabled: true,
})
require.NoError(t, err)
+16 -29
View File
@@ -2,7 +2,6 @@ package prebuilds
import (
"context"
"encoding/json"
"sync"
"github.com/google/uuid"
@@ -23,11 +22,7 @@ type PubsubWorkspaceClaimPublisher struct {
func (p PubsubWorkspaceClaimPublisher) PublishWorkspaceClaim(claim agentsdk.ReinitializationEvent) error {
channel := agentsdk.PrebuildClaimedChannel(claim.WorkspaceID)
payload, err := json.Marshal(claim)
if err != nil {
return xerrors.Errorf("marshal claim event: %w", err)
}
if err := p.ps.Publish(channel, payload); err != nil {
if err := p.ps.Publish(channel, []byte(claim.Reason)); err != nil {
return xerrors.Errorf("failed to trigger prebuilt workspace agent reinitialization: %w", err)
}
return nil
@@ -42,41 +37,33 @@ type PubsubWorkspaceClaimListener struct {
ps pubsub.Pubsub
}
// ListenForWorkspaceClaims subscribes to a pubsub channel and returns a
// receive-only channel that emits claim events for the given workspace.
// The returned channel is owned by this function and is never closed,
// because pubsub.Pubsub does not guarantee that all in-flight callbacks
// have returned after unsubscribe. Call the returned cancel function to
// unsubscribe when events are no longer needed; cancel is also called
// automatically if ctx expires or is canceled.
func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID) (<-chan agentsdk.ReinitializationEvent, func(), error) {
// ListenForWorkspaceClaims subscribes to a pubsub channel and sends any received events on the provided chan.
// pubsub.Pubsub does not communicate when its last callback has been called after it has been closed. As such the
// provided chan is never closed by this method. Call the returned cancel() function to close the subscription when it is no longer needed.
// cancel() will be called if ctx expires or is canceled.
func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID, reinitEvents chan<- agentsdk.ReinitializationEvent) (func(), error) {
select {
case <-ctx.Done():
return nil, func() {}, ctx.Err()
return func() {}, ctx.Err()
default:
}
reinitEvents := make(chan agentsdk.ReinitializationEvent, 1)
cancelSub, err := p.ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, payload []byte) {
var event agentsdk.ReinitializationEvent
if err := json.Unmarshal(payload, &event); err != nil {
// Rolling upgrade: old publishers send the raw reason
// string instead of JSON.
event = agentsdk.ReinitializationEvent{
WorkspaceID: workspaceID,
Reason: agentsdk.ReinitializationReason(payload),
}
cancelSub, err := p.ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, reason []byte) {
claim := agentsdk.ReinitializationEvent{
WorkspaceID: workspaceID,
Reason: agentsdk.ReinitializationReason(reason),
}
select {
case <-ctx.Done():
return
case <-inner.Done():
case reinitEvents <- event:
return
case reinitEvents <- claim:
}
})
if err != nil {
return nil, func() {}, xerrors.Errorf("failed to subscribe to prebuild claimed channel: %w", err)
return func() {}, xerrors.Errorf("failed to subscribe to prebuild claimed channel: %w", err)
}
var once sync.Once
@@ -91,5 +78,5 @@ func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Conte
cancel()
}()
return reinitEvents, cancel, nil
return cancel, nil
}
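Taken together with the publisher, the listener now fills a caller-owned channel instead of returning one. A minimal sketch of the resulting call pattern, assuming the import paths used by the tests in this diff; the helper itself is hypothetical:

package prebuilds_example

import (
	"context"

	"github.com/google/uuid"

	"github.com/coder/coder/v2/coderd/prebuilds" // import path assumed
	"github.com/coder/coder/v2/codersdk/agentsdk"
)

// waitForClaim subscribes, waits for a single claim event, and unsubscribes.
func waitForClaim(ctx context.Context, listener prebuilds.PubsubWorkspaceClaimListener, workspaceID uuid.UUID) (agentsdk.ReinitializationEvent, error) {
	// The caller owns the channel; a buffer of 1 keeps the subscription
	// callback from blocking if this goroutine is momentarily busy.
	events := make(chan agentsdk.ReinitializationEvent, 1)
	cancel, err := listener.ListenForWorkspaceClaims(ctx, workspaceID, events)
	if err != nil {
		return agentsdk.ReinitializationEvent{}, err
	}
	defer cancel()

	select {
	case <-ctx.Done():
		return agentsdk.ReinitializationEvent{}, ctx.Err()
	case ev := <-events:
		return ev, nil
	}
}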
+12 -11
View File
@@ -25,26 +25,24 @@ func TestPubsubWorkspaceClaimPublisher(t *testing.T) {
logger := testutil.Logger(t)
ps := pubsub.NewInMemory()
workspaceID := uuid.New()
reinitEvents := make(chan agentsdk.ReinitializationEvent, 1)
publisher := prebuilds.NewPubsubWorkspaceClaimPublisher(ps)
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, logger)
events, cancel, err := listener.ListenForWorkspaceClaims(ctx, workspaceID)
cancel, err := listener.ListenForWorkspaceClaims(ctx, workspaceID, reinitEvents)
require.NoError(t, err)
defer cancel()
userID := uuid.New()
claim := agentsdk.ReinitializationEvent{
WorkspaceID: workspaceID,
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
OwnerID: userID,
}
err = publisher.PublishWorkspaceClaim(claim)
require.NoError(t, err)
gotEvent := testutil.RequireReceive(ctx, t, events)
gotEvent := testutil.RequireReceive(ctx, t, reinitEvents)
require.Equal(t, workspaceID, gotEvent.WorkspaceID)
require.Equal(t, claim.Reason, gotEvent.Reason)
require.Equal(t, userID, gotEvent.OwnerID)
})
t.Run("fail to publish claim", func(t *testing.T) {
@@ -71,8 +69,10 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
ps := pubsub.NewInMemory()
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
claims := make(chan agentsdk.ReinitializationEvent, 1) // Buffered so the subscription callback can deliver without a dedicated receiver goroutine
workspaceID := uuid.New()
events, cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID)
cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID, claims)
require.NoError(t, err)
defer cancelFunc()
@@ -84,10 +84,9 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
// Verify we receive the claim
ctx := testutil.Context(t, testutil.WaitShort)
claim := testutil.RequireReceive(ctx, t, events)
claim := testutil.RequireReceive(ctx, t, claims)
require.Equal(t, workspaceID, claim.WorkspaceID)
require.Equal(t, reason, claim.Reason)
require.Equal(t, uuid.Nil, claim.OwnerID)
})
t.Run("ignores claim events for other workspaces", func(t *testing.T) {
@@ -96,9 +95,10 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
ps := pubsub.NewInMemory()
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
claims := make(chan agentsdk.ReinitializationEvent)
workspaceID := uuid.New()
otherWorkspaceID := uuid.New()
events, cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID)
cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID, claims)
require.NoError(t, err)
defer cancelFunc()
@@ -109,7 +109,7 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
// Verify we don't receive the claim
select {
case <-events:
case <-claims:
t.Fatal("received claim for wrong workspace")
case <-time.After(100 * time.Millisecond):
// Expected - no claim received
@@ -119,10 +119,11 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
t.Run("communicates the error if it can't subscribe", func(t *testing.T) {
t.Parallel()
claims := make(chan agentsdk.ReinitializationEvent)
ps := &brokenPubsub{}
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
_, _, err := listener.ListenForWorkspaceClaims(context.Background(), uuid.New())
_, err := listener.ListenForWorkspaceClaims(context.Background(), uuid.New(), claims)
require.ErrorContains(t, err, "failed to subscribe to prebuild claimed channel")
})
}
@@ -2539,7 +2539,6 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro
err = prebuilds.NewPubsubWorkspaceClaimPublisher(s.Pubsub).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
WorkspaceID: workspace.ID,
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
OwnerID: workspace.OwnerID,
})
if err != nil {
s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err))
@@ -51,6 +51,7 @@ import (
"github.com/coder/coder/v2/coderd/usage/usagetypes"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionersdk"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
@@ -2786,7 +2787,8 @@ func TestCompleteJob(t *testing.T) {
require.NoError(t, err)
// GIVEN something is listening to process workspace reinitialization:
reinitChan, cancel, err := agplprebuilds.NewPubsubWorkspaceClaimListener(ps, testutil.Logger(t)).ListenForWorkspaceClaims(ctx, workspace.ID)
reinitChan := make(chan agentsdk.ReinitializationEvent, 1) // Buffered to simplify test structure
cancel, err := agplprebuilds.NewPubsubWorkspaceClaimListener(ps, testutil.Logger(t)).ListenForWorkspaceClaims(ctx, workspace.ID, reinitChan)
require.NoError(t, err)
defer cancel()
-34
View File
@@ -298,40 +298,6 @@ neq(input.object.owner, "");
ExpectedSQL: p("'' = 'org-id'"),
VariableConverter: regosql.ChatConverter(),
},
{
Name: "AuditLogUUID",
Queries: []string{
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
`input.object.org_owner != ""`,
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`,
`"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: p(
p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("audit_logs.organization_id IS NOT NULL") + " OR " +
p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
"(false)"),
VariableConverter: regosql.AuditLogConverter(),
},
{
Name: "ConnectionLogUUID",
Queries: []string{
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
`input.object.org_owner != ""`,
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`,
`"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: p(
p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("connection_logs.organization_id IS NOT NULL") + " OR " +
p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
"(false)"),
VariableConverter: regosql.ConnectionLogConverter(),
},
}
for _, tc := range testCases {
+2 -2
View File
@@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter {
func AuditLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}),
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
// Audit logs have no user owner; they are owned only by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
@@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
func ConnectionLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}),
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
// Connection logs have no user owner; they are owned only by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
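For contrast, a hedged illustration of the two SQL shapes this swap trades between, reconstructed from the registrations above and the deleted test expectations earlier in the diff; the exact parenthesization is approximate:

package regosql_example

// Both constants describe the SQL for the rego expression
// `input.object.org_owner != ""` against audit logs.
const (
	// Retained text-based form (StringVarMatcher): a NULL organization_id
	// coalesces to '' before the comparison, which defeats UUID indexes.
	textForm = `COALESCE(audit_logs.organization_id :: text, '') != ''`

	// Removed native form (astUUIDVar, deleted below): an empty rego
	// string maps to a NULL check, letting the planner use the UUID index.
	uuidForm = `audit_logs.organization_id IS NOT NULL`
)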
-114
View File
@@ -1,114 +0,0 @@
package sqltypes
import (
"fmt"
"strings"
"github.com/open-policy-agent/opa/ast"
"golang.org/x/xerrors"
)
var (
_ VariableMatcher = astUUIDVar{}
_ Node = astUUIDVar{}
_ SupportsEquality = astUUIDVar{}
)
// astUUIDVar is a variable that represents a UUID column. Unlike
// astStringVar it emits native UUID comparisons (column = 'val'::uuid)
// instead of text-based ones (COALESCE(column::text, '') = 'val').
// This allows PostgreSQL to use indexes on UUID columns.
type astUUIDVar struct {
Source RegoSource
FieldPath []string
ColumnString string
}
func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher {
return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn}
}
func (astUUIDVar) UseAs() Node { return astUUIDVar{} }
func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) {
left, err := RegoVarPath(u.FieldPath, rego)
if err == nil && len(left) == 0 {
return astUUIDVar{
Source: RegoSource(rego.String()),
FieldPath: u.FieldPath,
ColumnString: u.ColumnString,
}, true
}
return nil, false
}
func (u astUUIDVar) SQLString(_ *SQLGenerator) string {
return u.ColumnString
}
// EqualsSQLString handles equality comparisons for UUID columns.
// Rego always produces string literals, so we accept AstString and
// cast the literal to ::uuid in the output SQL. This lets PG use
// native UUID indexes instead of falling back to text comparisons.
// nolint:revive
func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
switch other.UseAs().(type) {
case AstString:
// The other side is a rego string literal like
// "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison
// that casts the literal to uuid so PG can use indexes:
// column = 'val'::uuid
// instead of the text-based:
// 'val' = COALESCE(column::text, '')
s, ok := other.(AstString)
if !ok {
return "", xerrors.Errorf("expected AstString, got %T", other)
}
if s.Value == "" {
// Empty string in rego means "no value". Compare the
// column against NULL since UUID columns represent
// absent values as NULL, not empty strings.
op := "IS NULL"
if not {
op = "IS NOT NULL"
}
return fmt.Sprintf("%s %s", u.ColumnString, op), nil
}
return fmt.Sprintf("%s %s '%s'::uuid",
u.ColumnString, equalsOp(not), s.Value), nil
case astUUIDVar:
return basicSQLEquality(cfg, not, u, other), nil
default:
return "", xerrors.Errorf("unsupported equality: %T %s %T",
u, equalsOp(not), other)
}
}
// ContainedInSQL implements SupportsContainedIn so that a UUID column
// can appear in membership checks like `col = ANY(ARRAY[...])`. The
// array elements are rego strings, so we cast each to ::uuid.
func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) {
arr, ok := haystack.(ASTArray)
if !ok {
return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack)
}
if len(arr.Value) == 0 {
return "false", nil
}
// Build ARRAY['uuid1'::uuid, 'uuid2'::uuid, ...]
values := make([]string, 0, len(arr.Value))
for _, v := range arr.Value {
s, ok := v.(AstString)
if !ok {
return "", xerrors.Errorf("expected AstString array element, got %T", v)
}
values = append(values, fmt.Sprintf("'%s'::uuid", s.Value))
}
return fmt.Sprintf("%s = ANY(ARRAY [%s])",
u.ColumnString,
strings.Join(values, ",")), nil
}
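For reference, the membership shape the deleted ContainedInSQL produced, copied from the removed ConnectionLogUUID test case earlier in this diff:

package regosql_example

// For `input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`
// the removed matcher emitted a native UUID array membership test:
const uuidMembership = `connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])`

// An empty rego haystack short-circuited to constant false.
const emptyHaystack = `false`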
+1 -2
View File
@@ -66,7 +66,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G
}
// Prepare the count filter, which uses the same parameters as the GetAuditLogsOffsetParams.
// nolint:exhaustruct // UserID and CountCap are not obtained from the query parameters.
// nolint:exhaustruct // UserID is not obtained from the query parameters.
countFilter := database.CountAuditLogsParams{
RequestID: filter.RequestID,
ResourceID: filter.ResourceID,
@@ -123,7 +123,6 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey
}
// This MUST be kept in sync with the above
// nolint:exhaustruct // CountCap is not obtained from the query parameters.
countFilter := database.CountConnectionLogsParams{
OrganizationID: filter.OrganizationID,
WorkspaceOwner: filter.WorkspaceOwner,
-9
View File
@@ -1502,7 +1502,6 @@ type Snapshot struct {
PrebuiltWorkspaces []PrebuiltWorkspace `json:"prebuilt_workspaces"`
AIBridgeInterceptionsSummaries []AIBridgeInterceptionsSummary `json:"aibridge_interceptions_summaries"`
BoundaryUsageSummary *BoundaryUsageSummary `json:"boundary_usage_summary"`
FirstUserOnboarding *FirstUserOnboarding `json:"first_user_onboarding"`
}
// Deployment contains information about the host running Coder.
@@ -1552,14 +1551,6 @@ type User struct {
LoginType string `json:"login_type,omitempty"`
}
// FirstUserOnboarding contains optional newsletter preference data
// collected during first user setup. This is sent once when the first
// user is created.
type FirstUserOnboarding struct {
NewsletterMarketing bool `json:"newsletter_marketing"`
NewsletterReleases bool `json:"newsletter_releases"`
}
type Group struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
+8 -12
View File
@@ -223,12 +223,10 @@ func TestTelemetry(t *testing.T) {
StartedAt: previousAIBridgeInterceptionPeriod.Add(-30 * time.Minute),
}, nil)
_ = dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
InterceptionID: aiBridgeInterception1.ID,
InputTokens: 100,
OutputTokens: 200,
CacheReadInputTokens: 300,
CacheWriteInputTokens: 400,
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
InterceptionID: aiBridgeInterception1.ID,
InputTokens: 100,
OutputTokens: 200,
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
})
_ = dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
InterceptionID: aiBridgeInterception1.ID,
@@ -250,12 +248,10 @@ func TestTelemetry(t *testing.T) {
StartedAt: aiBridgeInterception1.StartedAt,
}, nil)
_ = dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
InterceptionID: aiBridgeInterception2.ID,
InputTokens: 100,
OutputTokens: 200,
CacheReadInputTokens: 300,
CacheWriteInputTokens: 400,
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
InterceptionID: aiBridgeInterception2.ID,
InputTokens: 100,
OutputTokens: 200,
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
})
_ = dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
InterceptionID: aiBridgeInterception2.ID,
+1 -12
View File
@@ -281,19 +281,8 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
telemetryUser := telemetry.ConvertUser(user)
// Send the initial users email address!
telemetryUser.Email = &user.Email
// Only populate onboarding data when the client actually sent it. A nil
// OnboardingInfo means the request came from an older client, the CLI, or
// the OIDC flow — not from a user who answered "no" to every question.
var onboarding *telemetry.FirstUserOnboarding
if createUser.OnboardingInfo != nil {
onboarding = &telemetry.FirstUserOnboarding{
NewsletterMarketing: createUser.OnboardingInfo.NewsletterMarketing,
NewsletterReleases: createUser.OnboardingInfo.NewsletterReleases,
}
}
api.Telemetry.Report(&telemetry.Snapshot{
Users: []telemetry.User{telemetryUser},
FirstUserOnboarding: onboarding,
Users: []telemetry.User{telemetryUser},
})
httpapi.Write(ctx, rw, http.StatusCreated, codersdk.CreateFirstUserResponse{
-71
View File
@@ -116,77 +116,6 @@ func TestFirstUser(t *testing.T) {
})
}
func TestFirstUser_OnboardingTelemetry(t *testing.T) {
t.Parallel()
t.Run("OnboardingInfoFlowsToSnapshot", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
fTelemetry := newFakeTelemetryReporter(ctx, t, 10)
client := coderdtest.New(t, &coderdtest.Options{
TelemetryReporter: fTelemetry,
})
_, err := client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{
Email: "admin@coder.com",
Username: "admin",
Password: "SomeSecurePassword!",
OnboardingInfo: &codersdk.CreateFirstUserOnboardingInfo{
NewsletterMarketing: false,
NewsletterReleases: true,
},
})
require.NoError(t, err)
snapshot := testutil.TryReceive(ctx, t, fTelemetry.snapshots)
require.NotNil(t, snapshot.FirstUserOnboarding)
require.False(t, snapshot.FirstUserOnboarding.NewsletterMarketing)
require.True(t, snapshot.FirstUserOnboarding.NewsletterReleases)
})
t.Run("NilWhenOnboardingInfoOmitted", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
fTelemetry := newFakeTelemetryReporter(ctx, t, 10)
client := coderdtest.New(t, &coderdtest.Options{
TelemetryReporter: fTelemetry,
})
_, err := client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{
Email: "admin@coder.com",
Username: "admin",
Password: "SomeSecurePassword!",
// No OnboardingInfo — simulates old CLI or OIDC flow.
})
require.NoError(t, err)
snapshot := testutil.TryReceive(ctx, t, fTelemetry.snapshots)
require.Nil(t, snapshot.FirstUserOnboarding)
})
t.Run("EmptyOnboardingInfoIsNonNilWithZeroFields", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
fTelemetry := newFakeTelemetryReporter(ctx, t, 10)
client := coderdtest.New(t, &coderdtest.Options{
TelemetryReporter: fTelemetry,
})
_, err := client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{
Email: "admin@coder.com", Username: "admin",
Password: "SomeSecurePassword!",
OnboardingInfo: &codersdk.CreateFirstUserOnboardingInfo{},
})
require.NoError(t, err)
snapshot := testutil.TryReceive(ctx, t, fTelemetry.snapshots)
require.NotNil(t, snapshot.FirstUserOnboarding,
"non-nil OnboardingInfo must produce non-nil telemetry")
require.False(t, snapshot.FirstUserOnboarding.NewsletterMarketing)
require.False(t, snapshot.FirstUserOnboarding.NewsletterReleases)
})
}
func TestPostLogin(t *testing.T) {
t.Parallel()
t.Run("InvalidUser", func(t *testing.T) {
+3 -100
View File
@@ -1465,9 +1465,7 @@ func (api *API) workspaceAgentPostLogSource(rw http.ResponseWriter, r *http.Requ
// @Security CoderSessionToken
// @Produce json
// @Tags Agents
// @Param wait query bool false "Opt in to durable reinit checks"
// @Success 200 {object} agentsdk.ReinitializationEvent
// @Failure 409 {object} codersdk.Response
// @Router /workspaceagents/me/reinit [get]
func (api *API) workspaceAgentReinit(rw http.ResponseWriter, r *http.Request) {
// Allow us to interrupt watch via cancel.
@@ -1484,113 +1482,18 @@ func (api *API) workspaceAgentReinit(rw http.ResponseWriter, r *http.Request) {
if err != nil {
log.Error(ctx, "failed to retrieve workspace from agent token", slog.Error(err))
httpapi.InternalServerError(rw, xerrors.New("failed to determine workspace from agent token"))
return
}
log = log.With(slog.F("workspace_id", workspace.ID))
log.Info(ctx, "agent waiting for reinit instruction")
// Subscribe to claim events BEFORE any durable checks to avoid a
// TOCTOU race: without this, a claim could fire between the
// IsPrebuild() check and the subscribe call, and we'd miss the
// pubsub event entirely. By subscribing first, any event that
// fires during the checks below is buffered in the channel.
pubsubCh, cancelSub, err := prebuilds.NewPubsubWorkspaceClaimListener(api.Pubsub, log).ListenForWorkspaceClaims(ctx, workspace.ID)
reinitEvents := make(chan agentsdk.ReinitializationEvent)
cancel, err = prebuilds.NewPubsubWorkspaceClaimListener(api.Pubsub, log).ListenForWorkspaceClaims(ctx, workspace.ID, reinitEvents)
if err != nil {
log.Error(ctx, "subscribe to prebuild claimed channel", slog.Error(err))
httpapi.InternalServerError(rw, xerrors.New("failed to subscribe to prebuild claimed channel"))
return
}
defer cancelSub()
reinitEvents := pubsubCh
// Only perform the durable claim check when the agent opts in via
// the "wait" query parameter. Older agents don't send it and lack
// the duplicate-reinit guard, so pre-seeding the channel on every
// connection would put them in an infinite reinit loop.
waitParam, _ := strconv.ParseBool(r.URL.Query().Get("wait"))
if waitParam && !workspace.IsPrebuild() {
firstBuild, err := api.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx,
database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
WorkspaceID: workspace.ID,
BuildNumber: 1,
})
if err != nil {
log.Error(ctx, "failed to get first workspace build", slog.Error(err))
httpapi.InternalServerError(rw, xerrors.New("failed to get first workspace build"))
return
}
if firstBuild.InitiatorID != database.PrebuildsSystemUserID {
// Not a claimed prebuild — this is a regular workspace.
// Return 409 so the agent stops reconnecting to this
// endpoint.
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "Workspace is not a prebuilt workspace waiting to be claimed.",
Detail: "This endpoint is only for agents running in prebuilt workspaces.",
})
return
}
// This workspace was a prebuild that got claimed. Check if
// the claim build completed successfully before sending
// reinit. We assume the latest build is the claim build
// (build 2). If a third build (e.g. a restart) starts
// between the claim and the agent's reconnection, this
// would check that build instead. The window is extremely
// small in practice, and a restart would trigger its own
// reinit path.
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID)
if err != nil {
log.Error(ctx, "failed to get latest workspace build", slog.Error(err))
httpapi.InternalServerError(rw, xerrors.New("failed to get latest workspace build"))
return
}
job, err := api.Database.GetProvisionerJobByID(ctx, latestBuild.JobID)
if err != nil {
log.Error(ctx, "failed to get provisioner job", slog.Error(err))
httpapi.InternalServerError(rw, xerrors.New("failed to get provisioner job"))
return
}
if job.CompletedAt.Valid && !job.Error.Valid {
// Claim build succeeded — cancel the pubsub
// subscription (no longer needed) and swap in a
// pre-seeded channel so the transmitter delivers
// exactly one reinit event.
cancelSub()
seeded := make(chan agentsdk.ReinitializationEvent, 1)
seeded <- agentsdk.ReinitializationEvent{
WorkspaceID: workspace.ID,
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
OwnerID: workspace.OwnerID,
}
reinitEvents = seeded
} else if job.CompletedAt.Valid && job.Error.Valid {
// Claim build failed permanently. Return 409 so the
// agent treats this as terminal and stops retrying
// (WaitForReinitLoop exits on any 409).
cancelSub()
log.Warn(ctx, "claim build failed",
slog.F("job_id", job.ID),
slog.F("error", job.Error.String))
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "Claim build failed permanently.",
Detail: job.Error.String,
})
return
}
// Claim build still in progress — fall through to the
// transmitter. The pubsub subscription (set up above)
// will deliver the event when the build completes
// successfully. Note: FailJob does not publish a claim
// event, so a failed in-progress build will leave the
// agent blocking here until it disconnects and
// reconnects (at which point the durable check above
// handles it).
}
defer cancel()
transmitter := agentsdk.NewSSEAgentReinitTransmitter(log, rw, r)
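Because this hunk interleaves removed and kept lines heavily, here is a condensed sketch of the control flow that remains: subscribe with a caller-owned channel, then stream events over SSE. The wrapper is hypothetical, and the final Transmit call is an assumption since the hunk is truncated above:

package coderd_example

import (
	"context"
	"net/http"

	"cdr.dev/slog"
	"github.com/google/uuid"

	"github.com/coder/coder/v2/codersdk/agentsdk"
)

// reinitFlow sketches the simplified handler. listen stands in for
// prebuilds.NewPubsubWorkspaceClaimListener(...).ListenForWorkspaceClaims
// so the sketch stays self-contained.
func reinitFlow(
	ctx context.Context,
	log slog.Logger,
	rw http.ResponseWriter,
	r *http.Request,
	workspaceID uuid.UUID,
	listen func(context.Context, uuid.UUID, chan<- agentsdk.ReinitializationEvent) (func(), error),
) error {
	reinitEvents := make(chan agentsdk.ReinitializationEvent)
	cancel, err := listen(ctx, workspaceID, reinitEvents)
	if err != nil {
		return err
	}
	defer cancel()

	transmitter := agentsdk.NewSSEAgentReinitTransmitter(log, rw, r)
	// Assumption: Transmit drains the channel and blocks until ctx ends.
	return transmitter.Transmit(ctx, reinitEvents)
}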

Some files were not shown because too many files have changed in this diff.