Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c174b3037b | |||
| f5165d304f | |||
| 684f21740d | |||
| 86ca61d6ca | |||
| f0521cfa3c | |||
| 0c5d189aff | |||
| d7c8213eee | |||
| 63924ac687 | |||
| 6c47e9ea23 | |||
| aede045549 | |||
| 2ea08aa168 | |||
| d4b9248202 | |||
| fd6c623560 | |||
| 99da498679 | |||
| a20b817c28 | |||
| d5a1792f07 | |||
| beb99c17de | |||
| 8913f9f5c1 | |||
| acd5f01b4b | |||
| 6c62d8f5e6 | |||
| 5000f15021 | |||
| 44be5a0d1e | |||
| 3ca2aae9ca | |||
| 01080302a5 | |||
| 61d6c728b9 | |||
| 648787e739 | |||
| d2950e7615 | |||
| df8f695e84 | |||
| 8bb48ffdda | |||
| 4cfbf544a0 | |||
| a2ce74f398 | |||
| 0060dee222 | |||
| 5ff1058f30 | |||
| 500fc5e2a4 | |||
| baba9e6ede | |||
| b36619b905 | |||
| 937f50f0ae | |||
| a16755dd66 | |||
| 8bdc35f91f | |||
| 5b32c4d79d | |||
| 8625543413 | |||
| e18094825a |
@@ -31,7 +31,8 @@ updates:
|
||||
patterns:
|
||||
- "golang.org/x/*"
|
||||
ignore:
|
||||
# Ignore patch updates for all dependencies
|
||||
# Patch updates are handled by the security-patch-prs workflow so this
|
||||
# lane stays focused on broader dependency updates.
|
||||
- dependency-name: "*"
|
||||
update-types:
|
||||
- version-update:semver-patch
|
||||
@@ -56,7 +57,7 @@ updates:
|
||||
labels: []
|
||||
ignore:
|
||||
# We need to coordinate terraform updates with the version hardcoded in
|
||||
# our Go code.
|
||||
# our Go code. These are handled by the security-patch-prs workflow.
|
||||
- dependency-name: "terraform"
|
||||
|
||||
- package-ecosystem: "npm"
|
||||
@@ -117,11 +118,11 @@ updates:
|
||||
interval: "weekly"
|
||||
commit-message:
|
||||
prefix: "chore"
|
||||
labels: []
|
||||
groups:
|
||||
coder-modules:
|
||||
patterns:
|
||||
- "coder/*/coder"
|
||||
labels: []
|
||||
ignore:
|
||||
- dependency-name: "*"
|
||||
update-types:
|
||||
|
||||
+17
-17
@@ -35,7 +35,7 @@ jobs:
|
||||
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -157,7 +157,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -247,7 +247,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -272,7 +272,7 @@ jobs:
|
||||
if: ${{ !cancelled() }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -327,7 +327,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -379,7 +379,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -575,7 +575,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -637,7 +637,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -709,7 +709,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -736,7 +736,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -769,7 +769,7 @@ jobs:
|
||||
name: ${{ matrix.variant.name }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -849,7 +849,7 @@ jobs:
|
||||
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -930,7 +930,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1005,7 +1005,7 @@ jobs:
|
||||
if: always()
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1043,7 +1043,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1097,7 +1097,7 @@ jobs:
|
||||
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1479,7 +1479,7 @@ jobs:
|
||||
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
steps:
|
||||
- name: Dependabot metadata
|
||||
id: metadata
|
||||
uses: dependabot/fetch-metadata@21025c705c08248db411dc16f3619e6b5f9ea21a # v2.5.0
|
||||
uses: dependabot/fetch-metadata@ffa630c65fa7e0ecfa0625b5ceda64399aea1b36 # v3.0.0
|
||||
with:
|
||||
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -65,7 +65,7 @@ jobs:
|
||||
packages: write # to retag image as dogfood
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
needs: deploy
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
if: github.repository_owner == 'coder'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -125,7 +125,7 @@ jobs:
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
packages: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
runs-on: "ubuntu-latest"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
pull-requests: write # needed for commenting on PRs
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -288,7 +288,7 @@ jobs:
|
||||
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ jobs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -673,7 +673,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -749,7 +749,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -47,6 +47,6 @@ jobs:
|
||||
|
||||
# Upload the results to GitHub's code scanning dashboard.
|
||||
- name: "Upload to code-scanning"
|
||||
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
name: security-backport
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types:
|
||||
- labeled
|
||||
- unlabeled
|
||||
- closed
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
pull_request:
|
||||
description: Pull request number to backport.
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || inputs.pull_request }}
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
LATEST_BRANCH: release/2.31
|
||||
STABLE_BRANCH: release/2.30
|
||||
STABLE_1_BRANCH: release/2.29
|
||||
|
||||
jobs:
|
||||
label-policy:
|
||||
if: github.event_name == 'pull_request_target'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Apply security backport label policy
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
|
||||
pr_json="$(gh pr view "${pr_number}" --json number,title,url,baseRefName,labels)"
|
||||
|
||||
PR_JSON="${pr_json}" \
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
pr = json.loads(os.environ["PR_JSON"])
|
||||
|
||||
pr_number = pr["number"]
|
||||
labels = [label["name"] for label in pr.get("labels", [])]
|
||||
|
||||
def has(label: str) -> bool:
|
||||
return label in labels
|
||||
|
||||
def ensure_label(label: str) -> None:
|
||||
if not has(label):
|
||||
subprocess.run(
|
||||
["gh", "pr", "edit", str(pr_number), "--add-label", label],
|
||||
check=False,
|
||||
)
|
||||
|
||||
def remove_label(label: str) -> None:
|
||||
if has(label):
|
||||
subprocess.run(
|
||||
["gh", "pr", "edit", str(pr_number), "--remove-label", label],
|
||||
check=False,
|
||||
)
|
||||
|
||||
def comment(body: str) -> None:
|
||||
subprocess.run(
|
||||
["gh", "pr", "comment", str(pr_number), "--body", body],
|
||||
check=True,
|
||||
)
|
||||
|
||||
if not has("security:patch"):
|
||||
remove_label("status:needs-severity")
|
||||
sys.exit(0)
|
||||
|
||||
severity_labels = [
|
||||
label
|
||||
for label in ("severity:medium", "severity:high", "severity:critical")
|
||||
if has(label)
|
||||
]
|
||||
if len(severity_labels) == 0:
|
||||
ensure_label("status:needs-severity")
|
||||
comment(
|
||||
"This PR is labeled `security:patch` but is missing a severity "
|
||||
"label. Add one of `severity:medium`, `severity:high`, or "
|
||||
"`severity:critical` before backport automation can proceed."
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
if len(severity_labels) > 1:
|
||||
comment(
|
||||
"This PR has multiple severity labels. Keep exactly one of "
|
||||
"`severity:medium`, `severity:high`, or `severity:critical`."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
remove_label("status:needs-severity")
|
||||
|
||||
target_labels = [
|
||||
label
|
||||
for label in ("backport:stable", "backport:stable-1")
|
||||
if has(label)
|
||||
]
|
||||
has_none = has("backport:none")
|
||||
if has_none and target_labels:
|
||||
comment(
|
||||
"`backport:none` cannot be combined with other backport labels. "
|
||||
"Remove `backport:none` or remove the explicit backport targets."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if not has_none and not target_labels:
|
||||
ensure_label("backport:stable")
|
||||
ensure_label("backport:stable-1")
|
||||
comment(
|
||||
"Applied default backport labels `backport:stable` and "
|
||||
"`backport:stable-1` for a qualifying security patch."
|
||||
)
|
||||
PY
|
||||
|
||||
backport:
|
||||
if: >
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
(
|
||||
github.event_name == 'pull_request_target' &&
|
||||
github.event.pull_request.merged == true
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Resolve PR metadata
|
||||
id: metadata
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
INPUT_PR_NUMBER: ${{ inputs.pull_request }}
|
||||
LATEST_BRANCH: ${{ env.LATEST_BRANCH }}
|
||||
STABLE_BRANCH: ${{ env.STABLE_BRANCH }}
|
||||
STABLE_1_BRANCH: ${{ env.STABLE_1_BRANCH }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
|
||||
pr_number="${INPUT_PR_NUMBER}"
|
||||
else
|
||||
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
|
||||
fi
|
||||
|
||||
case "${pr_number}" in
|
||||
''|*[!0-9]*)
|
||||
echo "A valid pull request number is required."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
pr_json="$(gh pr view "${pr_number}" --json number,title,url,mergeCommit,baseRefName,labels,mergedAt,author)"
|
||||
|
||||
PR_JSON="${pr_json}" \
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
pr = json.loads(os.environ["PR_JSON"])
|
||||
github_output = os.environ["GITHUB_OUTPUT"]
|
||||
|
||||
labels = [label["name"] for label in pr.get("labels", [])]
|
||||
if "security:patch" not in labels:
|
||||
print("Not a security patch PR; skipping.")
|
||||
sys.exit(0)
|
||||
|
||||
severity_labels = [
|
||||
label
|
||||
for label in ("severity:medium", "severity:high", "severity:critical")
|
||||
if label in labels
|
||||
]
|
||||
if len(severity_labels) != 1:
|
||||
raise SystemExit(
|
||||
"Merged security patch PR must have exactly one severity label."
|
||||
)
|
||||
|
||||
if not pr.get("mergedAt"):
|
||||
raise SystemExit(f"PR #{pr['number']} is not merged.")
|
||||
|
||||
if "backport:none" in labels:
|
||||
target_pairs = []
|
||||
else:
|
||||
mapping = {
|
||||
"backport:stable": os.environ["STABLE_BRANCH"],
|
||||
"backport:stable-1": os.environ["STABLE_1_BRANCH"],
|
||||
}
|
||||
target_pairs = []
|
||||
for label_name, branch in mapping.items():
|
||||
if label_name in labels and branch and branch != pr["baseRefName"]:
|
||||
target_pairs.append({"label": label_name, "branch": branch})
|
||||
|
||||
with open(github_output, "a", encoding="utf-8") as f:
|
||||
f.write(f"pr_number={pr['number']}\n")
|
||||
f.write(f"merge_sha={pr['mergeCommit']['oid']}\n")
|
||||
f.write(f"title={pr['title']}\n")
|
||||
f.write(f"url={pr['url']}\n")
|
||||
f.write(f"author={pr['author']['login']}\n")
|
||||
f.write(f"severity_label={severity_labels[0]}\n")
|
||||
f.write(f"target_pairs={json.dumps(target_pairs)}\n")
|
||||
PY
|
||||
|
||||
- name: Backport to release branches
|
||||
if: ${{ steps.metadata.outputs.target_pairs != '[]' }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PR_NUMBER: ${{ steps.metadata.outputs.pr_number }}
|
||||
MERGE_SHA: ${{ steps.metadata.outputs.merge_sha }}
|
||||
PR_TITLE: ${{ steps.metadata.outputs.title }}
|
||||
PR_URL: ${{ steps.metadata.outputs.url }}
|
||||
PR_AUTHOR: ${{ steps.metadata.outputs.author }}
|
||||
SEVERITY_LABEL: ${{ steps.metadata.outputs.severity_label }}
|
||||
TARGET_PAIRS: ${{ steps.metadata.outputs.target_pairs }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git remote set-url origin "https://x-access-token:${GH_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"
|
||||
git fetch origin --prune
|
||||
|
||||
merge_parent_count="$(git rev-list --parents -n 1 "${MERGE_SHA}" | awk '{print NF-1}')"
|
||||
|
||||
failures=()
|
||||
successes=()
|
||||
|
||||
while IFS=$'\t' read -r backport_label target_branch; do
|
||||
[ -n "${target_branch}" ] || continue
|
||||
|
||||
safe_branch_name="${target_branch//\//-}"
|
||||
head_branch="backport/${safe_branch_name}/pr-${PR_NUMBER}"
|
||||
|
||||
existing_pr="$(gh pr list \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--state all \
|
||||
--json number,url \
|
||||
--jq '.[0]')"
|
||||
if [ -n "${existing_pr}" ] && [ "${existing_pr}" != "null" ]; then
|
||||
pr_url="$(printf '%s' "${existing_pr}" | jq -r '.url')"
|
||||
successes+=("${target_branch}:existing:${pr_url}")
|
||||
continue
|
||||
fi
|
||||
|
||||
git checkout -B "${head_branch}" "origin/${target_branch}"
|
||||
|
||||
if [ "${merge_parent_count}" -gt 1 ]; then
|
||||
cherry_pick_args=(-m 1 "${MERGE_SHA}")
|
||||
else
|
||||
cherry_pick_args=("${MERGE_SHA}")
|
||||
fi
|
||||
|
||||
if ! git cherry-pick -x "${cherry_pick_args[@]}"; then
|
||||
git cherry-pick --abort || true
|
||||
gh pr edit "${PR_NUMBER}" --add-label "backport:conflict" || true
|
||||
gh pr comment "${PR_NUMBER}" --body \
|
||||
"Automatic backport to \`${target_branch}\` conflicted. The original author or release manager should resolve it manually."
|
||||
failures+=("${target_branch}:cherry-pick failed")
|
||||
continue
|
||||
fi
|
||||
|
||||
git push --force-with-lease origin "${head_branch}"
|
||||
|
||||
body_file="$(mktemp)"
|
||||
printf '%s\n' \
|
||||
"Automated backport of [#${PR_NUMBER}](${PR_URL})." \
|
||||
"" \
|
||||
"- Source PR: #${PR_NUMBER}" \
|
||||
"- Source commit: ${MERGE_SHA}" \
|
||||
"- Target branch: ${target_branch}" \
|
||||
"- Severity: ${SEVERITY_LABEL}" \
|
||||
> "${body_file}"
|
||||
|
||||
pr_url="$(gh pr create \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--title "${PR_TITLE} (backport to ${target_branch})" \
|
||||
--body-file "${body_file}")"
|
||||
|
||||
backport_pr_number="$(gh pr list \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--state open \
|
||||
--json number \
|
||||
--jq '.[0].number')"
|
||||
|
||||
gh pr edit "${backport_pr_number}" \
|
||||
--add-label "security:patch" \
|
||||
--add-label "${SEVERITY_LABEL}" \
|
||||
--add-label "${backport_label}" || true
|
||||
|
||||
successes+=("${target_branch}:created:${pr_url}")
|
||||
done < <(
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
for pair in json.loads(os.environ["TARGET_PAIRS"]):
|
||||
print(f"{pair['label']}\t{pair['branch']}")
|
||||
PY
|
||||
)
|
||||
|
||||
summary_file="$(mktemp)"
|
||||
{
|
||||
echo "## Security backport summary"
|
||||
echo
|
||||
if [ "${#successes[@]}" -gt 0 ]; then
|
||||
echo "### Created or existing"
|
||||
for entry in "${successes[@]}"; do
|
||||
echo "- ${entry}"
|
||||
done
|
||||
echo
|
||||
fi
|
||||
if [ "${#failures[@]}" -gt 0 ]; then
|
||||
echo "### Failures"
|
||||
for entry in "${failures[@]}"; do
|
||||
echo "- ${entry}"
|
||||
done
|
||||
fi
|
||||
} | tee -a "${GITHUB_STEP_SUMMARY}" > "${summary_file}"
|
||||
|
||||
gh pr comment "${PR_NUMBER}" --body-file "${summary_file}"
|
||||
|
||||
if [ "${#failures[@]}" -gt 0 ]; then
|
||||
printf 'Backport failures:\n%s\n' "${failures[@]}" >&2
|
||||
exit 1
|
||||
fi
|
||||
@@ -0,0 +1,214 @@
|
||||
name: security-patch-prs
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 3 * * 1-5"
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
patch:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
lane:
|
||||
- gomod
|
||||
- terraform
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Go
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Patch Go dependencies
|
||||
if: matrix.lane == 'gomod'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
go get -u=patch ./...
|
||||
go mod tidy
|
||||
|
||||
# Guardrail: do not auto-edit replace directives.
|
||||
if git diff --unified=0 -- go.mod | grep -E '^[+-]replace '; then
|
||||
echo "Refusing to auto-edit go.mod replace directives"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Guardrail: only go.mod / go.sum may change.
|
||||
extra="$(git diff --name-only | grep -Ev '^(go\.mod|go\.sum)$' || true)"
|
||||
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
|
||||
|
||||
- name: Patch bundled Terraform
|
||||
if: matrix.lane == 'terraform'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
current="$(
|
||||
grep -oE 'NewVersion\("[0-9]+\.[0-9]+\.[0-9]+"\)' \
|
||||
provisioner/terraform/install.go \
|
||||
| head -1 \
|
||||
| grep -oE '[0-9]+\.[0-9]+\.[0-9]+'
|
||||
)"
|
||||
|
||||
series="$(echo "$current" | cut -d. -f1,2)"
|
||||
|
||||
latest="$(
|
||||
curl -fsSL https://releases.hashicorp.com/terraform/index.json \
|
||||
| jq -r --arg series "$series" '
|
||||
.versions
|
||||
| keys[]
|
||||
| select(startswith($series + "."))
|
||||
' \
|
||||
| sort -V \
|
||||
| tail -1
|
||||
)"
|
||||
|
||||
test -n "$latest"
|
||||
[ "$latest" != "$current" ] || exit 0
|
||||
|
||||
CURRENT_TERRAFORM_VERSION="$current" \
|
||||
LATEST_TERRAFORM_VERSION="$latest" \
|
||||
python3 - <<'PY'
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
current = os.environ["CURRENT_TERRAFORM_VERSION"]
|
||||
latest = os.environ["LATEST_TERRAFORM_VERSION"]
|
||||
|
||||
updates = {
|
||||
"scripts/Dockerfile.base": (
|
||||
f"terraform/{current}/",
|
||||
f"terraform/{latest}/",
|
||||
),
|
||||
"provisioner/terraform/install.go": (
|
||||
f'NewVersion("{current}")',
|
||||
f'NewVersion("{latest}")',
|
||||
),
|
||||
"install.sh": (
|
||||
f'TERRAFORM_VERSION="{current}"',
|
||||
f'TERRAFORM_VERSION="{latest}"',
|
||||
),
|
||||
}
|
||||
|
||||
for path_str, (before, after) in updates.items():
|
||||
path = Path(path_str)
|
||||
content = path.read_text()
|
||||
if before not in content:
|
||||
raise SystemExit(f"did not find expected text in {path_str}: {before}")
|
||||
path.write_text(content.replace(before, after))
|
||||
PY
|
||||
|
||||
# Guardrail: only the Terraform-version files may change.
|
||||
extra="$(git diff --name-only | grep -Ev '^(scripts/Dockerfile.base|provisioner/terraform/install.go|install.sh)$' || true)"
|
||||
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
|
||||
|
||||
- name: Validate Go dependency patch
|
||||
if: matrix.lane == 'gomod'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
go test ./...
|
||||
|
||||
- name: Validate Terraform patch
|
||||
if: matrix.lane == 'terraform'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
go test ./provisioner/terraform/...
|
||||
docker build -f scripts/Dockerfile.base .
|
||||
|
||||
- name: Skip PR creation when there are no changes
|
||||
id: changes
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if git diff --quiet; then
|
||||
echo "has_changes=false" >> "${GITHUB_OUTPUT}"
|
||||
else
|
||||
echo "has_changes=true" >> "${GITHUB_OUTPUT}"
|
||||
fi
|
||||
|
||||
- name: Commit changes
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git checkout -B "secpatch/${{ matrix.lane }}"
|
||||
git add -A
|
||||
git commit -m "security: patch ${{ matrix.lane }}"
|
||||
|
||||
- name: Push branch
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git push --force-with-lease \
|
||||
"https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git" \
|
||||
"HEAD:refs/heads/secpatch/${{ matrix.lane }}"
|
||||
|
||||
- name: Create or update PR
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
branch="secpatch/${{ matrix.lane }}"
|
||||
title="security: patch ${{ matrix.lane }}"
|
||||
body="$(cat <<'EOF'
|
||||
Automated security patch PR for `${{ matrix.lane }}`.
|
||||
|
||||
Scope:
|
||||
- gomod: patch-level Go dependency updates only
|
||||
- terraform: bundled Terraform patch updates only
|
||||
|
||||
Guardrails:
|
||||
- no application-code edits
|
||||
- no auto-editing of go.mod replace directives
|
||||
- CI must pass
|
||||
EOF
|
||||
)"
|
||||
|
||||
existing_pr="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
|
||||
if [[ -n "${existing_pr}" ]]; then
|
||||
gh pr edit "${existing_pr}" \
|
||||
--title "${title}" \
|
||||
--body "${body}"
|
||||
pr_number="${existing_pr}"
|
||||
else
|
||||
gh pr create \
|
||||
--base main \
|
||||
--head "${branch}" \
|
||||
--title "${title}" \
|
||||
--body "${body}"
|
||||
pr_number="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
|
||||
fi
|
||||
|
||||
for label in security dependencies automated-pr; do
|
||||
if gh label list --json name --jq '.[].name' | grep -Fxq "${label}"; then
|
||||
gh pr edit "${pr_number}" --add-label "${label}"
|
||||
fi
|
||||
done
|
||||
@@ -27,7 +27,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
uses: github/codeql-action/init@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
|
||||
with:
|
||||
languages: go, javascript
|
||||
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
rm Makefile
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
uses: github/codeql-action/analyze@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
|
||||
|
||||
- name: Send Slack notification on failure
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -18,7 +18,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -96,7 +96,7 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
actions: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ jobs:
|
||||
pull-requests: write # required to post PR review comments by the action
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -187,7 +187,11 @@ func (*Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Cl
|
||||
connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := c.Start(connectCtx); err != nil {
|
||||
// Use the parent ctx (not connectCtx) so the subprocess outlives
|
||||
// the connect/initialize handshake. connectCtx bounds only the
|
||||
// Initialize call below. The subprocess is cleaned up when the
|
||||
// Manager is closed or ctx is canceled.
|
||||
if err := c.Start(ctx); err != nil {
|
||||
_ = c.Close()
|
||||
return nil, xerrors.Errorf("start %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package agentmcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
@@ -8,6 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestSplitToolName(t *testing.T) {
|
||||
@@ -193,3 +199,118 @@ func TestConvertResult(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectServer_StdioProcessSurvivesConnect verifies that a stdio MCP
|
||||
// server subprocess remains alive after connectServer returns. This is a
|
||||
// regression test for a bug where the subprocess was tied to a short-lived
|
||||
// connectCtx and killed as soon as the context was canceled.
|
||||
func TestConnectServer_StdioProcessSurvivesConnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
|
||||
// Child process: act as a minimal MCP server over stdio.
|
||||
runFakeMCPServer()
|
||||
return
|
||||
}
|
||||
|
||||
// Get the path to the test binary so we can re-exec ourselves
|
||||
// as a fake MCP server subprocess.
|
||||
testBin, err := os.Executable()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := ServerConfig{
|
||||
Name: "fake",
|
||||
Transport: "stdio",
|
||||
Command: testBin,
|
||||
Args: []string{"-test.run=^TestConnectServer_StdioProcessSurvivesConnect$"},
|
||||
Env: map[string]string{"TEST_MCP_FAKE_SERVER": "1"},
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
m := &Manager{}
|
||||
client, err := m.connectServer(ctx, cfg)
|
||||
require.NoError(t, err, "connectServer should succeed")
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
// At this point connectServer has returned and its internal
|
||||
// connectCtx has been canceled. The subprocess must still be
|
||||
// alive. Verify by listing tools (requires a live server).
|
||||
listCtx, listCancel := context.WithTimeout(ctx, testutil.WaitShort)
|
||||
defer listCancel()
|
||||
result, err := client.ListTools(listCtx, mcp.ListToolsRequest{})
|
||||
require.NoError(t, err, "ListTools should succeed — server must be alive after connect")
|
||||
require.Len(t, result.Tools, 1)
|
||||
assert.Equal(t, "echo", result.Tools[0].Name)
|
||||
}
|
||||
|
||||
// runFakeMCPServer implements a minimal JSON-RPC / MCP server over
|
||||
// stdin/stdout, just enough for initialize + tools/list.
|
||||
func runFakeMCPServer() {
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
|
||||
var req struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id"`
|
||||
Method string `json:"method"`
|
||||
}
|
||||
if err := json.Unmarshal(line, &req); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var resp any
|
||||
switch req.Method {
|
||||
case "initialize":
|
||||
resp = map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": req.ID,
|
||||
"result": map[string]any{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{},
|
||||
},
|
||||
"serverInfo": map[string]any{
|
||||
"name": "fake-server",
|
||||
"version": "0.0.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
case "notifications/initialized":
|
||||
// No response needed for notifications.
|
||||
continue
|
||||
case "tools/list":
|
||||
resp = map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": req.ID,
|
||||
"result": map[string]any{
|
||||
"tools": []map[string]any{
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "echoes input",
|
||||
"inputSchema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
resp = map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": req.ID,
|
||||
"error": map[string]any{
|
||||
"code": -32601,
|
||||
"message": "method not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
out, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(os.Stdout, "%s\n", out)
|
||||
}
|
||||
}
|
||||
|
||||
+28
-19
@@ -3,11 +3,13 @@
|
||||
"enabled": true,
|
||||
"clientKind": "git",
|
||||
"useIgnoreFile": true,
|
||||
"defaultBranch": "main",
|
||||
"defaultBranch": "main"
|
||||
},
|
||||
"files": {
|
||||
"includes": ["**", "!**/pnpm-lock.yaml"],
|
||||
"ignoreUnknown": true,
|
||||
// static/*.html are Go templates with {{ }} directives that
|
||||
// Biome's HTML parser does not support.
|
||||
"includes": ["**", "!**/pnpm-lock.yaml", "!**/static/*.html"],
|
||||
"ignoreUnknown": true
|
||||
},
|
||||
"linter": {
|
||||
"rules": {
|
||||
@@ -15,7 +17,7 @@
|
||||
"noSvgWithoutTitle": "off",
|
||||
"useButtonType": "off",
|
||||
"useSemanticElements": "off",
|
||||
"noStaticElementInteractions": "off",
|
||||
"noStaticElementInteractions": "off"
|
||||
},
|
||||
"correctness": {
|
||||
"noUnusedImports": "warn",
|
||||
@@ -24,9 +26,9 @@
|
||||
"noUnusedVariables": {
|
||||
"level": "warn",
|
||||
"options": {
|
||||
"ignoreRestSiblings": true,
|
||||
},
|
||||
},
|
||||
"ignoreRestSiblings": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"style": {
|
||||
"noNonNullAssertion": "off",
|
||||
@@ -47,7 +49,7 @@
|
||||
"paths": {
|
||||
"react": {
|
||||
"message": "React 19 no longer requires forwardRef. Use ref as a prop instead.",
|
||||
"importNames": ["forwardRef"],
|
||||
"importNames": ["forwardRef"]
|
||||
},
|
||||
// "@mui/material/Alert": "Use components/Alert/Alert instead.",
|
||||
// "@mui/material/AlertTitle": "Use components/Alert/Alert instead.",
|
||||
@@ -115,10 +117,10 @@
|
||||
"@emotion/styled": "Use Tailwind CSS instead.",
|
||||
// "@emotion/cache": "Use Tailwind CSS instead.",
|
||||
// "components/Stack/Stack": "Use Tailwind flex utilities instead (e.g., <div className='flex flex-col gap-4'>).",
|
||||
"lodash": "Use lodash/<name> instead.",
|
||||
},
|
||||
},
|
||||
},
|
||||
"lodash": "Use lodash/<name> instead."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"suspicious": {
|
||||
"noArrayIndexKey": "off",
|
||||
@@ -129,14 +131,21 @@
|
||||
"noConsole": {
|
||||
"level": "error",
|
||||
"options": {
|
||||
"allow": ["error", "info", "warn"],
|
||||
},
|
||||
},
|
||||
"allow": ["error", "info", "warn"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"complexity": {
|
||||
"noImportantStyles": "off", // TODO: check and fix !important styles
|
||||
},
|
||||
},
|
||||
"noImportantStyles": "off" // TODO: check and fix !important styles
|
||||
}
|
||||
}
|
||||
},
|
||||
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json",
|
||||
"css": {
|
||||
"parser": {
|
||||
// Biome 2.3+ requires opt-in for @apply and other
|
||||
// Tailwind directives.
|
||||
"tailwindDirectives": true
|
||||
}
|
||||
},
|
||||
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json"
|
||||
}
|
||||
|
||||
Generated
+6
@@ -14175,6 +14175,9 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -14496,6 +14499,9 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
Generated
+6
@@ -12739,6 +12739,9 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13039,6 +13042,9 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
+8
-1
@@ -26,6 +26,11 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// Limit the count query to avoid a slow sequential scan due to joins
|
||||
// on a large table. Set to 0 to disable capping (but also see the note
|
||||
// in the SQL query).
|
||||
const auditLogCountCap = 2000
|
||||
|
||||
// @Summary Get audit logs
|
||||
// @ID get-audit-logs
|
||||
// @Security CoderSessionToken
|
||||
@@ -66,7 +71,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
countFilter.Username = ""
|
||||
}
|
||||
|
||||
// Use the same filters to count the number of audit logs
|
||||
countFilter.CountCap = auditLogCountCap
|
||||
count, err := api.Database.CountAuditLogs(ctx, countFilter)
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
@@ -81,6 +86,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: []codersdk.AuditLog{},
|
||||
Count: 0,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -98,6 +104,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: api.convertAuditLogs(ctx, dblogs),
|
||||
Count: count,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1528,7 +1528,10 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
// Chat converts a database.Chat to a codersdk.Chat. It coalesces
|
||||
// nil slices and maps to empty values for JSON serialization and
|
||||
// derives RootChatID from the parent chain when not explicitly set.
|
||||
func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
||||
// When diffStatus is non-nil the response includes diff metadata.
|
||||
// When files is non-empty the response includes file metadata;
|
||||
// pass nil to omit the files field (e.g. list endpoints).
|
||||
func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database.GetChatFileMetadataByChatIDRow) codersdk.Chat {
|
||||
mcpServerIDs := c.MCPServerIDs
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
@@ -1581,6 +1584,19 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
||||
convertedDiffStatus := ChatDiffStatus(c.ID, diffStatus)
|
||||
chat.DiffStatus = &convertedDiffStatus
|
||||
}
|
||||
if len(files) > 0 {
|
||||
chat.Files = make([]codersdk.ChatFileMetadata, 0, len(files))
|
||||
for _, row := range files {
|
||||
chat.Files = append(chat.Files, codersdk.ChatFileMetadata{
|
||||
ID: row.ID,
|
||||
OwnerID: row.OwnerID,
|
||||
OrganizationID: row.OrganizationID,
|
||||
Name: row.Name,
|
||||
MimeType: row.Mimetype,
|
||||
CreatedAt: row.CreatedAt,
|
||||
})
|
||||
}
|
||||
}
|
||||
if c.LastInjectedContext.Valid {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
// Internal fields are stripped at write time in
|
||||
@@ -1604,9 +1620,9 @@ func ChatRows(rows []database.GetChatsRow, diffStatusesByChatID map[uuid.UUID]da
|
||||
for i, row := range rows {
|
||||
diffStatus, ok := diffStatusesByChatID[row.Chat.ID]
|
||||
if ok {
|
||||
result[i] = Chat(row.Chat, &diffStatus)
|
||||
result[i] = Chat(row.Chat, &diffStatus, nil)
|
||||
} else {
|
||||
result[i] = Chat(row.Chat, nil)
|
||||
result[i] = Chat(row.Chat, nil, nil)
|
||||
if diffStatusesByChatID != nil {
|
||||
emptyDiffStatus := ChatDiffStatus(row.Chat.ID, nil)
|
||||
result[i].DiffStatus = &emptyDiffStatus
|
||||
|
||||
@@ -561,14 +561,26 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
ChatID: input.ID,
|
||||
}
|
||||
|
||||
got := db2sdk.Chat(input, diffStatus)
|
||||
fileRows := []database.GetChatFileMetadataByChatIDRow{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
OwnerID: input.OwnerID,
|
||||
OrganizationID: uuid.New(),
|
||||
Name: "test.png",
|
||||
Mimetype: "image/png",
|
||||
CreatedAt: now,
|
||||
},
|
||||
}
|
||||
|
||||
got := db2sdk.Chat(input, diffStatus, fileRows)
|
||||
|
||||
v := reflect.ValueOf(got)
|
||||
typ := v.Type()
|
||||
// HasUnread is populated by ChatRows (which joins the
|
||||
// read-cursor query), not by Chat, so it is expected
|
||||
// to remain zero here.
|
||||
skip := map[string]bool{"HasUnread": true}
|
||||
// read-cursor query), not by Chat. Warnings is a transient
|
||||
// field populated by handlers, not the converter. Both are
|
||||
// expected to remain zero here.
|
||||
skip := map[string]bool{"HasUnread": true, "Warnings": true}
|
||||
for i := range typ.NumField() {
|
||||
field := typ.Field(i)
|
||||
if skip[field.Name] {
|
||||
@@ -581,6 +593,112 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_FileMetadataConversion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ownerID := uuid.New()
|
||||
orgID := uuid.New()
|
||||
fileID := uuid.New()
|
||||
now := dbtime.Now()
|
||||
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: uuid.New(),
|
||||
Title: "file metadata test",
|
||||
Status: database.ChatStatusWaiting,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
rows := []database.GetChatFileMetadataByChatIDRow{
|
||||
{
|
||||
ID: fileID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
Name: "screenshot.png",
|
||||
Mimetype: "image/png",
|
||||
CreatedAt: now,
|
||||
},
|
||||
}
|
||||
|
||||
result := db2sdk.Chat(chat, nil, rows)
|
||||
|
||||
require.Len(t, result.Files, 1)
|
||||
f := result.Files[0]
|
||||
require.Equal(t, fileID, f.ID)
|
||||
require.Equal(t, ownerID, f.OwnerID, "OwnerID must be mapped from DB row")
|
||||
require.Equal(t, orgID, f.OrganizationID, "OrganizationID must be mapped from DB row")
|
||||
require.Equal(t, "screenshot.png", f.Name)
|
||||
require.Equal(t, "image/png", f.MimeType)
|
||||
require.Equal(t, now, f.CreatedAt)
|
||||
|
||||
// Verify JSON serialization uses snake_case for mime_type.
|
||||
data, err := json.Marshal(f)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(data), `"mime_type"`)
|
||||
require.NotContains(t, string(data), `"mimetype"`)
|
||||
}
|
||||
|
||||
func TestChat_NilFilesOmitted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
LastModelConfigID: uuid.New(),
|
||||
Title: "no files",
|
||||
Status: database.ChatStatusWaiting,
|
||||
CreatedAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
}
|
||||
|
||||
result := db2sdk.Chat(chat, nil, nil)
|
||||
require.Empty(t, result.Files)
|
||||
}
|
||||
|
||||
func TestChat_MultipleFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := dbtime.Now()
|
||||
file1 := uuid.New()
|
||||
file2 := uuid.New()
|
||||
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
LastModelConfigID: uuid.New(),
|
||||
Title: "multi file test",
|
||||
Status: database.ChatStatusWaiting,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
rows := []database.GetChatFileMetadataByChatIDRow{
|
||||
{
|
||||
ID: file1,
|
||||
OwnerID: chat.OwnerID,
|
||||
OrganizationID: uuid.New(),
|
||||
Name: "a.png",
|
||||
Mimetype: "image/png",
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
ID: file2,
|
||||
OwnerID: chat.OwnerID,
|
||||
OrganizationID: uuid.New(),
|
||||
Name: "b.txt",
|
||||
Mimetype: "text/plain",
|
||||
CreatedAt: now,
|
||||
},
|
||||
}
|
||||
|
||||
result := db2sdk.Chat(chat, nil, rows)
|
||||
require.Len(t, result.Files, 2)
|
||||
require.Equal(t, "a.png", result.Files[0].Name)
|
||||
require.Equal(t, "b.txt", result.Files[1].Name)
|
||||
}
|
||||
|
||||
func TestChatQueuedMessage_MalformedContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -2583,6 +2583,10 @@ func (q *querier) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.C
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatFileMetadataByChatID)(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
files, err := q.db.GetChatFilesByIDs(ctx, ids)
|
||||
if err != nil {
|
||||
@@ -5393,6 +5397,17 @@ func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg datab
|
||||
return q.db.InsertWorkspaceResourceMetadata(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.LinkChatFiles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -5767,15 +5782,15 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
|
||||
return q.db.UpdateChatByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
// The batch heartbeat is a system-level operation filtered by
|
||||
// worker_id. Authorization is enforced by the AsChatd context
|
||||
// at the call site rather than per-row, because checking each
|
||||
// row individually would defeat the purpose of batching.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.UpdateChatHeartbeat(ctx, arg)
|
||||
return q.db.UpdateChatHeartbeats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
|
||||
|
||||
@@ -400,6 +400,17 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("LinkChatFiles", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.LinkChatFilesParams{
|
||||
ChatID: chat.ID,
|
||||
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
|
||||
FileIds: []uuid.UUID{uuid.New()},
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().LinkChatFiles(gomock.Any(), arg).Return(int32(0), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int32(0))
|
||||
}))
|
||||
s.Run("PinChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
@@ -576,6 +587,19 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatFilesByIDs(gomock.Any(), []uuid.UUID{file.ID}).Return([]database.ChatFile{file}, nil).AnyTimes()
|
||||
check.Args([]uuid.UUID{file.ID}).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns([]database.ChatFile{file})
|
||||
}))
|
||||
s.Run("GetChatFileMetadataByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
file := testutil.Fake(s.T(), faker, database.ChatFile{})
|
||||
rows := []database.GetChatFileMetadataByChatIDRow{{
|
||||
ID: file.ID,
|
||||
Name: file.Name,
|
||||
Mimetype: file.Mimetype,
|
||||
CreatedAt: file.CreatedAt,
|
||||
OwnerID: file.OwnerID,
|
||||
OrganizationID: file.OrganizationID,
|
||||
}}
|
||||
dbm.EXPECT().GetChatFileMetadataByChatID(gomock.Any(), file.ID).Return(rows, nil).AnyTimes()
|
||||
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
@@ -818,15 +842,15 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
resultID := uuid.New()
|
||||
arg := database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{resultID},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
|
||||
}))
|
||||
s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
|
||||
@@ -1128,6 +1128,14 @@ func (m queryMetricsStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatFileMetadataByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatFileMetadataByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFileMetadataByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatFilesByIDs(ctx, ids)
|
||||
@@ -3776,6 +3784,14 @@ func (m queryMetricsStore) InsertWorkspaceResourceMetadata(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.LinkChatFiles(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("LinkChatFiles").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "LinkChatFiles").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeClients(ctx, arg)
|
||||
@@ -4120,11 +4136,11 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
|
||||
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
|
||||
@@ -2072,6 +2072,21 @@ func (mr *MockStoreMockRecorder) GetChatFileByID(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileByID", reflect.TypeOf((*MockStore)(nil).GetChatFileByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatFileMetadataByChatID mocks base method.
|
||||
func (m *MockStore) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatFileMetadataByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].([]database.GetChatFileMetadataByChatIDRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatFileMetadataByChatID indicates an expected call of GetChatFileMetadataByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatFileMetadataByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileMetadataByChatID", reflect.TypeOf((*MockStore)(nil).GetChatFileMetadataByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatFilesByIDs mocks base method.
|
||||
func (m *MockStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7066,6 +7081,21 @@ func (mr *MockStoreMockRecorder) InsertWorkspaceResourceMetadata(ctx, arg any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertWorkspaceResourceMetadata", reflect.TypeOf((*MockStore)(nil).InsertWorkspaceResourceMetadata), ctx, arg)
|
||||
}
|
||||
|
||||
// LinkChatFiles mocks base method.
|
||||
func (m *MockStore) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkChatFiles", ctx, arg)
|
||||
ret0, _ := ret[0].(int32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LinkChatFiles indicates an expected call of LinkChatFiles.
|
||||
func (mr *MockStoreMockRecorder) LinkChatFiles(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkChatFiles", reflect.TypeOf((*MockStore)(nil).LinkChatFiles), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeClients mocks base method.
|
||||
func (m *MockStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7805,19 +7835,19 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
// UpdateChatHeartbeats mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeats", ctx, arg)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
|
||||
// UpdateChatHeartbeats indicates an expected call of UpdateChatHeartbeats.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeats(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeats", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLabelsByID mocks base method.
|
||||
|
||||
Generated
+16
@@ -1269,6 +1269,11 @@ CREATE TABLE chat_diff_statuses (
|
||||
head_branch text
|
||||
);
|
||||
|
||||
CREATE TABLE chat_file_links (
|
||||
chat_id uuid NOT NULL,
|
||||
file_id uuid NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE chat_files (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
@@ -3344,6 +3349,9 @@ ALTER TABLE ONLY boundary_usage_stats
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
|
||||
ALTER TABLE ONLY chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3734,6 +3742,8 @@ CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
|
||||
|
||||
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
|
||||
|
||||
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chat_files_org ON chat_files USING btree (organization_id);
|
||||
|
||||
CREATE INDEX idx_chat_files_owner ON chat_files USING btree (owner_id);
|
||||
@@ -4036,6 +4046,12 @@ ALTER TABLE ONLY api_keys
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ const (
|
||||
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFileLinksChatID ForeignKeyConstraint = "chat_file_links_chat_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
ALTER TABLE chats ADD COLUMN file_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL;
|
||||
|
||||
UPDATE chats SET file_ids = (
|
||||
SELECT COALESCE(array_agg(cfl.file_id), '{}')
|
||||
FROM chat_file_links cfl
|
||||
WHERE cfl.chat_id = chats.id
|
||||
);
|
||||
|
||||
DROP TABLE chat_file_links;
|
||||
@@ -0,0 +1,17 @@
|
||||
CREATE TABLE chat_file_links (
|
||||
chat_id uuid NOT NULL,
|
||||
file_id uuid NOT NULL,
|
||||
UNIQUE (chat_id, file_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links (chat_id);
|
||||
|
||||
ALTER TABLE chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_chat_id_fkey
|
||||
FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_file_id_fkey
|
||||
FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE chats DROP COLUMN IF EXISTS file_ids;
|
||||
@@ -0,0 +1,5 @@
|
||||
INSERT INTO chat_file_links (chat_id, file_id)
|
||||
VALUES (
|
||||
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
|
||||
'00000000-0000-0000-0000-000000000099'
|
||||
);
|
||||
@@ -187,6 +187,10 @@ func (c ChatFile) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
|
||||
}
|
||||
|
||||
func (c GetChatFileMetadataByChatIDRow) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
|
||||
}
|
||||
|
||||
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
|
||||
switch s {
|
||||
case ApiKeyScopeCoderAll:
|
||||
|
||||
@@ -584,6 +584,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -720,6 +721,7 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
@@ -145,5 +145,13 @@ func extractWhereClause(query string) string {
|
||||
// Remove SQL comments
|
||||
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
|
||||
|
||||
// Normalize indentation so subquery wrapping doesn't cause
|
||||
// mismatches.
|
||||
lines := strings.Split(whereClause, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = strings.TrimLeft(line, " \t")
|
||||
}
|
||||
whereClause = strings.Join(lines, "\n")
|
||||
|
||||
return strings.TrimSpace(whereClause)
|
||||
}
|
||||
|
||||
@@ -4218,6 +4218,11 @@ type ChatFile struct {
|
||||
Data []byte `db:"data" json:"data"`
|
||||
}
|
||||
|
||||
type ChatFileLink struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
FileID uuid.UUID `db:"file_id" json:"file_id"`
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
|
||||
@@ -244,6 +244,10 @@ type sqlcQuerier interface {
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
|
||||
// GetChatFileMetadataByChatID returns lightweight file metadata for
|
||||
// all files linked to a chat. The data column is excluded to avoid
|
||||
// loading file content.
|
||||
GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]GetChatFileMetadataByChatIDRow, error)
|
||||
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
|
||||
// GetChatIncludeDefaultSystemPrompt preserves the legacy default
|
||||
// for deployments created before the explicit include-default toggle.
|
||||
@@ -778,6 +782,15 @@ type sqlcQuerier interface {
|
||||
InsertWorkspaceProxy(ctx context.Context, arg InsertWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
|
||||
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
|
||||
// LinkChatFiles inserts file associations into the chat_file_links
|
||||
// join table with deduplication (ON CONFLICT DO NOTHING). The INSERT
|
||||
// is conditional: it only proceeds when the total number of links
|
||||
// (existing + genuinely new) does not exceed max_file_links. Returns
|
||||
// the number of genuinely new file IDs that were NOT inserted due to
|
||||
// the cap. A return value of 0 means all files were linked (or were
|
||||
// already linked). A positive value means the cap blocked that many
|
||||
// new links.
|
||||
LinkChatFiles(ctx context.Context, arg LinkChatFilesParams) (int32, error)
|
||||
ListAIBridgeClients(ctx context.Context, arg ListAIBridgeClientsParams) ([]string, error)
|
||||
ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]ListAIBridgeInterceptionsRow, error)
|
||||
// Finds all unique AI Bridge interception telemetry summaries combinations
|
||||
@@ -857,9 +870,11 @@ type sqlcQuerier interface {
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
|
||||
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
|
||||
// Updates the cached injected context parts (AGENTS.md +
|
||||
// skills) on the chat row. Called only when context changes
|
||||
|
||||
+362
-209
@@ -2275,93 +2275,105 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
|
||||
}
|
||||
|
||||
const countAuditLogs = `-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN $1::text != '' THEN resource_type = $1::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN resource_target = $4
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN action = $5::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower($7)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN $1::text != '' THEN resource_type = $1::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN resource_target = $4
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN action = $5::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower($7)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF($13::int, 0) + 1
|
||||
) AS limited_count
|
||||
`
|
||||
|
||||
type CountAuditLogsParams struct {
|
||||
@@ -2377,6 +2389,7 @@ type CountAuditLogsParams struct {
|
||||
DateTo time.Time `db:"date_to" json:"date_to"`
|
||||
BuildReason string `db:"build_reason" json:"build_reason"`
|
||||
RequestID uuid.UUID `db:"request_id" json:"request_id"`
|
||||
CountCap int32 `db:"count_cap" json:"count_cap"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) {
|
||||
@@ -2393,6 +2406,7 @@ func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParam
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
@@ -2889,6 +2903,56 @@ func (q *sqlQuerier) GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFil
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatFileMetadataByChatID = `-- name: GetChatFileMetadataByChatID :many
|
||||
SELECT cf.id, cf.owner_id, cf.organization_id, cf.name, cf.mimetype, cf.created_at
|
||||
FROM chat_files cf
|
||||
JOIN chat_file_links cfl ON cfl.file_id = cf.id
|
||||
WHERE cfl.chat_id = $1::uuid
|
||||
ORDER BY cf.created_at ASC
|
||||
`
|
||||
|
||||
type GetChatFileMetadataByChatIDRow struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Mimetype string `db:"mimetype" json:"mimetype"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
// GetChatFileMetadataByChatID returns lightweight file metadata for
|
||||
// all files linked to a chat. The data column is excluded to avoid
|
||||
// loading file content.
|
||||
func (q *sqlQuerier) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]GetChatFileMetadataByChatIDRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatFileMetadataByChatID, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetChatFileMetadataByChatIDRow
|
||||
for rows.Next() {
|
||||
var i GetChatFileMetadataByChatIDRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.OrganizationID,
|
||||
&i.Name,
|
||||
&i.Mimetype,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getChatFilesByIDs = `-- name: GetChatFilesByIDs :many
|
||||
SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = ANY($1::uuid[])
|
||||
`
|
||||
@@ -4530,7 +4594,8 @@ WITH chat_costs AS (
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM chat_messages cm
|
||||
JOIN chats c ON c.id = cm.chat_id
|
||||
WHERE c.owner_id = $1::uuid
|
||||
@@ -4547,7 +4612,8 @@ SELECT
|
||||
cc.total_input_tokens,
|
||||
cc.total_output_tokens,
|
||||
cc.total_cache_read_tokens,
|
||||
cc.total_cache_creation_tokens
|
||||
cc.total_cache_creation_tokens,
|
||||
cc.total_runtime_ms
|
||||
FROM chat_costs cc
|
||||
LEFT JOIN chats rc ON rc.id = cc.root_chat_id
|
||||
ORDER BY cc.total_cost_micros DESC
|
||||
@@ -4568,6 +4634,7 @@ type GetChatCostPerChatRow struct {
|
||||
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
|
||||
TotalCacheReadTokens int64 `db:"total_cache_read_tokens" json:"total_cache_read_tokens"`
|
||||
TotalCacheCreationTokens int64 `db:"total_cache_creation_tokens" json:"total_cache_creation_tokens"`
|
||||
TotalRuntimeMs int64 `db:"total_runtime_ms" json:"total_runtime_ms"`
|
||||
}
|
||||
|
||||
// Per-root-chat cost breakdown for a single user within a date range.
|
||||
@@ -4591,6 +4658,7 @@ func (q *sqlQuerier) GetChatCostPerChat(ctx context.Context, arg GetChatCostPerC
|
||||
&i.TotalOutputTokens,
|
||||
&i.TotalCacheReadTokens,
|
||||
&i.TotalCacheCreationTokens,
|
||||
&i.TotalRuntimeMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4622,7 +4690,8 @@ SELECT
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -4657,6 +4726,7 @@ type GetChatCostPerModelRow struct {
|
||||
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
|
||||
TotalCacheReadTokens int64 `db:"total_cache_read_tokens" json:"total_cache_read_tokens"`
|
||||
TotalCacheCreationTokens int64 `db:"total_cache_creation_tokens" json:"total_cache_creation_tokens"`
|
||||
TotalRuntimeMs int64 `db:"total_runtime_ms" json:"total_runtime_ms"`
|
||||
}
|
||||
|
||||
// Per-model cost breakdown for a single user within a date range.
|
||||
@@ -4681,6 +4751,7 @@ func (q *sqlQuerier) GetChatCostPerModel(ctx context.Context, arg GetChatCostPer
|
||||
&i.TotalOutputTokens,
|
||||
&i.TotalCacheReadTokens,
|
||||
&i.TotalCacheCreationTokens,
|
||||
&i.TotalRuntimeMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4714,7 +4785,8 @@ WITH chat_cost_users AS (
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -4748,6 +4820,7 @@ SELECT
|
||||
total_output_tokens,
|
||||
total_cache_read_tokens,
|
||||
total_cache_creation_tokens,
|
||||
total_runtime_ms,
|
||||
COUNT(*) OVER()::bigint AS total_count
|
||||
FROM
|
||||
chat_cost_users
|
||||
@@ -4780,6 +4853,7 @@ type GetChatCostPerUserRow struct {
|
||||
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
|
||||
TotalCacheReadTokens int64 `db:"total_cache_read_tokens" json:"total_cache_read_tokens"`
|
||||
TotalCacheCreationTokens int64 `db:"total_cache_creation_tokens" json:"total_cache_creation_tokens"`
|
||||
TotalRuntimeMs int64 `db:"total_runtime_ms" json:"total_runtime_ms"`
|
||||
TotalCount int64 `db:"total_count" json:"total_count"`
|
||||
}
|
||||
|
||||
@@ -4812,6 +4886,7 @@ func (q *sqlQuerier) GetChatCostPerUser(ctx context.Context, arg GetChatCostPerU
|
||||
&i.TotalOutputTokens,
|
||||
&i.TotalCacheReadTokens,
|
||||
&i.TotalCacheCreationTokens,
|
||||
&i.TotalRuntimeMs,
|
||||
&i.TotalCount,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
@@ -4846,7 +4921,8 @@ SELECT
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -4872,6 +4948,7 @@ type GetChatCostSummaryRow struct {
|
||||
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
|
||||
TotalCacheReadTokens int64 `db:"total_cache_read_tokens" json:"total_cache_read_tokens"`
|
||||
TotalCacheCreationTokens int64 `db:"total_cache_creation_tokens" json:"total_cache_creation_tokens"`
|
||||
TotalRuntimeMs int64 `db:"total_runtime_ms" json:"total_runtime_ms"`
|
||||
}
|
||||
|
||||
// Aggregate cost summary for a single user within a date range.
|
||||
@@ -4887,6 +4964,7 @@ func (q *sqlQuerier) GetChatCostSummary(ctx context.Context, arg GetChatCostSumm
|
||||
&i.TotalOutputTokens,
|
||||
&i.TotalCacheReadTokens,
|
||||
&i.TotalCacheCreationTokens,
|
||||
&i.TotalRuntimeMs,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -6019,6 +6097,57 @@ func (q *sqlQuerier) InsertChatQueuedMessage(ctx context.Context, arg InsertChat
|
||||
return i, err
|
||||
}
|
||||
|
||||
const linkChatFiles = `-- name: LinkChatFiles :one
|
||||
WITH current AS (
|
||||
SELECT COUNT(*) AS cnt
|
||||
FROM chat_file_links
|
||||
WHERE chat_id = $1::uuid
|
||||
),
|
||||
new_links AS (
|
||||
SELECT $1::uuid AS chat_id, unnest($2::uuid[]) AS file_id
|
||||
),
|
||||
genuinely_new AS (
|
||||
SELECT nl.chat_id, nl.file_id
|
||||
FROM new_links nl
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM chat_file_links cfl
|
||||
WHERE cfl.chat_id = nl.chat_id AND cfl.file_id = nl.file_id
|
||||
)
|
||||
),
|
||||
inserted AS (
|
||||
INSERT INTO chat_file_links (chat_id, file_id)
|
||||
SELECT gn.chat_id, gn.file_id
|
||||
FROM genuinely_new gn, current c
|
||||
WHERE c.cnt + (SELECT COUNT(*) FROM genuinely_new) <= $3::int
|
||||
ON CONFLICT (chat_id, file_id) DO NOTHING
|
||||
RETURNING file_id
|
||||
)
|
||||
SELECT
|
||||
(SELECT COUNT(*)::int FROM genuinely_new) -
|
||||
(SELECT COUNT(*)::int FROM inserted) AS rejected_new_files
|
||||
`
|
||||
|
||||
type LinkChatFilesParams struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
FileIds []uuid.UUID `db:"file_ids" json:"file_ids"`
|
||||
MaxFileLinks int32 `db:"max_file_links" json:"max_file_links"`
|
||||
}
|
||||
|
||||
// LinkChatFiles inserts file associations into the chat_file_links
|
||||
// join table with deduplication (ON CONFLICT DO NOTHING). The INSERT
|
||||
// is conditional: it only proceeds when the total number of links
|
||||
// (existing + genuinely new) does not exceed max_file_links. Returns
|
||||
// the number of genuinely new file IDs that were NOT inserted due to
|
||||
// the cap. A return value of 0 means all files were linked (or were
|
||||
// already linked). A positive value means the cap blocked that many
|
||||
// new links.
|
||||
func (q *sqlQuerier) LinkChatFiles(ctx context.Context, arg LinkChatFilesParams) (int32, error) {
|
||||
row := q.db.QueryRowContext(ctx, linkChatFiles, arg.ChatID, pq.Array(arg.FileIds), arg.MaxFileLinks)
|
||||
var rejected_new_files int32
|
||||
err := row.Scan(&rejected_new_files)
|
||||
return rejected_new_files, err
|
||||
}
|
||||
|
||||
const listChatUsageLimitGroupOverrides = `-- name: ListChatUsageLimitGroupOverrides :many
|
||||
SELECT
|
||||
g.id AS group_id,
|
||||
@@ -6486,30 +6615,49 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatHeartbeat = `-- name: UpdateChatHeartbeat :execrows
|
||||
const updateChatHeartbeats = `-- name: UpdateChatHeartbeats :many
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = NOW()
|
||||
heartbeat_at = $1::timestamptz
|
||||
WHERE
|
||||
id = $1::uuid
|
||||
AND worker_id = $2::uuid
|
||||
id = ANY($2::uuid[])
|
||||
AND worker_id = $3::uuid
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
type UpdateChatHeartbeatParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
|
||||
type UpdateChatHeartbeatsParams struct {
|
||||
Now time.Time `db:"now" json:"now"`
|
||||
IDs []uuid.UUID `db:"ids" json:"ids"`
|
||||
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
|
||||
}
|
||||
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
func (q *sqlQuerier) UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error) {
|
||||
result, err := q.db.ExecContext(ctx, updateChatHeartbeat, arg.ID, arg.WorkerID)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
func (q *sqlQuerier) UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
rows, err := q.db.QueryContext(ctx, updateChatHeartbeats, arg.Now, pq.Array(arg.IDs), arg.WorkerID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return nil, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
defer rows.Close()
|
||||
var items []uuid.UUID
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, id)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateChatLabelsByID = `-- name: UpdateChatLabelsByID :one
|
||||
@@ -7456,110 +7604,113 @@ func (q *sqlQuerier) BatchUpsertConnectionLogs(ctx context.Context, arg BatchUps
|
||||
}
|
||||
|
||||
const countConnectionLogs = `-- name: CountConnectionLogs :one
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = $1
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN $2 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($2) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN $4 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = $4 AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN $5 :: text != '' THEN
|
||||
type = $5 :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7 :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($7) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8 :: text != '' THEN
|
||||
users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN $13 :: text != '' THEN
|
||||
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = $1
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN $2 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($2) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN $4 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = $4 AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN $5 :: text != '' THEN
|
||||
type = $5 :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7 :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($7) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8 :: text != '' THEN
|
||||
users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN $13 :: text != '' THEN
|
||||
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF($14::int, 0) + 1
|
||||
) AS limited_count
|
||||
`
|
||||
|
||||
type CountConnectionLogsParams struct {
|
||||
@@ -7576,6 +7727,7 @@ type CountConnectionLogsParams struct {
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"`
|
||||
Status string `db:"status" json:"status"`
|
||||
CountCap int32 `db:"count_cap" json:"count_cap"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) {
|
||||
@@ -7593,6 +7745,7 @@ func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectio
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
|
||||
@@ -149,94 +149,105 @@ VALUES (
|
||||
RETURNING *;
|
||||
|
||||
-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
|
||||
-- name: DeleteOldAuditLogConnectionEvents :exec
|
||||
DELETE FROM audit_logs
|
||||
|
||||
@@ -8,3 +8,13 @@ SELECT * FROM chat_files WHERE id = @id::uuid;
|
||||
|
||||
-- name: GetChatFilesByIDs :many
|
||||
SELECT * FROM chat_files WHERE id = ANY(@ids::uuid[]);
|
||||
|
||||
-- name: GetChatFileMetadataByChatID :many
|
||||
-- GetChatFileMetadataByChatID returns lightweight file metadata for
|
||||
-- all files linked to a chat. The data column is excluded to avoid
|
||||
-- loading file content.
|
||||
SELECT cf.id, cf.owner_id, cf.organization_id, cf.name, cf.mimetype, cf.created_at
|
||||
FROM chat_files cf
|
||||
JOIN chat_file_links cfl ON cfl.file_id = cf.id
|
||||
WHERE cfl.chat_id = @chat_id::uuid
|
||||
ORDER BY cf.created_at ASC;
|
||||
|
||||
@@ -567,6 +567,43 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: LinkChatFiles :one
|
||||
-- LinkChatFiles inserts file associations into the chat_file_links
|
||||
-- join table with deduplication (ON CONFLICT DO NOTHING). The INSERT
|
||||
-- is conditional: it only proceeds when the total number of links
|
||||
-- (existing + genuinely new) does not exceed max_file_links. Returns
|
||||
-- the number of genuinely new file IDs that were NOT inserted due to
|
||||
-- the cap. A return value of 0 means all files were linked (or were
|
||||
-- already linked). A positive value means the cap blocked that many
|
||||
-- new links.
|
||||
WITH current AS (
|
||||
SELECT COUNT(*) AS cnt
|
||||
FROM chat_file_links
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
),
|
||||
new_links AS (
|
||||
SELECT @chat_id::uuid AS chat_id, unnest(@file_ids::uuid[]) AS file_id
|
||||
),
|
||||
genuinely_new AS (
|
||||
SELECT nl.chat_id, nl.file_id
|
||||
FROM new_links nl
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM chat_file_links cfl
|
||||
WHERE cfl.chat_id = nl.chat_id AND cfl.file_id = nl.file_id
|
||||
)
|
||||
),
|
||||
inserted AS (
|
||||
INSERT INTO chat_file_links (chat_id, file_id)
|
||||
SELECT gn.chat_id, gn.file_id
|
||||
FROM genuinely_new gn, current c
|
||||
WHERE c.cnt + (SELECT COUNT(*) FROM genuinely_new) <= @max_file_links::int
|
||||
ON CONFLICT (chat_id, file_id) DO NOTHING
|
||||
RETURNING file_id
|
||||
)
|
||||
SELECT
|
||||
(SELECT COUNT(*)::int FROM genuinely_new) -
|
||||
(SELECT COUNT(*)::int FROM inserted) AS rejected_new_files;
|
||||
|
||||
-- name: AcquireChats :many
|
||||
-- Acquires up to @num_chats pending chats for processing. Uses SKIP LOCKED
|
||||
-- to prevent multiple replicas from acquiring the same chat.
|
||||
@@ -637,17 +674,20 @@ WHERE
|
||||
status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz;
|
||||
|
||||
-- name: UpdateChatHeartbeat :execrows
|
||||
-- Bumps the heartbeat timestamp for a running chat so that other
|
||||
-- replicas know the worker is still alive.
|
||||
-- name: UpdateChatHeartbeats :many
|
||||
-- Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
-- provided they are still running and owned by the specified
|
||||
-- worker. Returns the IDs that were actually updated so the
|
||||
-- caller can detect stolen or completed chats via set-difference.
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = NOW()
|
||||
heartbeat_at = @now::timestamptz
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
id = ANY(@ids::uuid[])
|
||||
AND worker_id = @worker_id::uuid
|
||||
AND status = 'running'::chat_status;
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id;
|
||||
|
||||
-- name: GetChatDiffStatusByChatID :one
|
||||
SELECT
|
||||
@@ -883,7 +923,8 @@ SELECT
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -913,7 +954,8 @@ SELECT
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -948,7 +990,8 @@ WITH chat_costs AS (
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM chat_messages cm
|
||||
JOIN chats c ON c.id = cm.chat_id
|
||||
WHERE c.owner_id = @owner_id::uuid
|
||||
@@ -965,7 +1008,8 @@ SELECT
|
||||
cc.total_input_tokens,
|
||||
cc.total_output_tokens,
|
||||
cc.total_cache_read_tokens,
|
||||
cc.total_cache_creation_tokens
|
||||
cc.total_cache_creation_tokens,
|
||||
cc.total_runtime_ms
|
||||
FROM chat_costs cc
|
||||
LEFT JOIN chats rc ON rc.id = cc.root_chat_id
|
||||
ORDER BY cc.total_cost_micros DESC;
|
||||
@@ -991,7 +1035,8 @@ WITH chat_cost_users AS (
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -1025,6 +1070,7 @@ SELECT
|
||||
total_output_tokens,
|
||||
total_cache_read_tokens,
|
||||
total_cache_creation_tokens,
|
||||
total_runtime_ms,
|
||||
COUNT(*) OVER()::bigint AS total_count
|
||||
FROM
|
||||
chat_cost_users
|
||||
|
||||
@@ -133,111 +133,113 @@ OFFSET
|
||||
@offset_opt;
|
||||
|
||||
-- name: CountConnectionLogs :one
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
|
||||
-- name: DeleteOldConnectionLogs :execrows
|
||||
WITH old_logs AS (
|
||||
|
||||
@@ -247,6 +247,7 @@ sql:
|
||||
mcp_server_tool_snapshots: MCPServerToolSnapshots
|
||||
mcp_server_config_id: MCPServerConfigID
|
||||
mcp_server_ids: MCPServerIDs
|
||||
max_file_links: MaxFileLinks
|
||||
icon_url: IconURL
|
||||
oauth2_client_id: OAuth2ClientID
|
||||
oauth2_client_secret: OAuth2ClientSecret
|
||||
|
||||
@@ -16,6 +16,7 @@ const (
|
||||
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
|
||||
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
UniqueChatFileLinksChatIDFileIDKey UniqueConstraint = "chat_file_links_chat_id_file_id_key" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
|
||||
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
|
||||
|
||||
+169
-63
@@ -403,7 +403,17 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, titleSource, inputError := createChatInputFromRequest(ctx, api.Database, req)
|
||||
// Validate per-chat system prompt length.
|
||||
const maxSystemPromptLen = 10000
|
||||
if len(req.SystemPrompt) > maxSystemPromptLen {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "System prompt exceeds maximum length.",
|
||||
Detail: fmt.Sprintf("System prompt must be at most %d characters, got %d.", maxSystemPromptLen, len(req.SystemPrompt)),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, titleSource, fileIDs, inputError := createChatInputFromRequest(ctx, api.Database, req)
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, *inputError)
|
||||
return
|
||||
@@ -483,7 +493,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
WorkspaceID: workspaceSelection.WorkspaceID,
|
||||
Title: title,
|
||||
ModelConfigID: modelConfigID,
|
||||
SystemPrompt: api.resolvedChatSystemPrompt(ctx),
|
||||
SystemPrompt: req.SystemPrompt,
|
||||
InitialUserContent: contentBlocks,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: labels,
|
||||
@@ -514,7 +524,32 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, db2sdk.Chat(chat, nil))
|
||||
// Link any user-uploaded files referenced in the initial
|
||||
// message to this newly created chat (best-effort; cap
|
||||
// enforced in SQL).
|
||||
unlinked, capExceeded := api.linkFilesToChat(ctx, chat.ID, fileIDs)
|
||||
|
||||
// Re-read the chat so the response reflects the authoritative
|
||||
// database state (file links are deduped in the join table).
|
||||
chat, err = api.Database.GetChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to read back chat after creation.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chatFiles := api.fetchChatFileMetadata(ctx, chat.ID)
|
||||
response := db2sdk.Chat(chat, nil, chatFiles)
|
||||
if len(unlinked) > 0 {
|
||||
if capExceeded {
|
||||
response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked)))
|
||||
} else {
|
||||
response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked)))
|
||||
}
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -717,6 +752,7 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
||||
TotalOutputTokens: summary.TotalOutputTokens,
|
||||
TotalCacheReadTokens: summary.TotalCacheReadTokens,
|
||||
TotalCacheCreationTokens: summary.TotalCacheCreationTokens,
|
||||
TotalRuntimeMs: summary.TotalRuntimeMs,
|
||||
ByModel: modelBreakdowns,
|
||||
ByChat: chatBreakdowns,
|
||||
}
|
||||
@@ -1290,7 +1326,11 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, diffStatus))
|
||||
|
||||
// Hydrate file metadata for all files linked to this chat.
|
||||
chatFiles := api.fetchChatFileMetadata(ctx, chat.ID)
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, diffStatus, chatFiles))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -1780,7 +1820,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
contentBlocks, _, fileIDs, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: inputError.Message,
|
||||
@@ -1819,6 +1859,20 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
busyBehavior := chatd.SendMessageBusyBehaviorQueue
|
||||
switch req.BusyBehavior {
|
||||
case codersdk.ChatBusyBehaviorInterrupt:
|
||||
busyBehavior = chatd.SendMessageBusyBehaviorInterrupt
|
||||
case codersdk.ChatBusyBehaviorQueue, "":
|
||||
// Default to queue.
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid busy_behavior value.",
|
||||
Detail: `Must be "queue" or "interrupt".`,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sendResult, sendErr := api.chatDaemon.SendMessage(
|
||||
ctx,
|
||||
chatd.SendMessageOptions{
|
||||
@@ -1826,7 +1880,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
CreatedBy: apiKey.UserID,
|
||||
Content: contentBlocks,
|
||||
ModelConfigID: req.ModelConfigID,
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
BusyBehavior: busyBehavior,
|
||||
MCPServerIDs: req.MCPServerIDs,
|
||||
},
|
||||
)
|
||||
@@ -1848,6 +1902,9 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Link any user-uploaded files referenced in this message
|
||||
// to the chat (best-effort; cap enforced in SQL).
|
||||
unlinked, capExceeded := api.linkFilesToChat(ctx, chatID, fileIDs)
|
||||
response := codersdk.CreateChatMessageResponse{Queued: sendResult.Queued}
|
||||
if sendResult.Queued {
|
||||
if sendResult.QueuedMessage != nil {
|
||||
@@ -1857,6 +1914,13 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
message := convertChatMessage(sendResult.Message)
|
||||
response.Message = &message
|
||||
}
|
||||
if len(unlinked) > 0 {
|
||||
if capExceeded {
|
||||
response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked)))
|
||||
} else {
|
||||
response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked)))
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, response)
|
||||
}
|
||||
@@ -1890,7 +1954,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
contentBlocks, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
contentBlocks, _, fileIDs, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content")
|
||||
if inputError != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: inputError.Message,
|
||||
@@ -1929,8 +1993,20 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
message := convertChatMessage(editResult.Message)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, message)
|
||||
// Link any user-uploaded files referenced in the edited
|
||||
// message to the chat (best-effort; cap enforced in SQL).
|
||||
unlinked, capExceeded := api.linkFilesToChat(ctx, chat.ID, fileIDs)
|
||||
response := codersdk.EditChatMessageResponse{
|
||||
Message: convertChatMessage(editResult.Message),
|
||||
}
|
||||
if len(unlinked) > 0 {
|
||||
if capExceeded {
|
||||
response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked)))
|
||||
} else {
|
||||
response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked)))
|
||||
}
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -2207,7 +2283,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
chat = updatedChat
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, nil))
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, nil, nil))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -2251,7 +2327,7 @@ func (api *API) regenerateChatTitle(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(updatedChat, nil))
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(updatedChat, nil, nil))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -3476,35 +3552,6 @@ func (api *API) deleteUserChatCompactionThreshold(rw http.ResponseWriter, r *htt
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *API) resolvedChatSystemPrompt(ctx context.Context) string {
|
||||
config, err := api.Database.GetChatSystemPromptConfig(ctx)
|
||||
if err != nil {
|
||||
// We intentionally fail open here. When the prompt configuration
|
||||
// cannot be read, returning the built-in default keeps the chat
|
||||
// grounded instead of sending no system guidance at all.
|
||||
api.Logger.Error(ctx, "failed to fetch chat system prompt configuration, using default", slog.Error(err))
|
||||
return chatd.DefaultSystemPrompt
|
||||
}
|
||||
|
||||
sanitizedCustom := chatd.SanitizePromptText(config.ChatSystemPrompt)
|
||||
if sanitizedCustom == "" && strings.TrimSpace(config.ChatSystemPrompt) != "" {
|
||||
api.Logger.Warn(ctx, "custom system prompt became empty after sanitization, omitting custom portion")
|
||||
}
|
||||
|
||||
var parts []string
|
||||
if config.IncludeDefaultSystemPrompt {
|
||||
parts = append(parts, chatd.DefaultSystemPrompt)
|
||||
}
|
||||
if sanitizedCustom != "" {
|
||||
parts = append(parts, sanitizedCustom)
|
||||
}
|
||||
result := strings.Join(parts, "\n\n")
|
||||
if result == "" {
|
||||
api.Logger.Warn(ctx, "resolved system prompt is empty, no system prompt will be injected into chats")
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (api *API) postChatFile(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
@@ -3692,6 +3739,7 @@ func (api *API) chatFileByID(rw http.ResponseWriter, r *http.Request) {
|
||||
func createChatInputFromRequest(ctx context.Context, db database.Store, req codersdk.CreateChatRequest) (
|
||||
[]codersdk.ChatMessagePart,
|
||||
string,
|
||||
[]uuid.UUID,
|
||||
*codersdk.Response,
|
||||
) {
|
||||
return createChatInputFromParts(ctx, db, req.Content, "content")
|
||||
@@ -3702,14 +3750,15 @@ func createChatInputFromParts(
|
||||
db database.Store,
|
||||
parts []codersdk.ChatInputPart,
|
||||
fieldName string,
|
||||
) ([]codersdk.ChatMessagePart, string, *codersdk.Response) {
|
||||
) ([]codersdk.ChatMessagePart, string, []uuid.UUID, *codersdk.Response) {
|
||||
if len(parts) == 0 {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Content is required.",
|
||||
Detail: "Content cannot be empty.",
|
||||
}
|
||||
}
|
||||
|
||||
var fileIDs []uuid.UUID
|
||||
content := make([]codersdk.ChatMessagePart, 0, len(parts))
|
||||
textParts := make([]string, 0, len(parts))
|
||||
for i, part := range parts {
|
||||
@@ -3717,7 +3766,7 @@ func createChatInputFromParts(
|
||||
case string(codersdk.ChatInputPartTypeText):
|
||||
text := strings.TrimSpace(part.Text)
|
||||
if text == "" {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].text cannot be empty.", fieldName, i),
|
||||
}
|
||||
@@ -3726,7 +3775,7 @@ func createChatInputFromParts(
|
||||
textParts = append(textParts, text)
|
||||
case string(codersdk.ChatInputPartTypeFile):
|
||||
if part.FileID == uuid.Nil {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_id is required for file parts.", fieldName, i),
|
||||
}
|
||||
@@ -3737,20 +3786,23 @@ func createChatInputFromParts(
|
||||
chatFile, err := db.GetChatFileByID(ctx, part.FileID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_id references a file that does not exist.", fieldName, i),
|
||||
}
|
||||
}
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Internal error.",
|
||||
Detail: fmt.Sprintf("Failed to retrieve file for %s[%d].", fieldName, i),
|
||||
}
|
||||
}
|
||||
content = append(content, codersdk.ChatMessageFile(part.FileID, chatFile.Mimetype))
|
||||
fileIDs = append(fileIDs, part.FileID)
|
||||
// file-reference parts carry inline code snippets, not uploaded
|
||||
// files. They have no FileID and are excluded from file tracking.
|
||||
case string(codersdk.ChatInputPartTypeFileReference):
|
||||
if part.FileName == "" {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].file_name cannot be empty for file-reference.", fieldName, i),
|
||||
}
|
||||
@@ -3767,21 +3819,8 @@ func createChatInputFromParts(
|
||||
_, _ = fmt.Fprintf(&sb, "\n```%s\n%s\n```", part.FileName, strings.TrimSpace(part.Content))
|
||||
}
|
||||
textParts = append(textParts, sb.String())
|
||||
case string(codersdk.ChatInputPartTypeSkill):
|
||||
if part.SkillName == "" {
|
||||
return nil, "", &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf("%s[%d].skill_name cannot be empty for skill parts.", fieldName, i),
|
||||
}
|
||||
}
|
||||
content = append(content, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: part.SkillName,
|
||||
SkillDescription: part.SkillDescription,
|
||||
})
|
||||
textParts = append(textParts, fmt.Sprintf("Use the %q skill", part.SkillName))
|
||||
default:
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Invalid input part.",
|
||||
Detail: fmt.Sprintf(
|
||||
"%s[%d].type %q is not supported.",
|
||||
@@ -3796,13 +3835,13 @@ func createChatInputFromParts(
|
||||
// Allow file-only messages. The titleSource may be empty
|
||||
// when only file parts are provided, callers handle this.
|
||||
if len(content) == 0 {
|
||||
return nil, "", &codersdk.Response{
|
||||
return nil, "", nil, &codersdk.Response{
|
||||
Message: "Content is required.",
|
||||
Detail: fmt.Sprintf("%s must include at least one text or file part.", fieldName),
|
||||
}
|
||||
}
|
||||
titleSource := strings.TrimSpace(strings.Join(textParts, " "))
|
||||
return content, titleSource, nil
|
||||
return content, titleSource, fileIDs, nil
|
||||
}
|
||||
|
||||
func chatTitleFromMessage(message string) string {
|
||||
@@ -3837,6 +3876,70 @@ func truncateRunes(value string, maxLen int) string {
|
||||
return string(runes[:maxLen])
|
||||
}
|
||||
|
||||
// linkFilesToChat inserts file-link rows into the chat_file_links
|
||||
// join table. Cap enforcement and dedup are handled atomically in
|
||||
// SQL. On success returns (nil, false). On failure returns the full
|
||||
// input fileIDs slice — linking is all-or-nothing because the
|
||||
// SQL operates on the batch atomically. capExceeded indicates
|
||||
// whether the failure was due to the cap being exceeded (true)
|
||||
// or a database error (false).
|
||||
// Failures are logged but never block the caller.
|
||||
func (api *API) linkFilesToChat(ctx context.Context, chatID uuid.UUID, fileIDs []uuid.UUID) (unlinked []uuid.UUID, capExceeded bool) {
|
||||
if len(fileIDs) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
rejected, err := api.Database.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: chatID,
|
||||
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
|
||||
FileIds: fileIDs,
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "failed to link files to chat",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("file_ids", fileIDs),
|
||||
slog.Error(err),
|
||||
)
|
||||
return fileIDs, false
|
||||
}
|
||||
if rejected > 0 {
|
||||
api.Logger.Warn(ctx, "file cap reached, files not linked",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("file_ids", fileIDs),
|
||||
slog.F("max_file_links", codersdk.MaxChatFileIDs),
|
||||
)
|
||||
return fileIDs, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// fileLinkCapWarning builds a user-facing warning when a batch
|
||||
// of file IDs was atomically rejected because the resulting
|
||||
// array would exceed the per-chat file cap.
|
||||
func fileLinkCapWarning(count int) string {
|
||||
return fmt.Sprintf("file linking skipped: batch of %d file(s) would exceed limit of %d", count, codersdk.MaxChatFileIDs)
|
||||
}
|
||||
|
||||
// fileLinkErrorWarning builds a user-facing warning when a
|
||||
// database error prevented linking files to a chat.
|
||||
func fileLinkErrorWarning(count int) string {
|
||||
return fmt.Sprintf("%d file(s) could not be linked due to a server error", count)
|
||||
}
|
||||
|
||||
// fetchChatFileMetadata returns metadata for all files linked to
|
||||
// the given chat. Errors are logged and result in a nil return
|
||||
// (callers treat file metadata as best-effort).
|
||||
func (api *API) fetchChatFileMetadata(ctx context.Context, chatID uuid.UUID) []database.GetChatFileMetadataByChatIDRow {
|
||||
rows, err := api.Database.GetChatFileMetadataByChatID(ctx, chatID)
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "failed to fetch chat file metadata",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
func convertChatCostModelBreakdown(model database.GetChatCostPerModelRow) codersdk.ChatCostModelBreakdown {
|
||||
displayName := strings.TrimSpace(model.DisplayName)
|
||||
if displayName == "" {
|
||||
@@ -3853,6 +3956,7 @@ func convertChatCostModelBreakdown(model database.GetChatCostPerModelRow) coders
|
||||
TotalOutputTokens: model.TotalOutputTokens,
|
||||
TotalCacheReadTokens: model.TotalCacheReadTokens,
|
||||
TotalCacheCreationTokens: model.TotalCacheCreationTokens,
|
||||
TotalRuntimeMs: model.TotalRuntimeMs,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3866,6 +3970,7 @@ func convertChatCostChatBreakdown(chat database.GetChatCostPerChatRow) codersdk.
|
||||
TotalOutputTokens: chat.TotalOutputTokens,
|
||||
TotalCacheReadTokens: chat.TotalCacheReadTokens,
|
||||
TotalCacheCreationTokens: chat.TotalCacheCreationTokens,
|
||||
TotalRuntimeMs: chat.TotalRuntimeMs,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3882,6 +3987,7 @@ func convertChatCostUserRollup(user database.GetChatCostPerUserRow) codersdk.Cha
|
||||
TotalOutputTokens: user.TotalOutputTokens,
|
||||
TotalCacheReadTokens: user.TotalCacheReadTokens,
|
||||
TotalCacheCreationTokens: user.TotalCacheCreationTokens,
|
||||
TotalRuntimeMs: user.TotalRuntimeMs,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+558
-104
@@ -313,6 +313,111 @@ func TestPostChats(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WithPerChatSystemPrompt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello with system prompt",
|
||||
},
|
||||
},
|
||||
SystemPrompt: "You are a Go expert.",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, chat.ID)
|
||||
|
||||
// Use the DB directly to see system messages, which are
|
||||
// hidden from the public API.
|
||||
dbMessages, err := db.GetChatMessagesForPromptByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expect: deployment system prompt, per-chat system prompt,
|
||||
// workspace awareness, user message.
|
||||
var systemMessages []database.ChatMessage
|
||||
for _, msg := range dbMessages {
|
||||
if msg.Role == database.ChatMessageRoleSystem {
|
||||
systemMessages = append(systemMessages, msg)
|
||||
}
|
||||
}
|
||||
require.GreaterOrEqual(t, len(systemMessages), 2,
|
||||
"expected at least deployment + per-chat system messages")
|
||||
|
||||
// The per-chat system prompt should be the second system
|
||||
// message and contain the user-specified text.
|
||||
foundPerChat := false
|
||||
for _, msg := range systemMessages {
|
||||
if msg.Content.Valid {
|
||||
raw := string(msg.Content.RawMessage)
|
||||
if strings.Contains(raw, "You are a Go expert.") {
|
||||
foundPerChat = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
require.True(t, foundPerChat,
|
||||
"per-chat system prompt not found in system messages")
|
||||
})
|
||||
|
||||
t.Run("PerChatSystemPromptEmpty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello without system prompt",
|
||||
},
|
||||
},
|
||||
SystemPrompt: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
dbMessages, err := db.GetChatMessagesForPromptByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// No per-chat system prompt should be present.
|
||||
for _, msg := range dbMessages {
|
||||
if msg.Role == database.ChatMessageRoleSystem && msg.Content.Valid {
|
||||
raw := string(msg.Content.RawMessage)
|
||||
require.NotContains(t, raw, "You are a Go expert.",
|
||||
"unexpected per-chat system prompt in messages")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PerChatSystemPromptTooLong", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
longPrompt := strings.Repeat("a", 10001)
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
},
|
||||
},
|
||||
SystemPrompt: longPrompt,
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("WorkspaceNotAccessible", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -3122,6 +3227,153 @@ func TestGetChat(t *testing.T) {
|
||||
_, err = otherClient.GetChat(ctx, createdChat.ID)
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("FilesHydrated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Upload a file.
|
||||
pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "hydrated.png", bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a chat with a text + file part.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "check file hydration"},
|
||||
{Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// GET the chat — files must be hydrated with all metadata fields.
|
||||
chatResult, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatResult.Files, 1)
|
||||
f := chatResult.Files[0]
|
||||
require.Equal(t, uploadResp.ID, f.ID)
|
||||
require.Equal(t, firstUser.UserID, f.OwnerID)
|
||||
require.NotEqual(t, uuid.Nil, f.OrganizationID)
|
||||
require.Equal(t, "image/png", f.MimeType)
|
||||
require.Equal(t, "hydrated.png", f.Name)
|
||||
require.NotZero(t, f.CreatedAt)
|
||||
})
|
||||
|
||||
// ToolCreatedFilesLinked exercises the DB path that chatd uses
|
||||
// when a tool (e.g. propose_plan) creates a file: InsertChatFile
|
||||
// then LinkChatFiles. This is a DB-level test because driving
|
||||
// the full chatd tool-call pipeline requires an LLM mock.
|
||||
t.Run("ToolCreatedFilesLinked", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, store := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create a chat via the API so all metadata is set up.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "tool file test"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mimic what chatd's StoreFile closure does:
|
||||
// 1. InsertChatFile
|
||||
// 2. LinkChatFiles
|
||||
//nolint:gocritic // Using AsChatd to mimic the chatd background worker.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
fileRow, err := store.InsertChatFile(chatdCtx, database.InsertChatFileParams{
|
||||
OwnerID: firstUser.UserID,
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
Name: "plan.md",
|
||||
Mimetype: "text/markdown",
|
||||
Data: []byte("# Plan"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rejected, err := store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{
|
||||
ChatID: chat.ID,
|
||||
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
|
||||
FileIds: []uuid.UUID{fileRow.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(0), rejected, "0 rejected = all files linked")
|
||||
|
||||
// Verify via the API that the file appears in the chat.
|
||||
chatResult, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatResult.Files, 1)
|
||||
f := chatResult.Files[0]
|
||||
require.Equal(t, fileRow.ID, f.ID)
|
||||
require.Equal(t, firstUser.UserID, f.OwnerID)
|
||||
require.Equal(t, firstUser.OrganizationID, f.OrganizationID)
|
||||
require.Equal(t, "plan.md", f.Name)
|
||||
require.Equal(t, "text/markdown", f.MimeType)
|
||||
|
||||
// Fill up to the cap by inserting more files via the
|
||||
// chatd DB path, then verify the cap is enforced.
|
||||
for i := 1; i < codersdk.MaxChatFileIDs; i++ {
|
||||
extra, err := store.InsertChatFile(chatdCtx, database.InsertChatFileParams{
|
||||
OwnerID: firstUser.UserID,
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
Name: fmt.Sprintf("file%d.md", i),
|
||||
Mimetype: "text/markdown",
|
||||
Data: []byte("data"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{
|
||||
ChatID: chat.ID,
|
||||
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
|
||||
FileIds: []uuid.UUID{extra.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Chat should now have exactly MaxChatFileIDs files.
|
||||
chatResult, err = client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs)
|
||||
|
||||
// Attempt to add one more file — should be rejected (0 rows).
|
||||
overflow, err := store.InsertChatFile(chatdCtx, database.InsertChatFileParams{
|
||||
OwnerID: firstUser.UserID,
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
Name: "overflow.md",
|
||||
Mimetype: "text/markdown",
|
||||
Data: []byte("too many"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
rejected, err = store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{
|
||||
ChatID: chat.ID,
|
||||
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
|
||||
FileIds: []uuid.UUID{overflow.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), rejected, "cap should reject the 21st file")
|
||||
|
||||
// Re-appending an already-linked ID at cap should succeed
|
||||
// (dedup means no array growth).
|
||||
rejected, err = store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{
|
||||
ChatID: chat.ID,
|
||||
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
|
||||
FileIds: []uuid.UUID{fileRow.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// ON CONFLICT DO NOTHING returns 0 rows when the link
|
||||
// already exists, which is fine — the file is still linked.
|
||||
require.Equal(t, int32(0), rejected, "dedup of existing ID should be a no-op")
|
||||
|
||||
// Count should still be exactly MaxChatFileIDs.
|
||||
chatResult, err = client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestArchiveChat(t *testing.T) {
|
||||
@@ -4135,104 +4387,6 @@ func TestChatMessageWithFileReferences(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatMessageWithSkillParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
createChatForTest := func(t *testing.T, client *codersdk.ExperimentalClient) codersdk.Chat {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "initial message",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
|
||||
t.Run("SkillPartRoundTrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
chat := createChatForTest(t, client)
|
||||
|
||||
created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "please run this skill",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeSkill,
|
||||
SkillName: "deep-review",
|
||||
SkillDescription: "Multi-reviewer code review",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
checkSkill := func(part codersdk.ChatMessagePart) bool {
|
||||
return part.Type == codersdk.ChatMessagePartTypeSkill &&
|
||||
part.SkillName == "deep-review" &&
|
||||
part.SkillDescription == "Multi-reviewer code review"
|
||||
}
|
||||
|
||||
var found bool
|
||||
require.Eventually(t, func() bool {
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, message := range messagesResult.Messages {
|
||||
if message.Role != codersdk.ChatMessageRoleUser {
|
||||
continue
|
||||
}
|
||||
for _, part := range message.Content {
|
||||
if checkSkill(part) {
|
||||
found = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if created.Queued && created.QueuedMessage != nil {
|
||||
for _, queued := range messagesResult.QueuedMessages {
|
||||
for _, part := range queued.Content {
|
||||
if checkSkill(part) {
|
||||
found = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
require.True(t, found, "expected to find skill part in stored message")
|
||||
})
|
||||
|
||||
t.Run("SkillPartEmptyNameRejected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
chat := createChatForTest(t, client)
|
||||
|
||||
_, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeSkill,
|
||||
SkillName: "",
|
||||
}},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "skill_name cannot be empty")
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatMessageWithFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -4366,6 +4520,14 @@ func TestChatMessageWithFiles(t *testing.T) {
|
||||
|
||||
// With no text, chatTitleFromMessage("") returns "New Chat".
|
||||
require.Equal(t, "New Chat", chat.Title)
|
||||
require.Len(t, chat.Files, 1)
|
||||
f := chat.Files[0]
|
||||
require.Equal(t, uploadResp.ID, f.ID)
|
||||
require.Equal(t, firstUser.UserID, f.OwnerID)
|
||||
require.NotEqual(t, uuid.Nil, f.OrganizationID)
|
||||
require.Equal(t, "image/png", f.MimeType)
|
||||
require.Equal(t, "test.png", f.Name)
|
||||
require.NotZero(t, f.CreatedAt)
|
||||
})
|
||||
|
||||
t.Run("InvalidFileID", func(t *testing.T) {
|
||||
@@ -4400,6 +4562,189 @@ func TestChatMessageWithFiles(t *testing.T) {
|
||||
require.Equal(t, "Invalid input part.", sdkErr.Message)
|
||||
require.Contains(t, sdkErr.Detail, "does not exist")
|
||||
})
|
||||
|
||||
t.Run("FilesLinkedOnSend", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create a text-only chat (no files initially).
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "no files yet"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Upload a file.
|
||||
pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "linked.png", bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send a message with the file.
|
||||
_, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "here is a file"},
|
||||
{Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// GET the chat — file should be linked.
|
||||
chatResult, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatResult.Files, 1)
|
||||
require.Equal(t, uploadResp.ID, chatResult.Files[0].ID)
|
||||
require.Equal(t, "linked.png", chatResult.Files[0].Name)
|
||||
})
|
||||
|
||||
t.Run("DedupFileIDs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Upload a file.
|
||||
pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "dedup.png", bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a chat with a file.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "first mention"},
|
||||
{Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send another message with the SAME file.
|
||||
msgResp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "same file again"},
|
||||
{Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, msgResp.Warnings, "dedup below cap should not produce warnings")
|
||||
|
||||
// GET — should have exactly 1 file (deduped by SQL DISTINCT).
|
||||
chatResult, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatResult.Files, 1, "duplicate file IDs should be deduped")
|
||||
require.Equal(t, uploadResp.ID, chatResult.Files[0].ID)
|
||||
})
|
||||
|
||||
t.Run("FileCapExceeded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
|
||||
// Upload MaxChatFileIDs files.
|
||||
fileIDs := make([]uuid.UUID, 0, codersdk.MaxChatFileIDs)
|
||||
for i := range codersdk.MaxChatFileIDs {
|
||||
resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", fmt.Sprintf("file%d.png", i), bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
fileIDs = append(fileIDs, resp.ID)
|
||||
}
|
||||
|
||||
// Create a chat using all MaxChatFileIDs files.
|
||||
parts := []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "max files"},
|
||||
}
|
||||
for _, fid := range fileIDs {
|
||||
parts = append(parts, codersdk.ChatInputPart{Type: codersdk.ChatInputPartTypeFile, FileID: fid})
|
||||
}
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{Content: parts})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, chat.Warnings, "creating a chat at exactly the cap should not warn")
|
||||
require.Len(t, chat.Files, codersdk.MaxChatFileIDs, "all files should be linked on creation")
|
||||
|
||||
// Upload one more file.
|
||||
extraResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "one-too-many.png", bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Sending a message with the extra file should succeed
|
||||
// (message goes through) but the file should NOT be linked
|
||||
// (cap enforced in SQL). The response includes a warning.
|
||||
msgResp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "one too many"},
|
||||
{Type: codersdk.ChatInputPartTypeFile, FileID: extraResp.ID},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, msgResp.Warnings, "response should warn about unlinked files")
|
||||
require.Contains(t, msgResp.Warnings[0], "file linking skipped")
|
||||
|
||||
// The extra file should NOT appear in the chat's files.
|
||||
chatResult, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs,
|
||||
"file count should not exceed the cap")
|
||||
|
||||
// Sending a message referencing an already-linked file
|
||||
// should succeed with no warnings (dedup, no array growth).
|
||||
msgResp2, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "re-reference existing"},
|
||||
{Type: codersdk.ChatInputPartTypeFile, FileID: fileIDs[0]},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, msgResp2.Warnings, "re-referencing an existing file should not warn")
|
||||
})
|
||||
|
||||
t.Run("FileCapOnCreate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
|
||||
// Upload MaxChatFileIDs + 1 files.
|
||||
fileIDs := make([]uuid.UUID, 0, codersdk.MaxChatFileIDs+1)
|
||||
for i := range codersdk.MaxChatFileIDs + 1 {
|
||||
resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", fmt.Sprintf("create%d.png", i), bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
fileIDs = append(fileIDs, resp.ID)
|
||||
}
|
||||
|
||||
// Create a chat with all files (one over the cap).
|
||||
parts := []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "over cap on create"},
|
||||
}
|
||||
for _, fid := range fileIDs {
|
||||
parts = append(parts, codersdk.ChatInputPart{Type: codersdk.ChatInputPartTypeFile, FileID: fid})
|
||||
}
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{Content: parts})
|
||||
require.NoError(t, err, "chat creation should succeed even when cap is exceeded")
|
||||
require.NotEmpty(t, chat.Warnings, "response should warn about unlinked files")
|
||||
require.Contains(t, chat.Warnings[0], "file linking skipped")
|
||||
|
||||
// Only MaxChatFileIDs files should actually be linked.
|
||||
// With SQL-level batch rejection, ALL files are rejected
|
||||
// when the result would exceed the cap. Since we're
|
||||
// sending MaxChatFileIDs+1 files, the deduped count is
|
||||
// 21 > 20, so 0 rows are affected and all files are
|
||||
// unlinked.
|
||||
chatResult, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, chatResult.Files, "no files should be linked when batch exceeds cap")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPatchChatMessage(t *testing.T) {
|
||||
@@ -4446,11 +4791,11 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
// The edited message is soft-deleted and a new one is inserted,
|
||||
// so the returned ID will differ from the original.
|
||||
require.NotEqual(t, userMessageID, edited.ID)
|
||||
require.Equal(t, codersdk.ChatMessageRoleUser, edited.Role)
|
||||
require.NotEqual(t, userMessageID, edited.Message.ID)
|
||||
require.Equal(t, codersdk.ChatMessageRoleUser, edited.Message.Role)
|
||||
|
||||
foundEditedText := false
|
||||
for _, part := range edited.Content {
|
||||
for _, part := range edited.Message.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "hello after edit" {
|
||||
foundEditedText = true
|
||||
}
|
||||
@@ -4538,11 +4883,11 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
// The edited message is soft-deleted and a new one is inserted,
|
||||
// so the returned ID will differ from the original.
|
||||
require.NotEqual(t, userMessageID, edited.ID)
|
||||
require.NotEqual(t, userMessageID, edited.Message.ID)
|
||||
|
||||
// Assert the edit response preserves the file_id.
|
||||
var foundText, foundFile bool
|
||||
for _, part := range edited.Content {
|
||||
for _, part := range edited.Message.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "after edit with file" {
|
||||
foundText = true
|
||||
}
|
||||
@@ -4685,6 +5030,112 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid chat message ID.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("FilesLinkedOnEdit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create a text-only chat.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "before file edit"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Upload a file.
|
||||
pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "edit-linked.png", bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find the user message ID.
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
var userMessageID int64
|
||||
for _, msg := range messagesResult.Messages {
|
||||
if msg.Role == codersdk.ChatMessageRoleUser {
|
||||
userMessageID = msg.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotZero(t, userMessageID)
|
||||
|
||||
// Edit the message to include the file.
|
||||
_, err = client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "after file edit"},
|
||||
{Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// GET the chat — file should be linked.
|
||||
chatResult, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatResult.Files, 1)
|
||||
f := chatResult.Files[0]
|
||||
require.Equal(t, uploadResp.ID, f.ID)
|
||||
require.Equal(t, "edit-linked.png", f.Name)
|
||||
require.Equal(t, "image/png", f.MimeType)
|
||||
})
|
||||
|
||||
t.Run("CapExceededOnEdit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create a chat with MaxChatFileIDs files already linked.
|
||||
parts := []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "fill to cap"},
|
||||
}
|
||||
pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
for i := range codersdk.MaxChatFileIDs {
|
||||
up, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", fmt.Sprintf("cap-%d.png", i), bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
parts = append(parts, codersdk.ChatInputPart{Type: codersdk.ChatInputPartTypeFile, FileID: up.ID})
|
||||
}
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{Content: parts})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, chat.Warnings, "all files should link on create")
|
||||
|
||||
// Find the user message.
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
var userMessageID int64
|
||||
for _, msg := range messagesResult.Messages {
|
||||
if msg.Role == codersdk.ChatMessageRoleUser {
|
||||
userMessageID = msg.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotZero(t, userMessageID)
|
||||
|
||||
// Upload one more file and try to link via edit.
|
||||
extra, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "one-too-many.png", bytes.NewReader(pngData))
|
||||
require.NoError(t, err)
|
||||
edited, err := client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{Type: codersdk.ChatInputPartTypeText, Text: "edit with extra file"},
|
||||
{Type: codersdk.ChatInputPartTypeFile, FileID: extra.ID},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, edited.Warnings, "edit should surface cap warning")
|
||||
require.Contains(t, edited.Warnings[0], "file linking skipped")
|
||||
|
||||
// Verify the cap is still enforced.
|
||||
chatResult, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs,
|
||||
"file count should not exceed the cap")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamChat(t *testing.T) {
|
||||
@@ -6177,7 +6628,7 @@ func seedChatCostFixture(t *testing.T) chatCostTestFixture {
|
||||
ContextLimit: []int64{0, 0},
|
||||
Compressed: []bool{false, false},
|
||||
TotalCostMicros: []int64{500, 500},
|
||||
RuntimeMs: []int64{0, 0},
|
||||
RuntimeMs: []int64{1500, 2500},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 2)
|
||||
@@ -6210,16 +6661,19 @@ func assertChatCostSummary(t *testing.T, summary codersdk.ChatCostSummary, model
|
||||
require.Equal(t, int64(0), summary.UnpricedMessageCount)
|
||||
require.Equal(t, int64(200), summary.TotalInputTokens)
|
||||
require.Equal(t, int64(100), summary.TotalOutputTokens)
|
||||
require.Equal(t, int64(4000), summary.TotalRuntimeMs)
|
||||
|
||||
require.Len(t, summary.ByModel, 1)
|
||||
require.Equal(t, modelConfigID, summary.ByModel[0].ModelConfigID)
|
||||
require.Equal(t, int64(1000), summary.ByModel[0].TotalCostMicros)
|
||||
require.Equal(t, int64(2), summary.ByModel[0].MessageCount)
|
||||
require.Equal(t, int64(4000), summary.ByModel[0].TotalRuntimeMs)
|
||||
|
||||
require.Len(t, summary.ByChat, 1)
|
||||
require.Equal(t, chatID, summary.ByChat[0].RootChatID)
|
||||
require.Equal(t, int64(1000), summary.ByChat[0].TotalCostMicros)
|
||||
require.Equal(t, int64(2), summary.ByChat[0].MessageCount)
|
||||
require.Equal(t, int64(4000), summary.ByChat[0].TotalRuntimeMs)
|
||||
}
|
||||
|
||||
func TestChatCostSummary(t *testing.T) {
|
||||
|
||||
@@ -298,6 +298,40 @@ neq(input.object.owner, "");
|
||||
ExpectedSQL: p("'' = 'org-id'"),
|
||||
VariableConverter: regosql.ChatConverter(),
|
||||
},
|
||||
{
|
||||
Name: "AuditLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.AuditLogConverter(),
|
||||
},
|
||||
{
|
||||
Name: "ConnectionLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.ConnectionLogConverter(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter {
|
||||
func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
// Audit logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
@@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
func ConnectionLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
// Connection logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
_ VariableMatcher = astUUIDVar{}
|
||||
_ Node = astUUIDVar{}
|
||||
_ SupportsEquality = astUUIDVar{}
|
||||
)
|
||||
|
||||
// astUUIDVar is a variable that represents a UUID column. Unlike
|
||||
// astStringVar it emits native UUID comparisons (column = 'val'::uuid)
|
||||
// instead of text-based ones (COALESCE(column::text, ”) = 'val').
|
||||
// This allows PostgreSQL to use indexes on UUID columns.
|
||||
type astUUIDVar struct {
|
||||
Source RegoSource
|
||||
FieldPath []string
|
||||
ColumnString string
|
||||
}
|
||||
|
||||
func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher {
|
||||
return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn}
|
||||
}
|
||||
|
||||
func (astUUIDVar) UseAs() Node { return astUUIDVar{} }
|
||||
|
||||
func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||
left, err := RegoVarPath(u.FieldPath, rego)
|
||||
if err == nil && len(left) == 0 {
|
||||
return astUUIDVar{
|
||||
Source: RegoSource(rego.String()),
|
||||
FieldPath: u.FieldPath,
|
||||
ColumnString: u.ColumnString,
|
||||
}, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (u astUUIDVar) SQLString(_ *SQLGenerator) string {
|
||||
return u.ColumnString
|
||||
}
|
||||
|
||||
// EqualsSQLString handles equality comparisons for UUID columns.
|
||||
// Rego always produces string literals, so we accept AstString and
|
||||
// cast the literal to ::uuid in the output SQL. This lets PG use
|
||||
// native UUID indexes instead of falling back to text comparisons.
|
||||
// nolint:revive
|
||||
func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
switch other.UseAs().(type) {
|
||||
case AstString:
|
||||
// The other side is a rego string literal like
|
||||
// "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison
|
||||
// that casts the literal to uuid so PG can use indexes:
|
||||
// column = 'val'::uuid
|
||||
// instead of the text-based:
|
||||
// 'val' = COALESCE(column::text, '')
|
||||
s, ok := other.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString, got %T", other)
|
||||
}
|
||||
if s.Value == "" {
|
||||
// Empty string in rego means "no value". Compare the
|
||||
// column against NULL since UUID columns represent
|
||||
// absent values as NULL, not empty strings.
|
||||
op := "IS NULL"
|
||||
if not {
|
||||
op = "IS NOT NULL"
|
||||
}
|
||||
return fmt.Sprintf("%s %s", u.ColumnString, op), nil
|
||||
}
|
||||
return fmt.Sprintf("%s %s '%s'::uuid",
|
||||
u.ColumnString, equalsOp(not), s.Value), nil
|
||||
case astUUIDVar:
|
||||
return basicSQLEquality(cfg, not, u, other), nil
|
||||
default:
|
||||
return "", xerrors.Errorf("unsupported equality: %T %s %T",
|
||||
u, equalsOp(not), other)
|
||||
}
|
||||
}
|
||||
|
||||
// ContainedInSQL implements SupportsContainedIn so that a UUID column
|
||||
// can appear in membership checks like `col = ANY(ARRAY[...])`. The
|
||||
// array elements are rego strings, so we cast each to ::uuid.
|
||||
func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) {
|
||||
arr, ok := haystack.(ASTArray)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack)
|
||||
}
|
||||
|
||||
if len(arr.Value) == 0 {
|
||||
return "false", nil
|
||||
}
|
||||
|
||||
// Build ARRAY['uuid1'::uuid, 'uuid2'::uuid, ...]
|
||||
values := make([]string, 0, len(arr.Value))
|
||||
for _, v := range arr.Value {
|
||||
s, ok := v.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString array element, got %T", v)
|
||||
}
|
||||
values = append(values, fmt.Sprintf("'%s'::uuid", s.Value))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s = ANY(ARRAY [%s])",
|
||||
u.ColumnString,
|
||||
strings.Join(values, ",")), nil
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G
|
||||
}
|
||||
|
||||
// Prepare the count filter, which uses the same parameters as the GetAuditLogsOffsetParams.
|
||||
// nolint:exhaustruct // UserID is not obtained from the query parameters.
|
||||
// nolint:exhaustruct // UserID and CountCap are not obtained from the query parameters.
|
||||
countFilter := database.CountAuditLogsParams{
|
||||
RequestID: filter.RequestID,
|
||||
ResourceID: filter.ResourceID,
|
||||
@@ -123,6 +123,7 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey
|
||||
}
|
||||
|
||||
// This MUST be kept in sync with the above
|
||||
// nolint:exhaustruct // CountCap is not obtained from the query parameters.
|
||||
countFilter := database.CountConnectionLogsParams{
|
||||
OrganizationID: filter.OrganizationID,
|
||||
WorkspaceOwner: filter.WorkspaceOwner,
|
||||
|
||||
+264
-62
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -90,6 +91,12 @@ const (
|
||||
// goroutines and lifecycle management.
|
||||
streamDropWarnInterval = 10 * time.Second
|
||||
|
||||
// bufferRetainGracePeriod is how long the message_part
|
||||
// buffer is kept after processing completes. This gives
|
||||
// cross-replica relay subscribers time to connect and
|
||||
// snapshot the buffer before it is garbage-collected.
|
||||
bufferRetainGracePeriod = 5 * time.Second
|
||||
|
||||
// DefaultMaxChatsPerAcquire is the maximum number of chats to
|
||||
// acquire in a single processOnce call. Batching avoids
|
||||
// waiting a full polling interval between acquisitions
|
||||
@@ -145,6 +152,12 @@ type Server struct {
|
||||
inFlightChatStaleAfter time.Duration
|
||||
chatHeartbeatInterval time.Duration
|
||||
|
||||
// heartbeatMu guards heartbeatRegistry.
|
||||
heartbeatMu sync.Mutex
|
||||
// heartbeatRegistry maps chat IDs to their cancel functions
|
||||
// and workspace state for the centralized heartbeat loop.
|
||||
heartbeatRegistry map[uuid.UUID]*heartbeatEntry
|
||||
|
||||
// wakeCh is signaled by SendMessage, EditMessage, CreateChat,
|
||||
// and PromoteQueued so the run loop calls processOnce
|
||||
// immediately instead of waiting for the next ticker.
|
||||
@@ -691,6 +704,24 @@ type chatStreamState struct {
|
||||
bufferLastWarnAt time.Time
|
||||
subscriberDropCount int64
|
||||
subscriberLastWarnAt time.Time
|
||||
// bufferRetainedAt records when processing completed and
|
||||
// the buffer was retained for late-connecting relay
|
||||
// subscribers. Zero while buffering is active. When
|
||||
// non-zero, cleanupStreamIfIdle skips GC until the grace
|
||||
// period expires so cross-replica relays can still
|
||||
// snapshot the buffer.
|
||||
bufferRetainedAt time.Time
|
||||
}
|
||||
|
||||
// heartbeatEntry tracks a single chat's cancel function and workspace
|
||||
// state for the centralized heartbeat loop. Instead of spawning a
|
||||
// per-chat goroutine, processChat registers an entry here and the
|
||||
// single heartbeatLoop goroutine handles all chats.
|
||||
type heartbeatEntry struct {
|
||||
cancelWithCause context.CancelCauseFunc
|
||||
chatID uuid.UUID
|
||||
workspaceID uuid.NullUUID
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
// resetDropCounters zeroes the rate-limiting state for both buffer
|
||||
@@ -873,7 +904,8 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
return xerrors.Errorf("insert chat: %w", err)
|
||||
}
|
||||
|
||||
systemPrompt := strings.TrimSpace(opts.SystemPrompt)
|
||||
deploymentPrompt := p.resolveDeploymentSystemPrompt(ctx)
|
||||
userPrompt := SanitizePromptText(opts.SystemPrompt)
|
||||
var workspaceAwareness string
|
||||
if opts.WorkspaceID.Valid {
|
||||
workspaceAwareness = "This chat is attached to a workspace. You can use workspace tools like execute, read_file, write_file, etc."
|
||||
@@ -895,16 +927,32 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
ChatID: insertedChat.ID,
|
||||
}
|
||||
|
||||
if systemPrompt != "" {
|
||||
systemContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(systemPrompt),
|
||||
if deploymentPrompt != "" {
|
||||
deploymentContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(deploymentPrompt),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal system prompt: %w", err)
|
||||
return xerrors.Errorf("marshal deployment system prompt: %w", err)
|
||||
}
|
||||
appendChatMessage(&msgParams, newChatMessage(
|
||||
database.ChatMessageRoleSystem,
|
||||
systemContent,
|
||||
deploymentContent,
|
||||
database.ChatMessageVisibilityModel,
|
||||
opts.ModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
))
|
||||
}
|
||||
|
||||
if userPrompt != "" {
|
||||
userPromptContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(userPrompt),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal user system prompt: %w", err)
|
||||
}
|
||||
appendChatMessage(&msgParams, newChatMessage(
|
||||
database.ChatMessageRoleSystem,
|
||||
userPromptContent,
|
||||
database.ChatMessageVisibilityModel,
|
||||
opts.ModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
@@ -2390,8 +2438,8 @@ func New(cfg Config) *Server {
|
||||
clock: clk,
|
||||
recordingSem: make(chan struct{}, maxConcurrentRecordingUploads),
|
||||
wakeCh: make(chan struct{}, 1),
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
//nolint:gocritic // The chat processor uses a scoped chatd context.
|
||||
ctx = dbauthz.AsChatd(ctx)
|
||||
|
||||
@@ -2431,6 +2479,9 @@ func (p *Server) start(ctx context.Context) {
|
||||
// to handle chats orphaned by crashed or redeployed workers.
|
||||
p.recoverStaleChats(ctx)
|
||||
|
||||
// Single heartbeat loop for all chats on this replica.
|
||||
go p.heartbeatLoop(ctx)
|
||||
|
||||
acquireTicker := p.clock.NewTicker(
|
||||
p.pendingChatAcquireInterval,
|
||||
"chatd",
|
||||
@@ -2681,11 +2732,113 @@ func (p *Server) getOrCreateStreamState(chatID uuid.UUID) *chatStreamState {
|
||||
|
||||
// cleanupStreamIfIdle removes the chat entry from the sync.Map
|
||||
// when there are no subscribers and the stream is not buffering.
|
||||
// When bufferRetainedAt is set, cleanup is deferred until the
|
||||
// grace period expires so cross-replica relay subscribers can
|
||||
// still snapshot the buffer.
|
||||
// The caller must hold state.mu.
|
||||
func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) {
|
||||
if !state.buffering && len(state.subscribers) == 0 {
|
||||
p.chatStreams.Delete(chatID)
|
||||
p.workspaceMCPToolsCache.Delete(chatID)
|
||||
if state.buffering || len(state.subscribers) > 0 {
|
||||
return
|
||||
}
|
||||
// Keep stream state alive during the grace period so
|
||||
// late-connecting relay subscribers can snapshot the
|
||||
// buffer after the worker finishes processing.
|
||||
if !state.bufferRetainedAt.IsZero() &&
|
||||
p.clock.Now().Before(state.bufferRetainedAt.Add(bufferRetainGracePeriod)) {
|
||||
return
|
||||
}
|
||||
p.chatStreams.Delete(chatID)
|
||||
p.workspaceMCPToolsCache.Delete(chatID)
|
||||
}
|
||||
|
||||
// registerHeartbeat enrolls a chat in the centralized batch
|
||||
// heartbeat loop. Must be called after chatCtx is created.
|
||||
func (p *Server) registerHeartbeat(entry *heartbeatEntry) {
|
||||
p.heartbeatMu.Lock()
|
||||
defer p.heartbeatMu.Unlock()
|
||||
if _, exists := p.heartbeatRegistry[entry.chatID]; exists {
|
||||
p.logger.Warn(context.Background(),
|
||||
"duplicate heartbeat registration, skipping",
|
||||
slog.F("chat_id", entry.chatID))
|
||||
return
|
||||
}
|
||||
p.heartbeatRegistry[entry.chatID] = entry
|
||||
}
|
||||
|
||||
// unregisterHeartbeat removes a chat from the centralized
|
||||
// heartbeat loop when chat processing finishes.
|
||||
func (p *Server) unregisterHeartbeat(chatID uuid.UUID) {
|
||||
p.heartbeatMu.Lock()
|
||||
defer p.heartbeatMu.Unlock()
|
||||
delete(p.heartbeatRegistry, chatID)
|
||||
}
|
||||
|
||||
// heartbeatLoop runs in a single goroutine, issuing one batch
|
||||
// heartbeat query per interval for all registered chats.
|
||||
func (p *Server) heartbeatLoop(ctx context.Context) {
|
||||
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "batch-heartbeat")
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.heartbeatTick(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeatTick issues a single batch UPDATE for all running chats
|
||||
// owned by this worker. Chats missing from the result set are
|
||||
// interrupted (stolen by another replica or already completed).
|
||||
func (p *Server) heartbeatTick(ctx context.Context) {
|
||||
// Snapshot the registry under the lock.
|
||||
p.heartbeatMu.Lock()
|
||||
snapshot := maps.Clone(p.heartbeatRegistry)
|
||||
p.heartbeatMu.Unlock()
|
||||
|
||||
if len(snapshot) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect the IDs we believe we own.
|
||||
ids := slices.Collect(maps.Keys(snapshot))
|
||||
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
|
||||
// access for batch-updating heartbeats.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
updatedIDs, err := p.db.UpdateChatHeartbeats(chatdCtx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: ids,
|
||||
WorkerID: p.workerID,
|
||||
Now: p.clock.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "batch heartbeat failed", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Build a set of IDs that were successfully updated.
|
||||
updated := make(map[uuid.UUID]struct{}, len(updatedIDs))
|
||||
for _, id := range updatedIDs {
|
||||
updated[id] = struct{}{}
|
||||
}
|
||||
|
||||
// Interrupt registered chats that were not in the result
|
||||
// (stolen by another replica or already completed).
|
||||
for id, entry := range snapshot {
|
||||
if _, ok := updated[id]; !ok {
|
||||
entry.logger.Warn(ctx, "chat not in batch heartbeat result, interrupting")
|
||||
entry.cancelWithCause(chatloop.ErrInterrupted)
|
||||
continue
|
||||
}
|
||||
// Bump workspace usage for surviving chats.
|
||||
newWsID := p.trackWorkspaceUsage(ctx, entry.chatID, entry.workspaceID, entry.logger)
|
||||
// Update workspace ID in the registry for next tick.
|
||||
p.heartbeatMu.Lock()
|
||||
if current, exists := p.heartbeatRegistry[id]; exists {
|
||||
current.workspaceID = newWsID
|
||||
}
|
||||
p.heartbeatMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3163,7 +3316,11 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch
|
||||
if p.pubsub == nil {
|
||||
return
|
||||
}
|
||||
sdkChat := db2sdk.Chat(chat, nil) // we have diffStatus already converted
|
||||
// diffStatus is applied below. File metadata is intentionally
|
||||
// omitted from pubsub events to avoid an extra DB query per
|
||||
// publish. Clients must merge pubsub updates, not replace
|
||||
// cached file metadata.
|
||||
sdkChat := db2sdk.Chat(chat, nil, nil)
|
||||
if diffStatus != nil {
|
||||
sdkChat.DiffStatus = diffStatus
|
||||
}
|
||||
@@ -3530,33 +3687,17 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Periodically update the heartbeat so other replicas know this
|
||||
// worker is still alive. The goroutine stops when chatCtx is
|
||||
// canceled (either by completion or interruption).
|
||||
go func() {
|
||||
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "heartbeat")
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-chatCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rows, err := p.db.UpdateChatHeartbeat(chatCtx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: p.workerID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(chatCtx, "failed to update chat heartbeat", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
if rows == 0 {
|
||||
cancel(chatloop.ErrInterrupted)
|
||||
return
|
||||
}
|
||||
chat.WorkspaceID = p.trackWorkspaceUsage(chatCtx, chat.ID, chat.WorkspaceID, logger)
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Register with the centralized heartbeat loop instead of
|
||||
// running a per-chat goroutine. The loop issues a single batch
|
||||
// UPDATE for all chats on this worker and detects stolen chats
|
||||
// via set-difference.
|
||||
p.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel,
|
||||
chatID: chat.ID,
|
||||
workspaceID: chat.WorkspaceID,
|
||||
logger: logger,
|
||||
})
|
||||
defer p.unregisterHeartbeat(chat.ID)
|
||||
|
||||
// Start buffering stream events BEFORE publishing the running
|
||||
// status. This closes a race where a subscriber sees
|
||||
@@ -3567,15 +3708,20 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
streamState := p.getOrCreateStreamState(chat.ID)
|
||||
streamState.mu.Lock()
|
||||
streamState.buffer = nil
|
||||
streamState.bufferRetainedAt = time.Time{}
|
||||
streamState.resetDropCounters()
|
||||
streamState.buffering = true
|
||||
streamState.mu.Unlock()
|
||||
defer func() {
|
||||
streamState.mu.Lock()
|
||||
streamState.buffer = nil
|
||||
streamState.resetDropCounters()
|
||||
streamState.buffering = false
|
||||
p.cleanupStreamIfIdle(chat.ID, streamState)
|
||||
// Retain the buffer for a grace period so
|
||||
// cross-replica relay subscribers can still snapshot
|
||||
// it after processing completes. The buffer is
|
||||
// cleared when the next processChat starts or when
|
||||
// cleanupStreamIfIdle runs after the grace period.
|
||||
streamState.bufferRetainedAt = p.clock.Now()
|
||||
streamState.mu.Unlock()
|
||||
}()
|
||||
|
||||
@@ -3905,14 +4051,6 @@ func (p *Server) runChat(
|
||||
)
|
||||
}()
|
||||
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(ctx, messages, p.chatFileResolver(), logger)
|
||||
if err != nil {
|
||||
return result, xerrors.Errorf("build chat prompt: %w", err)
|
||||
}
|
||||
if chat.ParentChatID.Valid {
|
||||
prompt = chatprompt.InsertSystem(prompt, defaultSubagentInstruction)
|
||||
}
|
||||
|
||||
// Detect computer-use subagent via the mode column.
|
||||
isComputerUse := chat.Mode.Valid && chat.Mode.ChatMode == database.ChatModeComputerUse
|
||||
|
||||
@@ -3969,7 +4107,20 @@ func (p *Server) runChat(
|
||||
needsInstructionPersist = true
|
||||
}
|
||||
}
|
||||
// Convert messages to prompt format in parallel with g2 work.
|
||||
// ConvertMessagesWithFiles only reads `messages` (available
|
||||
// after g.Wait()) and resolves file references via the DB.
|
||||
// No g2 task reads or writes `prompt`, so this is safe.
|
||||
var prompt []fantasy.Message
|
||||
var g2 errgroup.Group
|
||||
g2.Go(func() error {
|
||||
var err error
|
||||
prompt, err = chatprompt.ConvertMessagesWithFiles(ctx, messages, p.chatFileResolver(), logger)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("build chat prompt: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if needsInstructionPersist {
|
||||
g2.Go(func() error {
|
||||
var persistErr error
|
||||
@@ -4083,8 +4234,12 @@ func (p *Server) runChat(
|
||||
return nil
|
||||
})
|
||||
}
|
||||
// All g2 goroutines return nil; error is discarded.
|
||||
_ = g2.Wait()
|
||||
if err := g2.Wait(); err != nil {
|
||||
return result, err
|
||||
}
|
||||
if chat.ParentChatID.Valid {
|
||||
prompt = chatprompt.InsertSystem(prompt, defaultSubagentInstruction)
|
||||
}
|
||||
if mcpCleanup != nil {
|
||||
defer mcpCleanup()
|
||||
}
|
||||
@@ -4302,17 +4457,12 @@ func (p *Server) runChat(
|
||||
p.publishMessage(chat.ID, msg)
|
||||
}
|
||||
|
||||
// Clear the stream buffer now that the step is
|
||||
// persisted. Late-joining subscribers will load
|
||||
// these messages from the database instead.
|
||||
if val, ok := p.chatStreams.Load(chat.ID); ok {
|
||||
if ss, ok := val.(*chatStreamState); ok {
|
||||
ss.mu.Lock()
|
||||
ss.buffer = nil
|
||||
ss.resetDropCounters()
|
||||
ss.mu.Unlock()
|
||||
}
|
||||
}
|
||||
// Do NOT clear the stream buffer here. Cross-replica
|
||||
// relay subscribers may still need to snapshot buffered
|
||||
// message_parts after processing completes. The buffer
|
||||
// is bounded by maxStreamBufferSize and is cleared when
|
||||
// the next processChat starts or when the stream state
|
||||
// is garbage-collected after the retention grace period.
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -4475,6 +4625,27 @@ func (p *Server) runChat(
|
||||
return uuid.Nil, xerrors.Errorf("insert chat file: %w", err)
|
||||
}
|
||||
|
||||
// Cap enforcement and dedup are handled atomically
|
||||
// in SQL. rejected > 0 = cap exceeded.
|
||||
rejected, err := p.db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: chatSnapshot.ID,
|
||||
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
|
||||
FileIds: []uuid.UUID{row.ID},
|
||||
})
|
||||
switch {
|
||||
case err != nil:
|
||||
p.logger.Error(ctx, "failed to link file to chat",
|
||||
slog.F("chat_id", chatSnapshot.ID),
|
||||
slog.F("file_id", row.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
case rejected > 0:
|
||||
p.logger.Warn(ctx, "file cap reached, file not linked to chat",
|
||||
slog.F("chat_id", chatSnapshot.ID),
|
||||
slog.F("file_id", row.ID),
|
||||
slog.F("max_file_links", codersdk.MaxChatFileIDs),
|
||||
)
|
||||
}
|
||||
return row.ID, nil
|
||||
},
|
||||
}))
|
||||
@@ -4547,7 +4718,7 @@ func (p *Server) runChat(
|
||||
prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount)
|
||||
}
|
||||
|
||||
err = chatloop.Run(ctx, chatloop.RunOptions{
|
||||
err := chatloop.Run(ctx, chatloop.RunOptions{
|
||||
Model: model,
|
||||
Messages: prompt,
|
||||
Tools: tools, MaxSteps: maxChatSteps,
|
||||
@@ -5196,6 +5367,37 @@ func (p *Server) resolveUserCompactionThreshold(ctx context.Context, userID uuid
|
||||
return int32(val), true
|
||||
}
|
||||
|
||||
// resolveDeploymentSystemPrompt builds the deployment-level system
|
||||
// prompt from the built-in default and the admin-configured custom
|
||||
// prompt stored in site_configs.
|
||||
func (p *Server) resolveDeploymentSystemPrompt(ctx context.Context) string {
|
||||
config, err := p.db.GetChatSystemPromptConfig(ctx)
|
||||
if err != nil {
|
||||
// Fail open: use the built-in default so chats always have
|
||||
// some system guidance.
|
||||
p.logger.Error(ctx, "failed to fetch chat system prompt configuration, using default", slog.Error(err))
|
||||
return DefaultSystemPrompt
|
||||
}
|
||||
|
||||
sanitizedCustom := SanitizePromptText(config.ChatSystemPrompt)
|
||||
if sanitizedCustom == "" && strings.TrimSpace(config.ChatSystemPrompt) != "" {
|
||||
p.logger.Warn(ctx, "custom system prompt became empty after sanitization, omitting custom portion")
|
||||
}
|
||||
|
||||
var parts []string
|
||||
if config.IncludeDefaultSystemPrompt {
|
||||
parts = append(parts, DefaultSystemPrompt)
|
||||
}
|
||||
if sanitizedCustom != "" {
|
||||
parts = append(parts, sanitizedCustom)
|
||||
}
|
||||
result := strings.Join(parts, "\n\n")
|
||||
if result == "" {
|
||||
p.logger.Warn(ctx, "resolved system prompt is empty, no system prompt will be injected into chats")
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// resolveUserPrompt fetches the user's custom chat prompt from the
|
||||
// database and wraps it in <user-instructions> tags. Returns empty
|
||||
// string if no prompt is set.
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
@@ -2071,6 +2072,7 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
workerID: workerID,
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
configCache: newChatConfigCache(ctx, db, clock),
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
// Publish a stale "pending" notification on the control channel
|
||||
@@ -2133,3 +2135,130 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
require.Equal(t, database.ChatStatusError, finalStatus,
|
||||
"processChat should have reached runChat (error), not been interrupted (waiting)")
|
||||
}
|
||||
|
||||
// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the
|
||||
// batch heartbeat UPDATE does not return a registered chat's ID
|
||||
// (because another replica stole it or it was completed), the
|
||||
// heartbeat tick cancels that chat's context with ErrInterrupted
|
||||
// while leaving surviving chats untouched.
|
||||
func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
workerID := uuid.New()
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
clock: clock,
|
||||
workerID: workerID,
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
// Create three chats with independent cancel functions.
|
||||
chat1 := uuid.New()
|
||||
chat2 := uuid.New()
|
||||
chat3 := uuid.New()
|
||||
|
||||
_, cancel1 := context.WithCancelCause(ctx)
|
||||
_, cancel2 := context.WithCancelCause(ctx)
|
||||
ctx3, cancel3 := context.WithCancelCause(ctx)
|
||||
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel1,
|
||||
chatID: chat1,
|
||||
logger: logger,
|
||||
})
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel2,
|
||||
chatID: chat2,
|
||||
logger: logger,
|
||||
})
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel3,
|
||||
chatID: chat3,
|
||||
logger: logger,
|
||||
})
|
||||
|
||||
// The batch UPDATE returns only chat1 and chat2 —
|
||||
// chat3 was "stolen" by another replica.
|
||||
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
require.Equal(t, workerID, params.WorkerID)
|
||||
require.Len(t, params.IDs, 3)
|
||||
// Return only chat1 and chat2 as surviving.
|
||||
return []uuid.UUID{chat1, chat2}, nil
|
||||
},
|
||||
)
|
||||
|
||||
server.heartbeatTick(ctx)
|
||||
|
||||
// chat3's context should be canceled with ErrInterrupted.
|
||||
require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted,
|
||||
"stolen chat should be interrupted")
|
||||
|
||||
// chat3 should have been removed from the registry by
|
||||
// unregister (in production this happens via defer in
|
||||
// processChat). The heartbeat tick itself does not
|
||||
// unregister — it only cancels. Verify the entry is
|
||||
// still present (processChat's defer would clean it up).
|
||||
server.heartbeatMu.Lock()
|
||||
_, chat1Exists := server.heartbeatRegistry[chat1]
|
||||
_, chat2Exists := server.heartbeatRegistry[chat2]
|
||||
_, chat3Exists := server.heartbeatRegistry[chat3]
|
||||
server.heartbeatMu.Unlock()
|
||||
|
||||
require.True(t, chat1Exists, "surviving chat1 should remain registered")
|
||||
require.True(t, chat2Exists, "surviving chat2 should remain registered")
|
||||
require.True(t, chat3Exists,
|
||||
"stolen chat3 should still be in registry (processChat defer removes it)")
|
||||
}
|
||||
|
||||
// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a
|
||||
// transient database failure causes the tick to log and return
|
||||
// without canceling any registered chats.
|
||||
func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
clock: clock,
|
||||
workerID: uuid.New(),
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
chatID := uuid.New()
|
||||
chatCtx, cancel := context.WithCancelCause(ctx)
|
||||
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel,
|
||||
chatID: chatID,
|
||||
logger: logger,
|
||||
})
|
||||
|
||||
// Simulate a transient DB error.
|
||||
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return(
|
||||
nil, xerrors.New("connection reset"),
|
||||
)
|
||||
|
||||
server.heartbeatTick(ctx)
|
||||
|
||||
// Chat should NOT be interrupted — the tick logged and
|
||||
// returned early.
|
||||
require.NoError(t, chatCtx.Err(),
|
||||
"chat context should not be canceled on transient DB error")
|
||||
}
|
||||
|
||||
@@ -474,7 +474,7 @@ func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
|
||||
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
|
||||
}
|
||||
|
||||
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -501,19 +501,24 @@ func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
// Wrong worker_id should return no IDs.
|
||||
ids, err := db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), rows)
|
||||
require.Empty(t, ids)
|
||||
|
||||
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
// Correct worker_id should return the chat's ID.
|
||||
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
WorkerID: workerID,
|
||||
Now: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), rows)
|
||||
require.Len(t, ids, 1)
|
||||
require.Equal(t, chat.ID, ids[0])
|
||||
}
|
||||
|
||||
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
|
||||
@@ -70,8 +70,8 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
|
||||
require.Equal(t, 1, persistStepCalls)
|
||||
require.True(t, persistedStep.ContextLimit.Valid)
|
||||
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
|
||||
require.Greater(t, persistedStep.Runtime, time.Duration(0),
|
||||
"step runtime should be positive")
|
||||
require.GreaterOrEqual(t, persistedStep.Runtime, time.Duration(0),
|
||||
"step runtime should be non-negative")
|
||||
|
||||
require.NotEmpty(t, capturedCall.Prompt)
|
||||
require.False(t, containsPromptSentinel(capturedCall.Prompt))
|
||||
|
||||
@@ -1192,23 +1192,6 @@ func fileReferencePartToText(part codersdk.ChatMessagePart) string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// skillPartToText formats a skill SDK part as plain text for
|
||||
// LLM consumption. The user explicitly attached this skill chip
|
||||
// to their message, so the text makes clear they want the agent
|
||||
// to use the named skill (listed in <available-skills>).
|
||||
func skillPartToText(part codersdk.ChatMessagePart) string {
|
||||
if part.SkillDescription != "" {
|
||||
return fmt.Sprintf(
|
||||
"Use the %q skill (%s). Read it with read_skill before following its instructions.",
|
||||
part.SkillName, part.SkillDescription,
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"Use the %q skill. Read it with read_skill before following its instructions.",
|
||||
part.SkillName,
|
||||
)
|
||||
}
|
||||
|
||||
// toolResultPartToMessagePart converts an SDK tool-result part
|
||||
// into a fantasy ToolResultPart for LLM dispatch.
|
||||
func toolResultPartToMessagePart(logger slog.Logger, part codersdk.ChatMessagePart) fantasy.ToolResultPart {
|
||||
@@ -1377,12 +1360,6 @@ func partsToMessageParts(
|
||||
result = append(result, fantasy.TextPart{
|
||||
Text: fileReferencePartToText(part),
|
||||
})
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
// Skill parts from user input are converted to text
|
||||
// so the LLM knows which skill the user requested.
|
||||
result = append(result, fantasy.TextPart{
|
||||
Text: skillPartToText(part),
|
||||
})
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
if part.ContextFileContent == "" {
|
||||
continue
|
||||
|
||||
@@ -1088,49 +1088,6 @@ func TestFileReferencePreservation(t *testing.T) {
|
||||
assert.Contains(t, textPart.Text, "func main() {}")
|
||||
}
|
||||
|
||||
// TestSkillPartPreservation verifies skill parts survive the
|
||||
// storage round-trip and convert to text for LLMs.
|
||||
func TestSkillPartPreservation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "deep-review",
|
||||
SkillDescription: "Multi-reviewer code review",
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Storage round-trip: all fields intact.
|
||||
parts, err := chatprompt.ParseContent(testMsg(codersdk.ChatMessageRoleUser, raw))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, codersdk.ChatMessagePartTypeSkill, parts[0].Type)
|
||||
assert.Equal(t, "deep-review", parts[0].SkillName)
|
||||
assert.Equal(t, "Multi-reviewer code review", parts[0].SkillDescription)
|
||||
|
||||
// LLM dispatch: skill becomes a TextPart.
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(
|
||||
context.Background(),
|
||||
[]database.ChatMessage{{
|
||||
Role: database.ChatMessageRoleUser,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
Content: raw,
|
||||
}},
|
||||
nil,
|
||||
slogtest.Make(t, nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, prompt, 1)
|
||||
require.Len(t, prompt[0].Content, 1)
|
||||
|
||||
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](prompt[0].Content[0])
|
||||
require.True(t, ok, "skill should become TextPart for LLM")
|
||||
assert.Contains(t, textPart.Text, "deep-review")
|
||||
assert.Contains(t, textPart.Text, "read_skill")
|
||||
assert.Contains(t, textPart.Text, "Multi-reviewer code review")
|
||||
assert.Contains(t, textPart.Text, "Multi-reviewer code review")
|
||||
}
|
||||
|
||||
// TestAssistantWriteRoundTrip verifies the Stage 4 write path:
|
||||
// fantasy.Content (with ProviderMetadata) → PartFromContent →
|
||||
// MarshalParts → DB → ParseContent (SDK path) →
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -49,10 +50,11 @@ const connectTimeout = 10 * time.Second
|
||||
const toolCallTimeout = 60 * time.Second
|
||||
|
||||
// ConnectAll connects to all configured MCP servers, discovers
|
||||
// their tools, and returns them as fantasy.AgentTool values. It
|
||||
// skips servers that fail to connect and logs warnings. The
|
||||
// returned cleanup function must be called to close all
|
||||
// connections.
|
||||
// their tools, and returns them as fantasy.AgentTool values.
|
||||
// Tools are sorted by their prefixed name so callers
|
||||
// receive a deterministic order. It skips servers that fail to
|
||||
// connect and logs warnings. The returned cleanup function
|
||||
// must be called to close all connections.
|
||||
func ConnectAll(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
@@ -108,7 +110,9 @@ func ConnectAll(
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
clients = append(clients, mcpClient)
|
||||
if mcpClient != nil {
|
||||
clients = append(clients, mcpClient)
|
||||
}
|
||||
tools = append(tools, serverTools...)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
@@ -119,6 +123,31 @@ func ConnectAll(
|
||||
// discarded.
|
||||
_ = eg.Wait()
|
||||
|
||||
// Sort tools by prefixed name for deterministic ordering
|
||||
// regardless of goroutine completion order. Ties, possible
|
||||
// when the __ separator produces ambiguous prefixed names,
|
||||
// are broken by config ID. Stable prompt construction
|
||||
// depends on consistent tool ordering.
|
||||
slices.SortFunc(tools, func(a, b fantasy.AgentTool) int {
|
||||
// All tools in this slice are mcpToolWrapper values
|
||||
// created by connectOne above, so these checked
|
||||
// assertions should always succeed. The config ID
|
||||
// tiebreaker resolves the __ separator ambiguity
|
||||
// documented at the top of this file.
|
||||
aTool, ok := a.(MCPToolIdentifier)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("unexpected tool type %T", a))
|
||||
}
|
||||
bTool, ok := b.(MCPToolIdentifier)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("unexpected tool type %T", b))
|
||||
}
|
||||
return cmp.Or(
|
||||
cmp.Compare(a.Info().Name, b.Info().Name),
|
||||
cmp.Compare(aTool.MCPServerConfigID().String(), bTool.MCPServerConfigID().String()),
|
||||
)
|
||||
})
|
||||
|
||||
return tools, cleanup
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,17 @@ func greetTool() mcpserver.ServerTool {
|
||||
}
|
||||
}
|
||||
|
||||
// makeTool returns a ServerTool with the given name and a
|
||||
// no-op handler that always returns "ok".
|
||||
func makeTool(name string) mcpserver.ServerTool {
|
||||
return mcpserver.ServerTool{
|
||||
Tool: mcp.NewTool(name),
|
||||
Handler: func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return mcp.NewToolResultText("ok"), nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// makeConfig builds a database.MCPServerConfig suitable for tests.
|
||||
func makeConfig(slug, url string) database.MCPServerConfig {
|
||||
return database.MCPServerConfig{
|
||||
@@ -198,6 +209,121 @@ func TestConnectAll_MultipleServers(t *testing.T) {
|
||||
assert.Contains(t, names, "beta__greet")
|
||||
}
|
||||
|
||||
func TestConnectAll_NoToolsAfterFiltering(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts := newTestMCPServer(t, echoTool())
|
||||
|
||||
cfg := makeConfig("filtered", ts.URL)
|
||||
cfg.ToolAllowList = []string{"greet"}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Empty(t, tools)
|
||||
assert.NotPanics(t, cleanup)
|
||||
}
|
||||
|
||||
func TestConnectAll_DeterministicOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("AcrossServers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts1 := newTestMCPServer(t, makeTool("zebra"))
|
||||
ts2 := newTestMCPServer(t, makeTool("alpha"))
|
||||
ts3 := newTestMCPServer(t, makeTool("middle"))
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{
|
||||
makeConfig("srv3", ts3.URL),
|
||||
makeConfig("srv1", ts1.URL),
|
||||
makeConfig("srv2", ts2.URL),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 3)
|
||||
// Sorted by full prefixed name (slug__tool), so slug
|
||||
// order determines the sequence, not the tool name.
|
||||
assert.Equal(t,
|
||||
[]string{"srv1__zebra", "srv2__alpha", "srv3__middle"},
|
||||
toolNames(tools),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("WithMultiToolServer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
multi := newTestMCPServer(t, makeTool("zeta"), makeTool("beta"))
|
||||
other := newTestMCPServer(t, makeTool("gamma"))
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{
|
||||
makeConfig("zzz", multi.URL),
|
||||
makeConfig("aaa", other.URL),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 3)
|
||||
assert.Equal(t,
|
||||
[]string{"aaa__gamma", "zzz__beta", "zzz__zeta"},
|
||||
toolNames(tools),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("TiebreakByConfigID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts1 := newTestMCPServer(t, makeTool("b__z"))
|
||||
ts2 := newTestMCPServer(t, makeTool("z"))
|
||||
|
||||
// Use fixed UUIDs so the tiebreaker order is
|
||||
// predictable. Both servers produce the same prefixed
|
||||
// name, a__b__z, due to the __ separator ambiguity.
|
||||
cfg1 := makeConfig("a", ts1.URL)
|
||||
cfg1.ID = uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
||||
|
||||
cfg2 := makeConfig("a__b", ts2.URL)
|
||||
cfg2.ID = uuid.MustParse("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 2)
|
||||
assert.Equal(t, []string{"a__b__z", "a__b__z"}, toolNames(tools))
|
||||
|
||||
id0 := tools[0].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
|
||||
id1 := tools[1].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
|
||||
assert.Equal(t, cfg2.ID, id0, "lower config ID should sort first")
|
||||
assert.Equal(t, cfg1.ID, id1, "higher config ID should sort second")
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectAll_AuthHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -1055,12 +1055,17 @@ func TestAwaitSubagentCompletion(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer cancelProbe()
|
||||
|
||||
// Transition the child first, then publish once the
|
||||
// durable completion state is observable. Pubsub only
|
||||
// wakes the waiter; it does not guarantee the report is
|
||||
// visible in the same instant as the notification.
|
||||
setChatStatus(ctx, t, db, child.ID, database.ChatStatusWaiting, "")
|
||||
// Insert the message BEFORE transitioning to Waiting.
|
||||
// Stale PG LISTEN/NOTIFY notifications from the
|
||||
// processor's earlier run can still be buffered in the
|
||||
// pgListener after drainInflight returns. If such a
|
||||
// notification is dispatched between setChatStatus and
|
||||
// insertAssistantMessage, checkSubagentCompletion would
|
||||
// see done=true (Waiting) with an empty report. By
|
||||
// inserting the message first, the report is guaranteed
|
||||
// to be committed before the status makes it visible.
|
||||
insertAssistantMessage(ctx, t, db, child.ID, model.ID, "pubsub result")
|
||||
setChatStatus(ctx, t, db, child.ID, database.ChatStatusWaiting, "")
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
chat, report, done, err := server.checkSubagentCompletion(ctx, child.ID)
|
||||
require.NoError(c, err)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package chatd_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -110,19 +111,19 @@ func TestSpawnComputerUseAgent_SystemPromptFormat(t *testing.T) {
|
||||
|
||||
// The system message raw content is a JSON-encoded string.
|
||||
// It should contain the system prompt with the user prompt.
|
||||
var rawSystemContent string
|
||||
var foundPrompt bool
|
||||
for _, msg := range messages {
|
||||
if msg.Role != "system" {
|
||||
continue
|
||||
}
|
||||
if msg.Content.Valid {
|
||||
rawSystemContent = string(msg.Content.RawMessage)
|
||||
if msg.Content.Valid && strings.Contains(string(msg.Content.RawMessage), prompt) {
|
||||
foundPrompt = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.Contains(t, rawSystemContent, prompt,
|
||||
"system prompt raw content should contain the user prompt")
|
||||
assert.True(t, foundPrompt,
|
||||
"at least one system message should contain the user prompt")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_ChildIsListedUnderParent(t *testing.T) {
|
||||
|
||||
@@ -212,6 +212,7 @@ type AuditLogsRequest struct {
|
||||
type AuditLogResponse struct {
|
||||
AuditLogs []AuditLog `json:"audit_logs"`
|
||||
Count int64 `json:"count"`
|
||||
CountCap int64 `json:"count_cap"`
|
||||
}
|
||||
|
||||
type CreateTestAuditLogRequest struct {
|
||||
|
||||
+81
-37
@@ -26,6 +26,12 @@ import (
|
||||
// threshold settings.
|
||||
const ChatCompactionThresholdKeyPrefix = "chat_compaction_threshold_pct:"
|
||||
|
||||
// MaxChatFileIDs is the maximum number of file IDs that can be
|
||||
// associated with a single chat. This limit prevents unbounded
|
||||
// growth in the chat_file_links table. It is easier to raise
|
||||
// this limit than to lower it.
|
||||
const MaxChatFileIDs = 20
|
||||
|
||||
// CompactionThresholdKey returns the user-config key for a specific
|
||||
// model configuration's compaction threshold.
|
||||
func CompactionThresholdKey(modelConfigID uuid.UUID) string {
|
||||
@@ -46,24 +52,25 @@ const (
|
||||
|
||||
// Chat represents a chat session with an AI agent.
|
||||
type Chat struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
OwnerID uuid.UUID `json:"owner_id" format:"uuid"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
BuildID *uuid.UUID `json:"build_id,omitempty" format:"uuid"`
|
||||
AgentID *uuid.UUID `json:"agent_id,omitempty" format:"uuid"`
|
||||
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
|
||||
RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"`
|
||||
LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"`
|
||||
Title string `json:"title"`
|
||||
Status ChatStatus `json:"status"`
|
||||
LastError *string `json:"last_error"`
|
||||
DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
Archived bool `json:"archived"`
|
||||
PinOrder int32 `json:"pin_order"`
|
||||
MCPServerIDs []uuid.UUID `json:"mcp_server_ids" format:"uuid"`
|
||||
Labels map[string]string `json:"labels"`
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
OwnerID uuid.UUID `json:"owner_id" format:"uuid"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
BuildID *uuid.UUID `json:"build_id,omitempty" format:"uuid"`
|
||||
AgentID *uuid.UUID `json:"agent_id,omitempty" format:"uuid"`
|
||||
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
|
||||
RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"`
|
||||
LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"`
|
||||
Title string `json:"title"`
|
||||
Status ChatStatus `json:"status"`
|
||||
LastError *string `json:"last_error"`
|
||||
DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
Archived bool `json:"archived"`
|
||||
PinOrder int32 `json:"pin_order"`
|
||||
MCPServerIDs []uuid.UUID `json:"mcp_server_ids" format:"uuid"`
|
||||
Labels map[string]string `json:"labels"`
|
||||
Files []ChatFileMetadata `json:"files,omitempty"`
|
||||
// HasUnread is true when assistant messages exist beyond
|
||||
// the owner's read cursor, which updates on stream
|
||||
// connect and disconnect.
|
||||
@@ -73,6 +80,18 @@ type Chat struct {
|
||||
// is updated only when context changes — first workspace
|
||||
// attach or agent change.
|
||||
LastInjectedContext []ChatMessagePart `json:"last_injected_context,omitempty"`
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
}
|
||||
|
||||
// ChatFileMetadata contains lightweight metadata about a file
|
||||
// associated with a chat, excluding the file content itself.
|
||||
type ChatFileMetadata struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
OwnerID uuid.UUID `json:"owner_id" format:"uuid"`
|
||||
OrganizationID uuid.UUID `json:"organization_id" format:"uuid"`
|
||||
Name string `json:"name"`
|
||||
MimeType string `json:"mime_type"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatMessage represents a single message in a chat.
|
||||
@@ -326,7 +345,6 @@ const (
|
||||
ChatInputPartTypeText ChatInputPartType = "text"
|
||||
ChatInputPartTypeFile ChatInputPartType = "file"
|
||||
ChatInputPartTypeFileReference ChatInputPartType = "file-reference"
|
||||
ChatInputPartTypeSkill ChatInputPartType = "skill"
|
||||
)
|
||||
|
||||
// ChatInputPart is a single user input part for creating a chat.
|
||||
@@ -341,15 +359,12 @@ type ChatInputPart struct {
|
||||
EndLine int `json:"end_line,omitempty"`
|
||||
// The code content from the diff that was commented on.
|
||||
Content string `json:"content,omitempty"`
|
||||
// The following fields are only set when Type is
|
||||
// ChatInputPartTypeSkill.
|
||||
SkillName string `json:"skill_name,omitempty"`
|
||||
SkillDescription string `json:"skill_description,omitempty"`
|
||||
}
|
||||
|
||||
// CreateChatRequest is the request to create a new chat.
|
||||
type CreateChatRequest struct {
|
||||
Content []ChatInputPart `json:"content"`
|
||||
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
|
||||
@@ -373,11 +388,27 @@ type UpdateChatRequest struct {
|
||||
Labels *map[string]string `json:"labels,omitempty"`
|
||||
}
|
||||
|
||||
// ChatBusyBehavior controls what happens when a user sends a message
|
||||
// while the chat is already processing.
|
||||
type ChatBusyBehavior string
|
||||
|
||||
const (
|
||||
// ChatBusyBehaviorQueue queues the message for processing after
|
||||
// the current run finishes.
|
||||
ChatBusyBehaviorQueue ChatBusyBehavior = "queue"
|
||||
// ChatBusyBehaviorInterrupt queues the message and interrupts
|
||||
// the active run. The partial assistant response is persisted
|
||||
// before the queued message is promoted, preserving correct
|
||||
// conversation order.
|
||||
ChatBusyBehaviorInterrupt ChatBusyBehavior = "interrupt"
|
||||
)
|
||||
|
||||
// CreateChatMessageRequest is the request to add a message to a chat.
|
||||
type CreateChatMessageRequest struct {
|
||||
Content []ChatInputPart `json:"content"`
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
MCPServerIDs *[]uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
|
||||
Content []ChatInputPart `json:"content"`
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
MCPServerIDs *[]uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
|
||||
BusyBehavior ChatBusyBehavior `json:"busy_behavior,omitempty" enums:"queue,interrupt"`
|
||||
}
|
||||
|
||||
// EditChatMessageRequest is the request to edit a user message in a chat.
|
||||
@@ -390,6 +421,15 @@ type CreateChatMessageResponse struct {
|
||||
Message *ChatMessage `json:"message,omitempty"`
|
||||
QueuedMessage *ChatQueuedMessage `json:"queued_message,omitempty"`
|
||||
Queued bool `json:"queued"`
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
}
|
||||
|
||||
// EditChatMessageResponse is the response from editing a message in a chat.
|
||||
// Edits are always synchronous (no queueing), so the message is returned
|
||||
// directly.
|
||||
type EditChatMessageResponse struct {
|
||||
Message ChatMessage `json:"message"`
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
}
|
||||
|
||||
// UploadChatFileResponse is the response from uploading a chat file.
|
||||
@@ -631,7 +671,7 @@ type ChatModelOpenAIProviderOptions struct {
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty" description:"Whether the model may make multiple tool calls in parallel"`
|
||||
User *string `json:"user,omitempty" description:"Unique identifier for the end user for abuse monitoring" hidden:"true"`
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty" description:"Controls the level of reasoning effort" enum:"none,minimal,low,medium,high,xhigh"`
|
||||
ReasoningSummary *string `json:"reasoning_summary,omitempty" description:"Controls whether reasoning tokens are summarized in the response"`
|
||||
ReasoningSummary *string `json:"reasoning_summary,omitempty" description:"Controls whether reasoning tokens are summarized in the response" enum:"auto,concise,detailed"`
|
||||
MaxCompletionTokens *int64 `json:"max_completion_tokens,omitempty" description:"Upper bound on tokens the model may generate"`
|
||||
TextVerbosity *string `json:"text_verbosity,omitempty" description:"Controls the verbosity of the text response" enum:"low,medium,high"`
|
||||
Prediction map[string]any `json:"prediction,omitempty" description:"Predicted output content to speed up responses" hidden:"true"`
|
||||
@@ -639,12 +679,12 @@ type ChatModelOpenAIProviderOptions struct {
|
||||
Metadata map[string]any `json:"metadata,omitempty" description:"Arbitrary metadata to attach to the request" hidden:"true"`
|
||||
PromptCacheKey *string `json:"prompt_cache_key,omitempty" description:"Key for enabling cross-request prompt caching"`
|
||||
SafetyIdentifier *string `json:"safety_identifier,omitempty" description:"Developer-specific safety identifier for the request" hidden:"true"`
|
||||
ServiceTier *string `json:"service_tier,omitempty" description:"Latency tier to use for processing the request"`
|
||||
ServiceTier *string `json:"service_tier,omitempty" description:"Latency tier to use for processing the request" enum:"auto,default,flex,scale,priority"`
|
||||
StructuredOutputs *bool `json:"structured_outputs,omitempty" description:"Whether to enable structured JSON output mode" hidden:"true"`
|
||||
StrictJSONSchema *bool `json:"strict_json_schema,omitempty" description:"Whether to enforce strict adherence to the JSON schema" hidden:"true"`
|
||||
WebSearchEnabled *bool `json:"web_search_enabled,omitempty" description:"Enable OpenAI web search tool for grounding responses with real-time information"`
|
||||
SearchContextSize *string `json:"search_context_size,omitempty" description:"Amount of search context to use" enum:"low,medium,high"`
|
||||
AllowedDomains []string `json:"allowed_domains,omitempty" description:"Restrict web search to these domains"`
|
||||
AllowedDomains []string `json:"allowed_domains,omitempty" label:"Web Search: Allowed Domains" description:"Restrict web search to these domains"`
|
||||
}
|
||||
|
||||
// ChatModelAnthropicThinkingOptions configures Anthropic thinking budget.
|
||||
@@ -656,11 +696,11 @@ type ChatModelAnthropicThinkingOptions struct {
|
||||
type ChatModelAnthropicProviderOptions struct {
|
||||
SendReasoning *bool `json:"send_reasoning,omitempty" description:"Whether to include reasoning content in the response"`
|
||||
Thinking *ChatModelAnthropicThinkingOptions `json:"thinking,omitempty" description:"Configuration for extended thinking"`
|
||||
Effort *string `json:"effort,omitempty" description:"Controls the level of reasoning effort" enum:"low,medium,high,max"`
|
||||
Effort *string `json:"effort,omitempty" label:"Reasoning Effort" description:"Controls the level of reasoning effort" enum:"low,medium,high,max"`
|
||||
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty" description:"Whether to disable parallel tool execution"`
|
||||
WebSearchEnabled *bool `json:"web_search_enabled,omitempty" description:"Enable Anthropic web search tool for grounding responses with real-time information"`
|
||||
AllowedDomains []string `json:"allowed_domains,omitempty" description:"Restrict web search to these domains (cannot be used with blocked_domains)"`
|
||||
BlockedDomains []string `json:"blocked_domains,omitempty" description:"Block web search on these domains (cannot be used with allowed_domains)"`
|
||||
AllowedDomains []string `json:"allowed_domains,omitempty" label:"Web Search: Allowed Domains" description:"Restrict web search to these domains (cannot be used with blocked_domains)"`
|
||||
BlockedDomains []string `json:"blocked_domains,omitempty" label:"Web Search: Blocked Domains" description:"Block web search on these domains (cannot be used with allowed_domains)"`
|
||||
}
|
||||
|
||||
// ChatModelGoogleThinkingConfig configures Google thinking behavior.
|
||||
@@ -979,6 +1019,7 @@ type ChatCostSummary struct {
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalRuntimeMs int64 `json:"total_runtime_ms"`
|
||||
ByModel []ChatCostModelBreakdown `json:"by_model"`
|
||||
ByChat []ChatCostChatBreakdown `json:"by_chat"`
|
||||
UsageLimit *ChatUsageLimitStatus `json:"usage_limit,omitempty"`
|
||||
@@ -996,6 +1037,7 @@ type ChatCostModelBreakdown struct {
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalRuntimeMs int64 `json:"total_runtime_ms"`
|
||||
}
|
||||
|
||||
// ChatCostChatBreakdown contains per-root-chat cost aggregation.
|
||||
@@ -1008,6 +1050,7 @@ type ChatCostChatBreakdown struct {
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalRuntimeMs int64 `json:"total_runtime_ms"`
|
||||
}
|
||||
|
||||
// ChatCostUserRollup contains per-user cost aggregation for admin views.
|
||||
@@ -1023,6 +1066,7 @@ type ChatCostUserRollup struct {
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalRuntimeMs int64 `json:"total_runtime_ms"`
|
||||
}
|
||||
|
||||
// ChatCostUsersResponse is the response from the admin chat cost users endpoint.
|
||||
@@ -1940,7 +1984,7 @@ func (c *ExperimentalClient) EditChatMessage(
|
||||
chatID uuid.UUID,
|
||||
messageID int64,
|
||||
req EditChatMessageRequest,
|
||||
) (ChatMessage, error) {
|
||||
) (EditChatMessageResponse, error) {
|
||||
res, err := c.Request(
|
||||
ctx,
|
||||
http.MethodPatch,
|
||||
@@ -1948,14 +1992,14 @@ func (c *ExperimentalClient) EditChatMessage(
|
||||
req,
|
||||
)
|
||||
if err != nil {
|
||||
return ChatMessage{}, err
|
||||
return EditChatMessageResponse{}, err
|
||||
}
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatMessage{}, readBodyAsChatUsageLimitError(res)
|
||||
return EditChatMessageResponse{}, readBodyAsChatUsageLimitError(res)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var message ChatMessage
|
||||
return message, json.NewDecoder(res.Body).Decode(&message)
|
||||
var resp EditChatMessageResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// InterruptChat cancels an in-flight chat run and leaves it waiting.
|
||||
|
||||
@@ -96,6 +96,7 @@ type ConnectionLogsRequest struct {
|
||||
type ConnectionLogResponse struct {
|
||||
ConnectionLogs []ConnectionLog `json:"connection_logs"`
|
||||
Count int64 `json:"count"`
|
||||
CountCap int64 `json:"count_cap"`
|
||||
}
|
||||
|
||||
func (c *Client) ConnectionLogs(ctx context.Context, req ConnectionLogsRequest) (ConnectionLogResponse, error) {
|
||||
|
||||
@@ -7,6 +7,14 @@ features, you can [request a trial](https://coder.com/trial) or
|
||||
|
||||

|
||||
|
||||
## Offline license validation
|
||||
|
||||
Coder license keys are signed JWTs that are validated locally using cryptographic
|
||||
signatures. No outbound connection to Coder's servers is required for license
|
||||
validation. This means licenses work in
|
||||
[air-gapped and offline deployments](../../install/airgap.md) without any
|
||||
additional configuration.
|
||||
|
||||
## Adding your license key
|
||||
|
||||
There are two ways to add a license to a Coder deployment:
|
||||
|
||||
@@ -134,6 +134,12 @@ edited message onward, truncating any messages that followed it.
|
||||
|-----------|-------------------|----------|----------------------------------|
|
||||
| `content` | `ChatInputPart[]` | yes | The replacement message content. |
|
||||
|
||||
The response is an `EditChatMessageResponse` with the edited `message`
|
||||
and an optional `warnings` array. When file references in the edited
|
||||
content cannot be linked (e.g. the per-chat file cap is reached), the
|
||||
edit still succeeds and the `warnings` array describes which files
|
||||
were not linked.
|
||||
|
||||
### Stream updates
|
||||
|
||||
`GET /api/experimental/chats/{chat}/stream`
|
||||
@@ -201,7 +207,9 @@ Each event is a JSON object with `kind` and `chat` fields:
|
||||
|
||||
`GET /api/experimental/chats`
|
||||
|
||||
Returns all chats owned by the authenticated user.
|
||||
Returns all chats owned by the authenticated user. The `files` field is
|
||||
populated on `POST /chats` and `GET /chats/{id}`. Other endpoints that
|
||||
return a `Chat` object omit it.
|
||||
|
||||
| Query parameter | Type | Required | Description |
|
||||
|-----------------|----------|----------|------------------------------------------------------------------|
|
||||
@@ -212,7 +220,17 @@ Returns all chats owned by the authenticated user.
|
||||
|
||||
`GET /api/experimental/chats/{chat}`
|
||||
|
||||
Returns the `Chat` object (metadata only, no messages).
|
||||
Returns the `Chat` object (metadata only, no messages). The response
|
||||
includes a `files` field (`ChatFileMetadata[]`) containing metadata for
|
||||
files that have been successfully linked to the chat. File linking is
|
||||
best-effort; if linking fails, the file remains in message content but
|
||||
will be absent from this field.
|
||||
|
||||
When file linking is skipped (e.g. the per-chat file cap is reached),
|
||||
`POST /chats` includes a `warnings` array on the `Chat` response and
|
||||
`POST /chats/{chat}/messages` includes a `warnings` array on the
|
||||
`CreateChatMessageResponse`. The `warnings` field is `omitempty` and
|
||||
absent when all files are linked successfully.
|
||||
|
||||
### Get chat messages
|
||||
|
||||
@@ -295,6 +313,10 @@ file, use `GET /api/experimental/chats/files/{file}`.
|
||||
Supported formats: PNG, JPEG, GIF, WebP (up to 10 MB). The server
|
||||
validates actual file content regardless of the declared `Content-Type`.
|
||||
|
||||
Files referenced in messages are automatically linked to the chat and
|
||||
appear in the `files` field on subsequent
|
||||
`GET /api/experimental/chats/{chat}` responses.
|
||||
|
||||
## Chat statuses
|
||||
|
||||
| Status | Meaning |
|
||||
|
||||
@@ -13,6 +13,7 @@ air-gapped with Kubernetes or Docker.
|
||||
| PostgreSQL | If no [PostgreSQL connection URL](../reference/cli/server.md#--postgres-url) is specified, Coder will download Postgres from [repo1.maven.org](https://repo1.maven.org) | An external database is required, you must specify a [PostgreSQL connection URL](../reference/cli/server.md#--postgres-url) |
|
||||
| Telemetry | Telemetry is on by default, and [can be disabled](../reference/cli/server.md#--telemetry) | Telemetry [can be disabled](../reference/cli/server.md#--telemetry) |
|
||||
| Update check | By default, Coder checks for updates from [GitHub releases](https://github.com/coder/coder/releases) | Update checks [can be disabled](../reference/cli/server.md#--update-check) |
|
||||
| License validation | License keys are validated locally using cryptographic signatures. No outbound connection to Coder is required | No changes needed. See [offline license validation](../admin/licensing/index.md#offline-license-validation) |
|
||||
| AI Governance Usage Count | By default, deployments with the [AI Governance Add On](../ai-coder/ai-governance.md) report usage data | [Contact us](https://coder.com/contact) to request a license with usage reporting off. |
|
||||
|
||||
## Air-gapped container images
|
||||
|
||||
@@ -1299,6 +1299,12 @@
|
||||
"description": "Custom claims/scopes with Okta for group/role sync",
|
||||
"path": "./tutorials/configuring-okta.md"
|
||||
},
|
||||
{
|
||||
"title": "Persistent Shared Workspaces",
|
||||
"description": "Set up long-lived shared workspaces with service accounts and workspace sharing",
|
||||
"path": "./tutorials/persistent-shared-workspaces.md",
|
||||
"state": ["premium"]
|
||||
},
|
||||
{
|
||||
"title": "Google to AWS Federation",
|
||||
"description": "Federating a Google Cloud service account to AWS",
|
||||
|
||||
Generated
+2
-1
@@ -90,7 +90,8 @@ curl -X GET http://coder-server:8080/api/v2/audit?limit=0 \
|
||||
"user_agent": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
Generated
+2
-1
@@ -291,7 +291,8 @@ curl -X GET http://coder-server:8080/api/v2/connectionlog?limit=0 \
|
||||
"workspace_owner_username": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
Generated
+6
-2
@@ -1740,7 +1740,8 @@
|
||||
"user_agent": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1750,6 +1751,7 @@
|
||||
|--------------|-------------------------------------------------|----------|--------------|-------------|
|
||||
| `audit_logs` | array of [codersdk.AuditLog](#codersdkauditlog) | false | | |
|
||||
| `count` | integer | false | | |
|
||||
| `count_cap` | integer | false | | |
|
||||
|
||||
## codersdk.AuthMethod
|
||||
|
||||
@@ -2173,7 +2175,8 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|
||||
"workspace_owner_username": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
@@ -2183,6 +2186,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|
||||
|-------------------|-----------------------------------------------------------|----------|--------------|-------------|
|
||||
| `connection_logs` | array of [codersdk.ConnectionLog](#codersdkconnectionlog) | false | | |
|
||||
| `count` | integer | false | | |
|
||||
| `count_cap` | integer | false | | |
|
||||
|
||||
## codersdk.ConnectionLogSSHInfo
|
||||
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
# Persistent Shared Workspaces with Service Accounts
|
||||
|
||||
> [!NOTE]
|
||||
> This guide requires a
|
||||
> [Premium license](https://coder.com/pricing#compare-plans) because service
|
||||
> accounts are a Premium feature. For more details,
|
||||
> [contact your account team](https://coder.com/contact).
|
||||
|
||||
This guide walks through setting up a long-lived workspace that is owned by a
|
||||
service account and shared with a rotating set of users. Because no single
|
||||
person owns the workspace, it persists across team changes and every user
|
||||
authenticates as themselves.
|
||||
|
||||
This pattern is useful for any scenario where a workspace outlives the people
|
||||
who use it:
|
||||
|
||||
- **On-call rotations** — Engineers share a workspace pre-loaded with runbooks,
|
||||
dashboards, and monitoring tools. Access rotates with the shift schedule.
|
||||
- **Shared staging or QA** — A team workspace hosts a persistent staging
|
||||
environment. Testers and reviewers are added and removed as sprints change.
|
||||
- **Pair programming** — A service-account-owned workspace gives two or more
|
||||
developers a shared environment without either one owning (and accidentally
|
||||
deleting) it.
|
||||
- **Contractor onboarding** — An external team gets scoped access to a workspace
|
||||
for the duration of an engagement, then access is revoked.
|
||||
|
||||
The steps below use an **on-call SRE workspace** as a running example, but the
|
||||
same commands apply to any of the scenarios above. Substitute the usernames,
|
||||
group names, and template to match your use case.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- A running Coder deployment (v2.32+) with workspace sharing enabled. Sharing
|
||||
is on by default for OSS; Premium deployments may require
|
||||
[admin configuration](../user-guides/shared-workspaces.md#policies).
|
||||
- The [Coder CLI](../install/index.md) installed and authenticated.
|
||||
- An account with the `Owner` or `User Admin` role.
|
||||
- [OIDC authentication](../admin/users/oidc-auth/index.md) configured so
|
||||
shared users log in with their corporate SSO identity. Configure
|
||||
[refresh tokens](../admin/users/oidc-auth/refresh-tokens.md) to prevent
|
||||
session timeouts during long work sessions.
|
||||
- A [wildcard access URL](../admin/networking/wildcard-access-url.md) configured
|
||||
(e.g. `*.coder.example.com`) so that shared users can access workspace apps
|
||||
without a 404.
|
||||
- (Recommended) [IdP Group Sync](../admin/users/idp-sync.md#group-sync)
|
||||
configured if your identity provider manages group membership for the teams
|
||||
that will share the workspace.
|
||||
|
||||
## 1. Create a service account
|
||||
|
||||
Create a dedicated service account that will own the shared workspace. Service
|
||||
accounts are non-human accounts intended for automation and shared ownership.
|
||||
Because no individual user owns the workspace, there are no personal
|
||||
credentials to expose and the shared environment is not affected when any user
|
||||
leaves the team or the organization.
|
||||
|
||||
```shell
|
||||
# On-call example — substitute a name that fits your use case
|
||||
coder users create \
|
||||
--username oncall-sre \
|
||||
--service-account
|
||||
```
|
||||
|
||||
## 2. Generate an API token for the service account
|
||||
|
||||
Generate a long-lived API token so you can create and manage workspaces on
|
||||
behalf of the service account:
|
||||
|
||||
```shell
|
||||
coder tokens create \
|
||||
--user oncall-sre \
|
||||
--name oncall-automation \
|
||||
--lifetime 8760h
|
||||
```
|
||||
|
||||
Store this token securely (e.g. in a secrets manager like Vault or AWS Secrets
|
||||
Manager).
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Never distribute this token to end users. The token is for workspace
|
||||
> administration only. Shared users authenticate as themselves and reach the
|
||||
> workspace through sharing.
|
||||
|
||||
## 3. Create the workspace
|
||||
|
||||
Authenticate as the service account and create the workspace:
|
||||
|
||||
```shell
|
||||
export CODER_SESSION_TOKEN="<token-from-step-2>"
|
||||
|
||||
coder create oncall-sre/oncall-workspace \
|
||||
--template your-oncall-template \
|
||||
--use-parameter-defaults \
|
||||
--yes
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Design a dedicated template for the workspace with the tools your team
|
||||
> needs pre-installed (e.g. monitoring dashboards for on-call, test runners
|
||||
> for QA). Set `subdomain = true` on workspace apps so that shared users can
|
||||
> access web-based tools without a 404. See
|
||||
> [Accessing workspace apps in shared workspaces](../user-guides/shared-workspaces.md#accessing-workspace-apps-in-shared-workspaces).
|
||||
|
||||
## 4. Share the workspace
|
||||
|
||||
Use `coder sharing share` to grant access to users who need the workspace:
|
||||
|
||||
```shell
|
||||
coder sharing share oncall-sre/oncall-workspace --user alice
|
||||
```
|
||||
|
||||
This gives `alice` the default `use` role, which allows connection via SSH and
|
||||
workspace apps, starting and stopping the workspace, and viewing logs and stats.
|
||||
|
||||
To grant `admin` permissions (which includes all `use` permissions as well as renaming, updating, and inviting
|
||||
others to join with the `use` role):
|
||||
|
||||
```shell
|
||||
coder sharing share oncall-sre/oncall-workspace --user alice:admin
|
||||
```
|
||||
|
||||
To share with multiple users at once:
|
||||
|
||||
```shell
|
||||
coder sharing share oncall-sre/oncall-workspace --user alice:admin,bob
|
||||
```
|
||||
|
||||
To share with an entire Coder group:
|
||||
|
||||
```shell
|
||||
coder sharing share oncall-sre/oncall-workspace --group sre-oncall
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Groups can be synced from your identity provider using
|
||||
> [IdP Sync](../admin/users/idp-sync.md#group-sync). If your IdP already
|
||||
> manages team membership, sharing with a group is the simplest approach.
|
||||
|
||||
## 5. Rotate access
|
||||
|
||||
When team membership changes, remove outgoing users and add incoming ones:
|
||||
|
||||
```shell
|
||||
# Remove outgoing user
|
||||
coder sharing remove oncall-sre/oncall-workspace --user alice
|
||||
|
||||
# Add incoming user
|
||||
coder sharing share oncall-sre/oncall-workspace --user carol
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The workspace must be restarted for user removal to take effect.
|
||||
|
||||
Verify current sharing status at any time:
|
||||
|
||||
```shell
|
||||
coder sharing status oncall-sre/oncall-workspace
|
||||
```
|
||||
|
||||
## 6. Automate access changes (optional)
|
||||
|
||||
For use cases with frequent rotation (such as on-call shifts), you can integrate
|
||||
the share/remove commands into external tooling like PagerDuty, Opsgenie, or a
|
||||
cron job.
|
||||
|
||||
### Rotation script
|
||||
|
||||
```shell
|
||||
#!/bin/bash
|
||||
# rotate-access.sh
|
||||
# Usage: ./rotate-access.sh <outgoing-user> <incoming-user>
|
||||
|
||||
WORKSPACE="oncall-sre/oncall-workspace"
|
||||
OUTGOING="$1"
|
||||
INCOMING="$2"
|
||||
|
||||
if [ -n "$OUTGOING" ]; then
|
||||
echo "Removing access for $OUTGOING..."
|
||||
coder sharing remove "$WORKSPACE" --user "$OUTGOING"
|
||||
fi
|
||||
|
||||
echo "Granting access to $INCOMING..."
|
||||
coder sharing share "$WORKSPACE" --user "$INCOMING"
|
||||
|
||||
echo "Restarting workspace to apply changes..."
|
||||
coder restart "$WORKSPACE" --yes
|
||||
|
||||
echo "Current sharing status:"
|
||||
coder sharing status "$WORKSPACE"
|
||||
```
|
||||
|
||||
### Group-based rotation with IdP Sync
|
||||
|
||||
If your identity provider manages group membership (e.g. an `sre-oncall` group
|
||||
in Okta or Azure AD), you can skip manual share/remove commands entirely:
|
||||
|
||||
1. Configure [Group Sync](../admin/users/idp-sync.md#group-sync) to
|
||||
synchronize the group from your IdP to Coder.
|
||||
|
||||
1. Share the workspace with the group once:
|
||||
|
||||
```shell
|
||||
coder sharing share oncall-sre/oncall-workspace --group sre-oncall
|
||||
```
|
||||
|
||||
1. When your IdP rotates group membership, Coder group membership updates on
|
||||
next login. All current members have access; removed members lose access
|
||||
after a workspace restart.
|
||||
|
||||
## Finding shared workspaces
|
||||
|
||||
Shared users can find workspaces shared with them:
|
||||
|
||||
```shell
|
||||
# List all workspaces shared with you
|
||||
coder list --search shared:true
|
||||
|
||||
# List workspaces shared with a specific user
|
||||
coder list --search shared_with_user:alice
|
||||
|
||||
# List workspaces shared with a specific group
|
||||
coder list --search shared_with_group:sre-oncall
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Shared user sees 404 on workspace apps
|
||||
|
||||
Workspace apps using path-based routing block non-owners by default. Configure a
|
||||
[wildcard access URL](../admin/networking/wildcard-access-url.md) and set
|
||||
`subdomain = true` on the workspace app in your template.
|
||||
|
||||
### Removed user still has access
|
||||
|
||||
Access removal requires a workspace restart. Run
|
||||
`coder restart <workspace>` after removing a user or group.
|
||||
|
||||
### Group sync not updating membership
|
||||
|
||||
Group membership changes in your IdP are not reflected until the user logs out
|
||||
and back in. Group sync runs at login time, not on a polling schedule. Check the
|
||||
Coder server logs with
|
||||
`CODER_LOG_FILTER=".*userauth.*|.*groups returned.*"` for details. See
|
||||
[Troubleshooting group sync](../admin/users/idp-sync.md#troubleshooting-grouproleorganization-sync)
|
||||
for more information.
|
||||
|
||||
## Next steps
|
||||
|
||||
- [Shared Workspaces](../user-guides/shared-workspaces.md) — full reference
|
||||
for workspace sharing features and UI
|
||||
- [IdP Sync](../admin/users/idp-sync.md) — group, role, and organization
|
||||
sync configuration
|
||||
- [Configuring Okta](./configuring-okta.md) — Okta-specific OIDC setup with
|
||||
custom claims and scopes
|
||||
- [Security Best Practices](./best-practices/security-best-practices.md) —
|
||||
deployment-wide security hardening
|
||||
- [Sessions and Tokens](../admin/users/sessions-tokens.md) — API token
|
||||
management and scoping
|
||||
@@ -5,7 +5,7 @@ terraform {
|
||||
}
|
||||
docker = {
|
||||
source = "kreuzwerker/docker"
|
||||
version = "~> 3.0"
|
||||
version = "~> 4.0"
|
||||
}
|
||||
envbuilder = {
|
||||
source = "coder/envbuilder"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# 1.93.1
|
||||
FROM rust:slim@sha256:1d0000a49fb62f4fde24455f49d59c6c088af46202d65d8f455b722f7263e8f8 AS rust-utils
|
||||
FROM rust:slim@sha256:a08d20a404f947ed358dfb63d1ee7e0b88ecad3c45ba9682ccbf2cb09c98acca AS rust-utils
|
||||
# Install rust helper programs
|
||||
ENV CARGO_INSTALL_ROOT=/tmp/
|
||||
# Use more reliable mirrors for Debian packages
|
||||
@@ -8,7 +8,7 @@ RUN sed -i 's|http://deb.debian.org/debian|http://mirrors.edge.kernel.org/debian
|
||||
RUN apt-get update && apt-get install -y libssl-dev openssl pkg-config build-essential
|
||||
RUN cargo install jj-cli typos-cli watchexec-cli
|
||||
|
||||
FROM ubuntu:jammy@sha256:5e5b128eb4ff35ee52687c20d081dcc15b8cb55e113247683f435224fc58b956 AS go
|
||||
FROM ubuntu:jammy@sha256:eb29ed27b0821dca09c2e28b39135e185fc1302036427d5f4d70a41ce8fd7659 AS go
|
||||
|
||||
# Install Go manually, so that we can control the version
|
||||
ARG GO_VERSION=1.25.8
|
||||
@@ -83,7 +83,7 @@ RUN curl -L -o protoc.zip https://github.com/protocolbuffers/protobuf/releases/d
|
||||
unzip protoc.zip && \
|
||||
rm protoc.zip
|
||||
|
||||
FROM ubuntu:jammy@sha256:5e5b128eb4ff35ee52687c20d081dcc15b8cb55e113247683f435224fc58b956
|
||||
FROM ubuntu:jammy@sha256:eb29ed27b0821dca09c2e28b39135e185fc1302036427d5f4d70a41ce8fd7659
|
||||
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ terraform {
|
||||
}
|
||||
docker = {
|
||||
source = "kreuzwerker/docker"
|
||||
version = "~> 3.6"
|
||||
version = "~> 4.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -922,7 +922,7 @@ resource "coder_script" "boundary_config_setup" {
|
||||
module "claude-code" {
|
||||
count = data.coder_task.me.enabled ? data.coder_workspace.me.start_count : 0
|
||||
source = "dev.registry.coder.com/coder/claude-code/coder"
|
||||
version = "4.8.2"
|
||||
version = "4.9.1"
|
||||
enable_boundary = true
|
||||
agent_id = coder_agent.dev.id
|
||||
workdir = local.repo_dir
|
||||
|
||||
@@ -16,6 +16,9 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// NOTE: See the auditLogCountCap note.
|
||||
const connectionLogCountCap = 2000
|
||||
|
||||
// @Summary Get connection logs
|
||||
// @ID get-connection-logs
|
||||
// @Security CoderSessionToken
|
||||
@@ -49,6 +52,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
// #nosec G115 - Safe conversion as pagination limit is expected to be within int32 range
|
||||
filter.LimitOpt = int32(page.Limit)
|
||||
|
||||
countFilter.CountCap = connectionLogCountCap
|
||||
count, err := api.Database.CountConnectionLogs(ctx, countFilter)
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
@@ -63,6 +67,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
|
||||
ConnectionLogs: []codersdk.ConnectionLog{},
|
||||
Count: 0,
|
||||
CountCap: connectionLogCountCap,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -80,6 +85,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
|
||||
ConnectionLogs: convertConnectionLogs(dblogs),
|
||||
Count: count,
|
||||
CountCap: connectionLogCountCap,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,12 @@ const RelaySourceHeader = "X-Coder-Relay-Source-Replica"
|
||||
const (
|
||||
authorizationHeader = "Authorization"
|
||||
cookieHeader = "Cookie"
|
||||
|
||||
// relayDrainTimeout is how long an established relay is
|
||||
// kept open after the chat leaves running state, giving
|
||||
// buffered snapshot events time to be forwarded before
|
||||
// the relay is torn down.
|
||||
relayDrainTimeout = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// MultiReplicaSubscribeConfig holds the dependencies for multi-replica chat
|
||||
@@ -169,6 +175,21 @@ func NewMultiReplicaSubscribeFn(
|
||||
var reconnectTimer *quartz.Timer
|
||||
var reconnectCh <-chan time.Time
|
||||
|
||||
// drainAndClose is set when the chat transitions away
|
||||
// from running while a relay dial is still in progress.
|
||||
// Instead of canceling the dial immediately, we let it
|
||||
// complete so the snapshot of buffered message_parts
|
||||
// can be forwarded to the subscriber.
|
||||
var drainAndClose bool
|
||||
|
||||
// Drain timer state. When the relay connects in
|
||||
// drain-and-close mode, a short timer is started.
|
||||
// During this window the normal relayPartsCh case
|
||||
// forwards buffered snapshot events. When the timer
|
||||
// fires the relay is torn down.
|
||||
var drainTimer *quartz.Timer
|
||||
var drainTimerCh <-chan time.Time
|
||||
|
||||
// Helper to close relay and stop any pending reconnect
|
||||
// timer.
|
||||
closeRelay := func() {
|
||||
@@ -200,6 +221,12 @@ func NewMultiReplicaSubscribeFn(
|
||||
reconnectTimer = nil
|
||||
reconnectCh = nil
|
||||
}
|
||||
if drainTimer != nil {
|
||||
drainTimer.Stop()
|
||||
drainTimer = nil
|
||||
drainTimerCh = nil
|
||||
}
|
||||
drainAndClose = false
|
||||
}
|
||||
|
||||
// openRelayAsync dials the remote replica in a background
|
||||
@@ -335,16 +362,52 @@ func NewMultiReplicaSubscribeFn(
|
||||
// A nil parts channel signals the dial
|
||||
// failed — schedule a retry.
|
||||
if result.parts == nil {
|
||||
scheduleRelayReconnect()
|
||||
if drainAndClose {
|
||||
// Dial failed and we were only
|
||||
// waiting to drain — nothing to do.
|
||||
drainAndClose = false
|
||||
} else {
|
||||
scheduleRelayReconnect()
|
||||
}
|
||||
continue
|
||||
}
|
||||
// An async relay dial completed; swap
|
||||
} // An async relay dial completed; swap
|
||||
// in the new relay channel.
|
||||
if relayCancel != nil {
|
||||
relayCancel()
|
||||
}
|
||||
relayParts = result.parts
|
||||
relayCancel = result.cancel
|
||||
if drainAndClose {
|
||||
// The chat is no longer running on
|
||||
// the remote worker, but the dial
|
||||
// completed. Verify no new worker
|
||||
// has claimed the chat before we
|
||||
// drain stale parts.
|
||||
currentChat, dbErr := params.DB.GetChatByID(ctx, chatID)
|
||||
if dbErr != nil {
|
||||
logger.Warn(ctx, "failed to check chat status for relay drain",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(dbErr),
|
||||
)
|
||||
}
|
||||
if dbErr == nil && currentChat.Status == database.ChatStatusRunning &&
|
||||
currentChat.WorkerID.Valid &&
|
||||
currentChat.WorkerID.UUID != params.WorkerID {
|
||||
// A new worker picked up the chat;
|
||||
// discard the stale relay and let
|
||||
// openRelayAsync handle the new one.
|
||||
closeRelay()
|
||||
} else {
|
||||
// Chat is still idle — drain the
|
||||
// buffered snapshot before closing.
|
||||
if drainTimer != nil {
|
||||
drainTimer.Stop()
|
||||
}
|
||||
drainTimer = cfg.clock().NewTimer(relayDrainTimeout, "drain")
|
||||
drainTimerCh = drainTimer.C
|
||||
drainAndClose = false
|
||||
}
|
||||
}
|
||||
case <-reconnectCh:
|
||||
reconnectCh = nil
|
||||
// Re-check whether the chat is still
|
||||
@@ -374,8 +437,31 @@ func NewMultiReplicaSubscribeFn(
|
||||
if sn.Status == database.ChatStatusRunning && sn.WorkerID != uuid.Nil && sn.WorkerID != params.WorkerID {
|
||||
openRelayAsync(sn.WorkerID)
|
||||
} else {
|
||||
closeRelay()
|
||||
switch {
|
||||
case dialCancel != nil && relayParts == nil:
|
||||
// In-progress dial: let it complete
|
||||
// so its snapshot can be forwarded.
|
||||
drainAndClose = true
|
||||
case relayParts != nil:
|
||||
// Active relay: give it a short
|
||||
// window to deliver any remaining
|
||||
// buffered parts before closing.
|
||||
if drainTimer != nil {
|
||||
drainTimer.Stop()
|
||||
}
|
||||
drainTimer = cfg.clock().NewTimer(relayDrainTimeout, "drain")
|
||||
drainTimerCh = drainTimer.C
|
||||
default:
|
||||
closeRelay()
|
||||
}
|
||||
}
|
||||
case <-drainTimerCh:
|
||||
drainTimerCh = nil
|
||||
drainTimer = nil
|
||||
closeRelay()
|
||||
drainTimerCh = nil
|
||||
drainTimer = nil
|
||||
closeRelay()
|
||||
case event, ok := <-relayPartsCh:
|
||||
if !ok {
|
||||
if relayCancel != nil {
|
||||
|
||||
@@ -1245,3 +1245,334 @@ func TestSubscribeRelayMultipleReconnects(t *testing.T) {
|
||||
consumePart("relay-3")
|
||||
require.GreaterOrEqual(t, int(callCount.Load()), 3)
|
||||
}
|
||||
|
||||
// TestSubscribeRelayDialCanceledOnFastCompletion demonstrates a race
|
||||
// condition in multi-replica chat streaming where the relay connection
|
||||
// from the subscriber replica to the worker replica is canceled before
|
||||
// it can be established because the worker completes processing before
|
||||
// the async relay dial finishes.
|
||||
//
|
||||
// Scenario:
|
||||
// 1. Subscriber subscribes to a chat while it's in waiting state (no relay).
|
||||
// 2. User sends a message → chat becomes pending → worker picks it up.
|
||||
// 3. Subscriber receives status=running via pubsub → enterprise opens relay async.
|
||||
// 4. Worker completes quickly → publishes committed message + status=waiting.
|
||||
// 5. Subscriber receives status=waiting → enterprise cancels the in-progress relay dial.
|
||||
// 6. The relay was never established, so no message_part events were delivered.
|
||||
// 7. The committed message arrives via pubsub (durable path), but streaming is lost.
|
||||
//
|
||||
// This reproduces the user-facing issue where refreshing the page is needed
|
||||
// to see a response: the streaming tokens never arrive via the relay, and
|
||||
// the response only appears after the full committed message is delivered.
|
||||
func TestSubscribeRelayDialCanceledOnFastCompletion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
workerID := uuid.New()
|
||||
subscriberID := uuid.New()
|
||||
|
||||
var dialAttempted atomic.Bool
|
||||
|
||||
// Gate: closed when the worker finishes processing.
|
||||
workerDone := make(chan struct{})
|
||||
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("fast-completion-relay-race")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("hello ", "world ", "from ", "the ", "worker")...,
|
||||
)
|
||||
})
|
||||
|
||||
// Worker server with a 1-hour acquire interval so it only processes
|
||||
// when explicitly woken by SendMessage's signalWake.
|
||||
workerLogger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
worker := osschatd.New(osschatd.Config{
|
||||
Logger: workerLogger,
|
||||
Database: db,
|
||||
ReplicaID: workerID,
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: time.Hour,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, worker.Close())
|
||||
})
|
||||
|
||||
// Subscriber's relay dialer blocks until the worker finishes,
|
||||
// simulating a slow relay dial (network latency between replicas).
|
||||
// After the worker completes, the dialer connects to the worker
|
||||
// to retrieve buffered parts from the retained buffer.
|
||||
subscriber := newTestServer(t, db, ps, subscriberID, func(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
targetWorkerID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
error,
|
||||
) {
|
||||
dialAttempted.Store(true)
|
||||
// Block until the worker finishes processing, simulating
|
||||
// a slow relay dial.
|
||||
select {
|
||||
case <-workerDone:
|
||||
case <-ctx.Done():
|
||||
return nil, nil, nil, ctx.Err()
|
||||
}
|
||||
// Connect to the worker. The buffer is retained for a
|
||||
// grace period after processing, so the relay still gets
|
||||
// the message_part snapshot.
|
||||
snapshot, relayEvents, cancel, ok := worker.Subscribe(ctx, chatID, requestHeader, math.MaxInt64)
|
||||
if !ok {
|
||||
return nil, nil, nil, xerrors.New("worker subscribe failed")
|
||||
}
|
||||
return snapshot, relayEvents, cancel, nil
|
||||
}, nil)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
// Create the chat in waiting state so the subscriber sees it
|
||||
// before the worker picks it up (avoids the synchronous relay
|
||||
// path in Subscribe).
|
||||
chat := seedWaitingChat(ctx, t, db, user, model, "fast-completion-relay-race")
|
||||
|
||||
// Subscribe from the subscriber replica while the chat is idle.
|
||||
// No relay is opened because the chat is in waiting state.
|
||||
_, events, subCancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
defer subCancel()
|
||||
|
||||
// Send a message via the worker server to transition the chat to
|
||||
// pending and wake the worker's processing loop.
|
||||
_, err := worker.SendMessage(ctx, osschatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: user.ID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the worker to fully process the chat.
|
||||
require.Eventually(t, func() bool {
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Status == database.ChatStatusWaiting
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
// Release the relay dial now that the worker is done.
|
||||
close(workerDone)
|
||||
|
||||
// Collect all events that arrived at the subscriber.
|
||||
var messageParts []string
|
||||
var committedAssistantMsgs int
|
||||
|
||||
// Drain events until we see both the committed message (via
|
||||
// pubsub) and at least one streaming part (via relay
|
||||
// drain-and-close).
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
switch event.Type {
|
||||
case codersdk.ChatStreamEventTypeMessagePart:
|
||||
if event.MessagePart != nil {
|
||||
messageParts = append(messageParts, event.MessagePart.Part.Text)
|
||||
}
|
||||
case codersdk.ChatStreamEventTypeMessage:
|
||||
if event.Message != nil && event.Message.Role == codersdk.ChatMessageRoleAssistant {
|
||||
committedAssistantMsgs++
|
||||
}
|
||||
}
|
||||
return committedAssistantMsgs > 0 && len(messageParts) > 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
// The committed assistant message arrives via pubsub → DB query
|
||||
// (durable path).
|
||||
require.Equal(t, 1, committedAssistantMsgs,
|
||||
"committed assistant message should arrive via pubsub durable path")
|
||||
|
||||
// The relay dial was attempted when status=running arrived.
|
||||
require.True(t, dialAttempted.Load(),
|
||||
"relay dial should have been attempted when status changed to running")
|
||||
|
||||
// Streaming parts are now received even though the relay was
|
||||
// slower than the worker: the OSS buffer retention grace period
|
||||
// keeps parts available, and the enterprise relay completes the
|
||||
// dial (drain-and-close) instead of canceling it immediately.
|
||||
require.NotEmpty(t, messageParts,
|
||||
"streaming parts should be received via the relay even when the "+
|
||||
"worker completes before the relay is established")
|
||||
}
|
||||
|
||||
// TestSubscribeRelayEstablishedMidStream demonstrates that when the
|
||||
// relay is established while the worker is still streaming, the
|
||||
// subscriber receives buffered parts via the relay snapshot and live
|
||||
// parts through the relay channel.
|
||||
//
|
||||
// This is the complementary test to TestSubscribeRelayDialCanceledOnFastCompletion:
|
||||
// it shows the relay mechanism works correctly when timing is favorable
|
||||
// (relay connects before the worker finishes), contrasting with the race
|
||||
// condition where the relay is too slow.
|
||||
func TestSubscribeRelayEstablishedMidStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
workerID := uuid.New()
|
||||
subscriberID := uuid.New()
|
||||
|
||||
// Gate: worker blocks after first streaming request until we
|
||||
// release it. This gives the relay time to establish.
|
||||
firstChunkEmitted := make(chan struct{})
|
||||
continueStreaming := make(chan struct{})
|
||||
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("mid-stream-relay")
|
||||
}
|
||||
// Signal that the first streaming request was received,
|
||||
// then block until released.
|
||||
select {
|
||||
case <-firstChunkEmitted:
|
||||
default:
|
||||
close(firstChunkEmitted)
|
||||
}
|
||||
<-continueStreaming
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("continued ", "response")...,
|
||||
)
|
||||
})
|
||||
|
||||
// Worker with a 1-hour acquire interval; only processes when
|
||||
// explicitly woken.
|
||||
workerLogger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
worker := osschatd.New(osschatd.Config{
|
||||
Logger: workerLogger,
|
||||
Database: db,
|
||||
ReplicaID: workerID,
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: time.Hour,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, worker.Close())
|
||||
})
|
||||
|
||||
// Subscriber's dialer connects to the worker with no delay.
|
||||
// This simulates a relay that succeeds promptly.
|
||||
subscriber := newTestServer(t, db, ps, subscriberID, func(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
targetWorkerID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
error,
|
||||
) {
|
||||
if targetWorkerID != workerID {
|
||||
return nil, nil, nil, xerrors.Errorf("unexpected relay target %s", targetWorkerID)
|
||||
}
|
||||
snapshot, relayEvents, cancel, ok := worker.Subscribe(ctx, chatID, requestHeader, math.MaxInt64)
|
||||
if !ok {
|
||||
return nil, nil, nil, xerrors.New("worker subscribe failed")
|
||||
}
|
||||
return snapshot, relayEvents, cancel, nil
|
||||
}, nil)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
// Create the chat in waiting state.
|
||||
chat := seedWaitingChat(ctx, t, db, user, model, "mid-stream-relay")
|
||||
|
||||
// Subscribe from the subscriber replica while the chat is idle.
|
||||
_, events, subCancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
defer subCancel()
|
||||
|
||||
// Send a message to make the chat pending and wake the worker.
|
||||
_, err := worker.SendMessage(ctx, osschatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: user.ID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the worker to reach the LLM (first streaming request).
|
||||
select {
|
||||
case <-firstChunkEmitted:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for worker to start streaming")
|
||||
}
|
||||
|
||||
// Wait for the subscriber to receive the running status, which
|
||||
// triggers the relay. Because the dialer is non-blocking, the
|
||||
// relay establishes promptly.
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
return event.Type == codersdk.ChatStreamEventTypeStatus &&
|
||||
event.Status != nil &&
|
||||
event.Status.Status == codersdk.ChatStatusRunning
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
// Now release the worker to continue streaming.
|
||||
close(continueStreaming)
|
||||
|
||||
// Wait for the worker to complete.
|
||||
require.Eventually(t, func() bool {
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Status == database.ChatStatusWaiting
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
// Collect remaining events.
|
||||
var messageParts []string
|
||||
var hasCommittedMsg bool
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
switch event.Type {
|
||||
case codersdk.ChatStreamEventTypeMessagePart:
|
||||
if event.MessagePart != nil {
|
||||
messageParts = append(messageParts, event.MessagePart.Part.Text)
|
||||
}
|
||||
case codersdk.ChatStreamEventTypeMessage:
|
||||
if event.Message != nil && event.Message.Role == codersdk.ChatMessageRoleAssistant {
|
||||
hasCommittedMsg = true
|
||||
}
|
||||
}
|
||||
return hasCommittedMsg
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
// The committed message arrives via pubsub.
|
||||
require.True(t, hasCommittedMsg,
|
||||
"committed assistant message should arrive")
|
||||
|
||||
// When the relay is established mid-stream, streaming parts
|
||||
// SHOULD be received through the relay. This contrasts with
|
||||
// TestSubscribeRelayDialCanceledOnFastCompletion where no parts
|
||||
// arrive because the relay is never established.
|
||||
require.NotEmpty(t, messageParts,
|
||||
"streaming parts should be received when relay establishes while worker is still streaming")
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ data "coder_task" "me" {}
|
||||
module "claude-code" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
source = "registry.coder.com/coder/claude-code/coder"
|
||||
version = "4.8.2"
|
||||
version = "4.9.1"
|
||||
agent_id = coder_agent.main.id
|
||||
workdir = "/home/coder/projects"
|
||||
order = 999
|
||||
|
||||
@@ -91,7 +91,7 @@ require (
|
||||
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d
|
||||
github.com/adrg/xdg v0.5.0
|
||||
github.com/ammario/tlru v0.4.0
|
||||
github.com/andybalholm/brotli v1.2.0
|
||||
github.com/andybalholm/brotli v1.2.1
|
||||
github.com/aquasecurity/trivy-iac v0.8.0
|
||||
github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2
|
||||
github.com/awalterschulze/gographviz v2.0.3+incompatible
|
||||
@@ -136,11 +136,11 @@ require (
|
||||
github.com/go-chi/chi/v5 v5.2.4
|
||||
github.com/go-chi/cors v1.2.1
|
||||
github.com/go-chi/httprate v0.15.0
|
||||
github.com/go-jose/go-jose/v4 v4.1.3
|
||||
github.com/go-jose/go-jose/v4 v4.1.4
|
||||
github.com/go-logr/logr v1.4.3
|
||||
github.com/go-playground/validator/v10 v10.30.0
|
||||
github.com/gofrs/flock v0.13.0
|
||||
github.com/gohugoio/hugo v0.159.2
|
||||
github.com/gohugoio/hugo v0.160.0
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||
github.com/golang-migrate/migrate/v4 v4.19.0
|
||||
github.com/gomarkdown/markdown v0.0.0-20240930133441-72d49d9543d8
|
||||
@@ -162,7 +162,7 @@ require (
|
||||
github.com/justinas/nosurf v1.2.0
|
||||
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51
|
||||
github.com/kirsle/configdir v0.0.0-20170128060238-e45d2f54772f
|
||||
github.com/klauspost/compress v1.18.4
|
||||
github.com/klauspost/compress v1.18.5
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/mattn/go-isatty v0.0.20
|
||||
github.com/mitchellh/go-wordwrap v1.0.1
|
||||
@@ -193,7 +193,7 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/u-root/u-root v0.14.0
|
||||
github.com/unrolled/secure v1.17.0
|
||||
github.com/valyala/fasthttp v1.69.0
|
||||
github.com/valyala/fasthttp v1.70.0
|
||||
github.com/wagslane/go-password-validator v0.3.0
|
||||
github.com/zclconf/go-cty-yaml v1.2.0
|
||||
go.mozilla.org/pkcs7 v0.9.0
|
||||
@@ -218,7 +218,7 @@ require (
|
||||
golang.org/x/text v0.35.0
|
||||
golang.org/x/tools v0.43.0
|
||||
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da
|
||||
google.golang.org/api v0.273.0
|
||||
google.golang.org/api v0.274.0
|
||||
google.golang.org/grpc v1.80.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
gopkg.in/DataDog/dd-trace-go.v1 v1.74.0
|
||||
@@ -434,7 +434,7 @@ require (
|
||||
github.com/xeipuuv/gojsonschema v1.2.0 // indirect
|
||||
github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect
|
||||
github.com/yashtewari/glob-intersection v0.2.0 // indirect
|
||||
github.com/yuin/goldmark v1.7.17 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.6 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
github.com/zclconf/go-cty v1.17.0
|
||||
|
||||
@@ -128,8 +128,8 @@ github.com/ammario/tlru v0.4.0 h1:sJ80I0swN3KOX2YxC6w8FbCqpQucWdbb+J36C05FPuU=
|
||||
github.com/ammario/tlru v0.4.0/go.mod h1:aYzRFu0XLo4KavE9W8Lx7tzjkX+pAApz+NgcKYIFUBQ=
|
||||
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ=
|
||||
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eTWro=
|
||||
github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/apparentlymart/go-cidr v1.1.0 h1:2mAhrMoF+nhXqxTzSZMUzDHkLjmIHC+Zzn4tdgBZjnU=
|
||||
@@ -518,8 +518,8 @@ github.com/go-git/go-git/v5 v5.17.1 h1:WnljyxIzSj9BRRUlnmAU35ohDsjRK0EKmL0evDqi5
|
||||
github.com/go-git/go-git/v5 v5.17.1/go.mod h1:pW/VmeqkanRFqR6AljLcs7EA7FbZaN5MQqO7oZADXpo=
|
||||
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
|
||||
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
|
||||
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU=
|
||||
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok=
|
||||
github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
@@ -611,12 +611,12 @@ github.com/gohugoio/hashstructure v0.6.0 h1:7wMB/2CfXoThFYhdWRGv3u3rUM761Cq29CxU
|
||||
github.com/gohugoio/hashstructure v0.6.0/go.mod h1:lapVLk9XidheHG1IQ4ZSbyYrXcaILU1ZEP/+vno5rBQ=
|
||||
github.com/gohugoio/httpcache v0.8.0 h1:hNdsmGSELztetYCsPVgjA960zSa4dfEqqF/SficorCU=
|
||||
github.com/gohugoio/httpcache v0.8.0/go.mod h1:fMlPrdY/vVJhAriLZnrF5QpN3BNAcoBClgAyQd+lGFI=
|
||||
github.com/gohugoio/hugo v0.159.2 h1:tpS6pcShcP3Khl8WA1NAxVHi2D3/ib9BbM8+m7NECUA=
|
||||
github.com/gohugoio/hugo v0.159.2/go.mod h1:vKww5V9i8MYzFD8JVvhRN+AKdDfKV0UvbFpmCDtTr/A=
|
||||
github.com/gohugoio/hugo-goldmark-extensions/extras v0.6.0 h1:c16engMi6zyOGeCrP73RWC9fom94wXGpVzncu3GXBjI=
|
||||
github.com/gohugoio/hugo-goldmark-extensions/extras v0.6.0/go.mod h1:e3+TRCT4Uz6NkZOAVMOMgPeJ+7KEtQMX8hdB+WG4qRs=
|
||||
github.com/gohugoio/hugo-goldmark-extensions/passthrough v0.4.0 h1:awFlqaCQ0N/RS9ndIBpDYNms101I1sGbDRG1bksa5Js=
|
||||
github.com/gohugoio/hugo-goldmark-extensions/passthrough v0.4.0/go.mod h1:lK1CjqrueCd3OBnsLLQJGrQ+uodWfT9M9Cq2zfDWJCE=
|
||||
github.com/gohugoio/hugo v0.160.0 h1:WmmygLg2ahijM4w2VHFn/DdBR+OpJ9H9pH3d8OApNDY=
|
||||
github.com/gohugoio/hugo v0.160.0/go.mod h1:+VA5jOO3iGELh+6cig098PT2Cd9iNhwUPRqCUe3Ce7w=
|
||||
github.com/gohugoio/hugo-goldmark-extensions/extras v0.7.0 h1:I/n6v7VImJ3aISLnn73JAHXyjcQsMVvbguQPTk9Ehus=
|
||||
github.com/gohugoio/hugo-goldmark-extensions/extras v0.7.0/go.mod h1:9LJNfKWFmhEJ7HW0in5znezMwH+FYMBIhNZ3VWtRcRs=
|
||||
github.com/gohugoio/hugo-goldmark-extensions/passthrough v0.5.0 h1:p13Q0DBCrBRpJGtbtlgkYNCs4TnIlZJh8vHgnAiofrI=
|
||||
github.com/gohugoio/hugo-goldmark-extensions/passthrough v0.5.0/go.mod h1:ob9PCHy/ocsQhTz68uxhyInaYCbbVNpOOrJkIoSeD+8=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
@@ -794,8 +794,8 @@ github.com/kirsle/configdir v0.0.0-20170128060238-e45d2f54772f h1:dKccXx7xA56UNq
|
||||
github.com/kirsle/configdir v0.0.0-20170128060238-e45d2f54772f/go.mod h1:4rEELDSfUAlBSyUjPG0JnaNGjf13JySHFeRdD/3dLP0=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
|
||||
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
||||
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
|
||||
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a h1:+RR6SqnTkDLWyICxS1xpjCi/3dhyV+TgZwA6Ww3KncQ=
|
||||
@@ -1188,8 +1188,8 @@ github.com/urfave/cli/v2 v2.27.5 h1:WoHEJLdsXr6dDWoJgMq/CboDmyY/8HMMH1fTECbih+w=
|
||||
github.com/urfave/cli/v2 v2.27.5/go.mod h1:3Sevf16NykTbInEnD0yKkjDAeZDS0A6bzhBH5hrMvTQ=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI=
|
||||
github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw=
|
||||
github.com/valyala/fasthttp v1.70.0 h1:LAhMGcWk13QZWm85+eg8ZBNbrq5mnkWFGbHMUJHIdXA=
|
||||
github.com/valyala/fasthttp v1.70.0/go.mod h1:oDZEHHkJ/Buyklg6uURmYs19442zFSnCIfX3j1FY3pE=
|
||||
github.com/valyala/fastjson v1.6.4 h1:uAUNq9Z6ymTgGhcm0UynUAB6tlbakBrz6CQFax3BXVQ=
|
||||
github.com/valyala/fastjson v1.6.4/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
|
||||
github.com/vbatts/tar-split v0.12.2 h1:w/Y6tjxpeiFMR47yzZPlPj/FcPLpXbTUi/9H7d3CPa4=
|
||||
@@ -1252,8 +1252,8 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/yuin/goldmark v1.7.17 h1:p36OVWwRb246iHxA/U4p8OPEpOTESm4n+g+8t0EE5uA=
|
||||
github.com/yuin/goldmark v1.7.17/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs=
|
||||
github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
@@ -1375,8 +1375,8 @@ golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0=
|
||||
golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA=
|
||||
golang.org/x/image v0.37.0 h1:ZiRjArKI8GwxZOoEtUfhrBtaCN+4b/7709dlT6SSnQA=
|
||||
golang.org/x/image v0.37.0/go.mod h1:/3f6vaXC+6CEanU4KJxbcUZyEePbyKbaLoDOe4ehFYY=
|
||||
golang.org/x/image v0.38.0 h1:5l+q+Y9JDC7mBOMjo4/aPhMDcxEptsX+Tt3GgRQRPuE=
|
||||
golang.org/x/image v0.38.0/go.mod h1:/3f6vaXC+6CEanU4KJxbcUZyEePbyKbaLoDOe4ehFYY=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
@@ -1514,8 +1514,8 @@ golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.273.0 h1:r/Bcv36Xa/te1ugaN1kdJ5LoA5Wj/cL+a4gj6FiPBjQ=
|
||||
google.golang.org/api v0.273.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/api v0.274.0 h1:aYhycS5QQCwxHLwfEHRRLf9yNsfvp1JadKKWBE54RFA=
|
||||
google.golang.org/api v0.274.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
|
||||
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
|
||||
|
||||
+1
-1
@@ -9,7 +9,7 @@
|
||||
"storybook": "pnpm run -C site/ storybook"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@biomejs/biome": "2.2.0",
|
||||
"@biomejs/biome": "2.4.10",
|
||||
"markdown-table-formatter": "^1.6.1",
|
||||
"markdownlint-cli2": "^0.16.0",
|
||||
"quicktype": "^23.0.0"
|
||||
|
||||
Generated
+37
-37
@@ -12,8 +12,8 @@ importers:
|
||||
.:
|
||||
devDependencies:
|
||||
'@biomejs/biome':
|
||||
specifier: 2.2.0
|
||||
version: 2.2.0
|
||||
specifier: 2.4.10
|
||||
version: 2.4.10
|
||||
markdown-table-formatter:
|
||||
specifier: ^1.6.1
|
||||
version: 1.6.1
|
||||
@@ -26,55 +26,55 @@ importers:
|
||||
|
||||
packages:
|
||||
|
||||
'@biomejs/biome@2.2.0':
|
||||
resolution: {integrity: sha512-3On3RSYLsX+n9KnoSgfoYlckYBoU6VRM22cw1gB4Y0OuUVSYd/O/2saOJMrA4HFfA1Ff0eacOvMN1yAAvHtzIw==}
|
||||
'@biomejs/biome@2.4.10':
|
||||
resolution: {integrity: sha512-xxA3AphFQ1geij4JTHXv4EeSTda1IFn22ye9LdyVPoJU19fNVl0uzfEuhsfQ4Yue/0FaLs2/ccVi4UDiE7R30w==}
|
||||
engines: {node: '>=14.21.3'}
|
||||
hasBin: true
|
||||
|
||||
'@biomejs/cli-darwin-arm64@2.2.0':
|
||||
resolution: {integrity: sha512-zKbwUUh+9uFmWfS8IFxmVD6XwqFcENjZvEyfOxHs1epjdH3wyyMQG80FGDsmauPwS2r5kXdEM0v/+dTIA9FXAg==}
|
||||
'@biomejs/cli-darwin-arm64@2.4.10':
|
||||
resolution: {integrity: sha512-vuzzI1cWqDVzOMIkYyHbKqp+AkQq4K7k+UCXWpkYcY/HDn1UxdsbsfgtVpa40shem8Kax4TLDLlx8kMAecgqiw==}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [arm64]
|
||||
os: [darwin]
|
||||
|
||||
'@biomejs/cli-darwin-x64@2.2.0':
|
||||
resolution: {integrity: sha512-+OmT4dsX2eTfhD5crUOPw3RPhaR+SKVspvGVmSdZ9y9O/AgL8pla6T4hOn1q+VAFBHuHhsdxDRJgFCSC7RaMOw==}
|
||||
'@biomejs/cli-darwin-x64@2.4.10':
|
||||
resolution: {integrity: sha512-14fzASRo+BPotwp7nWULy2W5xeUyFnTaq1V13Etrrxkrih+ez/2QfgFm5Ehtf5vSjtgx/IJycMMpn5kPd5ZNaA==}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [x64]
|
||||
os: [darwin]
|
||||
|
||||
'@biomejs/cli-linux-arm64-musl@2.2.0':
|
||||
resolution: {integrity: sha512-egKpOa+4FL9YO+SMUMLUvf543cprjevNc3CAgDNFLcjknuNMcZ0GLJYa3EGTCR2xIkIUJDVneBV3O9OcIlCEZQ==}
|
||||
'@biomejs/cli-linux-arm64-musl@2.4.10':
|
||||
resolution: {integrity: sha512-WrJY6UuiSD/Dh+nwK2qOTu8kdMDlLV3dLMmychIghHPAysWFq1/DGC1pVZx8POE3ZkzKR3PUUnVrtZfMfaJjyQ==}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
|
||||
'@biomejs/cli-linux-arm64@2.2.0':
|
||||
resolution: {integrity: sha512-6eoRdF2yW5FnW9Lpeivh7Mayhq0KDdaDMYOJnH9aT02KuSIX5V1HmWJCQQPwIQbhDh68Zrcpl8inRlTEan0SXw==}
|
||||
'@biomejs/cli-linux-arm64@2.4.10':
|
||||
resolution: {integrity: sha512-7MH1CMW5uuxQ/s7FLST63qF8B3Hgu2HRdZ7tA1X1+mk+St4JOuIrqdhIBnnyqeyWJNI+Bww7Es5QZ0wIc1Cmkw==}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
|
||||
'@biomejs/cli-linux-x64-musl@2.2.0':
|
||||
resolution: {integrity: sha512-I5J85yWwUWpgJyC1CcytNSGusu2p9HjDnOPAFG4Y515hwRD0jpR9sT9/T1cKHtuCvEQ/sBvx+6zhz9l9wEJGAg==}
|
||||
'@biomejs/cli-linux-x64-musl@2.4.10':
|
||||
resolution: {integrity: sha512-kDTi3pI6PBN6CiczsWYOyP2zk0IJI08EWEQyDMQWW221rPaaEz6FvjLhnU07KMzLv8q3qSuoB93ua6inSQ55Tw==}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
|
||||
'@biomejs/cli-linux-x64@2.2.0':
|
||||
resolution: {integrity: sha512-5UmQx/OZAfJfi25zAnAGHUMuOd+LOsliIt119x2soA2gLggQYrVPA+2kMUxR6Mw5M1deUF/AWWP2qpxgH7Nyfw==}
|
||||
'@biomejs/cli-linux-x64@2.4.10':
|
||||
resolution: {integrity: sha512-tZLvEEi2u9Xu1zAqRjTcpIDGVtldigVvzug2fTuPG0ME/g8/mXpRPcNgLB22bGn6FvLJpHHnqLnwliOu8xjYrg==}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
|
||||
'@biomejs/cli-win32-arm64@2.2.0':
|
||||
resolution: {integrity: sha512-n9a1/f2CwIDmNMNkFs+JI0ZjFnMO0jdOyGNtihgUNFnlmd84yIYY2KMTBmMV58ZlVHjgmY5Y6E1hVTnSRieggA==}
|
||||
'@biomejs/cli-win32-arm64@2.4.10':
|
||||
resolution: {integrity: sha512-umwQU6qPzH+ISTf/eHyJ/QoQnJs3V9Vpjz2OjZXe9MVBZ7prgGafMy7yYeRGnlmDAn87AKTF3Q6weLoMGpeqdQ==}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [arm64]
|
||||
os: [win32]
|
||||
|
||||
'@biomejs/cli-win32-x64@2.2.0':
|
||||
resolution: {integrity: sha512-Nawu5nHjP/zPKTIryh2AavzTc/KEg4um/MxWdXW0A6P/RZOyIpa7+QSjeXwAwX/utJGaCoXRPWtF3m5U/bB3Ww==}
|
||||
'@biomejs/cli-win32-x64@2.4.10':
|
||||
resolution: {integrity: sha512-aW/JU5GuyH4uxMrNYpoC2kjaHlyJGLgIa3XkhPEZI0uKhZhJZU8BuEyJmvgzSPQNGozBwWjC972RaNdcJ9KyJg==}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [x64]
|
||||
os: [win32]
|
||||
@@ -778,39 +778,39 @@ packages:
|
||||
|
||||
snapshots:
|
||||
|
||||
'@biomejs/biome@2.2.0':
|
||||
'@biomejs/biome@2.4.10':
|
||||
optionalDependencies:
|
||||
'@biomejs/cli-darwin-arm64': 2.2.0
|
||||
'@biomejs/cli-darwin-x64': 2.2.0
|
||||
'@biomejs/cli-linux-arm64': 2.2.0
|
||||
'@biomejs/cli-linux-arm64-musl': 2.2.0
|
||||
'@biomejs/cli-linux-x64': 2.2.0
|
||||
'@biomejs/cli-linux-x64-musl': 2.2.0
|
||||
'@biomejs/cli-win32-arm64': 2.2.0
|
||||
'@biomejs/cli-win32-x64': 2.2.0
|
||||
'@biomejs/cli-darwin-arm64': 2.4.10
|
||||
'@biomejs/cli-darwin-x64': 2.4.10
|
||||
'@biomejs/cli-linux-arm64': 2.4.10
|
||||
'@biomejs/cli-linux-arm64-musl': 2.4.10
|
||||
'@biomejs/cli-linux-x64': 2.4.10
|
||||
'@biomejs/cli-linux-x64-musl': 2.4.10
|
||||
'@biomejs/cli-win32-arm64': 2.4.10
|
||||
'@biomejs/cli-win32-x64': 2.4.10
|
||||
|
||||
'@biomejs/cli-darwin-arm64@2.2.0':
|
||||
'@biomejs/cli-darwin-arm64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-darwin-x64@2.2.0':
|
||||
'@biomejs/cli-darwin-x64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-linux-arm64-musl@2.2.0':
|
||||
'@biomejs/cli-linux-arm64-musl@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-linux-arm64@2.2.0':
|
||||
'@biomejs/cli-linux-arm64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-linux-x64-musl@2.2.0':
|
||||
'@biomejs/cli-linux-x64-musl@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-linux-x64@2.2.0':
|
||||
'@biomejs/cli-linux-x64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-win32-arm64@2.2.0':
|
||||
'@biomejs/cli-win32-arm64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-win32-x64@2.2.0':
|
||||
'@biomejs/cli-win32-x64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@cspotcode/source-map-support@0.8.1':
|
||||
|
||||
@@ -381,6 +381,7 @@ func provisionEnv(
|
||||
"CODER_WORKSPACE_BUILD_ID="+metadata.GetWorkspaceBuildId(),
|
||||
"CODER_TASK_ID="+metadata.GetTaskId(),
|
||||
"CODER_TASK_PROMPT="+metadata.GetTaskPrompt(),
|
||||
"AWS_SDK_UA_APP_ID=APN_1.1/pc_cdfmjwn8i6u8l9fwz8h82e4w3$",
|
||||
)
|
||||
if metadata.GetPrebuiltWorkspaceBuildStage().IsPrebuild() {
|
||||
env = append(env, provider.IsPrebuildEnvironmentVariable()+"=true")
|
||||
|
||||
@@ -1298,6 +1298,7 @@ func TestProvision_SafeEnv(t *testing.T) {
|
||||
require.Contains(t, log, passedValue)
|
||||
require.NotContains(t, log, secretValue)
|
||||
require.Contains(t, log, "CODER_")
|
||||
require.Contains(t, log, "AWS_SDK_UA_APP_ID=APN_1.1/pc_cdfmjwn8i6u8l9fwz8h82e4w3$")
|
||||
|
||||
apply := applyComplete.Type.(*proto.Response_Apply)
|
||||
require.NotEmpty(t, apply.Apply.State, "state exists")
|
||||
|
||||
@@ -18,6 +18,7 @@ type SchemaField struct {
|
||||
GoName string `json:"go_name"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Label string `json:"label,omitempty"`
|
||||
Required bool `json:"required"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
InputType string `json:"input_type"`
|
||||
@@ -135,6 +136,7 @@ func extractFields(t reflect.Type, prefix string, skip map[string]bool) FieldGro
|
||||
|
||||
typeName := goTypeToSchemaType(f.Type)
|
||||
description := f.Tag.Get("description")
|
||||
label := f.Tag.Get("label")
|
||||
enumTag := f.Tag.Get("enum")
|
||||
|
||||
var enumValues []string
|
||||
@@ -150,6 +152,7 @@ func extractFields(t reflect.Type, prefix string, skip map[string]bool) FieldGro
|
||||
GoName: goFieldPath(prefix, f.Name, t, fullJSONName),
|
||||
Type: typeName,
|
||||
Description: description,
|
||||
Label: label,
|
||||
Required: required,
|
||||
Enum: enumValues,
|
||||
InputType: inputType,
|
||||
|
||||
+152
-150
@@ -12,162 +12,164 @@ import {
|
||||
} from "../helpers";
|
||||
import { beforeCoderTest, resetExternalAuthKey } from "../hooks";
|
||||
|
||||
test.describe.skip("externalAuth", () => {
|
||||
test.beforeAll(async ({ baseURL }) => {
|
||||
const srv = await createServer(gitAuth.webPort);
|
||||
test.describe
|
||||
.skip("externalAuth", () => {
|
||||
test.beforeAll(async ({ baseURL }) => {
|
||||
const srv = await createServer(gitAuth.webPort);
|
||||
|
||||
// The GitHub validate endpoint returns the currently authenticated user!
|
||||
srv.use(gitAuth.validatePath, (_req, res) => {
|
||||
res.write(JSON.stringify(ghUser));
|
||||
res.end();
|
||||
// The GitHub validate endpoint returns the currently authenticated user!
|
||||
srv.use(gitAuth.validatePath, (_req, res) => {
|
||||
res.write(JSON.stringify(ghUser));
|
||||
res.end();
|
||||
});
|
||||
srv.use(gitAuth.tokenPath, (_req, res) => {
|
||||
const r = (Math.random() + 1).toString(36).substring(7);
|
||||
res.write(JSON.stringify({ access_token: r }));
|
||||
res.end();
|
||||
});
|
||||
srv.use(gitAuth.authPath, (req, res) => {
|
||||
res.redirect(
|
||||
`${baseURL}/external-auth/${gitAuth.webProvider}/callback?code=1234&state=${req.query.state}`,
|
||||
);
|
||||
});
|
||||
});
|
||||
srv.use(gitAuth.tokenPath, (_req, res) => {
|
||||
const r = (Math.random() + 1).toString(36).substring(7);
|
||||
res.write(JSON.stringify({ access_token: r }));
|
||||
res.end();
|
||||
|
||||
test.beforeEach(async ({ context, page }) => {
|
||||
beforeCoderTest(page);
|
||||
await login(page);
|
||||
await resetExternalAuthKey(context);
|
||||
});
|
||||
srv.use(gitAuth.authPath, (req, res) => {
|
||||
res.redirect(
|
||||
`${baseURL}/external-auth/${gitAuth.webProvider}/callback?code=1234&state=${req.query.state}`,
|
||||
|
||||
// Ensures that a Git auth provider with the device flow functions and completes!
|
||||
test("external auth device", async ({ page }) => {
|
||||
const device: ExternalAuthDevice = {
|
||||
device_code: "1234",
|
||||
user_code: "1234-5678",
|
||||
expires_in: 900,
|
||||
interval: 1,
|
||||
verification_uri: "",
|
||||
};
|
||||
|
||||
// Start a server to mock the GitHub API.
|
||||
const srv = await createServer(gitAuth.devicePort);
|
||||
srv.use(gitAuth.validatePath, (_req, res) => {
|
||||
res.write(JSON.stringify(ghUser));
|
||||
res.end();
|
||||
});
|
||||
srv.use(gitAuth.codePath, (_req, res) => {
|
||||
res.write(JSON.stringify(device));
|
||||
res.end();
|
||||
});
|
||||
srv.use(gitAuth.installationsPath, (_req, res) => {
|
||||
res.write(JSON.stringify(ghInstall));
|
||||
res.end();
|
||||
});
|
||||
|
||||
const token = {
|
||||
access_token: "",
|
||||
error: "authorization_pending",
|
||||
error_description: "",
|
||||
};
|
||||
// First we send a result from the API that the token hasn't been
|
||||
// authorized yet to ensure the UI reacts properly.
|
||||
const sentPending = new Awaiter();
|
||||
srv.use(gitAuth.tokenPath, (_req, res) => {
|
||||
res.write(JSON.stringify(token));
|
||||
res.end();
|
||||
sentPending.done();
|
||||
});
|
||||
|
||||
await page.goto(`/external-auth/${gitAuth.deviceProvider}`, {
|
||||
waitUntil: "domcontentloaded",
|
||||
});
|
||||
await page.getByText(device.user_code).isVisible();
|
||||
await sentPending.wait();
|
||||
// Update the token to be valid and ensure the UI updates!
|
||||
token.error = "";
|
||||
token.access_token = "hello-world";
|
||||
await page.waitForSelector("text=1 organization authorized");
|
||||
});
|
||||
|
||||
test("external auth web", async ({ page }) => {
|
||||
await page.goto(`/external-auth/${gitAuth.webProvider}`, {
|
||||
waitUntil: "domcontentloaded",
|
||||
});
|
||||
// This endpoint doesn't have the installations URL set intentionally!
|
||||
await page.waitForSelector("text=You've authenticated with GitHub!");
|
||||
});
|
||||
|
||||
test("successful external auth from workspace", async ({ page }) => {
|
||||
const templateName = await createTemplate(
|
||||
page,
|
||||
echoResponsesWithExternalAuth([
|
||||
{ id: gitAuth.webProvider, optional: false },
|
||||
]),
|
||||
);
|
||||
|
||||
await createWorkspace(page, templateName, { useExternalAuth: true });
|
||||
});
|
||||
});
|
||||
|
||||
test.beforeEach(async ({ context, page }) => {
|
||||
beforeCoderTest(page);
|
||||
await login(page);
|
||||
await resetExternalAuthKey(context);
|
||||
});
|
||||
|
||||
// Ensures that a Git auth provider with the device flow functions and completes!
|
||||
test("external auth device", async ({ page }) => {
|
||||
const device: ExternalAuthDevice = {
|
||||
device_code: "1234",
|
||||
user_code: "1234-5678",
|
||||
expires_in: 900,
|
||||
interval: 1,
|
||||
verification_uri: "",
|
||||
const ghUser: Endpoints["GET /user"]["response"]["data"] = {
|
||||
login: "kylecarbs",
|
||||
id: 7122116,
|
||||
node_id: "MDQ6VXNlcjcxMjIxMTY=",
|
||||
avatar_url: "https://avatars.githubusercontent.com/u/7122116?v=4",
|
||||
gravatar_id: "",
|
||||
url: "https://api.github.com/users/kylecarbs",
|
||||
html_url: "https://github.com/kylecarbs",
|
||||
followers_url: "https://api.github.com/users/kylecarbs/followers",
|
||||
following_url:
|
||||
"https://api.github.com/users/kylecarbs/following{/other_user}",
|
||||
gists_url: "https://api.github.com/users/kylecarbs/gists{/gist_id}",
|
||||
starred_url:
|
||||
"https://api.github.com/users/kylecarbs/starred{/owner}{/repo}",
|
||||
subscriptions_url: "https://api.github.com/users/kylecarbs/subscriptions",
|
||||
organizations_url: "https://api.github.com/users/kylecarbs/orgs",
|
||||
repos_url: "https://api.github.com/users/kylecarbs/repos",
|
||||
events_url: "https://api.github.com/users/kylecarbs/events{/privacy}",
|
||||
received_events_url:
|
||||
"https://api.github.com/users/kylecarbs/received_events",
|
||||
type: "User",
|
||||
site_admin: false,
|
||||
name: "Kyle Carberry",
|
||||
company: "@coder",
|
||||
blog: "https://carberry.com",
|
||||
location: "Austin, TX",
|
||||
email: "kyle@carberry.com",
|
||||
hireable: null,
|
||||
bio: "hey there",
|
||||
twitter_username: "kylecarbs",
|
||||
public_repos: 52,
|
||||
public_gists: 9,
|
||||
followers: 208,
|
||||
following: 31,
|
||||
created_at: "2014-04-01T02:24:41Z",
|
||||
updated_at: "2023-06-26T13:03:09Z",
|
||||
};
|
||||
|
||||
// Start a server to mock the GitHub API.
|
||||
const srv = await createServer(gitAuth.devicePort);
|
||||
srv.use(gitAuth.validatePath, (_req, res) => {
|
||||
res.write(JSON.stringify(ghUser));
|
||||
res.end();
|
||||
});
|
||||
srv.use(gitAuth.codePath, (_req, res) => {
|
||||
res.write(JSON.stringify(device));
|
||||
res.end();
|
||||
});
|
||||
srv.use(gitAuth.installationsPath, (_req, res) => {
|
||||
res.write(JSON.stringify(ghInstall));
|
||||
res.end();
|
||||
});
|
||||
|
||||
const token = {
|
||||
access_token: "",
|
||||
error: "authorization_pending",
|
||||
error_description: "",
|
||||
};
|
||||
// First we send a result from the API that the token hasn't been
|
||||
// authorized yet to ensure the UI reacts properly.
|
||||
const sentPending = new Awaiter();
|
||||
srv.use(gitAuth.tokenPath, (_req, res) => {
|
||||
res.write(JSON.stringify(token));
|
||||
res.end();
|
||||
sentPending.done();
|
||||
});
|
||||
|
||||
await page.goto(`/external-auth/${gitAuth.deviceProvider}`, {
|
||||
waitUntil: "domcontentloaded",
|
||||
});
|
||||
await page.getByText(device.user_code).isVisible();
|
||||
await sentPending.wait();
|
||||
// Update the token to be valid and ensure the UI updates!
|
||||
token.error = "";
|
||||
token.access_token = "hello-world";
|
||||
await page.waitForSelector("text=1 organization authorized");
|
||||
});
|
||||
|
||||
test("external auth web", async ({ page }) => {
|
||||
await page.goto(`/external-auth/${gitAuth.webProvider}`, {
|
||||
waitUntil: "domcontentloaded",
|
||||
});
|
||||
// This endpoint doesn't have the installations URL set intentionally!
|
||||
await page.waitForSelector("text=You've authenticated with GitHub!");
|
||||
});
|
||||
|
||||
test("successful external auth from workspace", async ({ page }) => {
|
||||
const templateName = await createTemplate(
|
||||
page,
|
||||
echoResponsesWithExternalAuth([
|
||||
{ id: gitAuth.webProvider, optional: false },
|
||||
]),
|
||||
);
|
||||
|
||||
await createWorkspace(page, templateName, { useExternalAuth: true });
|
||||
});
|
||||
|
||||
const ghUser: Endpoints["GET /user"]["response"]["data"] = {
|
||||
login: "kylecarbs",
|
||||
id: 7122116,
|
||||
node_id: "MDQ6VXNlcjcxMjIxMTY=",
|
||||
avatar_url: "https://avatars.githubusercontent.com/u/7122116?v=4",
|
||||
gravatar_id: "",
|
||||
url: "https://api.github.com/users/kylecarbs",
|
||||
html_url: "https://github.com/kylecarbs",
|
||||
followers_url: "https://api.github.com/users/kylecarbs/followers",
|
||||
following_url:
|
||||
"https://api.github.com/users/kylecarbs/following{/other_user}",
|
||||
gists_url: "https://api.github.com/users/kylecarbs/gists{/gist_id}",
|
||||
starred_url:
|
||||
"https://api.github.com/users/kylecarbs/starred{/owner}{/repo}",
|
||||
subscriptions_url: "https://api.github.com/users/kylecarbs/subscriptions",
|
||||
organizations_url: "https://api.github.com/users/kylecarbs/orgs",
|
||||
repos_url: "https://api.github.com/users/kylecarbs/repos",
|
||||
events_url: "https://api.github.com/users/kylecarbs/events{/privacy}",
|
||||
received_events_url:
|
||||
"https://api.github.com/users/kylecarbs/received_events",
|
||||
type: "User",
|
||||
site_admin: false,
|
||||
name: "Kyle Carberry",
|
||||
company: "@coder",
|
||||
blog: "https://carberry.com",
|
||||
location: "Austin, TX",
|
||||
email: "kyle@carberry.com",
|
||||
hireable: null,
|
||||
bio: "hey there",
|
||||
twitter_username: "kylecarbs",
|
||||
public_repos: 52,
|
||||
public_gists: 9,
|
||||
followers: 208,
|
||||
following: 31,
|
||||
created_at: "2014-04-01T02:24:41Z",
|
||||
updated_at: "2023-06-26T13:03:09Z",
|
||||
};
|
||||
|
||||
const ghInstall: Endpoints["GET /user/installations"]["response"]["data"] = {
|
||||
installations: [
|
||||
const ghInstall: Endpoints["GET /user/installations"]["response"]["data"] =
|
||||
{
|
||||
id: 1,
|
||||
access_tokens_url: "",
|
||||
account: ghUser,
|
||||
app_id: 1,
|
||||
app_slug: "coder",
|
||||
created_at: "2014-04-01T02:24:41Z",
|
||||
events: [],
|
||||
html_url: "",
|
||||
permissions: {},
|
||||
repositories_url: "",
|
||||
repository_selection: "all",
|
||||
single_file_name: "",
|
||||
suspended_at: null,
|
||||
suspended_by: null,
|
||||
target_id: 1,
|
||||
target_type: "",
|
||||
updated_at: "2023-06-26T13:03:09Z",
|
||||
},
|
||||
],
|
||||
total_count: 1,
|
||||
};
|
||||
});
|
||||
installations: [
|
||||
{
|
||||
id: 1,
|
||||
access_tokens_url: "",
|
||||
account: ghUser,
|
||||
app_id: 1,
|
||||
app_slug: "coder",
|
||||
created_at: "2014-04-01T02:24:41Z",
|
||||
events: [],
|
||||
html_url: "",
|
||||
permissions: {},
|
||||
repositories_url: "",
|
||||
repository_selection: "all",
|
||||
single_file_name: "",
|
||||
suspended_at: null,
|
||||
suspended_by: null,
|
||||
target_id: 1,
|
||||
target_type: "",
|
||||
updated_at: "2023-06-26T13:03:09Z",
|
||||
},
|
||||
],
|
||||
total_count: 1,
|
||||
};
|
||||
});
|
||||
|
||||
+1
-1
@@ -127,7 +127,7 @@
|
||||
"devDependencies": {
|
||||
"@babel/core": "7.29.0",
|
||||
"@babel/plugin-syntax-typescript": "7.28.6",
|
||||
"@biomejs/biome": "2.2.4",
|
||||
"@biomejs/biome": "2.4.10",
|
||||
"@chromatic-com/storybook": "5.0.1",
|
||||
"@octokit/types": "12.6.0",
|
||||
"@playwright/test": "1.50.1",
|
||||
|
||||
Generated
+40
-40
@@ -276,8 +276,8 @@ importers:
|
||||
specifier: 7.28.6
|
||||
version: 7.28.6(@babel/core@7.29.0)
|
||||
'@biomejs/biome':
|
||||
specifier: 2.2.4
|
||||
version: 2.2.4
|
||||
specifier: 2.4.10
|
||||
version: 2.4.10
|
||||
'@chromatic-com/storybook':
|
||||
specifier: 5.0.1
|
||||
version: 5.0.1(storybook@10.3.3(@testing-library/dom@10.4.0)(prettier@3.4.1)(react-dom@19.2.2(react@19.2.2))(react@19.2.2))
|
||||
@@ -469,7 +469,7 @@ importers:
|
||||
version: 8.0.2(@emnapi/core@1.7.1)(@emnapi/runtime@1.7.1)(@types/node@20.19.25)(esbuild@0.25.12)(jiti@1.21.7)(yaml@2.7.0)
|
||||
vite-plugin-checker:
|
||||
specifier: 0.12.0
|
||||
version: 0.12.0(@biomejs/biome@2.2.4)(optionator@0.9.3)(typescript@6.0.2)(vite@8.0.2(@emnapi/core@1.7.1)(@emnapi/runtime@1.7.1)(@types/node@20.19.25)(esbuild@0.25.12)(jiti@1.21.7)(yaml@2.7.0))
|
||||
version: 0.12.0(@biomejs/biome@2.4.10)(optionator@0.9.3)(typescript@6.0.2)(vite@8.0.2(@emnapi/core@1.7.1)(@emnapi/runtime@1.7.1)(@types/node@20.19.25)(esbuild@0.25.12)(jiti@1.21.7)(yaml@2.7.0))
|
||||
vitest:
|
||||
specifier: 4.1.1
|
||||
version: 4.1.1(@types/node@20.19.25)(@vitest/browser-playwright@4.1.1)(jsdom@27.2.0)(msw@2.4.8(typescript@6.0.2))(vite@8.0.2(@emnapi/core@1.7.1)(@emnapi/runtime@1.7.1)(@types/node@20.19.25)(esbuild@0.25.12)(jiti@1.21.7)(yaml@2.7.0))
|
||||
@@ -708,55 +708,55 @@ packages:
|
||||
'@bcoe/v8-coverage@0.2.3':
|
||||
resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==, tarball: https://registry.npmjs.org/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz}
|
||||
|
||||
'@biomejs/biome@2.2.4':
|
||||
resolution: {integrity: sha512-TBHU5bUy/Ok6m8c0y3pZiuO/BZoY/OcGxoLlrfQof5s8ISVwbVBdFINPQZyFfKwil8XibYWb7JMwnT8wT4WVPg==, tarball: https://registry.npmjs.org/@biomejs/biome/-/biome-2.2.4.tgz}
|
||||
'@biomejs/biome@2.4.10':
|
||||
resolution: {integrity: sha512-xxA3AphFQ1geij4JTHXv4EeSTda1IFn22ye9LdyVPoJU19fNVl0uzfEuhsfQ4Yue/0FaLs2/ccVi4UDiE7R30w==, tarball: https://registry.npmjs.org/@biomejs/biome/-/biome-2.4.10.tgz}
|
||||
engines: {node: '>=14.21.3'}
|
||||
hasBin: true
|
||||
|
||||
'@biomejs/cli-darwin-arm64@2.2.4':
|
||||
resolution: {integrity: sha512-RJe2uiyaloN4hne4d2+qVj3d3gFJFbmrr5PYtkkjei1O9c+BjGXgpUPVbi8Pl8syumhzJjFsSIYkcLt2VlVLMA==, tarball: https://registry.npmjs.org/@biomejs/cli-darwin-arm64/-/cli-darwin-arm64-2.2.4.tgz}
|
||||
'@biomejs/cli-darwin-arm64@2.4.10':
|
||||
resolution: {integrity: sha512-vuzzI1cWqDVzOMIkYyHbKqp+AkQq4K7k+UCXWpkYcY/HDn1UxdsbsfgtVpa40shem8Kax4TLDLlx8kMAecgqiw==, tarball: https://registry.npmjs.org/@biomejs/cli-darwin-arm64/-/cli-darwin-arm64-2.4.10.tgz}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [arm64]
|
||||
os: [darwin]
|
||||
|
||||
'@biomejs/cli-darwin-x64@2.2.4':
|
||||
resolution: {integrity: sha512-cFsdB4ePanVWfTnPVaUX+yr8qV8ifxjBKMkZwN7gKb20qXPxd/PmwqUH8mY5wnM9+U0QwM76CxFyBRJhC9tQwg==, tarball: https://registry.npmjs.org/@biomejs/cli-darwin-x64/-/cli-darwin-x64-2.2.4.tgz}
|
||||
'@biomejs/cli-darwin-x64@2.4.10':
|
||||
resolution: {integrity: sha512-14fzASRo+BPotwp7nWULy2W5xeUyFnTaq1V13Etrrxkrih+ez/2QfgFm5Ehtf5vSjtgx/IJycMMpn5kPd5ZNaA==, tarball: https://registry.npmjs.org/@biomejs/cli-darwin-x64/-/cli-darwin-x64-2.4.10.tgz}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [x64]
|
||||
os: [darwin]
|
||||
|
||||
'@biomejs/cli-linux-arm64-musl@2.2.4':
|
||||
resolution: {integrity: sha512-7TNPkMQEWfjvJDaZRSkDCPT/2r5ESFPKx+TEev+I2BXDGIjfCZk2+b88FOhnJNHtksbOZv8ZWnxrA5gyTYhSsQ==, tarball: https://registry.npmjs.org/@biomejs/cli-linux-arm64-musl/-/cli-linux-arm64-musl-2.2.4.tgz}
|
||||
'@biomejs/cli-linux-arm64-musl@2.4.10':
|
||||
resolution: {integrity: sha512-WrJY6UuiSD/Dh+nwK2qOTu8kdMDlLV3dLMmychIghHPAysWFq1/DGC1pVZx8POE3ZkzKR3PUUnVrtZfMfaJjyQ==, tarball: https://registry.npmjs.org/@biomejs/cli-linux-arm64-musl/-/cli-linux-arm64-musl-2.4.10.tgz}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
|
||||
'@biomejs/cli-linux-arm64@2.2.4':
|
||||
resolution: {integrity: sha512-M/Iz48p4NAzMXOuH+tsn5BvG/Jb07KOMTdSVwJpicmhN309BeEyRyQX+n1XDF0JVSlu28+hiTQ2L4rZPvu7nMw==, tarball: https://registry.npmjs.org/@biomejs/cli-linux-arm64/-/cli-linux-arm64-2.2.4.tgz}
|
||||
'@biomejs/cli-linux-arm64@2.4.10':
|
||||
resolution: {integrity: sha512-7MH1CMW5uuxQ/s7FLST63qF8B3Hgu2HRdZ7tA1X1+mk+St4JOuIrqdhIBnnyqeyWJNI+Bww7Es5QZ0wIc1Cmkw==, tarball: https://registry.npmjs.org/@biomejs/cli-linux-arm64/-/cli-linux-arm64-2.4.10.tgz}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
|
||||
'@biomejs/cli-linux-x64-musl@2.2.4':
|
||||
resolution: {integrity: sha512-m41nFDS0ksXK2gwXL6W6yZTYPMH0LughqbsxInSKetoH6morVj43szqKx79Iudkp8WRT5SxSh7qVb8KCUiewGg==, tarball: https://registry.npmjs.org/@biomejs/cli-linux-x64-musl/-/cli-linux-x64-musl-2.2.4.tgz}
|
||||
'@biomejs/cli-linux-x64-musl@2.4.10':
|
||||
resolution: {integrity: sha512-kDTi3pI6PBN6CiczsWYOyP2zk0IJI08EWEQyDMQWW221rPaaEz6FvjLhnU07KMzLv8q3qSuoB93ua6inSQ55Tw==, tarball: https://registry.npmjs.org/@biomejs/cli-linux-x64-musl/-/cli-linux-x64-musl-2.4.10.tgz}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
|
||||
'@biomejs/cli-linux-x64@2.2.4':
|
||||
resolution: {integrity: sha512-orr3nnf2Dpb2ssl6aihQtvcKtLySLta4E2UcXdp7+RTa7mfJjBgIsbS0B9GC8gVu0hjOu021aU8b3/I1tn+pVQ==, tarball: https://registry.npmjs.org/@biomejs/cli-linux-x64/-/cli-linux-x64-2.2.4.tgz}
|
||||
'@biomejs/cli-linux-x64@2.4.10':
|
||||
resolution: {integrity: sha512-tZLvEEi2u9Xu1zAqRjTcpIDGVtldigVvzug2fTuPG0ME/g8/mXpRPcNgLB22bGn6FvLJpHHnqLnwliOu8xjYrg==, tarball: https://registry.npmjs.org/@biomejs/cli-linux-x64/-/cli-linux-x64-2.4.10.tgz}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
|
||||
'@biomejs/cli-win32-arm64@2.2.4':
|
||||
resolution: {integrity: sha512-NXnfTeKHDFUWfxAefa57DiGmu9VyKi0cDqFpdI+1hJWQjGJhJutHPX0b5m+eXvTKOaf+brU+P0JrQAZMb5yYaQ==, tarball: https://registry.npmjs.org/@biomejs/cli-win32-arm64/-/cli-win32-arm64-2.2.4.tgz}
|
||||
'@biomejs/cli-win32-arm64@2.4.10':
|
||||
resolution: {integrity: sha512-umwQU6qPzH+ISTf/eHyJ/QoQnJs3V9Vpjz2OjZXe9MVBZ7prgGafMy7yYeRGnlmDAn87AKTF3Q6weLoMGpeqdQ==, tarball: https://registry.npmjs.org/@biomejs/cli-win32-arm64/-/cli-win32-arm64-2.4.10.tgz}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [arm64]
|
||||
os: [win32]
|
||||
|
||||
'@biomejs/cli-win32-x64@2.2.4':
|
||||
resolution: {integrity: sha512-3Y4V4zVRarVh/B/eSHczR4LYoSVyv3Dfuvm3cWs5w/HScccS0+Wt/lHOcDTRYeHjQmMYVC3rIRWqyN2EI52+zg==, tarball: https://registry.npmjs.org/@biomejs/cli-win32-x64/-/cli-win32-x64-2.2.4.tgz}
|
||||
'@biomejs/cli-win32-x64@2.4.10':
|
||||
resolution: {integrity: sha512-aW/JU5GuyH4uxMrNYpoC2kjaHlyJGLgIa3XkhPEZI0uKhZhJZU8BuEyJmvgzSPQNGozBwWjC972RaNdcJ9KyJg==, tarball: https://registry.npmjs.org/@biomejs/cli-win32-x64/-/cli-win32-x64-2.4.10.tgz}
|
||||
engines: {node: '>=14.21.3'}
|
||||
cpu: [x64]
|
||||
os: [win32]
|
||||
@@ -7697,39 +7697,39 @@ snapshots:
|
||||
|
||||
'@bcoe/v8-coverage@0.2.3': {}
|
||||
|
||||
'@biomejs/biome@2.2.4':
|
||||
'@biomejs/biome@2.4.10':
|
||||
optionalDependencies:
|
||||
'@biomejs/cli-darwin-arm64': 2.2.4
|
||||
'@biomejs/cli-darwin-x64': 2.2.4
|
||||
'@biomejs/cli-linux-arm64': 2.2.4
|
||||
'@biomejs/cli-linux-arm64-musl': 2.2.4
|
||||
'@biomejs/cli-linux-x64': 2.2.4
|
||||
'@biomejs/cli-linux-x64-musl': 2.2.4
|
||||
'@biomejs/cli-win32-arm64': 2.2.4
|
||||
'@biomejs/cli-win32-x64': 2.2.4
|
||||
'@biomejs/cli-darwin-arm64': 2.4.10
|
||||
'@biomejs/cli-darwin-x64': 2.4.10
|
||||
'@biomejs/cli-linux-arm64': 2.4.10
|
||||
'@biomejs/cli-linux-arm64-musl': 2.4.10
|
||||
'@biomejs/cli-linux-x64': 2.4.10
|
||||
'@biomejs/cli-linux-x64-musl': 2.4.10
|
||||
'@biomejs/cli-win32-arm64': 2.4.10
|
||||
'@biomejs/cli-win32-x64': 2.4.10
|
||||
|
||||
'@biomejs/cli-darwin-arm64@2.2.4':
|
||||
'@biomejs/cli-darwin-arm64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-darwin-x64@2.2.4':
|
||||
'@biomejs/cli-darwin-x64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-linux-arm64-musl@2.2.4':
|
||||
'@biomejs/cli-linux-arm64-musl@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-linux-arm64@2.2.4':
|
||||
'@biomejs/cli-linux-arm64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-linux-x64-musl@2.2.4':
|
||||
'@biomejs/cli-linux-x64-musl@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-linux-x64@2.2.4':
|
||||
'@biomejs/cli-linux-x64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-win32-arm64@2.2.4':
|
||||
'@biomejs/cli-win32-arm64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@biomejs/cli-win32-x64@2.2.4':
|
||||
'@biomejs/cli-win32-x64@2.4.10':
|
||||
optional: true
|
||||
|
||||
'@blazediff/core@1.9.1': {}
|
||||
@@ -15034,7 +15034,7 @@ snapshots:
|
||||
d3-time: 3.1.0
|
||||
d3-timer: 3.0.1
|
||||
|
||||
vite-plugin-checker@0.12.0(@biomejs/biome@2.2.4)(optionator@0.9.3)(typescript@6.0.2)(vite@8.0.2(@emnapi/core@1.7.1)(@emnapi/runtime@1.7.1)(@types/node@20.19.25)(esbuild@0.25.12)(jiti@1.21.7)(yaml@2.7.0)):
|
||||
vite-plugin-checker@0.12.0(@biomejs/biome@2.4.10)(optionator@0.9.3)(typescript@6.0.2)(vite@8.0.2(@emnapi/core@1.7.1)(@emnapi/runtime@1.7.1)(@types/node@20.19.25)(esbuild@0.25.12)(jiti@1.21.7)(yaml@2.7.0)):
|
||||
dependencies:
|
||||
'@babel/code-frame': 7.29.0
|
||||
chokidar: 4.0.3
|
||||
@@ -15046,7 +15046,7 @@ snapshots:
|
||||
vite: 8.0.2(@emnapi/core@1.7.1)(@emnapi/runtime@1.7.1)(@types/node@20.19.25)(esbuild@0.25.12)(jiti@1.21.7)(yaml@2.7.0)
|
||||
vscode-uri: 3.1.0
|
||||
optionalDependencies:
|
||||
'@biomejs/biome': 2.2.4
|
||||
'@biomejs/biome': 2.4.10
|
||||
optionator: 0.9.3
|
||||
typescript: 6.0.2
|
||||
|
||||
|
||||
@@ -147,12 +147,9 @@ describe("api.ts", () => {
|
||||
{ q: "owner:me" },
|
||||
"/api/v2/workspaces?q=owner%3Ame",
|
||||
],
|
||||
])(
|
||||
"Workspaces - getURLWithSearchParams(%p, %p) returns %p",
|
||||
(basePath, filter, expected) => {
|
||||
expect(getURLWithSearchParams(basePath, filter)).toBe(expected);
|
||||
},
|
||||
);
|
||||
])("Workspaces - getURLWithSearchParams(%p, %p) returns %p", (basePath, filter, expected) => {
|
||||
expect(getURLWithSearchParams(basePath, filter)).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getURLWithSearchParams - users", () => {
|
||||
@@ -164,12 +161,9 @@ describe("api.ts", () => {
|
||||
"/api/v2/users?q=status%3Aactive",
|
||||
],
|
||||
["/api/v2/users", { q: "" }, "/api/v2/users"],
|
||||
])(
|
||||
"Users - getURLWithSearchParams(%p, %p) returns %p",
|
||||
(basePath, filter, expected) => {
|
||||
expect(getURLWithSearchParams(basePath, filter)).toBe(expected);
|
||||
},
|
||||
);
|
||||
])("Users - getURLWithSearchParams(%p, %p) returns %p", (basePath, filter, expected) => {
|
||||
expect(getURLWithSearchParams(basePath, filter)).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
describe("update", () => {
|
||||
|
||||
+2
-3
@@ -3170,14 +3170,13 @@ class ExperimentalApiMethods {
|
||||
chatId: string,
|
||||
messageId: number,
|
||||
req: TypesGen.EditChatMessageRequest,
|
||||
): Promise<TypesGen.ChatMessage> => {
|
||||
const response = await this.axios.patch<TypesGen.ChatMessage>(
|
||||
): Promise<TypesGen.EditChatMessageResponse> => {
|
||||
const response = await this.axios.patch<TypesGen.EditChatMessageResponse>(
|
||||
`/api/experimental/chats/${chatId}/messages/${messageId}`,
|
||||
req,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
interruptChat = async (chatId: string): Promise<TypesGen.Chat> => {
|
||||
const response = await this.axios.post<TypesGen.Chat>(
|
||||
`/api/experimental/chats/${chatId}/interrupt`,
|
||||
|
||||
@@ -13,6 +13,8 @@ export interface FieldSchema {
|
||||
type: "string" | "integer" | "number" | "boolean" | "array" | "object";
|
||||
/** Human-readable description of the field. May be absent for some fields. */
|
||||
description?: string;
|
||||
/** Optional display label override. When absent, derive from json_name. */
|
||||
label?: string;
|
||||
/** Whether this field is required when configuring the provider. */
|
||||
required: boolean;
|
||||
/** Hint for how the frontend should render the input control. */
|
||||
|
||||
@@ -107,6 +107,7 @@
|
||||
"go_name": "Effort",
|
||||
"type": "string",
|
||||
"description": "Controls the level of reasoning effort",
|
||||
"label": "Reasoning Effort",
|
||||
"required": false,
|
||||
"enum": ["low", "medium", "high", "max"],
|
||||
"input_type": "select"
|
||||
@@ -132,6 +133,7 @@
|
||||
"go_name": "AllowedDomains",
|
||||
"type": "array",
|
||||
"description": "Restrict web search to these domains (cannot be used with blocked_domains)",
|
||||
"label": "Web Search: Allowed Domains",
|
||||
"required": false,
|
||||
"input_type": "json"
|
||||
},
|
||||
@@ -140,6 +142,7 @@
|
||||
"go_name": "BlockedDomains",
|
||||
"type": "array",
|
||||
"description": "Block web search on these domains (cannot be used with allowed_domains)",
|
||||
"label": "Web Search: Blocked Domains",
|
||||
"required": false,
|
||||
"input_type": "json"
|
||||
}
|
||||
@@ -286,7 +289,8 @@
|
||||
"type": "string",
|
||||
"description": "Controls whether reasoning tokens are summarized in the response",
|
||||
"required": false,
|
||||
"input_type": "input"
|
||||
"enum": ["auto", "concise", "detailed"],
|
||||
"input_type": "select"
|
||||
},
|
||||
{
|
||||
"json_name": "max_completion_tokens",
|
||||
@@ -354,7 +358,8 @@
|
||||
"type": "string",
|
||||
"description": "Latency tier to use for processing the request",
|
||||
"required": false,
|
||||
"input_type": "input"
|
||||
"enum": ["auto", "default", "flex", "scale", "priority"],
|
||||
"input_type": "select"
|
||||
},
|
||||
{
|
||||
"json_name": "structured_outputs",
|
||||
@@ -396,6 +401,7 @@
|
||||
"go_name": "AllowedDomains",
|
||||
"type": "array",
|
||||
"description": "Restrict web search to these domains",
|
||||
"label": "Web Search: Allowed Domains",
|
||||
"required": false,
|
||||
"input_type": "json"
|
||||
}
|
||||
|
||||
Generated
+51
-8
@@ -913,6 +913,7 @@ export interface AuditLog {
|
||||
export interface AuditLogResponse {
|
||||
readonly audit_logs: readonly AuditLog[];
|
||||
readonly count: number;
|
||||
readonly count_cap: number;
|
||||
}
|
||||
|
||||
// From codersdk/audit.go
|
||||
@@ -1200,6 +1201,7 @@ export interface Chat {
|
||||
readonly pin_order: number;
|
||||
readonly mcp_server_ids: readonly string[];
|
||||
readonly labels: Record<string, string>;
|
||||
readonly files?: readonly ChatFileMetadata[];
|
||||
/**
|
||||
* HasUnread is true when assistant messages exist beyond
|
||||
* the owner's read cursor, which updates on stream
|
||||
@@ -1213,8 +1215,14 @@ export interface Chat {
|
||||
* attach or agent change.
|
||||
*/
|
||||
readonly last_injected_context?: readonly ChatMessagePart[];
|
||||
readonly warnings?: readonly string[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export type ChatBusyBehavior = "interrupt" | "queue";
|
||||
|
||||
export const ChatBusyBehaviors: ChatBusyBehavior[] = ["interrupt", "queue"];
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatCompactionThresholdKeyPrefix scopes per-model chat compaction
|
||||
@@ -1263,6 +1271,7 @@ export interface ChatCostChatBreakdown {
|
||||
readonly total_output_tokens: number;
|
||||
readonly total_cache_read_tokens: number;
|
||||
readonly total_cache_creation_tokens: number;
|
||||
readonly total_runtime_ms: number;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
@@ -1280,6 +1289,7 @@ export interface ChatCostModelBreakdown {
|
||||
readonly total_output_tokens: number;
|
||||
readonly total_cache_read_tokens: number;
|
||||
readonly total_cache_creation_tokens: number;
|
||||
readonly total_runtime_ms: number;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
@@ -1296,6 +1306,7 @@ export interface ChatCostSummary {
|
||||
readonly total_output_tokens: number;
|
||||
readonly total_cache_read_tokens: number;
|
||||
readonly total_cache_creation_tokens: number;
|
||||
readonly total_runtime_ms: number;
|
||||
readonly by_model: readonly ChatCostModelBreakdown[];
|
||||
readonly by_chat: readonly ChatCostChatBreakdown[];
|
||||
readonly usage_limit?: ChatUsageLimitStatus;
|
||||
@@ -1326,6 +1337,7 @@ export interface ChatCostUserRollup {
|
||||
readonly total_output_tokens: number;
|
||||
readonly total_cache_read_tokens: number;
|
||||
readonly total_cache_creation_tokens: number;
|
||||
readonly total_runtime_ms: number;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
@@ -1398,6 +1410,20 @@ export interface ChatDiffStatus {
|
||||
readonly stale_at?: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatFileMetadata contains lightweight metadata about a file
|
||||
* associated with a chat, excluding the file content itself.
|
||||
*/
|
||||
export interface ChatFileMetadata {
|
||||
readonly id: string;
|
||||
readonly owner_id: string;
|
||||
readonly organization_id: string;
|
||||
readonly name: string;
|
||||
readonly mime_type: string;
|
||||
readonly created_at: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export interface ChatFilePart {
|
||||
readonly type: "file";
|
||||
@@ -1451,21 +1477,14 @@ export interface ChatInputPart {
|
||||
* The code content from the diff that was commented on.
|
||||
*/
|
||||
readonly content?: string;
|
||||
/**
|
||||
* The following fields are only set when Type is
|
||||
* ChatInputPartTypeSkill.
|
||||
*/
|
||||
readonly skill_name?: string;
|
||||
readonly skill_description?: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export type ChatInputPartType = "file" | "file-reference" | "skill" | "text";
|
||||
export type ChatInputPartType = "file" | "file-reference" | "text";
|
||||
|
||||
export const ChatInputPartTypes: ChatInputPartType[] = [
|
||||
"file",
|
||||
"file-reference",
|
||||
"skill",
|
||||
"text",
|
||||
];
|
||||
|
||||
@@ -2251,6 +2270,7 @@ export interface ConnectionLog {
|
||||
export interface ConnectionLogResponse {
|
||||
readonly connection_logs: readonly ConnectionLog[];
|
||||
readonly count: number;
|
||||
readonly count_cap: number;
|
||||
}
|
||||
|
||||
// From codersdk/connectionlog.go
|
||||
@@ -2341,6 +2361,7 @@ export interface CreateChatMessageRequest {
|
||||
readonly content: readonly ChatInputPart[];
|
||||
readonly model_config_id?: string;
|
||||
readonly mcp_server_ids?: string[];
|
||||
readonly busy_behavior?: ChatBusyBehavior;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
@@ -2351,6 +2372,7 @@ export interface CreateChatMessageResponse {
|
||||
readonly message?: ChatMessage;
|
||||
readonly queued_message?: ChatQueuedMessage;
|
||||
readonly queued: boolean;
|
||||
readonly warnings?: readonly string[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
@@ -2389,6 +2411,7 @@ export interface CreateChatProviderConfigRequest {
|
||||
*/
|
||||
export interface CreateChatRequest {
|
||||
readonly content: readonly ChatInputPart[];
|
||||
readonly system_prompt?: string;
|
||||
readonly workspace_id?: string;
|
||||
readonly model_config_id?: string;
|
||||
readonly mcp_server_ids?: readonly string[];
|
||||
@@ -3201,6 +3224,17 @@ export interface EditChatMessageRequest {
|
||||
readonly content: readonly ChatInputPart[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* EditChatMessageResponse is the response from editing a message in a chat.
|
||||
* Edits are always synchronous (no queueing), so the message is returned
|
||||
* directly.
|
||||
*/
|
||||
export interface EditChatMessageResponse {
|
||||
readonly message: ChatMessage;
|
||||
readonly warnings?: readonly string[];
|
||||
}
|
||||
|
||||
// From codersdk/externalauth.go
|
||||
export type EnhancedExternalAuthProvider =
|
||||
| "azure-devops"
|
||||
@@ -4108,6 +4142,15 @@ export interface MatchedProvisioners {
|
||||
readonly most_recently_seen?: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* MaxChatFileIDs is the maximum number of file IDs that can be
|
||||
* associated with a single chat. This limit prevents unbounded
|
||||
* growth in the chat_file_links table. It is easier to raise
|
||||
* this limit than to lower it.
|
||||
*/
|
||||
export const MaxChatFileIDs = 20;
|
||||
|
||||
// From codersdk/organizations.go
|
||||
export interface MinimalOrganization {
|
||||
readonly id: string;
|
||||
|
||||
@@ -10,4 +10,4 @@ const CollapsibleTrigger = CollapsiblePrimitive.CollapsibleTrigger;
|
||||
|
||||
const CollapsibleContent = CollapsiblePrimitive.CollapsibleContent;
|
||||
|
||||
export { Collapsible, CollapsibleTrigger, CollapsibleContent };
|
||||
export { Collapsible, CollapsibleContent, CollapsibleTrigger };
|
||||
|
||||
@@ -12,11 +12,13 @@ import { useClipboard } from "#/hooks/useClipboard";
|
||||
type CopyButtonProps = ButtonProps & {
|
||||
text: string;
|
||||
label: string;
|
||||
tooltipSide?: "top" | "bottom" | "left" | "right";
|
||||
};
|
||||
|
||||
export const CopyButton: FC<CopyButtonProps> = ({
|
||||
text,
|
||||
label,
|
||||
tooltipSide,
|
||||
...buttonProps
|
||||
}) => {
|
||||
const { showCopiedSuccess, copyToClipboard } = useClipboard();
|
||||
@@ -34,7 +36,7 @@ export const CopyButton: FC<CopyButtonProps> = ({
|
||||
<span className="sr-only">{label}</span>
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>{label}</TooltipContent>
|
||||
<TooltipContent side={tooltipSide}>{label}</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -67,4 +67,4 @@ export const DialogActionButtons: FC<DialogActionButtonsProps> = ({
|
||||
* Re-export of MUI's Dialog component, for convenience.
|
||||
* @link See original documentation here: https://mui.com/material-ui/react-dialog/
|
||||
*/
|
||||
export { MuiDialog as Dialog, type DialogProps };
|
||||
export { type DialogProps, MuiDialog as Dialog };
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { css, type Interpolation, type Theme } from "@emotion/react";
|
||||
import CircularProgress from "@mui/material/CircularProgress";
|
||||
import IconButton from "@mui/material/IconButton";
|
||||
import { CloudUploadIcon, FolderIcon, TrashIcon } from "lucide-react";
|
||||
import { type DragEvent, type FC, type ReactNode, useRef } from "react";
|
||||
import { Button } from "#/components/Button/Button";
|
||||
import { Stack } from "#/components/Stack/Stack";
|
||||
import { useClickable } from "#/hooks/useClickable";
|
||||
|
||||
@@ -46,9 +46,14 @@ export const FileUpload: FC<FileUploadProps> = ({
|
||||
<span>{file.name}</span>
|
||||
</Stack>
|
||||
|
||||
<IconButton title={removeLabel} size="small" onClick={onRemove}>
|
||||
<Button
|
||||
variant="subtle"
|
||||
size="icon-lg"
|
||||
onClick={onRemove}
|
||||
title={removeLabel}
|
||||
>
|
||||
<TrashIcon className="size-icon-sm" />
|
||||
</IconButton>
|
||||
</Button>
|
||||
</Stack>
|
||||
);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user