Compare commits
84 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cfd7730194 | |||
| 1937ada0cd | |||
| d64cd6415d | |||
| c1851d9453 | |||
| 8f73453681 | |||
| 165db3d31c | |||
| 1bd1516fd1 | |||
| 81ba35a987 | |||
| 53d63cf8e9 | |||
| 4213a43b53 | |||
| 5453a6c6d6 | |||
| 21c08a37d7 | |||
| 2bd261fbbf | |||
| cffc68df58 | |||
| 6e5335df1e | |||
| 16265e834e | |||
| 565a15bc9b | |||
| 76a2cb1af5 | |||
| 684f21740d | |||
| 86ca61d6ca | |||
| f0521cfa3c | |||
| 0c5d189aff | |||
| d7c8213eee | |||
| 63924ac687 | |||
| 6c47e9ea23 | |||
| aede045549 | |||
| 2ea08aa168 | |||
| d4b9248202 | |||
| fd6c623560 | |||
| 99da498679 | |||
| a20b817c28 | |||
| d5a1792f07 | |||
| beb99c17de | |||
| 8913f9f5c1 | |||
| acd5f01b4b | |||
| 6c62d8f5e6 | |||
| 5000f15021 | |||
| 44be5a0d1e | |||
| 3ca2aae9ca | |||
| 01080302a5 | |||
| 61d6c728b9 | |||
| 648787e739 | |||
| d2950e7615 | |||
| df8f695e84 | |||
| 8bb48ffdda | |||
| 4cfbf544a0 | |||
| a2ce74f398 | |||
| 0060dee222 | |||
| 5ff1058f30 | |||
| 500fc5e2a4 | |||
| baba9e6ede | |||
| b36619b905 | |||
| 937f50f0ae | |||
| a16755dd66 | |||
| 8bdc35f91f | |||
| 5b32c4d79d | |||
| 8625543413 | |||
| e18094825a | |||
| 919dc299fc | |||
| 7e63fe68f7 | |||
| a1d51f0dab | |||
| 333503f74e | |||
| 01b8cdb00d | |||
| ec83065b59 | |||
| 2a1bef18e0 | |||
| 8369fa88fd | |||
| da3c46b557 | |||
| 53482adc2d | |||
| aa0e288b88 | |||
| 1c4a9ed745 | |||
| 5b6b7719df | |||
| 990c006f28 | |||
| 0cb942aab2 | |||
| b0a6802d12 | |||
| 8d08885792 | |||
| f68161350a | |||
| 9ac67a5253 | |||
| 7d0a0c6495 | |||
| 17dec2a70f | |||
| f796f3645f | |||
| d5ed51a190 | |||
| 796e8e4e18 | |||
| 5b28548d1c | |||
| b5da77ff55 |
+17
-17
@@ -35,7 +35,7 @@ jobs:
|
||||
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -157,7 +157,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -247,7 +247,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -272,7 +272,7 @@ jobs:
|
||||
if: ${{ !cancelled() }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -327,7 +327,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -379,7 +379,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -575,7 +575,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -637,7 +637,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -709,7 +709,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -736,7 +736,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -769,7 +769,7 @@ jobs:
|
||||
name: ${{ matrix.variant.name }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -849,7 +849,7 @@ jobs:
|
||||
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -930,7 +930,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1005,7 +1005,7 @@ jobs:
|
||||
if: always()
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1043,7 +1043,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1097,7 +1097,7 @@ jobs:
|
||||
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1479,7 +1479,7 @@ jobs:
|
||||
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
steps:
|
||||
- name: Dependabot metadata
|
||||
id: metadata
|
||||
uses: dependabot/fetch-metadata@21025c705c08248db411dc16f3619e6b5f9ea21a # v2.5.0
|
||||
uses: dependabot/fetch-metadata@ffa630c65fa7e0ecfa0625b5ceda64399aea1b36 # v3.0.0
|
||||
with:
|
||||
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -65,7 +65,7 @@ jobs:
|
||||
packages: write # to retag image as dogfood
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
needs: deploy
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
if: github.repository_owner == 'coder'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -125,7 +125,7 @@ jobs:
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
packages: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
runs-on: "ubuntu-latest"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
pull-requests: write # needed for commenting on PRs
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -288,7 +288,7 @@ jobs:
|
||||
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ jobs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -673,7 +673,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -749,7 +749,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -47,6 +47,6 @@ jobs:
|
||||
|
||||
# Upload the results to GitHub's code scanning dashboard.
|
||||
- name: "Upload to code-scanning"
|
||||
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
@@ -27,7 +27,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
uses: github/codeql-action/init@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
|
||||
with:
|
||||
languages: go, javascript
|
||||
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
rm Makefile
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
uses: github/codeql-action/analyze@c10b8064de6f491fea524254123dbe5e09572f13 # v3.29.5
|
||||
|
||||
- name: Send Slack notification on failure
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -18,7 +18,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -96,7 +96,7 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
actions: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ jobs:
|
||||
pull-requests: write # required to post PR review comments by the action
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -110,6 +110,9 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
|
||||
- For experimental or unstable API paths, skip public doc generation with
|
||||
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
|
||||
keeps them out of the published API reference until they stabilize.
|
||||
- Experimental chat endpoints in `coderd/exp_chats.go` omit swagger
|
||||
annotations entirely. Do not add `@Summary`, `@Router`, or other
|
||||
swagger comments to handlers in that file.
|
||||
|
||||
### Database Query Naming
|
||||
|
||||
|
||||
+2
-2
@@ -398,7 +398,7 @@ func (a *agent) init() {
|
||||
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
|
||||
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
|
||||
desktop := agentdesktop.NewPortableDesktop(
|
||||
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
|
||||
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(), nil,
|
||||
)
|
||||
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
|
||||
a.mcpManager = agentmcp.NewManager(a.logger.Named("mcp"))
|
||||
@@ -1366,7 +1366,7 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
|
||||
// lifecycle transition to avoid delaying Ready.
|
||||
// This runs inside the tracked goroutine so it
|
||||
// is properly awaited on shutdown.
|
||||
if mcpErr := a.mcpManager.Connect(a.gracefulCtx, a.contextConfigAPI.Config().MCPConfigFiles); mcpErr != nil {
|
||||
if mcpErr := a.mcpManager.Connect(a.gracefulCtx, a.contextConfigAPI.MCPConfigFiles()); mcpErr != nil {
|
||||
a.logger.Warn(ctx, "failed to connect to workspace MCP servers", slog.Error(mcpErr))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -83,14 +83,14 @@ func TestContextConfigAPI_InitOnce(t *testing.T) {
|
||||
return ""
|
||||
})
|
||||
|
||||
cfg1 := a.contextConfigAPI.Config()
|
||||
require.NotEmpty(t, cfg1.MCPConfigFiles)
|
||||
require.Contains(t, cfg1.MCPConfigFiles[0], dir1)
|
||||
mcpFiles1 := a.contextConfigAPI.MCPConfigFiles()
|
||||
require.NotEmpty(t, mcpFiles1)
|
||||
require.Contains(t, mcpFiles1[0], dir1)
|
||||
|
||||
// Simulate manifest update on reconnection — no field
|
||||
// Simulate manifest update on reconnection -- no field
|
||||
// reassignment needed, the lazy closure picks it up.
|
||||
a.manifest.Store(&agentsdk.Manifest{Directory: dir2})
|
||||
cfg2 := a.contextConfigAPI.Config()
|
||||
require.NotEmpty(t, cfg2.MCPConfigFiles)
|
||||
require.Contains(t, cfg2.MCPConfigFiles[0], dir2)
|
||||
mcpFiles2 := a.contextConfigAPI.MCPConfigFiles()
|
||||
require.NotEmpty(t, mcpFiles2)
|
||||
require.Contains(t, mcpFiles2[0], dir2)
|
||||
}
|
||||
|
||||
+251
-22
@@ -2,13 +2,17 @@ package agentcontextconfig
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
@@ -22,9 +26,47 @@ const (
|
||||
EnvMCPConfigFiles = "CODER_AGENT_EXP_MCP_CONFIG_FILES"
|
||||
)
|
||||
|
||||
// Defaults are defined in codersdk/workspacesdk so both
|
||||
// the agent and server can reference them without a
|
||||
// cross-layer import.
|
||||
const (
|
||||
maxInstructionFileBytes = 64 * 1024
|
||||
maxSkillMetaBytes = 64 * 1024
|
||||
)
|
||||
|
||||
// markdownCommentPattern strips HTML comments from instruction
|
||||
// file content for security (prevents hidden prompt injection).
|
||||
var markdownCommentPattern = regexp.MustCompile(`<!--[\s\S]*?-->`)
|
||||
|
||||
// invisibleRunePattern strips invisible Unicode characters that
|
||||
// could be used for prompt injection.
|
||||
//
|
||||
//nolint:gocritic // Non-ASCII char ranges are intentional for invisible Unicode stripping.
|
||||
var invisibleRunePattern = regexp.MustCompile(
|
||||
"[\u00ad\u034f\u061c\u070f" +
|
||||
"\u115f\u1160\u17b4\u17b5" +
|
||||
"\u180b-\u180f" +
|
||||
"\u200b\u200d\u200e\u200f" +
|
||||
"\u202a-\u202e" +
|
||||
"\u2060-\u206f" +
|
||||
"\u3164" +
|
||||
"\ufe00-\ufe0f" +
|
||||
"\ufeff" +
|
||||
"\uffa0" +
|
||||
"\ufff0-\ufff8]",
|
||||
)
|
||||
|
||||
// skillNamePattern validates kebab-case skill names.
|
||||
var skillNamePattern = regexp.MustCompile(
|
||||
`^[a-z0-9]+(-[a-z0-9]+)*$`,
|
||||
)
|
||||
|
||||
// Default values for agent-internal configuration. These are
|
||||
// used when the corresponding env vars are unset.
|
||||
const (
|
||||
DefaultInstructionsDir = "~/.coder"
|
||||
DefaultInstructionsFile = "AGENTS.md"
|
||||
DefaultSkillsDir = ".agents/skills"
|
||||
DefaultSkillMetaFile = "SKILL.md"
|
||||
DefaultMCPConfigFile = ".mcp.json"
|
||||
)
|
||||
|
||||
// API exposes the resolved context configuration through the
|
||||
// agent's HTTP API.
|
||||
@@ -42,33 +84,61 @@ func NewAPI(workingDir func() string) *API {
|
||||
return &API{workingDir: workingDir}
|
||||
}
|
||||
|
||||
// Config reads env vars and resolves paths. Exported for use
|
||||
// by the MCP manager and tests.
|
||||
func Config(workingDir string) workspacesdk.ContextConfigResponse {
|
||||
// Config reads env vars, resolves paths, reads instruction files,
|
||||
// and discovers skills. Returns the HTTP response and the resolved
|
||||
// MCP config file paths (used only agent-internally). Exported
|
||||
// for use by tests.
|
||||
func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) {
|
||||
// TrimSpace all env vars before cmp.Or so that a
|
||||
// whitespace-only value falls through to the default
|
||||
// consistently. ResolvePaths also trims each comma-
|
||||
// separated entry, but without pre-trimming here a
|
||||
// bare " " would bypass cmp.Or and produce nil.
|
||||
instructionsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsDirs)), workspacesdk.DefaultInstructionsDir)
|
||||
instructionsFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsFile)), workspacesdk.DefaultInstructionsFile)
|
||||
skillsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillsDirs)), workspacesdk.DefaultSkillsDir)
|
||||
skillMetaFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillMetaFile)), workspacesdk.DefaultSkillMetaFile)
|
||||
mcpConfigFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvMCPConfigFiles)), workspacesdk.DefaultMCPConfigFile)
|
||||
instructionsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsDirs)), DefaultInstructionsDir)
|
||||
instructionsFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvInstructionsFile)), DefaultInstructionsFile)
|
||||
skillsDir := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillsDirs)), DefaultSkillsDir)
|
||||
skillMetaFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvSkillMetaFile)), DefaultSkillMetaFile)
|
||||
mcpConfigFile := cmp.Or(strings.TrimSpace(os.Getenv(EnvMCPConfigFiles)), DefaultMCPConfigFile)
|
||||
|
||||
resolvedInstructionsDirs := ResolvePaths(instructionsDir, workingDir)
|
||||
resolvedSkillsDirs := ResolvePaths(skillsDir, workingDir)
|
||||
|
||||
// Read instruction files from each configured directory.
|
||||
parts := readInstructionFiles(resolvedInstructionsDirs, instructionsFile)
|
||||
|
||||
// Also check the working directory for the instruction file,
|
||||
// unless it was already covered by InstructionsDirs.
|
||||
if workingDir != "" {
|
||||
seenDirs := make(map[string]struct{}, len(resolvedInstructionsDirs))
|
||||
for _, d := range resolvedInstructionsDirs {
|
||||
seenDirs[d] = struct{}{}
|
||||
}
|
||||
if _, ok := seenDirs[workingDir]; !ok {
|
||||
if entry, found := readInstructionFileFromDir(workingDir, instructionsFile); found {
|
||||
parts = append(parts, entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Discover skills from each configured skills directory.
|
||||
skillParts := discoverSkills(resolvedSkillsDirs, skillMetaFile)
|
||||
parts = append(parts, skillParts...)
|
||||
|
||||
// Guarantee non-nil slice to signal agent support.
|
||||
if parts == nil {
|
||||
parts = []codersdk.ChatMessagePart{}
|
||||
}
|
||||
|
||||
return workspacesdk.ContextConfigResponse{
|
||||
InstructionsDirs: ResolvePaths(instructionsDir, workingDir),
|
||||
InstructionsFile: instructionsFile,
|
||||
SkillsDirs: ResolvePaths(skillsDir, workingDir),
|
||||
SkillMetaFile: skillMetaFile,
|
||||
MCPConfigFiles: ResolvePaths(mcpConfigFile, workingDir),
|
||||
}
|
||||
Parts: parts,
|
||||
}, ResolvePaths(mcpConfigFile, workingDir)
|
||||
}
|
||||
|
||||
// Config returns the resolved config for use by other agent
|
||||
// components (e.g. MCP manager).
|
||||
func (api *API) Config() workspacesdk.ContextConfigResponse {
|
||||
return Config(api.workingDir())
|
||||
// MCPConfigFiles returns the resolved MCP configuration file
|
||||
// paths for the agent's MCP manager.
|
||||
func (api *API) MCPConfigFiles() []string {
|
||||
_, mcpFiles := Config(api.workingDir())
|
||||
return mcpFiles
|
||||
}
|
||||
|
||||
// Routes returns the HTTP handler for the context config
|
||||
@@ -80,5 +150,164 @@ func (api *API) Routes() http.Handler {
|
||||
}
|
||||
|
||||
func (api *API) handleGet(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, api.Config())
|
||||
response, _ := Config(api.workingDir())
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// readInstructionFiles reads instruction files from each given
|
||||
// directory. Missing directories are silently skipped. Duplicate
|
||||
// directories are deduplicated.
|
||||
func readInstructionFiles(dirs []string, fileName string) []codersdk.ChatMessagePart {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
seen := make(map[string]struct{}, len(dirs))
|
||||
for _, dir := range dirs {
|
||||
if _, ok := seen[dir]; ok {
|
||||
continue
|
||||
}
|
||||
seen[dir] = struct{}{}
|
||||
if part, found := readInstructionFileFromDir(dir, fileName); found {
|
||||
parts = append(parts, part)
|
||||
}
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// readInstructionFileFromDir scans a directory for a file matching
|
||||
// fileName (case-insensitive) and reads its contents.
|
||||
func readInstructionFileFromDir(dir, fileName string) (codersdk.ChatMessagePart, bool) {
|
||||
dirEntries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return codersdk.ChatMessagePart{}, false
|
||||
}
|
||||
|
||||
for _, e := range dirEntries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(e.Name()), fileName) {
|
||||
filePath := filepath.Join(dir, e.Name())
|
||||
content, truncated, ok := readAndSanitizeFile(filePath, maxInstructionFileBytes)
|
||||
if !ok {
|
||||
return codersdk.ChatMessagePart{}, false
|
||||
}
|
||||
if content == "" {
|
||||
return codersdk.ChatMessagePart{}, false
|
||||
}
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: filePath,
|
||||
ContextFileContent: content,
|
||||
ContextFileTruncated: truncated,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
return codersdk.ChatMessagePart{}, false
|
||||
}
|
||||
|
||||
// readAndSanitizeFile reads the file at path, capping the read
|
||||
// at maxBytes to avoid unbounded memory allocation. It sanitizes
|
||||
// the content (strips HTML comments and invisible Unicode) and
|
||||
// returns the result. Returns false if the file cannot be read.
|
||||
func readAndSanitizeFile(path string, maxBytes int64) (content string, truncated bool, ok bool) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", false, false
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Read at most maxBytes+1 to detect truncation without
|
||||
// allocating the entire file into memory.
|
||||
raw, err := io.ReadAll(io.LimitReader(f, maxBytes+1))
|
||||
if err != nil {
|
||||
return "", false, false
|
||||
}
|
||||
|
||||
truncated = int64(len(raw)) > maxBytes
|
||||
if truncated {
|
||||
raw = raw[:maxBytes]
|
||||
}
|
||||
|
||||
s := sanitizeInstructionMarkdown(string(raw))
|
||||
if s == "" {
|
||||
return "", truncated, true
|
||||
}
|
||||
return s, truncated, true
|
||||
}
|
||||
|
||||
// sanitizeInstructionMarkdown strips HTML comments, invisible
|
||||
// Unicode characters, and CRLF line endings from instruction
|
||||
// file content.
|
||||
func sanitizeInstructionMarkdown(content string) string {
|
||||
content = strings.ReplaceAll(content, "\r\n", "\n")
|
||||
content = strings.ReplaceAll(content, "\r", "\n")
|
||||
content = markdownCommentPattern.ReplaceAllString(content, "")
|
||||
content = invisibleRunePattern.ReplaceAllString(content, "")
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
// discoverSkills walks the given skills directories and returns
|
||||
// metadata for every valid skill it finds. Body and supporting
|
||||
// file lists are NOT included; chatd fetches those on demand
|
||||
// via read_skill. Missing directories or individual errors are
|
||||
// silently skipped.
|
||||
func discoverSkills(skillsDirs []string, metaFile string) []codersdk.ChatMessagePart {
|
||||
seen := make(map[string]struct{})
|
||||
var parts []codersdk.ChatMessagePart
|
||||
|
||||
for _, skillsDir := range skillsDirs {
|
||||
entries, err := os.ReadDir(skillsDir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
metaPath := filepath.Join(skillsDir, entry.Name(), metaFile)
|
||||
f, err := os.Open(metaPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
raw, err := io.ReadAll(io.LimitReader(f, maxSkillMetaBytes+1))
|
||||
_ = f.Close()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if int64(len(raw)) > maxSkillMetaBytes {
|
||||
raw = raw[:maxSkillMetaBytes]
|
||||
}
|
||||
|
||||
name, description, _, err := workspacesdk.ParseSkillFrontmatter(string(raw))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// The directory name must match the declared name.
|
||||
if name != entry.Name() {
|
||||
continue
|
||||
}
|
||||
if !skillNamePattern.MatchString(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
// First occurrence wins across directories.
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
|
||||
skillDir := filepath.Join(skillsDir, entry.Name())
|
||||
parts = append(parts, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: name,
|
||||
SkillDescription: description,
|
||||
SkillDir: skillDir,
|
||||
ContextFileSkillMetaFile: metaFile,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
@@ -1,15 +1,28 @@
|
||||
package agentcontextconfig_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// filterParts returns only the parts matching the given type.
|
||||
func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartType) []codersdk.ChatMessagePart {
|
||||
var out []codersdk.ChatMessagePart
|
||||
for _, p := range parts {
|
||||
if p.Type == t {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
@@ -24,19 +37,13 @@ func TestConfig(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
cfg, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
|
||||
require.Equal(t, workspacesdk.DefaultInstructionsFile, cfg.InstructionsFile)
|
||||
require.Equal(t, workspacesdk.DefaultSkillMetaFile, cfg.SkillMetaFile)
|
||||
// Default instructions dir is "~/.coder" which resolves
|
||||
// to the home directory.
|
||||
require.Equal(t, []string{filepath.Join(fakeHome, ".coder")}, cfg.InstructionsDirs)
|
||||
// Default skills dir is ".agents/skills" (relative),
|
||||
// resolved against the working directory.
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".agents", "skills")}, cfg.SkillsDirs)
|
||||
// Parts is always non-nil.
|
||||
require.NotNil(t, cfg.Parts)
|
||||
// Default MCP config file is ".mcp.json" (relative),
|
||||
// resolved against the working directory.
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, cfg.MCPConfigFiles)
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles)
|
||||
})
|
||||
|
||||
t.Run("CustomEnvVars", func(t *testing.T) {
|
||||
@@ -44,8 +51,8 @@ func TestConfig(t *testing.T) {
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
optInstructions := platformAbsPath("opt", "instructions")
|
||||
optSkills := platformAbsPath("opt", "skills")
|
||||
optInstructions := t.TempDir()
|
||||
optSkills := t.TempDir()
|
||||
optMCP := platformAbsPath("opt", "mcp.json")
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
|
||||
@@ -54,32 +61,58 @@ func TestConfig(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
// Create files matching the custom names so we can
|
||||
// verify the env vars actually change lookup behavior.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(optInstructions, "CUSTOM.md"), []byte("custom instructions"), 0o600))
|
||||
skillDir := filepath.Join(optSkills, "my-skill")
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(skillDir, "META.yaml"),
|
||||
[]byte("---\nname: my-skill\ndescription: custom meta\n---\n"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
require.Equal(t, "CUSTOM.md", cfg.InstructionsFile)
|
||||
require.Equal(t, "META.yaml", cfg.SkillMetaFile)
|
||||
require.Equal(t, []string{optInstructions}, cfg.InstructionsDirs)
|
||||
require.Equal(t, []string{optSkills}, cfg.SkillsDirs)
|
||||
require.Equal(t, []string{optMCP}, cfg.MCPConfigFiles)
|
||||
workDir := platformAbsPath("work")
|
||||
cfg, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
|
||||
require.Equal(t, []string{optMCP}, mcpFiles)
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "custom instructions", ctxFiles[0].ContextFileContent)
|
||||
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("WhitespaceInFileNames", func(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
workDir := t.TempDir()
|
||||
// Create a file matching the trimmed name.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(fakeHome, "CLAUDE.md"), []byte("hello"), 0o600))
|
||||
|
||||
require.Equal(t, "CLAUDE.md", cfg.InstructionsFile)
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "hello", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("CommaSeparatedDirs", func(t *testing.T) {
|
||||
a := platformAbsPath("opt", "a")
|
||||
b := platformAbsPath("opt", "b")
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
a := t.TempDir()
|
||||
b := t.TempDir()
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
@@ -87,10 +120,300 @@ func TestConfig(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg := agentcontextconfig.Config(workDir)
|
||||
// Put instruction files in both dirs.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(b, "AGENTS.md"), []byte("from b"), 0o600))
|
||||
|
||||
require.Equal(t, []string{a, b}, cfg.InstructionsDirs)
|
||||
workDir := t.TempDir()
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 2)
|
||||
require.Equal(t, "from a", ctxFiles[0].ContextFileContent)
|
||||
require.Equal(t, "from b", ctxFiles[1].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("ReadsInstructionFiles", func(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
// Create ~/.coder/AGENTS.md
|
||||
coderDir := filepath.Join(fakeHome, ".coder")
|
||||
require.NoError(t, os.MkdirAll(coderDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(coderDir, "AGENTS.md"),
|
||||
[]byte("home instructions"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.NotNil(t, cfg.Parts)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "home instructions", ctxFiles[0].ContextFileContent)
|
||||
require.Equal(t, filepath.Join(coderDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
|
||||
require.False(t, ctxFiles[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
|
||||
// Create AGENTS.md in the working directory.
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("project instructions"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
// Should find the working dir file (not in instruction dirs).
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.NotNil(t, cfg.Parts)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "project instructions", ctxFiles[0].ContextFileContent)
|
||||
require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
|
||||
})
|
||||
|
||||
t.Run("TruncatesLargeInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
largeContent := strings.Repeat("a", 64*1024+100)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.True(t, ctxFiles[0].ContextFileTruncated)
|
||||
require.Len(t, ctxFiles[0].ContextFileContent, 64*1024)
|
||||
})
|
||||
|
||||
t.Run("SanitizesHTMLComments", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("visible\n<!-- hidden -->content"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("SanitizesInvisibleUnicode", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// U+200B (zero-width space) should be stripped.
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("before\u200bafter"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("NormalizesCRLF", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("line1\r\nline2\rline3"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("DiscoversSkills", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir := filepath.Join(workDir, ".agents", "skills")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, skillsDir)
|
||||
|
||||
// Create a valid skill.
|
||||
skillDir := filepath.Join(skillsDir, "my-skill")
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte("---\nname: my-skill\ndescription: A test skill\n---\nSkill body"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("SkipsMissingDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
nonExistent := filepath.Join(t.TempDir(), "does-not-exist")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
// Non-nil empty slice (signals agent supports new format).
|
||||
require.NotNil(t, cfg.Parts)
|
||||
require.Empty(t, cfg.Parts)
|
||||
})
|
||||
|
||||
t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
|
||||
optMCP := platformAbsPath("opt", "custom.json")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
|
||||
workDir := t.TempDir()
|
||||
_, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
|
||||
require.Equal(t, []string{optMCP}, mcpFiles)
|
||||
})
|
||||
|
||||
t.Run("SkillNameMustMatchDir", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir := filepath.Join(workDir, "skills")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, skillsDir)
|
||||
|
||||
// Skill name in frontmatter doesn't match directory name.
|
||||
skillDir := filepath.Join(skillsDir, "wrong-dir-name")
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte("---\nname: actual-name\ndescription: mismatch\n---\n"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
|
||||
require.Empty(t, skillParts)
|
||||
})
|
||||
|
||||
t.Run("DuplicateSkillsFirstWins", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir1 := filepath.Join(workDir, "skills1")
|
||||
skillsDir2 := filepath.Join(workDir, "skills2")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, skillsDir1+","+skillsDir2)
|
||||
|
||||
// Same skill name in both directories.
|
||||
for _, dir := range []string{skillsDir1, skillsDir2} {
|
||||
skillDir := filepath.Join(dir, "dup-skill")
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte("---\nname: dup-skill\ndescription: from "+filepath.Base(dir)+"\n---\n"),
|
||||
0o600,
|
||||
))
|
||||
}
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
skillParts := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeSkill)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "from skills1", skillParts[0].SkillDescription)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -104,14 +427,13 @@ func TestNewAPI_LazyDirectory(t *testing.T) {
|
||||
dir := ""
|
||||
api := agentcontextconfig.NewAPI(func() string { return dir })
|
||||
|
||||
// Before directory is set, relative paths resolve to nothing.
|
||||
cfg := api.Config()
|
||||
require.Empty(t, cfg.SkillsDirs)
|
||||
require.Empty(t, cfg.MCPConfigFiles)
|
||||
// Before directory is set, MCP paths resolve to nothing.
|
||||
mcpFiles := api.MCPConfigFiles()
|
||||
require.Empty(t, mcpFiles)
|
||||
|
||||
// After setting the directory, Config() picks it up lazily.
|
||||
// After setting the directory, MCPConfigFiles() picks it up.
|
||||
dir = platformAbsPath("work")
|
||||
cfg = api.Config()
|
||||
require.NotEmpty(t, cfg.SkillsDirs)
|
||||
require.Equal(t, []string{filepath.Join(dir, ".agents", "skills")}, cfg.SkillsDirs)
|
||||
mcpFiles = api.MCPConfigFiles()
|
||||
require.NotEmpty(t, mcpFiles)
|
||||
require.Equal(t, []string{filepath.Join(dir, ".mcp.json")}, mcpFiles)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
@@ -47,6 +52,9 @@ type API struct {
|
||||
logger slog.Logger
|
||||
desktop Desktop
|
||||
clock quartz.Clock
|
||||
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewAPI creates a new desktop streaming API.
|
||||
@@ -66,6 +74,10 @@ func (a *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/vnc", a.handleDesktopVNC)
|
||||
r.Post("/action", a.handleAction)
|
||||
r.Route("/recording", func(r chi.Router) {
|
||||
r.Post("/start", a.handleRecordingStart)
|
||||
r.Post("/stop", a.handleRecordingStop)
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -116,6 +128,9 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
handlerStart := a.clock.Now()
|
||||
|
||||
// Update last desktop action timestamp for idle recording monitor.
|
||||
a.desktop.RecordActivity()
|
||||
|
||||
// Ensure the desktop is running and grab native dimensions.
|
||||
cfg, err := a.desktop.Start(ctx)
|
||||
if err != nil {
|
||||
@@ -480,9 +495,150 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Close shuts down the desktop session if one is running.
|
||||
func (a *API) Close() error {
|
||||
a.closeMu.Lock()
|
||||
if a.closed {
|
||||
a.closeMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
a.closed = true
|
||||
a.closeMu.Unlock()
|
||||
|
||||
return a.desktop.Close()
|
||||
}
|
||||
|
||||
// decodeRecordingRequest decodes and validates a recording request
|
||||
// from the HTTP body, returning the recording ID. Returns false if
|
||||
// the request was invalid and an error response was already written.
|
||||
func (*API) decodeRecordingRequest(rw http.ResponseWriter, r *http.Request) (string, bool) {
|
||||
ctx := r.Context()
|
||||
var req struct {
|
||||
RecordingID string `json:"recording_id"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to decode request body.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return "", false
|
||||
}
|
||||
if req.RecordingID == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing recording_id.",
|
||||
})
|
||||
return "", false
|
||||
}
|
||||
if _, err := uuid.Parse(req.RecordingID); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid recording_id format.",
|
||||
Detail: "recording_id must be a valid UUID.",
|
||||
})
|
||||
return "", false
|
||||
}
|
||||
return req.RecordingID, true
|
||||
}
|
||||
|
||||
func (a *API) handleRecordingStart(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
recordingID, ok := a.decodeRecordingRequest(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
a.closeMu.Lock()
|
||||
if a.closed {
|
||||
a.closeMu.Unlock()
|
||||
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
|
||||
Message: "Desktop API is shutting down.",
|
||||
})
|
||||
return
|
||||
}
|
||||
a.closeMu.Unlock()
|
||||
|
||||
if err := a.desktop.StartRecording(ctx, recordingID); err != nil {
|
||||
if errors.Is(err, ErrDesktopClosed) {
|
||||
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
|
||||
Message: "Desktop API is shutting down.",
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to start recording.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
|
||||
Message: "Recording started.",
|
||||
})
|
||||
}
|
||||
|
||||
// handleRecordingStop stops an in-progress recording and, on
// success, streams the resulting MP4 back to the caller. Error
// mapping: unknown recording -> 404, corrupted recording -> 500,
// anything else -> 500; artifacts over MaxRecordingSize -> 413.
func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// On validation failure an error response was already written.
	recordingID, ok := a.decodeRecordingRequest(rw, r)
	if !ok {
		return
	}

	// Refuse new work once Close has begun. The flag is only
	// sampled here; behavior under a concurrent Close after the
	// unlock depends on the Desktop implementation — TODO confirm.
	a.closeMu.Lock()
	if a.closed {
		a.closeMu.Unlock()
		httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
			Message: "Desktop API is shutting down.",
		})
		return
	}
	a.closeMu.Unlock()

	// Stop recording (idempotent).
	// Use a context detached from the HTTP request so that if the
	// connection drops, the recording process can still shut down
	// gracefully. WithoutCancel preserves request-scoped values.
	stopCtx, stopCancel := context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
	defer stopCancel()
	artifact, err := a.desktop.StopRecording(stopCtx, recordingID)
	if err != nil {
		// Map the two well-known sentinel errors to specific
		// status codes before the generic fallback.
		if errors.Is(err, ErrUnknownRecording) {
			httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
				Message: "Recording not found.",
				Detail:  err.Error(),
			})
			return
		}
		if errors.Is(err, ErrRecordingCorrupted) {
			httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
				Message: "Recording is corrupted.",
				Detail:  err.Error(),
			})
			return
		}
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to stop recording.",
			Detail:  err.Error(),
		})
		return
	}
	defer artifact.Reader.Close()

	// Never stream an artifact larger than the protocol allows;
	// log the oversize for operators and reject with 413.
	if artifact.Size > workspacesdk.MaxRecordingSize {
		a.logger.Warn(ctx, "recording file exceeds maximum size",
			slog.F("recording_id", recordingID),
			slog.F("size", artifact.Size),
			slog.F("max_size", workspacesdk.MaxRecordingSize),
		)
		httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
			Message: "Recording file exceeds maximum allowed size.",
		})
		return
	}

	// Stream the MP4 to the response. The copy error is ignored:
	// headers are already written, so nothing useful can be sent
	// to the client if the copy fails mid-stream.
	rw.Header().Set("Content-Type", "video/mp4")
	rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
	rw.WriteHeader(http.StatusOK)
	_, _ = io.Copy(rw, artifact.Reader)
}
|
||||
|
||||
// coordFromAction extracts the coordinate pair from a DesktopAction,
|
||||
// returning an error if the coordinate field is missing.
|
||||
func coordFromAction(action DesktopAction) (x, y int, err error) {
|
||||
|
||||
@@ -4,12 +4,17 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -21,6 +26,16 @@ import (
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Test recording UUIDs used across tests.
|
||||
const (
|
||||
testRecIDDefault = "870e1f02-8118-4300-a37e-4adb0117baf3"
|
||||
testRecIDStartIdempotent = "250a2ffb-a5e5-4c94-9754-4d6a4ab7ba20"
|
||||
testRecIDStopIdempotent = "38f8a378-f98f-4758-a4ae-950b44cf989a"
|
||||
testRecIDConcurrentA = "8dc173eb-23c6-4601-a485-b6dfb2a42c3a"
|
||||
testRecIDConcurrentB = "fea490d4-70f0-4798-a181-29d65ce25ae1"
|
||||
testRecIDRestart = "75173a0d-b018-4e2e-a771-defa3fc6af69"
|
||||
)
|
||||
|
||||
// Ensure fakeDesktop satisfies the Desktop interface at compile time.
|
||||
var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
|
||||
|
||||
@@ -43,6 +58,14 @@ type fakeDesktop struct {
|
||||
lastTyped string
|
||||
lastKeyDown string
|
||||
lastKeyUp string
|
||||
|
||||
// Recording tracking (guarded by recMu).
|
||||
recMu sync.Mutex
|
||||
recordings map[string]string // ID → file path
|
||||
stopCalls []string // recording IDs passed to StopRecording
|
||||
recStopCh chan string // optional: signaled when StopRecording is called
|
||||
startCount int // incremented on each new recording start
|
||||
activityCount int // incremented by RecordActivity
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Start(context.Context) (agentdesktop.DisplayConfig, error) {
|
||||
@@ -107,11 +130,140 @@ func (f *fakeDesktop) CursorPosition(context.Context) (x int, y int, err error)
|
||||
return f.cursorPos[0], f.cursorPos[1], nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) StartRecording(_ context.Context, recordingID string) error {
|
||||
f.recMu.Lock()
|
||||
defer f.recMu.Unlock()
|
||||
if f.recordings == nil {
|
||||
f.recordings = make(map[string]string)
|
||||
}
|
||||
if path, ok := f.recordings[recordingID]; ok {
|
||||
// Check if already stopped (file still exists but stop was
|
||||
// called). For the fake, a stopped recording means its ID
|
||||
// appears in stopCalls. In that case, remove the old file
|
||||
// and start fresh.
|
||||
stopped := slices.Contains(f.stopCalls, recordingID)
|
||||
if !stopped {
|
||||
// Active recording - no-op.
|
||||
return nil
|
||||
}
|
||||
// Completed recording - discard old file, start fresh.
|
||||
_ = os.Remove(path)
|
||||
delete(f.recordings, recordingID)
|
||||
}
|
||||
f.startCount++
|
||||
tmpFile, err := os.CreateTemp("", "fake-recording-*.mp4")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = tmpFile.Write([]byte(fmt.Sprintf("fake-mp4-data-%s-%d", recordingID, f.startCount)))
|
||||
_ = tmpFile.Close()
|
||||
f.recordings[recordingID] = tmpFile.Name()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*agentdesktop.RecordingArtifact, error) {
|
||||
f.recMu.Lock()
|
||||
defer f.recMu.Unlock()
|
||||
if f.recordings == nil {
|
||||
return nil, agentdesktop.ErrUnknownRecording
|
||||
}
|
||||
path, ok := f.recordings[recordingID]
|
||||
if !ok {
|
||||
return nil, agentdesktop.ErrUnknownRecording
|
||||
}
|
||||
f.stopCalls = append(f.stopCalls, recordingID)
|
||||
if f.recStopCh != nil {
|
||||
select {
|
||||
case f.recStopCh <- recordingID:
|
||||
default:
|
||||
}
|
||||
}
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &agentdesktop.RecordingArtifact{
|
||||
Reader: file,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) RecordActivity() {
|
||||
f.recMu.Lock()
|
||||
f.activityCount++
|
||||
f.recMu.Unlock()
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Close() error {
|
||||
f.closed = true
|
||||
f.recMu.Lock()
|
||||
defer f.recMu.Unlock()
|
||||
for _, path := range f.recordings {
|
||||
_ = os.Remove(path)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// failStartRecordingDesktop wraps fakeDesktop and overrides
// StartRecording to always return an error, for exercising the
// handler's start-failure path.
type failStartRecordingDesktop struct {
	fakeDesktop
	// startRecordingErr is returned by every StartRecording call.
	startRecordingErr error
}

// StartRecording always fails with the configured error.
func (f *failStartRecordingDesktop) StartRecording(_ context.Context, _ string) error {
	return f.startRecordingErr
}
|
||||
|
||||
// corruptedStopDesktop wraps fakeDesktop and overrides
// StopRecording to always return ErrRecordingCorrupted, for
// exercising the handler's corrupted-recording path.
type corruptedStopDesktop struct {
	fakeDesktop
}

// StopRecording always fails with ErrRecordingCorrupted.
func (*corruptedStopDesktop) StopRecording(_ context.Context, _ string) (*agentdesktop.RecordingArtifact, error) {
	return nil, agentdesktop.ErrRecordingCorrupted
}
|
||||
|
||||
// oversizedFakeDesktop wraps fakeDesktop and expands recording files
|
||||
// beyond MaxRecordingSize when StopRecording is called.
|
||||
type oversizedFakeDesktop struct {
|
||||
fakeDesktop
|
||||
}
|
||||
|
||||
func (f *oversizedFakeDesktop) StopRecording(ctx context.Context, recordingID string) (*agentdesktop.RecordingArtifact, error) {
|
||||
artifact, err := f.fakeDesktop.StopRecording(ctx, recordingID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Close the original reader since we're going to re-open after truncation.
|
||||
artifact.Reader.Close()
|
||||
|
||||
// Look up the path from the fakeDesktop recordings.
|
||||
f.fakeDesktop.recMu.Lock()
|
||||
path := f.fakeDesktop.recordings[recordingID]
|
||||
f.fakeDesktop.recMu.Unlock()
|
||||
|
||||
// Expand the file to exceed the maximum recording size.
|
||||
if err := os.Truncate(path, workspacesdk.MaxRecordingSize+1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Re-open the truncated file.
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &agentdesktop.RecordingArtifact{
|
||||
Reader: file,
|
||||
Size: workspacesdk.MaxRecordingSize + 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestHandleDesktopVNC_StartError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -134,6 +286,37 @@ func TestHandleDesktopVNC_StartError(t *testing.T) {
|
||||
assert.Equal(t, "Failed to start desktop session.", resp.Message)
|
||||
}
|
||||
|
||||
func TestHandleAction_CallsRecordActivity(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "left_click",
|
||||
Coordinate: &[2]int{100, 200},
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
fake.recMu.Lock()
|
||||
count := fake.activityCount
|
||||
fake.recMu.Unlock()
|
||||
assert.Equal(t, 1, count, "handleAction should call RecordActivity exactly once")
|
||||
}
|
||||
|
||||
func TestHandleAction_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -574,3 +757,481 @@ func TestHandleAction_CursorPositionReturnsDeclaredCoordinates(t *testing.T) {
|
||||
// Native (960,540) in 1920x1080 should map to declared space in 1280x720.
|
||||
assert.Equal(t, "x=640,y=360", resp.Output)
|
||||
}
|
||||
|
||||
func TestRecordingStartStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": testRecIDDefault})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": testRecIDDefault})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes())
|
||||
}
|
||||
|
||||
func TestRecordingStartFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &failStartRecordingDesktop{
|
||||
fakeDesktop: fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
},
|
||||
startRecordingErr: xerrors.New("start recording error"),
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Failed to start recording.", resp.Message)
|
||||
}
|
||||
|
||||
func TestRecordingStartIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start same recording twice - both should succeed.
|
||||
for range 2 {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStartIdempotent})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
// Stop once, verify normal response.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": testRecIDStartIdempotent})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes())
|
||||
}
|
||||
|
||||
func TestRecordingStopIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop twice - both should succeed with identical data.
|
||||
var bodies [2][]byte
|
||||
for i := range 2 {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
|
||||
require.NoError(t, err)
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(recorder, request)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type"))
|
||||
bodies[i] = recorder.Body.Bytes()
|
||||
}
|
||||
assert.Equal(t, bodies[0], bodies[1])
|
||||
}
|
||||
|
||||
func TestRecordingStopInvalidIDFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": "not-a-uuid"})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestRecordingStopUnknownRecording(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Send a valid UUID that was never started - should reach
|
||||
// StopRecording, get ErrUnknownRecording, and return 404.
|
||||
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording not found.", resp.Message)
|
||||
}
|
||||
|
||||
func TestRecordingStopOversizedFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &oversizedFakeDesktop{
|
||||
fakeDesktop: fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording - file exceeds max size, expect 413.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording file exceeds maximum allowed size.", resp.Message)
|
||||
}
|
||||
|
||||
func TestRecordingMultipleSimultaneous(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start two recordings with different IDs.
|
||||
for _, id := range []string{testRecIDConcurrentA, testRecIDConcurrentB} {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": id})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
// Stop both and verify each returns its own data.
|
||||
expected := map[string][]byte{
|
||||
testRecIDConcurrentA: []byte("fake-mp4-data-" + testRecIDConcurrentA + "-1"),
|
||||
testRecIDConcurrentB: []byte("fake-mp4-data-" + testRecIDConcurrentB + "-2"),
|
||||
}
|
||||
for _, id := range []string{testRecIDConcurrentA, testRecIDConcurrentB} {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": id})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, expected[id], rr.Body.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordingStartMalformedBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader([]byte("not json")))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestRecordingStartEmptyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": ""})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestRecordingStopEmptyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": ""})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestRecordingStopMalformedBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader([]byte("not json")))
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Step 1: Start recording.
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": testRecIDRestart})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Step 2: Stop recording (gets first MP4 data).
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": testRecIDRestart})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
firstData := rr.Body.Bytes()
|
||||
require.NotEmpty(t, firstData)
|
||||
|
||||
// Step 3: Start again with the same ID - should succeed
|
||||
// (old file discarded, new recording started).
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Step 4: Stop again - should return NEW MP4 data.
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
secondData := rr.Body.Bytes()
|
||||
require.NotEmpty(t, secondData)
|
||||
|
||||
// The two recordings should have different data because the
|
||||
// fake increments a counter on each fresh start.
|
||||
assert.NotEqual(t, firstData, secondData,
|
||||
"restarted recording should produce different data")
|
||||
}
|
||||
|
||||
func TestRecordingStartAfterClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Close the API before sending the request.
|
||||
api.Close()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Desktop API is shutting down.", resp.Message)
|
||||
}
|
||||
|
||||
func TestRecordingStartDesktopClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// StartRecording returns ErrDesktopClosed to simulate a race
|
||||
// where the desktop is closed between the API-level check and
|
||||
// the desktop-level StartRecording call.
|
||||
fake := &failStartRecordingDesktop{
|
||||
fakeDesktop: fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
},
|
||||
startRecordingErr: agentdesktop.ErrDesktopClosed,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
body, err := json.Marshal(map[string]string{"recording_id": uuid.New().String()})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Desktop API is shutting down.", resp.Message)
|
||||
}
|
||||
|
||||
func TestRecordingStopCorrupted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &corruptedStopDesktop{
|
||||
fakeDesktop: fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start a recording so the stop has something to find.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop returns ErrRecordingCorrupted.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
|
||||
var respStop codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&respStop)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording is corrupted.", respStop.Message)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,10 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Desktop abstracts a virtual desktop session running inside a workspace.
|
||||
@@ -58,10 +61,52 @@ type Desktop interface {
|
||||
// CursorPosition returns the current cursor coordinates.
|
||||
CursorPosition(ctx context.Context) (x, y int, err error)
|
||||
|
||||
// RecordActivity marks the desktop as having received user
|
||||
// interaction, resetting the idle-recording timer.
|
||||
RecordActivity()
|
||||
|
||||
// StartRecording begins recording the desktop to an MP4 file
|
||||
// using the caller-provided recording ID. Safe to call
|
||||
// repeatedly - active recordings continue unchanged, stopped
|
||||
// recordings are discarded and restarted. Concurrent recordings
|
||||
// are supported.
|
||||
StartRecording(ctx context.Context, recordingID string) error
|
||||
|
||||
// StopRecording finalizes the recording identified by the given
|
||||
// ID. Idempotent - safe to call on an already-stopped recording.
|
||||
// Returns a RecordingArtifact that the caller can stream. The
|
||||
// caller must close the artifact when done. Returns an error if
|
||||
// the recording ID is unknown.
|
||||
StopRecording(ctx context.Context, recordingID string) (*RecordingArtifact, error)
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// ErrUnknownRecording is returned by StopRecording when the
|
||||
// recording ID is not recognized.
|
||||
var ErrUnknownRecording = xerrors.New("unknown recording ID")
|
||||
|
||||
// ErrDesktopClosed is returned when an operation is attempted on a
|
||||
// closed desktop session.
|
||||
var ErrDesktopClosed = xerrors.New("desktop closed")
|
||||
|
||||
// ErrRecordingCorrupted is returned by StopRecording when the
|
||||
// recording process was force-killed and the artifact is likely
|
||||
// incomplete or corrupt.
|
||||
var ErrRecordingCorrupted = xerrors.New("recording corrupted: process was force-killed")
|
||||
|
||||
// RecordingArtifact is a finalized recording returned by StopRecording.
|
||||
// The caller streams the artifact and must call Close when done. The
|
||||
// artifact remains valid even if the same recording ID is restarted
|
||||
// or the desktop is closed while the caller is reading.
|
||||
type RecordingArtifact struct {
|
||||
// Reader is the MP4 content. Callers must close it when done.
|
||||
Reader io.ReadCloser
|
||||
// Size is the byte length of the MP4 content.
|
||||
Size int64
|
||||
}
|
||||
|
||||
// DisplayConfig describes a running desktop session.
|
||||
type DisplayConfig struct {
|
||||
Width int // native width in pixels
|
||||
|
||||
@@ -3,6 +3,7 @@ package agentdesktop
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
@@ -18,6 +20,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// portableDesktopOutput is the JSON output from
|
||||
@@ -49,32 +52,65 @@ type screenshotOutput struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// recordingProcess tracks a single desktop recording subprocess.
|
||||
type recordingProcess struct {
|
||||
cmd *exec.Cmd
|
||||
filePath string
|
||||
stopped bool
|
||||
killed bool // true when the process was SIGKILLed
|
||||
done chan struct{} // closed when cmd.Wait() returns
|
||||
waitErr error // set before done is closed
|
||||
stopOnce sync.Once
|
||||
idleCancel context.CancelFunc // cancels the per-recording idle goroutine
|
||||
idleDone chan struct{} // closed when idle goroutine exits
|
||||
}
|
||||
|
||||
// maxConcurrentRecordings is the maximum number of active (non-stopped)
|
||||
// recordings allowed at once. This prevents resource exhaustion.
|
||||
const maxConcurrentRecordings = 5
|
||||
|
||||
// idleTimeout is the duration of desktop inactivity after which all
|
||||
// active recordings are automatically stopped.
|
||||
const idleTimeout = 10 * time.Minute
|
||||
|
||||
// portableDesktop implements Desktop by shelling out to the
|
||||
// portabledesktop CLI via agentexec.Execer.
|
||||
type portableDesktop struct {
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
scriptBinDir string // coder script bin directory
|
||||
clock quartz.Clock
|
||||
|
||||
mu sync.Mutex
|
||||
session *desktopSession // nil until started
|
||||
binPath string // resolved path to binary, cached
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
session *desktopSession // nil until started
|
||||
binPath string // resolved path to binary, cached
|
||||
closed bool
|
||||
recordings map[string]*recordingProcess // guarded by mu
|
||||
lastDesktopActionAt atomic.Int64
|
||||
}
|
||||
|
||||
// NewPortableDesktop creates a Desktop backed by the portabledesktop
|
||||
// CLI binary, using execer to spawn child processes. scriptBinDir is
|
||||
// the coder script bin directory checked for the binary.
|
||||
// the coder script bin directory checked for the binary. If clk is
|
||||
// nil, a real clock is used.
|
||||
func NewPortableDesktop(
|
||||
logger slog.Logger,
|
||||
execer agentexec.Execer,
|
||||
scriptBinDir string,
|
||||
clk quartz.Clock,
|
||||
) Desktop {
|
||||
return &portableDesktop{
|
||||
if clk == nil {
|
||||
clk = quartz.NewReal()
|
||||
}
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
clock: clk,
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
return pd
|
||||
}
|
||||
|
||||
// Start launches the desktop session (idempotent).
|
||||
@@ -83,7 +119,7 @@ func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return DisplayConfig{}, xerrors.New("desktop is closed")
|
||||
return DisplayConfig{}, ErrDesktopClosed
|
||||
}
|
||||
|
||||
if err := p.ensureBinary(ctx); err != nil {
|
||||
@@ -313,23 +349,328 @@ func (p *portableDesktop) CursorPosition(ctx context.Context) (x int, y int, err
|
||||
return result.X, result.Y, nil
|
||||
}
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
func (p *portableDesktop) Close() error {
|
||||
// StartRecording begins recording the desktop to an MP4 file.
|
||||
// Three-state idempotency: active recordings are no-ops,
|
||||
// completed recordings are discarded and restarted.
|
||||
func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string) error {
|
||||
// Ensure the desktop session is running before acquiring the
|
||||
// recording lock. Start is independently locked and idempotent.
|
||||
if _, err := p.Start(ctx); err != nil {
|
||||
return xerrors.Errorf("ensure desktop session: %w", err)
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return ErrDesktopClosed
|
||||
}
|
||||
|
||||
// Three-state idempotency:
|
||||
// - Active recording → no-op, continue recording.
|
||||
// - Completed recording → discard old file, start fresh.
|
||||
// - Unknown ID → fall through to start a new recording.
|
||||
if rec, ok := p.recordings[recordingID]; ok {
|
||||
if !rec.stopped {
|
||||
select {
|
||||
case <-rec.done:
|
||||
// Process exited unexpectedly; treat as completed
|
||||
// so we fall through to discard the old file and
|
||||
// restart.
|
||||
default:
|
||||
// Active recording - no-op, continue recording.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Completed recording - discard old file, start fresh.
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
p.logger.Warn(ctx, "failed to remove old recording file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, recordingID)
|
||||
}
|
||||
|
||||
// Check concurrent recording limit.
|
||||
if p.lockedActiveRecordingCount() >= maxConcurrentRecordings {
|
||||
return xerrors.Errorf("too many concurrent recordings (max %d)", maxConcurrentRecordings)
|
||||
}
|
||||
|
||||
// GC sweep: remove stopped recordings with stale files.
|
||||
p.lockedCleanStaleRecordings(ctx)
|
||||
|
||||
if err := p.ensureBinary(ctx); err != nil {
|
||||
return xerrors.Errorf("ensure portabledesktop binary: %w", err)
|
||||
}
|
||||
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
|
||||
|
||||
// Use a background context so the process outlives the HTTP
|
||||
// request that triggered it.
|
||||
procCtx, procCancel := context.WithCancel(context.Background())
|
||||
|
||||
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
|
||||
cmd := p.execer.CommandContext(procCtx, p.binPath, "record",
|
||||
// The following options are used to speed up the recording when the desktop is idle.
|
||||
// They were taken out of an example in the portabledesktop repo.
|
||||
// There's likely room for improvement to optimize the values.
|
||||
"--idle-speedup", "20",
|
||||
"--idle-min-duration", "0.35",
|
||||
"--idle-noise-tolerance", "-38dB",
|
||||
filePath)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
procCancel()
|
||||
return xerrors.Errorf("start recording process: %w", err)
|
||||
}
|
||||
|
||||
rec := &recordingProcess{
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go func() {
|
||||
rec.waitErr = cmd.Wait()
|
||||
close(rec.done)
|
||||
// avoid a context resource leak by canceling the context
|
||||
procCancel()
|
||||
}()
|
||||
|
||||
p.recordings[recordingID] = rec
|
||||
|
||||
p.logger.Info(ctx, "started desktop recording",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", filePath),
|
||||
slog.F("pid", cmd.Process.Pid),
|
||||
)
|
||||
|
||||
// Record activity so a recording started on an already-idle
|
||||
// desktop does not stop immediately.
|
||||
p.lastDesktopActionAt.Store(p.clock.Now().UnixNano())
|
||||
|
||||
// Spawn a per-recording idle goroutine.
|
||||
idleCtx, idleCancel := context.WithCancel(context.Background())
|
||||
rec.idleCancel = idleCancel
|
||||
rec.idleDone = make(chan struct{})
|
||||
go func() {
|
||||
defer close(rec.idleDone)
|
||||
p.monitorRecordingIdle(idleCtx, rec)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopRecording finalizes the recording. Idempotent - safe to call
|
||||
// on an already-stopped recording. Returns a RecordingArtifact
|
||||
// that the caller can stream. The caller must close the Reader
|
||||
// on the returned artifact to avoid leaking file descriptors.
|
||||
func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string) (*RecordingArtifact, error) {
|
||||
p.mu.Lock()
|
||||
rec, ok := p.recordings[recordingID]
|
||||
if !ok {
|
||||
p.mu.Unlock()
|
||||
return nil, ErrUnknownRecording
|
||||
}
|
||||
|
||||
p.lockedStopRecordingProcess(ctx, rec, false)
|
||||
killed := rec.killed
|
||||
p.mu.Unlock()
|
||||
|
||||
p.logger.Info(ctx, "stopped desktop recording",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", rec.filePath),
|
||||
)
|
||||
|
||||
if killed {
|
||||
return nil, ErrRecordingCorrupted
|
||||
}
|
||||
|
||||
// Open the file and return an artifact. Each call opens a fresh
|
||||
// file descriptor so the caller is insulated from restarts and
|
||||
// desktop close.
|
||||
f, err := os.Open(rec.filePath)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("open recording artifact: %w", err)
|
||||
}
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
return nil, xerrors.Errorf("stat recording artifact: %w", err)
|
||||
}
|
||||
return &RecordingArtifact{
|
||||
Reader: f,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// lockedStopRecordingProcess stops a single recording via stopOnce.
|
||||
// It sends SIGINT, waits up to 15 seconds for graceful exit, then
|
||||
// SIGKILLs. When force is true the process is SIGKILLed immediately
|
||||
// without attempting a graceful shutdown. Must be called while p.mu
|
||||
// is held; the lock is held for the full duration so that no
|
||||
// concurrent StopRecording caller can read rec.stopped = true
|
||||
// before the process has finished writing the MP4 file.
|
||||
//
|
||||
//nolint:revive // force flag keeps shared stopOnce/cleanup logic in one place.
|
||||
func (p *portableDesktop) lockedStopRecordingProcess(ctx context.Context, rec *recordingProcess, force bool) {
|
||||
rec.stopOnce.Do(func() {
|
||||
if force {
|
||||
_ = rec.cmd.Process.Kill()
|
||||
rec.killed = true
|
||||
} else {
|
||||
_ = interruptRecordingProcess(rec.cmd.Process)
|
||||
timer := p.clock.NewTimer(15*time.Second, "agentdesktop", "stop_timeout")
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-rec.done:
|
||||
case <-ctx.Done():
|
||||
_ = rec.cmd.Process.Kill()
|
||||
rec.killed = true
|
||||
case <-timer.C:
|
||||
_ = rec.cmd.Process.Kill()
|
||||
rec.killed = true
|
||||
}
|
||||
}
|
||||
rec.stopped = true
|
||||
if rec.idleCancel != nil {
|
||||
rec.idleCancel()
|
||||
}
|
||||
})
|
||||
// NOTE: We intentionally do not wait on rec.done here.
|
||||
// If goleak is added to this package's tests, this may
|
||||
// need revisiting to avoid flakes.
|
||||
}
|
||||
|
||||
// lockedActiveRecordingCount returns the number of recordings that
|
||||
// are still actively running. Must be called while p.mu is held.
|
||||
// The max concurrency is low (maxConcurrentRecordings = 5), so a
|
||||
// full scan is cheap and avoids maintaining a separate counter.
|
||||
func (p *portableDesktop) lockedActiveRecordingCount() int {
|
||||
active := 0
|
||||
for _, rec := range p.recordings {
|
||||
if rec.stopped {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-rec.done:
|
||||
default:
|
||||
active++
|
||||
}
|
||||
}
|
||||
return active
|
||||
}
|
||||
|
||||
// lockedCleanStaleRecordings removes stopped recordings whose temp
|
||||
// files are older than one hour. Must be called while p.mu is held.
|
||||
func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
|
||||
for id, rec := range p.recordings {
|
||||
if !rec.stopped {
|
||||
continue
|
||||
}
|
||||
info, err := os.Stat(rec.filePath)
|
||||
if err != nil {
|
||||
// File already removed or inaccessible; drop entry.
|
||||
delete(p.recordings, id)
|
||||
continue
|
||||
}
|
||||
if p.clock.Since(info.ModTime()) > time.Hour {
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
p.logger.Warn(ctx, "failed to remove stale recording file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
func (p *portableDesktop) Close() error {
|
||||
p.mu.Lock()
|
||||
p.closed = true
|
||||
if p.session != nil {
|
||||
p.session.cancel()
|
||||
// Xvnc is a child process — killing it cleans up the X
|
||||
// session.
|
||||
_ = p.session.cmd.Process.Kill()
|
||||
_ = p.session.cmd.Wait()
|
||||
p.session = nil
|
||||
|
||||
// Force-kill all active recordings. The stopOnce inside
|
||||
// lockedStopRecordingProcess makes this safe for
|
||||
// already-stopped recordings.
|
||||
for _, rec := range p.recordings {
|
||||
p.lockedStopRecordingProcess(context.Background(), rec, true)
|
||||
}
|
||||
|
||||
// Snapshot recording file paths and idle goroutine channels
|
||||
// for cleanup, then clear the map.
|
||||
type recEntry struct {
|
||||
id string
|
||||
filePath string
|
||||
idleDone chan struct{}
|
||||
}
|
||||
var allRecs []recEntry
|
||||
for id, rec := range p.recordings {
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
session := p.session
|
||||
p.session = nil
|
||||
p.mu.Unlock()
|
||||
|
||||
// Wait for all per-recording idle goroutines to exit.
|
||||
for _, entry := range allRecs {
|
||||
if entry.idleDone != nil {
|
||||
<-entry.idleDone
|
||||
}
|
||||
}
|
||||
|
||||
// Remove all recording files and wait for the session to
|
||||
// exit with a timeout so a slow filesystem or hung process
|
||||
// cannot block agent shutdown indefinitely.
|
||||
cleanupDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(cleanupDone)
|
||||
for _, entry := range allRecs {
|
||||
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
|
||||
p.logger.Warn(context.Background(), "failed to remove recording file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("file_path", entry.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
session.cancel()
|
||||
if err := session.cmd.Process.Kill(); err != nil {
|
||||
p.logger.Warn(context.Background(), "failed to kill portabledesktop process",
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := session.cmd.Wait(); err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
p.logger.Warn(context.Background(), "portabledesktop process exited with error",
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
timer := p.clock.NewTimer(15*time.Second, "agentdesktop", "close_cleanup_timeout")
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-cleanupDone:
|
||||
case <-timer.C:
|
||||
p.logger.Warn(context.Background(), "timed out waiting for close cleanup")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordActivity marks the desktop as having received user
|
||||
// interaction, resetting the idle-recording timer.
|
||||
func (p *portableDesktop) RecordActivity() {
|
||||
p.lastDesktopActionAt.Store(p.clock.Now().UnixNano())
|
||||
}
|
||||
|
||||
// runCmd executes a portabledesktop subcommand and returns combined
|
||||
// output. The caller must have previously called ensureBinary.
|
||||
func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, error) {
|
||||
@@ -397,3 +738,31 @@ func (p *portableDesktop) ensureBinary(ctx context.Context) error {
|
||||
|
||||
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
|
||||
}
|
||||
|
||||
// monitorRecordingIdle watches for desktop inactivity and stops the
|
||||
// given recording when the idle timeout is reached.
|
||||
func (p *portableDesktop) monitorRecordingIdle(ctx context.Context, rec *recordingProcess) {
|
||||
timer := p.clock.NewTimer(idleTimeout, "agentdesktop", "recording_idle")
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
lastNano := p.lastDesktopActionAt.Load()
|
||||
lastAction := time.Unix(0, lastNano)
|
||||
elapsed := p.clock.Since(lastAction)
|
||||
if elapsed >= idleTimeout {
|
||||
p.mu.Lock()
|
||||
p.lockedStopRecordingProcess(context.Background(), rec, false)
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
// Activity happened; reset with remaining budget.
|
||||
timer.Reset(idleTimeout-elapsed, "agentdesktop", "recording_idle")
|
||||
case <-rec.done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,13 +9,17 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// recordedExecer implements agentexec.Execer by recording every
|
||||
@@ -86,6 +90,7 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
@@ -117,6 +122,7 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
@@ -159,6 +165,7 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
@@ -184,6 +191,7 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
@@ -282,6 +290,7 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
@@ -289,7 +298,6 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds, "expected at least one command")
|
||||
|
||||
// Find at least one recorded command that contains
|
||||
// all expected argument substrings.
|
||||
found := false
|
||||
@@ -367,6 +375,7 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
@@ -423,6 +432,7 @@ func TestPortableDesktop_Close(t *testing.T) {
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
@@ -445,7 +455,7 @@ func TestPortableDesktop_Close(t *testing.T) {
|
||||
// Subsequent Start must fail.
|
||||
_, err = pd.Start(ctx)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "desktop is closed")
|
||||
assert.Contains(t, err.Error(), "desktop closed")
|
||||
}
|
||||
|
||||
// --- ensureBinary tests ---
|
||||
@@ -539,7 +549,410 @@ func TestEnsureBinary_NotFound(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StartRecording(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds)
|
||||
// Find the record command (not the up command).
|
||||
found := false
|
||||
for _, cmd := range cmds {
|
||||
joined := strings.Join(cmd, " ")
|
||||
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected a record command with the recording ID")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StartRecording_ConcurrentLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
for i := range maxConcurrentRecordings {
|
||||
err := pd.StartRecording(ctx, uuid.New().String())
|
||||
require.NoError(t, err, "recording %d should succeed", i)
|
||||
}
|
||||
|
||||
err := pd.StartRecording(ctx, uuid.New().String())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "too many concurrent recordings")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write a dummy MP4 file at the expected path so StopRecording
|
||||
// can open it as an artifact.
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(filePath) })
|
||||
|
||||
artifact, err := pd.StopRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
defer artifact.Reader.Close()
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StopRecording_UnknownID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
_, err := pd.StopRecording(ctx, uuid.New().String())
|
||||
require.ErrorIs(t, err, ErrUnknownRecording)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
// Ensure that portableDesktop satisfies the Desktop interface at
|
||||
// compile time. This uses the unexported type so it lives in the
|
||||
// internal test package.
|
||||
var _ Desktop = (*portableDesktop)(nil)
|
||||
|
||||
func TestPortableDesktop_IdleTimeout_StopsRecordings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewMock(t)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
|
||||
// Install the trap before StartRecording so it is guaranteed
|
||||
// to catch the idle monitor's NewTimer call regardless of
|
||||
// goroutine scheduling.
|
||||
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
|
||||
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify recording is active.
|
||||
pd.mu.Lock()
|
||||
require.False(t, pd.recordings[recID].stopped)
|
||||
pd.mu.Unlock()
|
||||
|
||||
// Wait for the idle monitor timer to be created and release
|
||||
// it so the monitor enters its select loop.
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
trap.Close()
|
||||
|
||||
// The stop-all path calls lockedStopRecordingProcess which
|
||||
// creates a per-recording 15s stop_timeout timer.
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout to trigger the stop-all.
|
||||
clk.Advance(idleTimeout)
|
||||
|
||||
// Wait for the stop timer to be created, then release it.
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// The recording process should now be stopped.
|
||||
require.Eventually(t, func() bool {
|
||||
pd.mu.Lock()
|
||||
defer pd.mu.Unlock()
|
||||
rec, ok := pd.recordings[recID]
|
||||
return ok && rec.stopped
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_IdleTimeout_ActivityResetsTimer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewMock(t)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
|
||||
// Install the trap before StartRecording so it is guaranteed
|
||||
// to catch the idle monitor's NewTimer call regardless of
|
||||
// goroutine scheduling.
|
||||
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
|
||||
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the idle monitor timer to be created.
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
trap.Close()
|
||||
|
||||
// Advance most of the way but not past the timeout.
|
||||
clk.Advance(idleTimeout - time.Minute)
|
||||
|
||||
// Record activity to reset the timer.
|
||||
pd.RecordActivity()
|
||||
|
||||
// Trap the Reset call that the idle monitor makes when it
|
||||
// sees recent activity.
|
||||
resetTrap := clk.Trap().TimerReset("agentdesktop", "recording_idle")
|
||||
|
||||
// Advance past the original idle timeout deadline. The
|
||||
// monitor should see the recent activity and reset instead
|
||||
// of stopping.
|
||||
clk.Advance(time.Minute)
|
||||
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
resetTrap.Close()
|
||||
|
||||
// Recording should still be active because activity was
|
||||
// recorded.
|
||||
pd.mu.Lock()
|
||||
require.False(t, pd.recordings[recID].stopped)
|
||||
pd.mu.Unlock()
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_IdleTimeout_MultipleRecordings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewMock(t)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID1 := uuid.New().String()
|
||||
recID2 := uuid.New().String()
|
||||
|
||||
// Trap idle timer creation for both recordings.
|
||||
trap := clk.Trap().NewTimer("agentdesktop", "recording_idle")
|
||||
|
||||
err := pd.StartRecording(ctx, recID1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for first recording's idle timer.
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
err = pd.StartRecording(ctx, recID2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for second recording's idle timer.
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
trap.Close()
|
||||
|
||||
// Trap the stop timers that will be created when idle fires.
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout.
|
||||
clk.Advance(idleTimeout)
|
||||
|
||||
// Wait for both stop timers.
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// Both recordings should be stopped.
|
||||
require.Eventually(t, func() bool {
|
||||
pd.mu.Lock()
|
||||
defer pd.mu.Unlock()
|
||||
r1, ok1 := pd.recordings[recID1]
|
||||
r2, ok2 := pd.recordings[recID2]
|
||||
return ok1 && r1.stopped && ok2 && r2.stopped
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StartRecording_ReturnsErrDesktopClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
// Start and close the desktop so it's in the closed state.
|
||||
ctx := t.Context()
|
||||
_, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, pd.Close())
|
||||
|
||||
// StartRecording should now return ErrDesktopClosed.
|
||||
err = pd.StartRecording(ctx, uuid.New().String())
|
||||
require.ErrorIs(t, err, ErrDesktopClosed)
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Start_ReturnsErrDesktopClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: quartz.NewReal(),
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(pd.clock.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
_, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, pd.Close())
|
||||
|
||||
_, err = pd.Start(ctx)
|
||||
require.ErrorIs(t, err, ErrDesktopClosed)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package agentdesktop
|
||||
|
||||
import "os"
|
||||
|
||||
// interruptRecordingProcess sends a SIGINT to the recording process
|
||||
// for graceful shutdown. On Unix, os.Interrupt is delivered as
|
||||
// SIGINT which lets the recorder finalize the MP4 container.
|
||||
func interruptRecordingProcess(p *os.Process) error {
|
||||
return p.Signal(os.Interrupt)
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package agentdesktop
|
||||
|
||||
import "os"
|
||||
|
||||
// interruptRecordingProcess kills the recording process directly
|
||||
// because os.Process.Signal(os.Interrupt) is not supported on
|
||||
// Windows and returns an error without delivering a signal.
|
||||
func interruptRecordingProcess(p *os.Process) error {
|
||||
return p.Kill()
|
||||
}
|
||||
@@ -187,7 +187,11 @@ func (*Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Cl
|
||||
connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := c.Start(connectCtx); err != nil {
|
||||
// Use the parent ctx (not connectCtx) so the subprocess outlives
|
||||
// the connect/initialize handshake. connectCtx bounds only the
|
||||
// Initialize call below. The subprocess is cleaned up when the
|
||||
// Manager is closed or ctx is canceled.
|
||||
if err := c.Start(ctx); err != nil {
|
||||
_ = c.Close()
|
||||
return nil, xerrors.Errorf("start %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package agentmcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
@@ -8,6 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestSplitToolName(t *testing.T) {
|
||||
@@ -193,3 +199,118 @@ func TestConvertResult(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectServer_StdioProcessSurvivesConnect verifies that a stdio MCP
|
||||
// server subprocess remains alive after connectServer returns. This is a
|
||||
// regression test for a bug where the subprocess was tied to a short-lived
|
||||
// connectCtx and killed as soon as the context was canceled.
|
||||
func TestConnectServer_StdioProcessSurvivesConnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
|
||||
// Child process: act as a minimal MCP server over stdio.
|
||||
runFakeMCPServer()
|
||||
return
|
||||
}
|
||||
|
||||
// Get the path to the test binary so we can re-exec ourselves
|
||||
// as a fake MCP server subprocess.
|
||||
testBin, err := os.Executable()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := ServerConfig{
|
||||
Name: "fake",
|
||||
Transport: "stdio",
|
||||
Command: testBin,
|
||||
Args: []string{"-test.run=^TestConnectServer_StdioProcessSurvivesConnect$"},
|
||||
Env: map[string]string{"TEST_MCP_FAKE_SERVER": "1"},
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
m := &Manager{}
|
||||
client, err := m.connectServer(ctx, cfg)
|
||||
require.NoError(t, err, "connectServer should succeed")
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
// At this point connectServer has returned and its internal
|
||||
// connectCtx has been canceled. The subprocess must still be
|
||||
// alive. Verify by listing tools (requires a live server).
|
||||
listCtx, listCancel := context.WithTimeout(ctx, testutil.WaitShort)
|
||||
defer listCancel()
|
||||
result, err := client.ListTools(listCtx, mcp.ListToolsRequest{})
|
||||
require.NoError(t, err, "ListTools should succeed — server must be alive after connect")
|
||||
require.Len(t, result.Tools, 1)
|
||||
assert.Equal(t, "echo", result.Tools[0].Name)
|
||||
}
|
||||
|
||||
// runFakeMCPServer implements a minimal JSON-RPC / MCP server over
|
||||
// stdin/stdout, just enough for initialize + tools/list.
|
||||
func runFakeMCPServer() {
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
|
||||
var req struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id"`
|
||||
Method string `json:"method"`
|
||||
}
|
||||
if err := json.Unmarshal(line, &req); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var resp any
|
||||
switch req.Method {
|
||||
case "initialize":
|
||||
resp = map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": req.ID,
|
||||
"result": map[string]any{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{},
|
||||
},
|
||||
"serverInfo": map[string]any{
|
||||
"name": "fake-server",
|
||||
"version": "0.0.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
case "notifications/initialized":
|
||||
// No response needed for notifications.
|
||||
continue
|
||||
case "tools/list":
|
||||
resp = map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": req.ID,
|
||||
"result": map[string]any{
|
||||
"tools": []map[string]any{
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "echoes input",
|
||||
"inputSchema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
resp = map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": req.ID,
|
||||
"error": map[string]any{
|
||||
"code": -32601,
|
||||
"message": "method not found",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
out, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(os.Stdout, "%s\n", out)
|
||||
}
|
||||
}
|
||||
|
||||
+28
-19
@@ -3,11 +3,13 @@
|
||||
"enabled": true,
|
||||
"clientKind": "git",
|
||||
"useIgnoreFile": true,
|
||||
"defaultBranch": "main",
|
||||
"defaultBranch": "main"
|
||||
},
|
||||
"files": {
|
||||
"includes": ["**", "!**/pnpm-lock.yaml"],
|
||||
"ignoreUnknown": true,
|
||||
// static/*.html are Go templates with {{ }} directives that
|
||||
// Biome's HTML parser does not support.
|
||||
"includes": ["**", "!**/pnpm-lock.yaml", "!**/static/*.html"],
|
||||
"ignoreUnknown": true
|
||||
},
|
||||
"linter": {
|
||||
"rules": {
|
||||
@@ -15,7 +17,7 @@
|
||||
"noSvgWithoutTitle": "off",
|
||||
"useButtonType": "off",
|
||||
"useSemanticElements": "off",
|
||||
"noStaticElementInteractions": "off",
|
||||
"noStaticElementInteractions": "off"
|
||||
},
|
||||
"correctness": {
|
||||
"noUnusedImports": "warn",
|
||||
@@ -24,9 +26,9 @@
|
||||
"noUnusedVariables": {
|
||||
"level": "warn",
|
||||
"options": {
|
||||
"ignoreRestSiblings": true,
|
||||
},
|
||||
},
|
||||
"ignoreRestSiblings": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"style": {
|
||||
"noNonNullAssertion": "off",
|
||||
@@ -47,7 +49,7 @@
|
||||
"paths": {
|
||||
"react": {
|
||||
"message": "React 19 no longer requires forwardRef. Use ref as a prop instead.",
|
||||
"importNames": ["forwardRef"],
|
||||
"importNames": ["forwardRef"]
|
||||
},
|
||||
// "@mui/material/Alert": "Use components/Alert/Alert instead.",
|
||||
// "@mui/material/AlertTitle": "Use components/Alert/Alert instead.",
|
||||
@@ -115,10 +117,10 @@
|
||||
"@emotion/styled": "Use Tailwind CSS instead.",
|
||||
// "@emotion/cache": "Use Tailwind CSS instead.",
|
||||
// "components/Stack/Stack": "Use Tailwind flex utilities instead (e.g., <div className='flex flex-col gap-4'>).",
|
||||
"lodash": "Use lodash/<name> instead.",
|
||||
},
|
||||
},
|
||||
},
|
||||
"lodash": "Use lodash/<name> instead."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"suspicious": {
|
||||
"noArrayIndexKey": "off",
|
||||
@@ -129,14 +131,21 @@
|
||||
"noConsole": {
|
||||
"level": "error",
|
||||
"options": {
|
||||
"allow": ["error", "info", "warn"],
|
||||
},
|
||||
},
|
||||
"allow": ["error", "info", "warn"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"complexity": {
|
||||
"noImportantStyles": "off", // TODO: check and fix !important styles
|
||||
},
|
||||
},
|
||||
"noImportantStyles": "off" // TODO: check and fix !important styles
|
||||
}
|
||||
}
|
||||
},
|
||||
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json",
|
||||
"css": {
|
||||
"parser": {
|
||||
// Biome 2.3+ requires opt-in for @apply and other
|
||||
// Tailwind directives.
|
||||
"tailwindDirectives": true
|
||||
}
|
||||
},
|
||||
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json"
|
||||
}
|
||||
|
||||
+33
-6
@@ -17,6 +17,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
@@ -272,11 +273,14 @@ func workspaceAgent() *serpent.Command {
|
||||
logger.Info(ctx, "agent devcontainer detection not enabled")
|
||||
}
|
||||
|
||||
reinitEvents := agentsdk.WaitForReinitLoop(ctx, logger, client)
|
||||
reinitCtx, reinitCancel := context.WithCancel(ctx)
|
||||
defer reinitCancel()
|
||||
reinitEvents := agentsdk.WaitForReinitLoop(reinitCtx, logger, client)
|
||||
|
||||
var (
|
||||
lastErr error
|
||||
mustExit bool
|
||||
lastOwnerID uuid.UUID
|
||||
lastErr error
|
||||
mustExit bool
|
||||
)
|
||||
for {
|
||||
prometheusRegistry := prometheus.NewRegistry()
|
||||
@@ -343,9 +347,32 @@ func workspaceAgent() *serpent.Command {
|
||||
case <-ctx.Done():
|
||||
logger.Info(ctx, "agent shutting down", slog.Error(context.Cause(ctx)))
|
||||
mustExit = true
|
||||
case event := <-reinitEvents:
|
||||
logger.Info(ctx, "agent received instruction to reinitialize",
|
||||
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
|
||||
case event, ok := <-reinitEvents:
|
||||
switch {
|
||||
case !ok:
|
||||
// Channel closed — the reinit loop exited
|
||||
// (terminal 409 or context expired). Keep
|
||||
// running the current agent until the parent
|
||||
// context is canceled.
|
||||
logger.Info(ctx, "reinit channel closed, running without reinit capability")
|
||||
reinitEvents = nil
|
||||
<-ctx.Done()
|
||||
mustExit = true
|
||||
case event.OwnerID != uuid.Nil && event.OwnerID == lastOwnerID:
|
||||
// Duplicate reinit for same owner — already
|
||||
// reinitialized. Cancel the reinit loop
|
||||
// goroutine and keep the current agent.
|
||||
logger.Info(ctx, "skipping redundant reinit, owner unchanged",
|
||||
slog.F("owner_id", event.OwnerID))
|
||||
reinitCancel()
|
||||
reinitEvents = nil
|
||||
<-ctx.Done()
|
||||
mustExit = true
|
||||
default:
|
||||
lastOwnerID = event.OwnerID
|
||||
logger.Info(ctx, "agent received instruction to reinitialize",
|
||||
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
|
||||
}
|
||||
}
|
||||
|
||||
lastErr = agnt.Close()
|
||||
|
||||
@@ -104,7 +104,7 @@ func (b *Builder) Build(inv *serpent.Invocation) (log slog.Logger, closeLog func
|
||||
|
||||
addSinkIfProvided := func(sinkFn func(io.Writer) slog.Sink, loc string) error {
|
||||
switch loc {
|
||||
case "":
|
||||
case "", "/dev/null":
|
||||
case "/dev/stdout":
|
||||
sinks = append(sinks, sinkFn(inv.Stdout))
|
||||
|
||||
|
||||
@@ -1401,6 +1401,9 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
|
||||
// Setup our workspace agent connection.
|
||||
config := workspacetraffic.Config{
|
||||
AgentID: agent.ID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: agent.Name,
|
||||
BytesPerTick: bytesPerTick,
|
||||
Duration: strategy.timeout,
|
||||
TickInterval: tickInterval,
|
||||
|
||||
@@ -85,7 +85,7 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
|
||||
AgentName: a.AgentName,
|
||||
Type: connectionType,
|
||||
Code: code,
|
||||
Ip: logIP,
|
||||
IP: logIP,
|
||||
ConnectionID: uuid.NullUUID{
|
||||
UUID: connectionID,
|
||||
Valid: true,
|
||||
|
||||
@@ -152,7 +152,7 @@ func TestConnectionLog(t *testing.T) {
|
||||
Int32: tt.status,
|
||||
Valid: *tt.action == agentproto.Connection_DISCONNECT,
|
||||
},
|
||||
Ip: expectedIP,
|
||||
IP: expectedIP,
|
||||
Type: agentProtoConnectionTypeToConnectionLog(t, *tt.typ),
|
||||
DisconnectReason: sql.NullString{
|
||||
String: tt.reason,
|
||||
|
||||
Generated
+59
-2
@@ -10205,12 +10205,26 @@ const docTemplate = `{
|
||||
],
|
||||
"summary": "Get workspace agent reinitialization",
|
||||
"operationId": "get-workspace-agent-reinitialization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "boolean",
|
||||
"description": "Opt in to durable reinit checks",
|
||||
"name": "wait",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
|
||||
}
|
||||
},
|
||||
"409": {
|
||||
"description": "Conflict",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
@@ -12647,11 +12661,16 @@ const docTemplate = `{
|
||||
"agentsdk.ReinitializationEvent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"reason": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationReason"
|
||||
},
|
||||
"workspaceID": {
|
||||
"type": "string"
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13114,6 +13133,12 @@ const docTemplate = `{
|
||||
"codersdk.AIBridgeSessionThreadsTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cache_read_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"cache_write_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
@@ -13129,6 +13154,12 @@ const docTemplate = `{
|
||||
"codersdk.AIBridgeSessionTokenUsageSummary": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cache_read_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"cache_write_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
@@ -13175,6 +13206,12 @@ const docTemplate = `{
|
||||
"codersdk.AIBridgeTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cache_read_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"cache_write_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -14138,6 +14175,9 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -14459,6 +14499,9 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -14546,6 +14589,17 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateFirstUserOnboardingInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"newsletter_marketing": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"newsletter_releases": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateFirstUserRequest": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
@@ -14560,6 +14614,9 @@ const docTemplate = `{
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"onboarding_info": {
|
||||
"$ref": "#/definitions/codersdk.CreateFirstUserOnboardingInfo"
|
||||
},
|
||||
"password": {
|
||||
"type": "string"
|
||||
},
|
||||
|
||||
Generated
+59
-2
@@ -9038,12 +9038,26 @@
|
||||
"tags": ["Agents"],
|
||||
"summary": "Get workspace agent reinitialization",
|
||||
"operationId": "get-workspace-agent-reinitialization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "boolean",
|
||||
"description": "Opt in to durable reinit checks",
|
||||
"name": "wait",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
|
||||
}
|
||||
},
|
||||
"409": {
|
||||
"description": "Conflict",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Response"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
@@ -11229,11 +11243,16 @@
|
||||
"agentsdk.ReinitializationEvent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"reason": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationReason"
|
||||
},
|
||||
"workspaceID": {
|
||||
"type": "string"
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -11692,6 +11711,12 @@
|
||||
"codersdk.AIBridgeSessionThreadsTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cache_read_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"cache_write_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
@@ -11707,6 +11732,12 @@
|
||||
"codersdk.AIBridgeSessionTokenUsageSummary": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cache_read_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"cache_write_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
@@ -11753,6 +11784,12 @@
|
||||
"codersdk.AIBridgeTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cache_read_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"cache_write_input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -12702,6 +12739,9 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13002,6 +13042,9 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13086,6 +13129,17 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateFirstUserOnboardingInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"newsletter_marketing": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"newsletter_releases": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateFirstUserRequest": {
|
||||
"type": "object",
|
||||
"required": ["email", "password", "username"],
|
||||
@@ -13096,6 +13150,9 @@
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"onboarding_info": {
|
||||
"$ref": "#/definitions/codersdk.CreateFirstUserOnboardingInfo"
|
||||
},
|
||||
"password": {
|
||||
"type": "string"
|
||||
},
|
||||
|
||||
+8
-1
@@ -26,6 +26,11 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// Limit the count query to avoid a slow sequential scan due to joins
|
||||
// on a large table. Set to 0 to disable capping (but also see the note
|
||||
// in the SQL query).
|
||||
const auditLogCountCap = 2000
|
||||
|
||||
// @Summary Get audit logs
|
||||
// @ID get-audit-logs
|
||||
// @Security CoderSessionToken
|
||||
@@ -66,7 +71,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
countFilter.Username = ""
|
||||
}
|
||||
|
||||
// Use the same filters to count the number of audit logs
|
||||
countFilter.CountCap = auditLogCountCap
|
||||
count, err := api.Database.CountAuditLogs(ctx, countFilter)
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
@@ -81,6 +86,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: []codersdk.AuditLog{},
|
||||
Count: 0,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -98,6 +104,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: api.convertAuditLogs(ctx, dblogs),
|
||||
Count: count,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
+10
-1
@@ -168,6 +168,7 @@ type Options struct {
|
||||
ConnectionLogger connectionlog.ConnectionLogger
|
||||
AgentConnectionUpdateFrequency time.Duration
|
||||
AgentInactiveDisconnectTimeout time.Duration
|
||||
ChatdInstructionLookupTimeout time.Duration
|
||||
AWSCertificates awsidentity.Certificates
|
||||
Authorizer rbac.Authorizer
|
||||
AzureCertificates x509.VerifyOptions
|
||||
@@ -782,9 +783,10 @@ func New(options *Options) *API {
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
ProviderAPIKeys: ChatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
|
||||
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
@@ -1221,6 +1223,13 @@ func New(options *Options) *API {
|
||||
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
|
||||
})
|
||||
})
|
||||
r.Route("/user-provider-configs", func(r chi.Router) {
|
||||
r.Get("/", api.listUserChatProviderConfigs)
|
||||
r.Route("/{providerConfig}", func(r chi.Router) {
|
||||
r.Put("/", api.upsertUserChatProviderKey)
|
||||
r.Delete("/", api.deleteUserChatProviderKey)
|
||||
})
|
||||
})
|
||||
r.Route("/{chat}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractChatParam(options.Database))
|
||||
r.Get("/", api.getChat)
|
||||
|
||||
@@ -149,12 +149,13 @@ type Options struct {
|
||||
OneTimePasscodeValidityPeriod time.Duration
|
||||
|
||||
// IncludeProvisionerDaemon when true means to start an in-memory provisionerD
|
||||
IncludeProvisionerDaemon bool
|
||||
ProvisionerDaemonVersion string
|
||||
ProvisionerDaemonTags map[string]string
|
||||
MetricsCacheRefreshInterval time.Duration
|
||||
AgentStatsRefreshInterval time.Duration
|
||||
DeploymentValues *codersdk.DeploymentValues
|
||||
IncludeProvisionerDaemon bool
|
||||
ChatdInstructionLookupTimeout time.Duration
|
||||
ProvisionerDaemonVersion string
|
||||
ProvisionerDaemonTags map[string]string
|
||||
MetricsCacheRefreshInterval time.Duration
|
||||
AgentStatsRefreshInterval time.Duration
|
||||
DeploymentValues *codersdk.DeploymentValues
|
||||
|
||||
// Set update check options to enable update check.
|
||||
UpdateCheckOptions *updatecheck.Options
|
||||
@@ -575,6 +576,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
|
||||
// Force a long disconnection timeout to ensure
|
||||
// agents are not marked as disconnected during slow tests.
|
||||
AgentInactiveDisconnectTimeout: testutil.WaitShort,
|
||||
ChatdInstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
|
||||
AccessURL: accessURL,
|
||||
AppHostname: options.AppHostname,
|
||||
AppHostnameRegex: appHostnameRegex,
|
||||
|
||||
@@ -90,8 +90,8 @@ func (m *FakeConnectionLogger) Contains(t testing.TB, expected database.UpsertCo
|
||||
t.Logf("connection log %d: expected Code %d, got %d", idx+1, expected.Code.Int32, cl.Code.Int32)
|
||||
continue
|
||||
}
|
||||
if expected.Ip.Valid && cl.Ip.IPNet.String() != expected.Ip.IPNet.String() {
|
||||
t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.Ip.IPNet, cl.Ip.IPNet)
|
||||
if expected.IP.Valid && cl.IP.IPNet.String() != expected.IP.IPNet.String() {
|
||||
t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.IP.IPNet, cl.IP.IPNet)
|
||||
continue
|
||||
}
|
||||
if expected.UserAgent.Valid && cl.UserAgent.String != expected.UserAgent.String {
|
||||
|
||||
@@ -10,6 +10,7 @@ const (
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckValidCredentialPolicy CheckConstraint = "valid_credential_policy" // chat_providers
|
||||
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
|
||||
@@ -32,4 +33,5 @@ const (
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys
|
||||
)
|
||||
|
||||
@@ -1037,8 +1037,10 @@ func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSess
|
||||
StartedAt: row.StartedAt,
|
||||
Threads: row.Threads,
|
||||
TokenUsageSummary: codersdk.AIBridgeSessionTokenUsageSummary{
|
||||
InputTokens: row.InputTokens,
|
||||
OutputTokens: row.OutputTokens,
|
||||
InputTokens: row.InputTokens,
|
||||
OutputTokens: row.OutputTokens,
|
||||
CacheReadInputTokens: row.CacheReadInputTokens,
|
||||
CacheWriteInputTokens: row.CacheWriteInputTokens,
|
||||
},
|
||||
}
|
||||
// Ensure non-nil slices for JSON serialization.
|
||||
@@ -1062,13 +1064,15 @@ func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSess
|
||||
|
||||
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
|
||||
return codersdk.AIBridgeTokenUsage{
|
||||
ID: usage.ID,
|
||||
InterceptionID: usage.InterceptionID,
|
||||
ProviderResponseID: usage.ProviderResponseID,
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
Metadata: jsonOrEmptyMap(usage.Metadata),
|
||||
CreatedAt: usage.CreatedAt,
|
||||
ID: usage.ID,
|
||||
InterceptionID: usage.InterceptionID,
|
||||
ProviderResponseID: usage.ProviderResponseID,
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||
CacheWriteInputTokens: usage.CacheWriteInputTokens,
|
||||
Metadata: jsonOrEmptyMap(usage.Metadata),
|
||||
CreatedAt: usage.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1179,9 +1183,11 @@ func AIBridgeSessionThreads(
|
||||
PageStartedAt: pageStartedAt,
|
||||
PageEndedAt: pageEndedAt,
|
||||
TokenUsageSummary: codersdk.AIBridgeSessionThreadsTokenUsage{
|
||||
InputTokens: session.InputTokens,
|
||||
OutputTokens: session.OutputTokens,
|
||||
Metadata: sessionTokenMeta,
|
||||
InputTokens: session.InputTokens,
|
||||
OutputTokens: session.OutputTokens,
|
||||
CacheReadInputTokens: session.CacheReadInputTokens,
|
||||
CacheWriteInputTokens: session.CacheWriteInputTokens,
|
||||
Metadata: sessionTokenMeta,
|
||||
},
|
||||
Threads: threads,
|
||||
}
|
||||
@@ -1314,17 +1320,19 @@ func buildAIBridgeThread(
|
||||
|
||||
// aggregateTokenUsage sums token usage rows and aggregates metadata.
|
||||
func aggregateTokenUsage(tokens []database.AIBridgeTokenUsage) codersdk.AIBridgeSessionThreadsTokenUsage {
|
||||
var inputTokens, outputTokens int64
|
||||
var inputTokens, outputTokens, cacheRead, cacheWrite int64
|
||||
for _, tu := range tokens {
|
||||
inputTokens += tu.InputTokens
|
||||
outputTokens += tu.OutputTokens
|
||||
// TODO: once https://github.com/coder/aibridge/issues/150 lands we
|
||||
// should aggregate the other token types.
|
||||
cacheRead += tu.CacheReadInputTokens
|
||||
cacheWrite += tu.CacheWriteInputTokens
|
||||
}
|
||||
return codersdk.AIBridgeSessionThreadsTokenUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
Metadata: aggregateTokenMetadata(tokens),
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
CacheReadInputTokens: cacheRead,
|
||||
CacheWriteInputTokens: cacheWrite,
|
||||
Metadata: aggregateTokenMetadata(tokens),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1520,7 +1528,10 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
// Chat converts a database.Chat to a codersdk.Chat. It coalesces
|
||||
// nil slices and maps to empty values for JSON serialization and
|
||||
// derives RootChatID from the parent chain when not explicitly set.
|
||||
func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
||||
// When diffStatus is non-nil the response includes diff metadata.
|
||||
// When files is non-empty the response includes file metadata;
|
||||
// pass nil to omit the files field (e.g. list endpoints).
|
||||
func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database.GetChatFileMetadataByChatIDRow) codersdk.Chat {
|
||||
mcpServerIDs := c.MCPServerIDs
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
@@ -1573,6 +1584,19 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
||||
convertedDiffStatus := ChatDiffStatus(c.ID, diffStatus)
|
||||
chat.DiffStatus = &convertedDiffStatus
|
||||
}
|
||||
if len(files) > 0 {
|
||||
chat.Files = make([]codersdk.ChatFileMetadata, 0, len(files))
|
||||
for _, row := range files {
|
||||
chat.Files = append(chat.Files, codersdk.ChatFileMetadata{
|
||||
ID: row.ID,
|
||||
OwnerID: row.OwnerID,
|
||||
OrganizationID: row.OrganizationID,
|
||||
Name: row.Name,
|
||||
MimeType: row.Mimetype,
|
||||
CreatedAt: row.CreatedAt,
|
||||
})
|
||||
}
|
||||
}
|
||||
if c.LastInjectedContext.Valid {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
// Internal fields are stripped at write time in
|
||||
@@ -1596,9 +1620,9 @@ func ChatRows(rows []database.GetChatsRow, diffStatusesByChatID map[uuid.UUID]da
|
||||
for i, row := range rows {
|
||||
diffStatus, ok := diffStatusesByChatID[row.Chat.ID]
|
||||
if ok {
|
||||
result[i] = Chat(row.Chat, &diffStatus)
|
||||
result[i] = Chat(row.Chat, &diffStatus, nil)
|
||||
} else {
|
||||
result[i] = Chat(row.Chat, nil)
|
||||
result[i] = Chat(row.Chat, nil, nil)
|
||||
if diffStatusesByChatID != nil {
|
||||
emptyDiffStatus := ChatDiffStatus(row.Chat.ID, nil)
|
||||
result[i].DiffStatus = &emptyDiffStatus
|
||||
|
||||
@@ -259,11 +259,13 @@ func TestAIBridgeInterception(t *testing.T) {
|
||||
},
|
||||
tokenUsages: []database.AIBridgeTokenUsage{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: interceptionID,
|
||||
ProviderResponseID: "resp-123",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
ID: uuid.New(),
|
||||
InterceptionID: interceptionID,
|
||||
ProviderResponseID: "resp-123",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
CacheReadInputTokens: 50,
|
||||
CacheWriteInputTokens: 10,
|
||||
Metadata: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"cache":"hit"}`),
|
||||
Valid: true,
|
||||
@@ -413,6 +415,8 @@ func TestAIBridgeInterception(t *testing.T) {
|
||||
require.Equal(t, tu.ProviderResponseID, result.TokenUsages[i].ProviderResponseID)
|
||||
require.Equal(t, tu.InputTokens, result.TokenUsages[i].InputTokens)
|
||||
require.Equal(t, tu.OutputTokens, result.TokenUsages[i].OutputTokens)
|
||||
require.Equal(t, tu.CacheReadInputTokens, result.TokenUsages[i].CacheReadInputTokens)
|
||||
require.Equal(t, tu.CacheWriteInputTokens, result.TokenUsages[i].CacheWriteInputTokens)
|
||||
}
|
||||
|
||||
// Verify user prompts are converted correctly.
|
||||
@@ -557,14 +561,26 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
ChatID: input.ID,
|
||||
}
|
||||
|
||||
got := db2sdk.Chat(input, diffStatus)
|
||||
fileRows := []database.GetChatFileMetadataByChatIDRow{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
OwnerID: input.OwnerID,
|
||||
OrganizationID: uuid.New(),
|
||||
Name: "test.png",
|
||||
Mimetype: "image/png",
|
||||
CreatedAt: now,
|
||||
},
|
||||
}
|
||||
|
||||
got := db2sdk.Chat(input, diffStatus, fileRows)
|
||||
|
||||
v := reflect.ValueOf(got)
|
||||
typ := v.Type()
|
||||
// HasUnread is populated by ChatRows (which joins the
|
||||
// read-cursor query), not by Chat, so it is expected
|
||||
// to remain zero here.
|
||||
skip := map[string]bool{"HasUnread": true}
|
||||
// read-cursor query), not by Chat. Warnings is a transient
|
||||
// field populated by handlers, not the converter. Both are
|
||||
// expected to remain zero here.
|
||||
skip := map[string]bool{"HasUnread": true, "Warnings": true}
|
||||
for i := range typ.NumField() {
|
||||
field := typ.Field(i)
|
||||
if skip[field.Name] {
|
||||
@@ -577,6 +593,112 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_FileMetadataConversion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ownerID := uuid.New()
|
||||
orgID := uuid.New()
|
||||
fileID := uuid.New()
|
||||
now := dbtime.Now()
|
||||
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: uuid.New(),
|
||||
Title: "file metadata test",
|
||||
Status: database.ChatStatusWaiting,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
rows := []database.GetChatFileMetadataByChatIDRow{
|
||||
{
|
||||
ID: fileID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
Name: "screenshot.png",
|
||||
Mimetype: "image/png",
|
||||
CreatedAt: now,
|
||||
},
|
||||
}
|
||||
|
||||
result := db2sdk.Chat(chat, nil, rows)
|
||||
|
||||
require.Len(t, result.Files, 1)
|
||||
f := result.Files[0]
|
||||
require.Equal(t, fileID, f.ID)
|
||||
require.Equal(t, ownerID, f.OwnerID, "OwnerID must be mapped from DB row")
|
||||
require.Equal(t, orgID, f.OrganizationID, "OrganizationID must be mapped from DB row")
|
||||
require.Equal(t, "screenshot.png", f.Name)
|
||||
require.Equal(t, "image/png", f.MimeType)
|
||||
require.Equal(t, now, f.CreatedAt)
|
||||
|
||||
// Verify JSON serialization uses snake_case for mime_type.
|
||||
data, err := json.Marshal(f)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(data), `"mime_type"`)
|
||||
require.NotContains(t, string(data), `"mimetype"`)
|
||||
}
|
||||
|
||||
func TestChat_NilFilesOmitted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
LastModelConfigID: uuid.New(),
|
||||
Title: "no files",
|
||||
Status: database.ChatStatusWaiting,
|
||||
CreatedAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
}
|
||||
|
||||
result := db2sdk.Chat(chat, nil, nil)
|
||||
require.Empty(t, result.Files)
|
||||
}
|
||||
|
||||
func TestChat_MultipleFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := dbtime.Now()
|
||||
file1 := uuid.New()
|
||||
file2 := uuid.New()
|
||||
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
LastModelConfigID: uuid.New(),
|
||||
Title: "multi file test",
|
||||
Status: database.ChatStatusWaiting,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
rows := []database.GetChatFileMetadataByChatIDRow{
|
||||
{
|
||||
ID: file1,
|
||||
OwnerID: chat.OwnerID,
|
||||
OrganizationID: uuid.New(),
|
||||
Name: "a.png",
|
||||
Mimetype: "image/png",
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
ID: file2,
|
||||
OwnerID: chat.OwnerID,
|
||||
OrganizationID: uuid.New(),
|
||||
Name: "b.txt",
|
||||
Mimetype: "text/plain",
|
||||
CreatedAt: now,
|
||||
},
|
||||
}
|
||||
|
||||
result := db2sdk.Chat(chat, nil, rows)
|
||||
require.Len(t, result.Files, 2)
|
||||
require.Equal(t, "a.png", result.Files[0].Name)
|
||||
require.Equal(t, "b.txt", result.Files[1].Name)
|
||||
}
|
||||
|
||||
func TestChatQueuedMessage_MalformedContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1627,6 +1627,13 @@ func (q *querier) BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg datab
|
||||
return q.db.BatchUpdateWorkspaceNextStartAt(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.BatchUpsertConnectionLogs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceNotificationMessage); err != nil {
|
||||
return 0, err
|
||||
@@ -2137,17 +2144,23 @@ func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg dat
|
||||
return q.db.DeleteUserChatCompactionThreshold(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.GetUserSecret(ctx, id)
|
||||
func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, secret); err != nil {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserSecret(ctx, id)
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -2565,6 +2578,10 @@ func (q *querier) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.C
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatFileMetadataByChatID)(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
files, err := q.db.GetChatFilesByIDs(ctx, ids)
|
||||
if err != nil {
|
||||
@@ -3635,18 +3652,18 @@ func (q *querier) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database
|
||||
return q.db.GetTailnetPeers(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsRow, error) {
|
||||
func (q *querier) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetTailnetTunnelPeerBindings(ctx, srcID)
|
||||
return q.db.GetTailnetTunnelPeerBindingsBatch(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
|
||||
func (q *querier) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetTailnetTunnelPeerIDs(ctx, srcID)
|
||||
return q.db.GetTailnetTunnelPeerIDsBatch(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task, error) {
|
||||
@@ -4024,6 +4041,17 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
|
||||
return q.db.GetUserChatCustomPrompt(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetUserChatProviderKeys(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
|
||||
return 0, err
|
||||
@@ -4095,19 +4123,6 @@ func (q *querier) GetUserNotificationPreferences(ctx context.Context, userID uui
|
||||
return q.db.GetUserNotificationPreferences(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, id)
|
||||
if err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
@@ -5364,6 +5379,17 @@ func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg datab
|
||||
return q.db.InsertWorkspaceResourceMetadata(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.LinkChatFiles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -5480,7 +5506,7 @@ func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID u
|
||||
return q.db.ListUserChatCompactionThresholds(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(userID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
return nil, err
|
||||
@@ -5488,6 +5514,16 @@ func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]data
|
||||
return q.db.ListUserSecrets(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
// This query returns decrypted secret values and must only be called
|
||||
// from system contexts (provisioner, agent manifest). REST API
|
||||
// handlers should use ListUserSecrets (metadata only).
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListUserSecretsWithValues(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
workspace, err := q.db.GetWorkspaceByID(ctx, workspaceID)
|
||||
if err != nil {
|
||||
@@ -5738,15 +5774,15 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
|
||||
return q.db.UpdateChatByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
// The batch heartbeat is a system-level operation filtered by
|
||||
// worker_id. Authorization is enforced by the AsChatd context
|
||||
// at the call site rather than per-row, because checking each
|
||||
// row individually would defeat the purpose of batching.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.UpdateChatHeartbeat(ctx, arg)
|
||||
return q.db.UpdateChatHeartbeats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
|
||||
@@ -6454,6 +6490,17 @@ func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.U
|
||||
return q.db.UpdateUserChatCustomPrompt(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return q.db.UpdateUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
|
||||
return deleteQ(q.log, q.auth, q.db.GetUserByID, q.db.UpdateUserDeletedByID)(ctx, id)
|
||||
}
|
||||
@@ -6577,17 +6624,12 @@ func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRo
|
||||
return q.db.UpdateUserRoles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, arg.ID)
|
||||
if err != nil {
|
||||
func (q *querier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return q.db.UpdateUserSecret(ctx, arg)
|
||||
return q.db.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) {
|
||||
@@ -7032,13 +7074,6 @@ func (q *querier) UpsertChatWorkspaceTTL(ctx context.Context, workspaceTtl strin
|
||||
return q.db.UpsertChatWorkspaceTTL(ctx, workspaceTtl)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
|
||||
return database.ConnectionLog{}, err
|
||||
}
|
||||
return q.db.UpsertConnectionLog(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
@@ -7181,6 +7216,17 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error {
|
||||
return q.db.UpsertTemplateUsageStats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return q.db.UpsertUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
|
||||
@@ -338,10 +338,9 @@ func (s *MethodTestSuite) TestAuditLogs() {
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestConnectionLogs() {
|
||||
s.Run("UpsertConnectionLog", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ws := testutil.Fake(s.T(), faker, database.WorkspaceTable{})
|
||||
arg := database.UpsertConnectionLogParams{Ip: defaultIPAddress(), Type: database.ConnectionTypeSsh, WorkspaceID: ws.ID, OrganizationID: ws.OrganizationID, ConnectionStatus: database.ConnectionStatusConnected, WorkspaceOwnerID: ws.OwnerID}
|
||||
dbm.EXPECT().UpsertConnectionLog(gomock.Any(), arg).Return(database.ConnectionLog{}, nil).AnyTimes()
|
||||
s.Run("BatchUpsertConnectionLogs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.BatchUpsertConnectionLogsParams{}
|
||||
dbm.EXPECT().BatchUpsertConnectionLogs(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceConnectionLog, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetConnectionLogsOffset", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
@@ -401,6 +400,17 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("LinkChatFiles", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.LinkChatFilesParams{
|
||||
ChatID: chat.ID,
|
||||
MaxFileLinks: int32(codersdk.MaxChatFileIDs),
|
||||
FileIds: []uuid.UUID{uuid.New()},
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().LinkChatFiles(gomock.Any(), arg).Return(int32(0), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int32(0))
|
||||
}))
|
||||
s.Run("PinChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
@@ -577,6 +587,19 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatFilesByIDs(gomock.Any(), []uuid.UUID{file.ID}).Return([]database.ChatFile{file}, nil).AnyTimes()
|
||||
check.Args([]uuid.UUID{file.ID}).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns([]database.ChatFile{file})
|
||||
}))
|
||||
s.Run("GetChatFileMetadataByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
file := testutil.Fake(s.T(), faker, database.ChatFile{})
|
||||
rows := []database.GetChatFileMetadataByChatIDRow{{
|
||||
ID: file.ID,
|
||||
Name: file.Name,
|
||||
Mimetype: file.Mimetype,
|
||||
CreatedAt: file.CreatedAt,
|
||||
OwnerID: file.OwnerID,
|
||||
OrganizationID: file.OrganizationID,
|
||||
}}
|
||||
dbm.EXPECT().GetChatFileMetadataByChatID(gomock.Any(), file.ID).Return(rows, nil).AnyTimes()
|
||||
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
@@ -819,15 +842,15 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
resultID := uuid.New()
|
||||
arg := database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{resultID},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
|
||||
}))
|
||||
s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -2407,6 +2430,36 @@ func (s *MethodTestSuite) TestUser() {
|
||||
dbm.EXPECT().GetUserChatCustomPrompt(gomock.Any(), u.ID).Return("my custom prompt", nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("my custom prompt")
|
||||
}))
|
||||
s.Run("GetUserChatProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().GetUserChatProviderKeys(gomock.Any(), u.ID).Return([]database.UserChatProviderKey{key}, nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserChatProviderKey{key})
|
||||
}))
|
||||
s.Run("DeleteUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New()}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserChatProviderKey(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns()
|
||||
}))
|
||||
s.Run("UpdateUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.UpdateUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "updated-api-key"}
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
|
||||
}))
|
||||
s.Run("UpsertUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.UpsertUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "upserted-api-key"}
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
|
||||
}))
|
||||
s.Run("UpdateUserChatCustomPrompt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
uc := database.UserConfig{UserID: u.ID, Key: "chat_custom_prompt", Value: "my custom prompt"}
|
||||
@@ -3720,13 +3773,11 @@ func (s *MethodTestSuite) TestTailnetFunctions() {
|
||||
check.Args(uuid.New()).
|
||||
Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTailnetTunnelPeerBindings", s.Subtest(func(_ database.Store, check *expects) {
|
||||
check.Args(uuid.New()).
|
||||
Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
|
||||
s.Run("GetTailnetTunnelPeerBindingsBatch", s.Subtest(func(_ database.Store, check *expects) {
|
||||
check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTailnetTunnelPeerIDs", s.Subtest(func(_ database.Store, check *expects) {
|
||||
check.Args(uuid.New()).
|
||||
Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
|
||||
s.Run("GetTailnetTunnelPeerIDsBatch", s.Subtest(func(_ database.Store, check *expects) {
|
||||
check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetAllTailnetCoordinators", s.Subtest(func(_ database.Store, check *expects) {
|
||||
check.Args().
|
||||
@@ -5295,19 +5346,20 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("GetUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("ListUserSecrets", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
row := testutil.Fake(s.T(), faker, database.ListUserSecretsRow{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.ListUserSecretsRow{row}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns([]database.ListUserSecretsRow{row})
|
||||
}))
|
||||
s.Run("ListUserSecretsWithValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecretsWithValues(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceSystem, policy.ActionRead).
|
||||
Returns([]database.UserSecret{secret})
|
||||
}))
|
||||
s.Run("CreateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
@@ -5319,22 +5371,21 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionCreate).
|
||||
Returns(ret)
|
||||
}))
|
||||
s.Run("UpdateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{ID: secret.ID})
|
||||
arg := database.UpdateUserSecretParams{ID: secret.ID}
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserSecret(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
s.Run("UpdateUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
arg := database.UpdateUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().UpdateUserSecretByUserIDAndName(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(secret, policy.ActionUpdate).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionUpdate).
|
||||
Returns(updated)
|
||||
}))
|
||||
s.Run("DeleteUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserSecret(gomock.Any(), secret.ID).Return(nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead, secret, policy.ActionDelete).
|
||||
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
|
||||
Returns()
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func AuditLog(t testing.TB, db database.Store, seed database.AuditLog) database.
|
||||
}
|
||||
|
||||
func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnectionLogParams) database.ConnectionLog {
|
||||
log, err := db.UpsertConnectionLog(genCtx, database.UpsertConnectionLogParams{
|
||||
arg := database.UpsertConnectionLogParams{
|
||||
ID: takeFirst(seed.ID, uuid.New()),
|
||||
Time: takeFirst(seed.Time, dbtime.Now()),
|
||||
OrganizationID: takeFirst(seed.OrganizationID, uuid.New()),
|
||||
@@ -89,7 +89,7 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti
|
||||
Int32: takeFirst(seed.Code.Int32, 0),
|
||||
Valid: takeFirst(seed.Code.Valid, false),
|
||||
},
|
||||
Ip: pqtype.Inet{
|
||||
IP: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255),
|
||||
@@ -117,9 +117,53 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti
|
||||
Valid: takeFirst(seed.DisconnectReason.Valid, false),
|
||||
},
|
||||
ConnectionStatus: takeFirst(seed.ConnectionStatus, database.ConnectionStatusConnected),
|
||||
}
|
||||
|
||||
var disconnectTime sql.NullTime
|
||||
if arg.ConnectionStatus == database.ConnectionStatusDisconnected {
|
||||
disconnectTime = sql.NullTime{Time: arg.Time, Valid: true}
|
||||
}
|
||||
|
||||
err := db.BatchUpsertConnectionLogs(genCtx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{arg.ID},
|
||||
ConnectTime: []time.Time{arg.Time},
|
||||
OrganizationID: []uuid.UUID{arg.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{arg.WorkspaceOwnerID},
|
||||
WorkspaceID: []uuid.UUID{arg.WorkspaceID},
|
||||
WorkspaceName: []string{arg.WorkspaceName},
|
||||
AgentName: []string{arg.AgentName},
|
||||
Type: []database.ConnectionType{arg.Type},
|
||||
Code: []int32{arg.Code.Int32},
|
||||
CodeValid: []bool{arg.Code.Valid},
|
||||
Ip: []pqtype.Inet{arg.IP},
|
||||
UserAgent: []string{arg.UserAgent.String},
|
||||
UserID: []uuid.UUID{arg.UserID.UUID},
|
||||
SlugOrPort: []string{arg.SlugOrPort.String},
|
||||
ConnectionID: []uuid.UUID{arg.ConnectionID.UUID},
|
||||
DisconnectReason: []string{arg.DisconnectReason.String},
|
||||
DisconnectTime: []time.Time{disconnectTime.Time},
|
||||
})
|
||||
require.NoError(t, err, "insert connection log")
|
||||
return log
|
||||
|
||||
// Query back the actual row from the database. On upsert
|
||||
// conflict the DB keeps the original row's ID, so we can't
|
||||
// rely on arg.ID. Match on the conflict key for rows with a
|
||||
// connection_id, or by primary key for NULL connection_id.
|
||||
rows, err := db.GetConnectionLogsOffset(genCtx, database.GetConnectionLogsOffsetParams{})
|
||||
require.NoError(t, err, "query connection logs")
|
||||
for _, row := range rows {
|
||||
if arg.ConnectionID.Valid {
|
||||
if row.ConnectionLog.ConnectionID == arg.ConnectionID &&
|
||||
row.ConnectionLog.WorkspaceID == arg.WorkspaceID &&
|
||||
row.ConnectionLog.AgentName == arg.AgentName {
|
||||
return row.ConnectionLog
|
||||
}
|
||||
} else if row.ConnectionLog.ID == arg.ID {
|
||||
return row.ConnectionLog
|
||||
}
|
||||
}
|
||||
require.Failf(t, "connection log not found", "id=%s", arg.ID)
|
||||
return database.ConnectionLog{} // unreachable
|
||||
}
|
||||
|
||||
func Template(t testing.TB, db database.Store, seed database.Template) database.Template {
|
||||
@@ -1553,6 +1597,7 @@ func UserSecret(t testing.TB, db database.Store, seed database.UserSecret) datab
|
||||
Name: takeFirst(seed.Name, "secret-name"),
|
||||
Description: takeFirst(seed.Description, "secret description"),
|
||||
Value: takeFirst(seed.Value, "secret value"),
|
||||
ValueKeyID: seed.ValueKeyID,
|
||||
EnvName: takeFirst(seed.EnvName, "SECRET_ENV_NAME"),
|
||||
FilePath: takeFirst(seed.FilePath, "~/secret/file/path"),
|
||||
})
|
||||
@@ -1613,13 +1658,15 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
|
||||
|
||||
func AIBridgeTokenUsage(t testing.TB, db database.Store, seed database.InsertAIBridgeTokenUsageParams) database.AIBridgeTokenUsage {
|
||||
usage, err := db.InsertAIBridgeTokenUsage(genCtx, database.InsertAIBridgeTokenUsageParams{
|
||||
ID: takeFirst(seed.ID, uuid.New()),
|
||||
InterceptionID: takeFirst(seed.InterceptionID, uuid.New()),
|
||||
ProviderResponseID: takeFirst(seed.ProviderResponseID, "provider_response_id"),
|
||||
InputTokens: takeFirst(seed.InputTokens, 100),
|
||||
OutputTokens: takeFirst(seed.OutputTokens, 100),
|
||||
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
|
||||
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
|
||||
ID: takeFirst(seed.ID, uuid.New()),
|
||||
InterceptionID: takeFirst(seed.InterceptionID, uuid.New()),
|
||||
ProviderResponseID: takeFirst(seed.ProviderResponseID, "provider_response_id"),
|
||||
InputTokens: takeFirst(seed.InputTokens, 100),
|
||||
OutputTokens: takeFirst(seed.OutputTokens, 100),
|
||||
CacheReadInputTokens: seed.CacheReadInputTokens,
|
||||
CacheWriteInputTokens: seed.CacheWriteInputTokens,
|
||||
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
|
||||
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
|
||||
})
|
||||
require.NoError(t, err, "insert aibridge token usage")
|
||||
return usage
|
||||
|
||||
@@ -208,6 +208,14 @@ func (m queryMetricsStore) BatchUpdateWorkspaceNextStartAt(ctx context.Context,
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.BatchUpsertConnectionLogs(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("BatchUpsertConnectionLogs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpsertConnectionLogs").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.BulkMarkNotificationMessagesFailed(ctx, arg)
|
||||
@@ -696,11 +704,19 @@ func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecret").Inc()
|
||||
r0 := m.s.DeleteUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatProviderKey").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -1112,6 +1128,14 @@ func (m queryMetricsStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatFileMetadataByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatFileMetadataByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFileMetadataByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatFilesByIDs(ctx, ids)
|
||||
@@ -2208,19 +2232,19 @@ func (m queryMetricsStore) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsRow, error) {
|
||||
func (m queryMetricsStore) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetTailnetTunnelPeerBindings(ctx, srcID)
|
||||
m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerBindings").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerBindings").Inc()
|
||||
r0, r1 := m.s.GetTailnetTunnelPeerBindingsBatch(ctx, ids)
|
||||
m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerBindingsBatch").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerBindingsBatch").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
|
||||
func (m queryMetricsStore) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetTailnetTunnelPeerIDs(ctx, srcID)
|
||||
m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerIDs").Inc()
|
||||
r0, r1 := m.s.GetTailnetTunnelPeerIDsBatch(ctx, ids)
|
||||
m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerIDsBatch").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerIDsBatch").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -2528,6 +2552,14 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetUserChatProviderKeys").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatProviderKeys").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
|
||||
@@ -2592,14 +2624,6 @@ func (m queryMetricsStore) GetUserNotificationPreferences(ctx context.Context, u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserSecret").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecretByUserIDAndName(ctx, arg)
|
||||
@@ -3752,6 +3776,14 @@ func (m queryMetricsStore) InsertWorkspaceResourceMetadata(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.LinkChatFiles(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("LinkChatFiles").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "LinkChatFiles").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeClients(ctx, arg)
|
||||
@@ -3880,7 +3912,7 @@ func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecrets(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecrets").Observe(time.Since(start).Seconds())
|
||||
@@ -3888,6 +3920,14 @@ func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecretsWithValues(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecretsWithValues").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserSecretsWithValues").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListWorkspaceAgentPortShares(ctx, workspaceID)
|
||||
@@ -4096,11 +4136,11 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
|
||||
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4560,6 +4600,14 @@ func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatProviderKey").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateUserDeletedByID(ctx, id)
|
||||
@@ -4648,11 +4696,11 @@ func (m queryMetricsStore) UpdateUserRoles(ctx context.Context, arg database.Upd
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
func (m queryMetricsStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserSecret(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecret").Inc()
|
||||
r0, r1 := m.s.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecretByUserIDAndName").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -5000,14 +5048,6 @@ func (m queryMetricsStore) UpsertChatWorkspaceTTL(ctx context.Context, workspace
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertConnectionLog").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertConnectionLog").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertDefaultProxy(ctx, arg)
|
||||
@@ -5152,6 +5192,14 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error {
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatProviderKey").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertWebpushVAPIDKeys(ctx, arg)
|
||||
|
||||
@@ -233,6 +233,20 @@ func (mr *MockStoreMockRecorder) BatchUpdateWorkspaceNextStartAt(ctx, arg any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateWorkspaceNextStartAt", reflect.TypeOf((*MockStore)(nil).BatchUpdateWorkspaceNextStartAt), ctx, arg)
|
||||
}
|
||||
|
||||
// BatchUpsertConnectionLogs mocks base method.
|
||||
func (m *MockStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BatchUpsertConnectionLogs", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BatchUpsertConnectionLogs indicates an expected call of BatchUpsertConnectionLogs.
|
||||
func (mr *MockStoreMockRecorder) BatchUpsertConnectionLogs(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpsertConnectionLogs", reflect.TypeOf((*MockStore)(nil).BatchUpsertConnectionLogs), ctx, arg)
|
||||
}
|
||||
|
||||
// BulkMarkNotificationMessagesFailed mocks base method.
|
||||
func (m *MockStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1171,18 +1185,32 @@ func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecret mocks base method.
|
||||
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// DeleteUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecret", ctx, id)
|
||||
ret := m.ctrl.Call(m, "DeleteUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserSecret indicates an expected call of DeleteUserSecret.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecret(ctx, id any) *gomock.Call {
|
||||
// DeleteUserChatProviderKey indicates an expected call of DeleteUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecret", reflect.TypeOf((*MockStore)(nil).DeleteUserSecret), ctx, id)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).DeleteUserSecretByUserIDAndName), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteWebpushSubscriptionByUserIDAndEndpoint mocks base method.
|
||||
@@ -2044,6 +2072,21 @@ func (mr *MockStoreMockRecorder) GetChatFileByID(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileByID", reflect.TypeOf((*MockStore)(nil).GetChatFileByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatFileMetadataByChatID mocks base method.
|
||||
func (m *MockStore) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatFileMetadataByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].([]database.GetChatFileMetadataByChatIDRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatFileMetadataByChatID indicates an expected call of GetChatFileMetadataByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatFileMetadataByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileMetadataByChatID", reflect.TypeOf((*MockStore)(nil).GetChatFileMetadataByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatFilesByIDs mocks base method.
|
||||
func (m *MockStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4099,34 +4142,34 @@ func (mr *MockStoreMockRecorder) GetTailnetPeers(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetPeers", reflect.TypeOf((*MockStore)(nil).GetTailnetPeers), ctx, id)
|
||||
}
|
||||
|
||||
// GetTailnetTunnelPeerBindings mocks base method.
|
||||
func (m *MockStore) GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsRow, error) {
|
||||
// GetTailnetTunnelPeerBindingsBatch mocks base method.
|
||||
func (m *MockStore) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerBindings", ctx, srcID)
|
||||
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerBindingsRow)
|
||||
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerBindingsBatch", ctx, ids)
|
||||
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerBindingsBatchRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetTailnetTunnelPeerBindings indicates an expected call of GetTailnetTunnelPeerBindings.
|
||||
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerBindings(ctx, srcID any) *gomock.Call {
|
||||
// GetTailnetTunnelPeerBindingsBatch indicates an expected call of GetTailnetTunnelPeerBindingsBatch.
|
||||
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerBindingsBatch(ctx, ids any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerBindings", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerBindings), ctx, srcID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerBindingsBatch", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerBindingsBatch), ctx, ids)
|
||||
}
|
||||
|
||||
// GetTailnetTunnelPeerIDs mocks base method.
|
||||
func (m *MockStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
|
||||
// GetTailnetTunnelPeerIDsBatch mocks base method.
|
||||
func (m *MockStore) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerIDs", ctx, srcID)
|
||||
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerIDsRow)
|
||||
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerIDsBatch", ctx, ids)
|
||||
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerIDsBatchRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetTailnetTunnelPeerIDs indicates an expected call of GetTailnetTunnelPeerIDs.
|
||||
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerIDs(ctx, srcID any) *gomock.Call {
|
||||
// GetTailnetTunnelPeerIDsBatch indicates an expected call of GetTailnetTunnelPeerIDsBatch.
|
||||
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerIDsBatch(ctx, ids any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerIDs", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerIDs), ctx, srcID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerIDsBatch", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerIDsBatch), ctx, ids)
|
||||
}
|
||||
|
||||
// GetTaskByID mocks base method.
|
||||
@@ -4729,6 +4772,21 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatProviderKeys mocks base method.
|
||||
func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserChatProviderKeys", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserChatProviderKeys indicates an expected call of GetUserChatProviderKeys.
|
||||
func (mr *MockStoreMockRecorder) GetUserChatProviderKeys(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatProviderKeys", reflect.TypeOf((*MockStore)(nil).GetUserChatProviderKeys), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatSpendInPeriod mocks base method.
|
||||
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4849,21 +4907,6 @@ func (mr *MockStoreMockRecorder) GetUserNotificationPreferences(ctx, userID any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserNotificationPreferences", reflect.TypeOf((*MockStore)(nil).GetUserNotificationPreferences), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserSecret mocks base method.
|
||||
func (m *MockStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserSecret", ctx, id)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserSecret indicates an expected call of GetUserSecret.
|
||||
func (mr *MockStoreMockRecorder) GetUserSecret(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSecret", reflect.TypeOf((*MockStore)(nil).GetUserSecret), ctx, id)
|
||||
}
|
||||
|
||||
// GetUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7023,6 +7066,21 @@ func (mr *MockStoreMockRecorder) InsertWorkspaceResourceMetadata(ctx, arg any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertWorkspaceResourceMetadata", reflect.TypeOf((*MockStore)(nil).InsertWorkspaceResourceMetadata), ctx, arg)
|
||||
}
|
||||
|
||||
// LinkChatFiles mocks base method.
|
||||
func (m *MockStore) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LinkChatFiles", ctx, arg)
|
||||
ret0, _ := ret[0].(int32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LinkChatFiles indicates an expected call of LinkChatFiles.
|
||||
func (mr *MockStoreMockRecorder) LinkChatFiles(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkChatFiles", reflect.TypeOf((*MockStore)(nil).LinkChatFiles), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeClients mocks base method.
|
||||
func (m *MockStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7339,10 +7397,10 @@ func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID an
|
||||
}
|
||||
|
||||
// ListUserSecrets mocks base method.
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecrets", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret0, _ := ret[0].([]database.ListUserSecretsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -7353,6 +7411,21 @@ func (mr *MockStoreMockRecorder) ListUserSecrets(ctx, userID any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecrets", reflect.TypeOf((*MockStore)(nil).ListUserSecrets), ctx, userID)
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues mocks base method.
|
||||
func (m *MockStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecretsWithValues", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues indicates an expected call of ListUserSecretsWithValues.
|
||||
func (mr *MockStoreMockRecorder) ListUserSecretsWithValues(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecretsWithValues", reflect.TypeOf((*MockStore)(nil).ListUserSecretsWithValues), ctx, userID)
|
||||
}
|
||||
|
||||
// ListWorkspaceAgentPortShares mocks base method.
|
||||
func (m *MockStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7762,19 +7835,19 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
// UpdateChatHeartbeats mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeats", ctx, arg)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
|
||||
// UpdateChatHeartbeats indicates an expected call of UpdateChatHeartbeats.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeats(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeats", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLabelsByID mocks base method.
|
||||
@@ -8605,6 +8678,21 @@ func (mr *MockStoreMockRecorder) UpdateUserChatCustomPrompt(ctx, arg any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCustomPrompt), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserChatProviderKey indicates an expected call of UpdateUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserDeletedByID mocks base method.
|
||||
func (m *MockStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8766,19 +8854,19 @@ func (mr *MockStoreMockRecorder) UpdateUserRoles(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserRoles", reflect.TypeOf((*MockStore)(nil).UpdateUserRoles), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserSecret mocks base method.
|
||||
func (m *MockStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
// UpdateUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecret", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserSecret indicates an expected call of UpdateUserSecret.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecret(ctx, arg any) *gomock.Call {
|
||||
// UpdateUserSecretByUserIDAndName indicates an expected call of UpdateUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecret", reflect.TypeOf((*MockStore)(nil).UpdateUserSecret), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).UpdateUserSecretByUserIDAndName), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserStatus mocks base method.
|
||||
@@ -9398,21 +9486,6 @@ func (mr *MockStoreMockRecorder) UpsertChatWorkspaceTTL(ctx, workspaceTtl any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatWorkspaceTTL", reflect.TypeOf((*MockStore)(nil).UpsertChatWorkspaceTTL), ctx, workspaceTtl)
|
||||
}
|
||||
|
||||
// UpsertConnectionLog mocks base method.
|
||||
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertConnectionLog", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ConnectionLog)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertConnectionLog indicates an expected call of UpsertConnectionLog.
|
||||
func (mr *MockStoreMockRecorder) UpsertConnectionLog(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertConnectionLog", reflect.TypeOf((*MockStore)(nil).UpsertConnectionLog), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertDefaultProxy mocks base method.
|
||||
func (m *MockStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9671,6 +9744,21 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx)
|
||||
}
|
||||
|
||||
// UpsertUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertUserChatProviderKey indicates an expected call of UpsertUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) UpsertUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpsertUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertWebpushVAPIDKeys mocks base method.
|
||||
func (m *MockStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+55
-3
@@ -1134,7 +1134,9 @@ CREATE TABLE aibridge_token_usages (
|
||||
input_tokens bigint NOT NULL,
|
||||
output_tokens bigint NOT NULL,
|
||||
metadata jsonb,
|
||||
created_at timestamp with time zone NOT NULL
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
cache_read_input_tokens bigint DEFAULT 0 NOT NULL,
|
||||
cache_write_input_tokens bigint DEFAULT 0 NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_token_usages IS 'Audit log of tokens used by intercepted requests in AI Bridge';
|
||||
@@ -1267,6 +1269,11 @@ CREATE TABLE chat_diff_statuses (
|
||||
head_branch text
|
||||
);
|
||||
|
||||
CREATE TABLE chat_file_links (
|
||||
chat_id uuid NOT NULL,
|
||||
file_id uuid NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE chat_files (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
@@ -1341,7 +1348,11 @@ CREATE TABLE chat_providers (
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
base_url text DEFAULT ''::text NOT NULL,
|
||||
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text])))
|
||||
central_api_key_enabled boolean DEFAULT true NOT NULL,
|
||||
allow_user_api_key boolean DEFAULT false NOT NULL,
|
||||
allow_central_api_key_fallback boolean DEFAULT false NOT NULL,
|
||||
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text]))),
|
||||
CONSTRAINT valid_credential_policy CHECK (((central_api_key_enabled OR allow_user_api_key) AND ((NOT allow_central_api_key_fallback) OR (central_api_key_enabled AND allow_user_api_key))))
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
|
||||
@@ -2752,6 +2763,17 @@ COMMENT ON TABLE usage_events_daily IS 'usage_events_daily is a daily rollup of
|
||||
|
||||
COMMENT ON COLUMN usage_events_daily.day IS 'The date of the summed usage events, always in UTC.';
|
||||
|
||||
CREATE TABLE user_chat_provider_keys (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
user_id uuid NOT NULL,
|
||||
chat_provider_id uuid NOT NULL,
|
||||
api_key text NOT NULL,
|
||||
api_key_key_id text,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT user_chat_provider_keys_api_key_check CHECK ((api_key <> ''::text))
|
||||
);
|
||||
|
||||
CREATE TABLE user_configs (
|
||||
user_id uuid NOT NULL,
|
||||
key character varying(256) NOT NULL,
|
||||
@@ -2793,7 +2815,8 @@ CREATE TABLE user_secrets (
|
||||
env_name text DEFAULT ''::text NOT NULL,
|
||||
file_path text DEFAULT ''::text NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
value_key_id text
|
||||
);
|
||||
|
||||
CREATE TABLE user_status_changes (
|
||||
@@ -3326,6 +3349,9 @@ ALTER TABLE ONLY boundary_usage_stats
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
|
||||
ALTER TABLE ONLY chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3548,6 +3574,12 @@ ALTER TABLE ONLY usage_events_daily
|
||||
ALTER TABLE ONLY usage_events
|
||||
ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
|
||||
|
||||
ALTER TABLE ONLY user_configs
|
||||
ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
|
||||
|
||||
@@ -3710,6 +3742,8 @@ CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
|
||||
|
||||
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
|
||||
|
||||
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chat_files_org ON chat_files USING btree (organization_id);
|
||||
|
||||
CREATE INDEX idx_chat_files_owner ON chat_files USING btree (owner_id);
|
||||
@@ -4012,6 +4046,12 @@ ALTER TABLE ONLY api_keys
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -4258,6 +4298,15 @@ ALTER TABLE ONLY templates
|
||||
ALTER TABLE ONLY templates
|
||||
ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_configs
|
||||
ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -4276,6 +4325,9 @@ ALTER TABLE ONLY user_links
|
||||
ALTER TABLE ONLY user_secrets
|
||||
ADD CONSTRAINT user_secrets_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_secrets
|
||||
ADD CONSTRAINT user_secrets_value_key_id_fkey FOREIGN KEY (value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY user_status_changes
|
||||
ADD CONSTRAINT user_status_changes_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ const (
|
||||
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFileLinksChatID ForeignKeyConstraint = "chat_file_links_chat_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
@@ -92,12 +94,16 @@ const (
|
||||
ForeignKeyTemplateVersionsTemplateID ForeignKeyConstraint = "template_versions_template_id_fkey" // ALTER TABLE ONLY template_versions ADD CONSTRAINT template_versions_template_id_fkey FOREIGN KEY (template_id) REFERENCES templates(id) ON DELETE CASCADE;
|
||||
ForeignKeyTemplatesCreatedBy ForeignKeyConstraint = "templates_created_by_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE RESTRICT;
|
||||
ForeignKeyTemplatesOrganizationID ForeignKeyConstraint = "templates_organization_id_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserChatProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_chat_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyUserChatProviderKeysChatProviderID ForeignKeyConstraint = "user_chat_provider_keys_chat_provider_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserChatProviderKeysUserID ForeignKeyConstraint = "user_chat_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserConfigsUserID ForeignKeyConstraint = "user_configs_user_id_fkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserDeletedUserID ForeignKeyConstraint = "user_deleted_user_id_fkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
|
||||
ForeignKeyUserLinksOauthAccessTokenKeyID ForeignKeyConstraint = "user_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyUserLinksOauthRefreshTokenKeyID ForeignKeyConstraint = "user_links_oauth_refresh_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyUserLinksUserID ForeignKeyConstraint = "user_links_user_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserSecretsUserID ForeignKeyConstraint = "user_secrets_user_id_fkey" // ALTER TABLE ONLY user_secrets ADD CONSTRAINT user_secrets_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserSecretsValueKeyID ForeignKeyConstraint = "user_secrets_value_key_id_fkey" // ALTER TABLE ONLY user_secrets ADD CONSTRAINT user_secrets_value_key_id_fkey FOREIGN KEY (value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyUserStatusChangesUserID ForeignKeyConstraint = "user_status_changes_user_id_fkey" // ALTER TABLE ONLY user_status_changes ADD CONSTRAINT user_status_changes_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
|
||||
ForeignKeyWebpushSubscriptionsUserID ForeignKeyConstraint = "webpush_subscriptions_user_id_fkey" // ALTER TABLE ONLY webpush_subscriptions ADD CONSTRAINT webpush_subscriptions_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyWorkspaceAgentDevcontainersSubagentID ForeignKeyConstraint = "workspace_agent_devcontainers_subagent_id_fkey" // ALTER TABLE ONLY workspace_agent_devcontainers ADD CONSTRAINT workspace_agent_devcontainers_subagent_id_fkey FOREIGN KEY (subagent_id) REFERENCES workspace_agents(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
DROP TABLE IF EXISTS user_chat_provider_keys;
|
||||
|
||||
ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy;
|
||||
|
||||
ALTER TABLE chat_providers
|
||||
DROP COLUMN IF EXISTS central_api_key_enabled,
|
||||
DROP COLUMN IF EXISTS allow_user_api_key,
|
||||
DROP COLUMN IF EXISTS allow_central_api_key_fallback;
|
||||
@@ -0,0 +1,24 @@
|
||||
ALTER TABLE chat_providers
|
||||
ADD COLUMN central_api_key_enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
ADD COLUMN allow_user_api_key BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
ADD COLUMN allow_central_api_key_fallback BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
ALTER TABLE chat_providers
|
||||
ADD CONSTRAINT valid_credential_policy CHECK (
|
||||
(central_api_key_enabled OR allow_user_api_key) AND
|
||||
(
|
||||
NOT allow_central_api_key_fallback OR
|
||||
(central_api_key_enabled AND allow_user_api_key)
|
||||
)
|
||||
);
|
||||
|
||||
CREATE TABLE user_chat_provider_keys (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
chat_provider_id UUID NOT NULL REFERENCES chat_providers(id) ON DELETE CASCADE,
|
||||
api_key TEXT NOT NULL CHECK (api_key != ''),
|
||||
api_key_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE (user_id, chat_provider_id)
|
||||
);
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE user_secrets
|
||||
DROP CONSTRAINT user_secrets_value_key_id_fkey,
|
||||
DROP COLUMN value_key_id;
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE user_secrets
|
||||
ADD COLUMN value_key_id TEXT;
|
||||
|
||||
ALTER TABLE ONLY user_secrets
|
||||
ADD CONSTRAINT user_secrets_value_key_id_fkey FOREIGN KEY (value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE aibridge_token_usages
|
||||
DROP COLUMN cache_read_input_tokens,
|
||||
DROP COLUMN cache_write_input_tokens;
|
||||
@@ -0,0 +1,26 @@
|
||||
ALTER TABLE aibridge_token_usages
|
||||
ADD COLUMN cache_read_input_tokens BIGINT NOT NULL DEFAULT 0,
|
||||
ADD COLUMN cache_write_input_tokens BIGINT NOT NULL DEFAULT 0;
|
||||
|
||||
-- Backfill from metadata JSONB. Old rows stored cache tokens under
|
||||
-- provider-specific keys; new rows use the dedicated columns above.
|
||||
UPDATE aibridge_token_usages
|
||||
SET
|
||||
|
||||
-- Cache-read metadata keys by provider:
|
||||
-- Anthropic (/v1/messages): "cache_read_input"
|
||||
-- OpenAI (/v1/responses): "input_cached"
|
||||
-- OpenAI (/v1/chat/completions): "prompt_cached"
|
||||
cache_read_input_tokens = GREATEST(
|
||||
COALESCE((metadata->>'cache_read_input')::bigint, 0),
|
||||
COALESCE((metadata->>'input_cached')::bigint, 0),
|
||||
COALESCE((metadata->>'prompt_cached')::bigint, 0)
|
||||
),
|
||||
|
||||
-- Cache-write metadata keys by provider:
|
||||
-- Anthropic (/v1/messages): "cache_creation_input"
|
||||
-- OpenAI does not report cache-write tokens.
|
||||
cache_write_input_tokens = COALESCE((metadata->>'cache_creation_input')::bigint, 0)
|
||||
WHERE metadata IS NOT NULL
|
||||
AND cache_read_input_tokens = 0
|
||||
AND cache_write_input_tokens = 0;
|
||||
@@ -0,0 +1,9 @@
|
||||
ALTER TABLE chats ADD COLUMN file_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL;
|
||||
|
||||
UPDATE chats SET file_ids = (
|
||||
SELECT COALESCE(array_agg(cfl.file_id), '{}')
|
||||
FROM chat_file_links cfl
|
||||
WHERE cfl.chat_id = chats.id
|
||||
);
|
||||
|
||||
DROP TABLE chat_file_links;
|
||||
@@ -0,0 +1,17 @@
|
||||
CREATE TABLE chat_file_links (
|
||||
chat_id uuid NOT NULL,
|
||||
file_id uuid NOT NULL,
|
||||
UNIQUE (chat_id, file_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links (chat_id);
|
||||
|
||||
ALTER TABLE chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_chat_id_fkey
|
||||
FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE chat_file_links
|
||||
ADD CONSTRAINT chat_file_links_file_id_fkey
|
||||
FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE chats DROP COLUMN IF EXISTS file_ids;
|
||||
+16
@@ -0,0 +1,16 @@
|
||||
INSERT INTO user_chat_provider_keys (
|
||||
user_id,
|
||||
chat_provider_id,
|
||||
api_key,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
id,
|
||||
'0a8b2f84-b5a8-4c44-8c9f-e58c44a534a7',
|
||||
'fixture-test-key',
|
||||
'2025-01-01 00:00:00+00',
|
||||
'2025-01-01 00:00:00+00'
|
||||
FROM users
|
||||
ORDER BY created_at, id
|
||||
LIMIT 1;
|
||||
@@ -0,0 +1,5 @@
|
||||
INSERT INTO chat_file_links (chat_id, file_id)
|
||||
VALUES (
|
||||
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
|
||||
'00000000-0000-0000-0000-000000000099'
|
||||
);
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -186,6 +187,10 @@ func (c ChatFile) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
|
||||
}
|
||||
|
||||
func (c GetChatFileMetadataByChatIDRow) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
|
||||
}
|
||||
|
||||
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
|
||||
switch s {
|
||||
case ApiKeyScopeCoderAll:
|
||||
@@ -923,3 +928,28 @@ func WorkspaceIdentityFromWorkspace(w Workspace) WorkspaceIdentity {
|
||||
func (r GetWorkspaceAgentAndWorkspaceByIDRow) RBACObject() rbac.Object {
|
||||
return r.WorkspaceTable.RBACObject()
|
||||
}
|
||||
|
||||
// UpsertConnectionLogParams contains the parameters for upserting a
|
||||
// connection log entry. This struct is hand-maintained (not generated
|
||||
// by sqlc) because the single-row UpsertConnectionLog query was
|
||||
// removed in favor of BatchUpsertConnectionLogs, but the struct is
|
||||
// still used as the canonical connection log event type throughout
|
||||
// the codebase.
|
||||
type UpsertConnectionLogParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"`
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
|
||||
AgentName string `db:"agent_name" json:"agent_name"`
|
||||
Type ConnectionType `db:"type" json:"type"`
|
||||
Code sql.NullInt32 `db:"code" json:"code"`
|
||||
IP pqtype.Inet `db:"ip" json:"ip"`
|
||||
UserAgent sql.NullString `db:"user_agent" json:"user_agent"`
|
||||
UserID uuid.NullUUID `db:"user_id" json:"user_id"`
|
||||
SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"`
|
||||
ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"`
|
||||
DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"`
|
||||
Time time.Time `db:"time" json:"time"`
|
||||
ConnectionStatus ConnectionStatus `db:"connection_status" json:"connection_status"`
|
||||
}
|
||||
|
||||
@@ -584,6 +584,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -720,6 +721,7 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -1029,6 +1031,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg Lis
|
||||
&i.Threads,
|
||||
&i.InputTokens,
|
||||
&i.OutputTokens,
|
||||
&i.CacheReadInputTokens,
|
||||
&i.CacheWriteInputTokens,
|
||||
&i.LastPrompt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -145,5 +145,13 @@ func extractWhereClause(query string) string {
|
||||
// Remove SQL comments
|
||||
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
|
||||
|
||||
// Normalize indentation so subquery wrapping doesn't cause
|
||||
// mismatches.
|
||||
lines := strings.Split(whereClause, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = strings.TrimLeft(line, " \t")
|
||||
}
|
||||
whereClause = strings.Join(lines, "\n")
|
||||
|
||||
return strings.TrimSpace(whereClause)
|
||||
}
|
||||
|
||||
+41
-20
@@ -4055,11 +4055,13 @@ type AIBridgeTokenUsage struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
InterceptionID uuid.UUID `db:"interception_id" json:"interception_id"`
|
||||
// The ID for the response in which the tokens were used, produced by the provider.
|
||||
ProviderResponseID string `db:"provider_response_id" json:"provider_response_id"`
|
||||
InputTokens int64 `db:"input_tokens" json:"input_tokens"`
|
||||
OutputTokens int64 `db:"output_tokens" json:"output_tokens"`
|
||||
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
ProviderResponseID string `db:"provider_response_id" json:"provider_response_id"`
|
||||
InputTokens int64 `db:"input_tokens" json:"input_tokens"`
|
||||
OutputTokens int64 `db:"output_tokens" json:"output_tokens"`
|
||||
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
CacheReadInputTokens int64 `db:"cache_read_input_tokens" json:"cache_read_input_tokens"`
|
||||
CacheWriteInputTokens int64 `db:"cache_write_input_tokens" json:"cache_write_input_tokens"`
|
||||
}
|
||||
|
||||
// Audit log of tool calls in intercepted requests in AI Bridge
|
||||
@@ -4216,6 +4218,11 @@ type ChatFile struct {
|
||||
Data []byte `db:"data" json:"data"`
|
||||
}
|
||||
|
||||
type ChatFileLink struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
FileID uuid.UUID `db:"file_id" json:"file_id"`
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
@@ -4264,12 +4271,15 @@ type ChatProvider struct {
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
// The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
|
||||
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
|
||||
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
|
||||
}
|
||||
|
||||
type ChatQueuedMessage struct {
|
||||
@@ -5222,6 +5232,16 @@ type User struct {
|
||||
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
|
||||
}
|
||||
|
||||
type UserChatProviderKey struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserConfig struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Key string `db:"key" json:"key"`
|
||||
@@ -5251,15 +5271,16 @@ type UserLink struct {
|
||||
}
|
||||
|
||||
type UserSecret struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
Value string `db:"value" json:"value"`
|
||||
EnvName string `db:"env_name" json:"env_name"`
|
||||
FilePath string `db:"file_path" json:"file_path"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ValueKeyID sql.NullString `db:"value_key_id" json:"value_key_id"`
|
||||
}
|
||||
|
||||
// Tracks the history of user status changes
|
||||
|
||||
@@ -81,8 +81,8 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue
|
||||
}
|
||||
|
||||
func (q *msgQueue) run() {
|
||||
var batch [maxDrainBatch]msgOrErr
|
||||
for {
|
||||
// wait until there is something on the queue or we are closed
|
||||
q.cond.L.Lock()
|
||||
for q.size == 0 && !q.closed {
|
||||
q.cond.Wait()
|
||||
@@ -91,28 +91,32 @@ func (q *msgQueue) run() {
|
||||
q.cond.L.Unlock()
|
||||
return
|
||||
}
|
||||
item := q.q[q.front]
|
||||
q.front = (q.front + 1) % BufferSize
|
||||
q.size--
|
||||
// Drain up to maxDrainBatch items while holding the lock.
|
||||
n := min(q.size, maxDrainBatch)
|
||||
for i := range n {
|
||||
batch[i] = q.q[q.front]
|
||||
q.front = (q.front + 1) % BufferSize
|
||||
}
|
||||
q.size -= n
|
||||
q.cond.L.Unlock()
|
||||
|
||||
// process item without holding lock
|
||||
if item.err == nil {
|
||||
// real message
|
||||
if q.l != nil {
|
||||
q.l(q.ctx, item.msg)
|
||||
// Dispatch each message individually without holding the lock.
|
||||
for i := range n {
|
||||
item := batch[i]
|
||||
if item.err == nil {
|
||||
if q.l != nil {
|
||||
q.l(q.ctx, item.msg)
|
||||
continue
|
||||
}
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, item.msg, nil)
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, item.msg, nil)
|
||||
continue
|
||||
q.le(q.ctx, nil, item.err)
|
||||
}
|
||||
// unhittable
|
||||
continue
|
||||
}
|
||||
// if the listener wants errors, send it.
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, nil, item.err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -233,6 +237,12 @@ type PGPubsub struct {
|
||||
// for a subscriber before dropping messages.
|
||||
const BufferSize = 2048
|
||||
|
||||
// maxDrainBatch is the maximum number of messages to drain from the ring
|
||||
// buffer per iteration. Batching amortizes the cost of mutex
|
||||
// acquire/release and cond.Wait across many messages, improving drain
|
||||
// throughput during bursts.
|
||||
const maxDrainBatch = 256
|
||||
|
||||
// Subscribe calls the listener when an event matching the name is received.
|
||||
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
|
||||
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
|
||||
|
||||
+34
-10
@@ -65,6 +65,7 @@ type sqlcQuerier interface {
|
||||
BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error
|
||||
BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error
|
||||
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
|
||||
BatchUpsertConnectionLogs(ctx context.Context, arg BatchUpsertConnectionLogsParams) error
|
||||
BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error)
|
||||
BulkMarkNotificationMessagesSent(ctx context.Context, arg BulkMarkNotificationMessagesSentParams) (int64, error)
|
||||
// Calculates the telemetry summary for a given provider, model, and client
|
||||
@@ -150,7 +151,8 @@ type sqlcQuerier interface {
|
||||
DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error)
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -242,6 +244,10 @@ type sqlcQuerier interface {
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
|
||||
// GetChatFileMetadataByChatID returns lightweight file metadata for
|
||||
// all files linked to a chat. The data column is excluded to avoid
|
||||
// loading file content.
|
||||
GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]GetChatFileMetadataByChatIDRow, error)
|
||||
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
|
||||
// GetChatIncludeDefaultSystemPrompt preserves the legacy default
|
||||
// for deployments created before the explicit include-default toggle.
|
||||
@@ -476,8 +482,8 @@ type sqlcQuerier interface {
|
||||
// Used for recovery after coderd crashes or long hangs.
|
||||
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
|
||||
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
|
||||
GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerBindingsRow, error)
|
||||
GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerIDsRow, error)
|
||||
GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error)
|
||||
GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerIDsBatchRow, error)
|
||||
GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error)
|
||||
GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByOwnerIDAndNameParams) (Task, error)
|
||||
GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error)
|
||||
@@ -577,6 +583,7 @@ type sqlcQuerier interface {
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
|
||||
GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
|
||||
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
|
||||
GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error)
|
||||
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
|
||||
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
// Returns the minimum (most restrictive) group limit for a user.
|
||||
@@ -591,7 +598,6 @@ type sqlcQuerier interface {
|
||||
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
|
||||
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
|
||||
GetUserNotificationPreferences(ctx context.Context, userID uuid.UUID) ([]NotificationPreference, error)
|
||||
GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error)
|
||||
GetUserSecretByUserIDAndName(ctx context.Context, arg GetUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
// GetUserStatusCounts returns the count of users in each status over time.
|
||||
// The time range is inclusively defined by the start_time and end_time parameters.
|
||||
@@ -775,6 +781,15 @@ type sqlcQuerier interface {
|
||||
InsertWorkspaceProxy(ctx context.Context, arg InsertWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
|
||||
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
|
||||
// LinkChatFiles inserts file associations into the chat_file_links
|
||||
// join table with deduplication (ON CONFLICT DO NOTHING). The INSERT
|
||||
// is conditional: it only proceeds when the total number of links
|
||||
// (existing + genuinely new) does not exceed max_file_links. Returns
|
||||
// the number of genuinely new file IDs that were NOT inserted due to
|
||||
// the cap. A return value of 0 means all files were linked (or were
|
||||
// already linked). A positive value means the cap blocked that many
|
||||
// new links.
|
||||
LinkChatFiles(ctx context.Context, arg LinkChatFilesParams) (int32, error)
|
||||
ListAIBridgeClients(ctx context.Context, arg ListAIBridgeClientsParams) ([]string, error)
|
||||
ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]ListAIBridgeInterceptionsRow, error)
|
||||
// Finds all unique AI Bridge interception telemetry summaries combinations
|
||||
@@ -802,7 +817,13 @@ type sqlcQuerier interface {
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
// Returns metadata only (no value or value_key_id) for the
|
||||
// REST API list and get endpoints.
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error)
|
||||
// Returns all columns including the secret value. Used by the
|
||||
// provisioner (build-time injection) and the agent manifest
|
||||
// (runtime injection).
|
||||
ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
|
||||
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
|
||||
OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error)
|
||||
@@ -854,9 +875,11 @@ type sqlcQuerier interface {
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
|
||||
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
|
||||
// Updates the cached injected context parts (AGENTS.md +
|
||||
// skills) on the chat row. Called only when context changes
|
||||
@@ -927,6 +950,7 @@ type sqlcQuerier interface {
|
||||
UpdateUsageEventsPostPublish(ctx context.Context, arg UpdateUsageEventsPostPublishParams) error
|
||||
UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error)
|
||||
UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error)
|
||||
UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error)
|
||||
UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error
|
||||
UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error
|
||||
UpdateUserHashedOneTimePasscode(ctx context.Context, arg UpdateUserHashedOneTimePasscodeParams) error
|
||||
@@ -938,7 +962,7 @@ type sqlcQuerier interface {
|
||||
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
|
||||
UpdateUserQuietHoursSchedule(ctx context.Context, arg UpdateUserQuietHoursScheduleParams) (User, error)
|
||||
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
|
||||
UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error)
|
||||
UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
|
||||
UpdateUserTaskNotificationAlertDismissed(ctx context.Context, arg UpdateUserTaskNotificationAlertDismissedParams) (bool, error)
|
||||
UpdateUserTerminalFont(ctx context.Context, arg UpdateUserTerminalFontParams) (UserConfig, error)
|
||||
@@ -988,7 +1012,6 @@ type sqlcQuerier interface {
|
||||
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
|
||||
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
|
||||
UpsertChatWorkspaceTTL(ctx context.Context, workspaceTtl string) error
|
||||
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
|
||||
// The default proxy is implied and not actually stored in the database.
|
||||
// So we need to store it's configuration here for display purposes.
|
||||
// The functional values are immutable and controlled implicitly.
|
||||
@@ -1015,6 +1038,7 @@ type sqlcQuerier interface {
|
||||
// used to store the data, and the minutes are summed for each user and template
|
||||
// combination. The result is stored in the template_usage_stats table.
|
||||
UpsertTemplateUsageStats(ctx context.Context) error
|
||||
UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error)
|
||||
UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error
|
||||
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
|
||||
UpsertWorkspaceApp(ctx context.Context, arg UpsertWorkspaceAppParams) (WorkspaceApp, error)
|
||||
|
||||
+562
-255
File diff suppressed because it is too large
Load Diff
+874
-479
File diff suppressed because it is too large
Load Diff
@@ -31,9 +31,9 @@ WHERE aibridge_interceptions.id = (
|
||||
|
||||
-- name: InsertAIBridgeTokenUsage :one
|
||||
INSERT INTO aibridge_token_usages (
|
||||
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
|
||||
id, interception_id, provider_response_id, input_tokens, output_tokens, cache_read_input_tokens, cache_write_input_tokens, metadata, created_at
|
||||
) VALUES (
|
||||
@id, @interception_id, @provider_response_id, @input_tokens, @output_tokens, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
|
||||
@id, @interception_id, @provider_response_id, @input_tokens, @output_tokens, @cache_read_input_tokens, @cache_write_input_tokens, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
@@ -299,21 +299,8 @@ token_aggregates AS (
|
||||
SELECT
|
||||
COALESCE(SUM(tu.input_tokens), 0) AS token_count_input,
|
||||
COALESCE(SUM(tu.output_tokens), 0) AS token_count_output,
|
||||
-- Cached tokens are stored in metadata JSON, extract if available.
|
||||
-- Read tokens may be stored in:
|
||||
-- - cache_read_input (Anthropic)
|
||||
-- - prompt_cached (OpenAI)
|
||||
COALESCE(SUM(
|
||||
COALESCE((tu.metadata->>'cache_read_input')::bigint, 0) +
|
||||
COALESCE((tu.metadata->>'prompt_cached')::bigint, 0)
|
||||
), 0) AS token_count_cached_read,
|
||||
-- Written tokens may be stored in:
|
||||
-- - cache_creation_input (Anthropic)
|
||||
-- Note that cache_ephemeral_5m_input and cache_ephemeral_1h_input on
|
||||
-- Anthropic are included in the cache_creation_input field.
|
||||
COALESCE(SUM(
|
||||
COALESCE((tu.metadata->>'cache_creation_input')::bigint, 0)
|
||||
), 0) AS token_count_cached_written,
|
||||
COALESCE(SUM(tu.cache_read_input_tokens), 0) AS token_count_cached_read,
|
||||
COALESCE(SUM(tu.cache_write_input_tokens), 0) AS token_count_cached_written,
|
||||
COUNT(tu.id) AS token_usages_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
@@ -552,6 +539,8 @@ SELECT
|
||||
sp.threads,
|
||||
COALESCE(st.input_tokens, 0)::bigint AS input_tokens,
|
||||
COALESCE(st.output_tokens, 0)::bigint AS output_tokens,
|
||||
COALESCE(st.cache_read_input_tokens, 0)::bigint AS cache_read_input_tokens,
|
||||
COALESCE(st.cache_write_input_tokens, 0)::bigint AS cache_write_input_tokens,
|
||||
COALESCE(slp.prompt, '') AS last_prompt
|
||||
FROM
|
||||
session_page sp
|
||||
@@ -573,7 +562,9 @@ LEFT JOIN LATERAL (
|
||||
-- Aggregate tokens only for this session's interceptions.
|
||||
SELECT
|
||||
COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
|
||||
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
|
||||
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens,
|
||||
COALESCE(SUM(tu.cache_read_input_tokens), 0)::bigint AS cache_read_input_tokens,
|
||||
COALESCE(SUM(tu.cache_write_input_tokens), 0)::bigint AS cache_write_input_tokens
|
||||
FROM aibridge_token_usages tu
|
||||
WHERE tu.interception_id = ANY(sr.interception_ids)
|
||||
) st ON true
|
||||
|
||||
@@ -149,94 +149,105 @@ VALUES (
|
||||
RETURNING *;
|
||||
|
||||
-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
|
||||
-- name: DeleteOldAuditLogConnectionEvents :exec
|
||||
DELETE FROM audit_logs
|
||||
|
||||
@@ -8,3 +8,13 @@ SELECT * FROM chat_files WHERE id = @id::uuid;
|
||||
|
||||
-- name: GetChatFilesByIDs :many
|
||||
SELECT * FROM chat_files WHERE id = ANY(@ids::uuid[]);
|
||||
|
||||
-- name: GetChatFileMetadataByChatID :many
|
||||
-- GetChatFileMetadataByChatID returns lightweight file metadata for
|
||||
-- all files linked to a chat. The data column is excluded to avoid
|
||||
-- loading file content.
|
||||
SELECT cf.id, cf.owner_id, cf.organization_id, cf.name, cf.mimetype, cf.created_at
|
||||
FROM chat_files cf
|
||||
JOIN chat_file_links cfl ON cfl.file_id = cf.id
|
||||
WHERE cfl.chat_id = @chat_id::uuid
|
||||
ORDER BY cf.created_at ASC;
|
||||
|
||||
@@ -40,7 +40,10 @@ INSERT INTO chat_providers (
|
||||
base_url,
|
||||
api_key_key_id,
|
||||
created_by,
|
||||
enabled
|
||||
enabled,
|
||||
central_api_key_enabled,
|
||||
allow_user_api_key,
|
||||
allow_central_api_key_fallback
|
||||
) VALUES (
|
||||
@provider::text,
|
||||
@display_name::text,
|
||||
@@ -48,7 +51,10 @@ INSERT INTO chat_providers (
|
||||
@base_url::text,
|
||||
sqlc.narg('api_key_key_id')::text,
|
||||
sqlc.narg('created_by')::uuid,
|
||||
@enabled::boolean
|
||||
@enabled::boolean,
|
||||
@central_api_key_enabled::boolean,
|
||||
@allow_user_api_key::boolean,
|
||||
@allow_central_api_key_fallback::boolean
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -62,6 +68,9 @@ SET
|
||||
base_url = @base_url::text,
|
||||
api_key_key_id = sqlc.narg('api_key_key_id')::text,
|
||||
enabled = @enabled::boolean,
|
||||
central_api_key_enabled = @central_api_key_enabled::boolean,
|
||||
allow_user_api_key = @allow_user_api_key::boolean,
|
||||
allow_central_api_key_fallback = @allow_central_api_key_fallback::boolean,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
|
||||
@@ -567,6 +567,43 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: LinkChatFiles :one
|
||||
-- LinkChatFiles inserts file associations into the chat_file_links
|
||||
-- join table with deduplication (ON CONFLICT DO NOTHING). The INSERT
|
||||
-- is conditional: it only proceeds when the total number of links
|
||||
-- (existing + genuinely new) does not exceed max_file_links. Returns
|
||||
-- the number of genuinely new file IDs that were NOT inserted due to
|
||||
-- the cap. A return value of 0 means all files were linked (or were
|
||||
-- already linked). A positive value means the cap blocked that many
|
||||
-- new links.
|
||||
WITH current AS (
|
||||
SELECT COUNT(*) AS cnt
|
||||
FROM chat_file_links
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
),
|
||||
new_links AS (
|
||||
SELECT @chat_id::uuid AS chat_id, unnest(@file_ids::uuid[]) AS file_id
|
||||
),
|
||||
genuinely_new AS (
|
||||
SELECT nl.chat_id, nl.file_id
|
||||
FROM new_links nl
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM chat_file_links cfl
|
||||
WHERE cfl.chat_id = nl.chat_id AND cfl.file_id = nl.file_id
|
||||
)
|
||||
),
|
||||
inserted AS (
|
||||
INSERT INTO chat_file_links (chat_id, file_id)
|
||||
SELECT gn.chat_id, gn.file_id
|
||||
FROM genuinely_new gn, current c
|
||||
WHERE c.cnt + (SELECT COUNT(*) FROM genuinely_new) <= @max_file_links::int
|
||||
ON CONFLICT (chat_id, file_id) DO NOTHING
|
||||
RETURNING file_id
|
||||
)
|
||||
SELECT
|
||||
(SELECT COUNT(*)::int FROM genuinely_new) -
|
||||
(SELECT COUNT(*)::int FROM inserted) AS rejected_new_files;
|
||||
|
||||
-- name: AcquireChats :many
|
||||
-- Acquires up to @num_chats pending chats for processing. Uses SKIP LOCKED
|
||||
-- to prevent multiple replicas from acquiring the same chat.
|
||||
@@ -637,17 +674,20 @@ WHERE
|
||||
status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz;
|
||||
|
||||
-- name: UpdateChatHeartbeat :execrows
|
||||
-- Bumps the heartbeat timestamp for a running chat so that other
|
||||
-- replicas know the worker is still alive.
|
||||
-- name: UpdateChatHeartbeats :many
|
||||
-- Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
-- provided they are still running and owned by the specified
|
||||
-- worker. Returns the IDs that were actually updated so the
|
||||
-- caller can detect stolen or completed chats via set-difference.
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = NOW()
|
||||
heartbeat_at = @now::timestamptz
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
id = ANY(@ids::uuid[])
|
||||
AND worker_id = @worker_id::uuid
|
||||
AND status = 'running'::chat_status;
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id;
|
||||
|
||||
-- name: GetChatDiffStatusByChatID :one
|
||||
SELECT
|
||||
@@ -883,7 +923,8 @@ SELECT
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -913,7 +954,8 @@ SELECT
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -948,7 +990,8 @@ WITH chat_costs AS (
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM chat_messages cm
|
||||
JOIN chats c ON c.id = cm.chat_id
|
||||
WHERE c.owner_id = @owner_id::uuid
|
||||
@@ -965,7 +1008,8 @@ SELECT
|
||||
cc.total_input_tokens,
|
||||
cc.total_output_tokens,
|
||||
cc.total_cache_read_tokens,
|
||||
cc.total_cache_creation_tokens
|
||||
cc.total_cache_creation_tokens,
|
||||
cc.total_runtime_ms
|
||||
FROM chat_costs cc
|
||||
LEFT JOIN chats rc ON rc.id = cc.root_chat_id
|
||||
ORDER BY cc.total_cost_micros DESC;
|
||||
@@ -991,7 +1035,8 @@ WITH chat_cost_users AS (
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
|
||||
FROM
|
||||
chat_messages cm
|
||||
JOIN
|
||||
@@ -1025,6 +1070,7 @@ SELECT
|
||||
total_output_tokens,
|
||||
total_cache_read_tokens,
|
||||
total_cache_creation_tokens,
|
||||
total_runtime_ms,
|
||||
COUNT(*) OVER()::bigint AS total_count
|
||||
FROM
|
||||
chat_cost_users
|
||||
|
||||
@@ -133,111 +133,113 @@ OFFSET
|
||||
@offset_opt;
|
||||
|
||||
-- name: CountConnectionLogs :one
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
|
||||
-- name: DeleteOldConnectionLogs :execrows
|
||||
WITH old_logs AS (
|
||||
@@ -251,55 +253,75 @@ DELETE FROM connection_logs
|
||||
USING old_logs
|
||||
WHERE connection_logs.id = old_logs.id;
|
||||
|
||||
-- name: UpsertConnectionLog :one
|
||||
-- name: BatchUpsertConnectionLogs :exec
|
||||
INSERT INTO connection_logs (
|
||||
id,
|
||||
connect_time,
|
||||
organization_id,
|
||||
workspace_owner_id,
|
||||
workspace_id,
|
||||
workspace_name,
|
||||
agent_name,
|
||||
type,
|
||||
code,
|
||||
ip,
|
||||
user_agent,
|
||||
user_id,
|
||||
slug_or_port,
|
||||
connection_id,
|
||||
disconnect_reason,
|
||||
disconnect_time
|
||||
) VALUES
|
||||
($1, @time, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14,
|
||||
-- If we've only received a disconnect event, mark the event as immediately
|
||||
-- closed.
|
||||
CASE
|
||||
WHEN @connection_status::connection_status = 'disconnected'
|
||||
THEN @time :: timestamp with time zone
|
||||
ELSE NULL
|
||||
END)
|
||||
id, connect_time, organization_id, workspace_owner_id, workspace_id,
|
||||
workspace_name, agent_name, type, code, ip, user_agent, user_id,
|
||||
slug_or_port, connection_id, disconnect_reason, disconnect_time
|
||||
)
|
||||
SELECT
|
||||
u.id,
|
||||
u.connect_time,
|
||||
u.organization_id,
|
||||
u.workspace_owner_id,
|
||||
u.workspace_id,
|
||||
u.workspace_name,
|
||||
u.agent_name,
|
||||
u.type,
|
||||
-- Use the validity flag to distinguish "no code" (NULL) from a
|
||||
-- legitimate zero exit code.
|
||||
CASE WHEN u.code_valid THEN u.code ELSE NULL END,
|
||||
u.ip,
|
||||
NULLIF(u.user_agent, ''),
|
||||
NULLIF(u.user_id, '00000000-0000-0000-0000-000000000000'::uuid),
|
||||
NULLIF(u.slug_or_port, ''),
|
||||
NULLIF(u.connection_id, '00000000-0000-0000-0000-000000000000'::uuid),
|
||||
NULLIF(u.disconnect_reason, ''),
|
||||
NULLIF(u.disconnect_time, '0001-01-01 00:00:00Z'::timestamptz)
|
||||
FROM (
|
||||
SELECT
|
||||
unnest(sqlc.arg('id')::uuid[]) AS id,
|
||||
unnest(sqlc.arg('connect_time')::timestamptz[]) AS connect_time,
|
||||
unnest(sqlc.arg('organization_id')::uuid[]) AS organization_id,
|
||||
unnest(sqlc.arg('workspace_owner_id')::uuid[]) AS workspace_owner_id,
|
||||
unnest(sqlc.arg('workspace_id')::uuid[]) AS workspace_id,
|
||||
unnest(sqlc.arg('workspace_name')::text[]) AS workspace_name,
|
||||
unnest(sqlc.arg('agent_name')::text[]) AS agent_name,
|
||||
unnest(sqlc.arg('type')::connection_type[]) AS type,
|
||||
unnest(sqlc.arg('code')::int4[]) AS code,
|
||||
unnest(sqlc.arg('code_valid')::bool[]) AS code_valid,
|
||||
unnest(sqlc.arg('ip')::inet[]) AS ip,
|
||||
unnest(sqlc.arg('user_agent')::text[]) AS user_agent,
|
||||
unnest(sqlc.arg('user_id')::uuid[]) AS user_id,
|
||||
unnest(sqlc.arg('slug_or_port')::text[]) AS slug_or_port,
|
||||
unnest(sqlc.arg('connection_id')::uuid[]) AS connection_id,
|
||||
unnest(sqlc.arg('disconnect_reason')::text[]) AS disconnect_reason,
|
||||
unnest(sqlc.arg('disconnect_time')::timestamptz[]) AS disconnect_time
|
||||
) AS u
|
||||
ON CONFLICT (connection_id, workspace_id, agent_name)
|
||||
DO UPDATE SET
|
||||
-- No-op if the connection is still open.
|
||||
disconnect_time = CASE
|
||||
WHEN @connection_status::connection_status = 'disconnected'
|
||||
-- Can only be set once
|
||||
AND connection_logs.disconnect_time IS NULL
|
||||
THEN EXCLUDED.connect_time
|
||||
ELSE connection_logs.disconnect_time
|
||||
END,
|
||||
disconnect_reason = CASE
|
||||
WHEN @connection_status::connection_status = 'disconnected'
|
||||
-- Can only be set once
|
||||
AND connection_logs.disconnect_reason IS NULL
|
||||
THEN EXCLUDED.disconnect_reason
|
||||
ELSE connection_logs.disconnect_reason
|
||||
END,
|
||||
code = CASE
|
||||
WHEN @connection_status::connection_status = 'disconnected'
|
||||
-- Can only be set once
|
||||
AND connection_logs.code IS NULL
|
||||
THEN EXCLUDED.code
|
||||
ELSE connection_logs.code
|
||||
END
|
||||
RETURNING *;
|
||||
-- Pick the earliest real connect_time. The zero sentinel
|
||||
-- ('0001-01-01') means the batch didn't know the connect_time
|
||||
-- (e.g. a pure disconnect event), so we keep the existing value.
|
||||
connect_time = CASE
|
||||
WHEN EXCLUDED.connect_time = '0001-01-01 00:00:00Z'::timestamptz
|
||||
THEN connection_logs.connect_time
|
||||
WHEN connection_logs.connect_time = '0001-01-01 00:00:00Z'::timestamptz
|
||||
THEN EXCLUDED.connect_time
|
||||
ELSE LEAST(connection_logs.connect_time, EXCLUDED.connect_time)
|
||||
END,
|
||||
disconnect_time = CASE
|
||||
WHEN connection_logs.disconnect_time IS NULL
|
||||
THEN EXCLUDED.disconnect_time
|
||||
ELSE connection_logs.disconnect_time
|
||||
END,
|
||||
disconnect_reason = CASE
|
||||
WHEN connection_logs.disconnect_reason IS NULL
|
||||
THEN EXCLUDED.disconnect_reason
|
||||
ELSE connection_logs.disconnect_reason
|
||||
END,
|
||||
code = CASE
|
||||
WHEN connection_logs.code IS NULL
|
||||
THEN EXCLUDED.code
|
||||
ELSE connection_logs.code
|
||||
END;
|
||||
|
||||
@@ -96,28 +96,6 @@ DELETE
|
||||
FROM tailnet_tunnels
|
||||
WHERE coordinator_id = $1 and src_id = $2;
|
||||
|
||||
-- name: GetTailnetTunnelPeerIDs :many
|
||||
SELECT dst_id as peer_id, coordinator_id, updated_at
|
||||
FROM tailnet_tunnels
|
||||
WHERE tailnet_tunnels.src_id = $1
|
||||
UNION
|
||||
SELECT src_id as peer_id, coordinator_id, updated_at
|
||||
FROM tailnet_tunnels
|
||||
WHERE tailnet_tunnels.dst_id = $1;
|
||||
|
||||
-- name: GetTailnetTunnelPeerBindings :many
|
||||
SELECT id AS peer_id, coordinator_id, updated_at, node, status
|
||||
FROM tailnet_peers
|
||||
WHERE id IN (
|
||||
SELECT dst_id as peer_id
|
||||
FROM tailnet_tunnels
|
||||
WHERE tailnet_tunnels.src_id = $1
|
||||
UNION
|
||||
SELECT src_id as peer_id
|
||||
FROM tailnet_tunnels
|
||||
WHERE tailnet_tunnels.dst_id = $1
|
||||
);
|
||||
|
||||
-- For PG Coordinator HTMLDebug
|
||||
|
||||
-- name: GetAllTailnetCoordinators :many
|
||||
@@ -128,3 +106,22 @@ SELECT * FROM tailnet_peers;
|
||||
|
||||
-- name: GetAllTailnetTunnels :many
|
||||
SELECT * FROM tailnet_tunnels;
|
||||
|
||||
-- name: GetTailnetTunnelPeerIDsBatch :many
|
||||
SELECT src_id AS lookup_id, dst_id AS peer_id, coordinator_id, updated_at
|
||||
FROM tailnet_tunnels WHERE src_id = ANY(@ids :: uuid[])
|
||||
UNION ALL
|
||||
SELECT dst_id AS lookup_id, src_id AS peer_id, coordinator_id, updated_at
|
||||
FROM tailnet_tunnels WHERE dst_id = ANY(@ids :: uuid[]);
|
||||
|
||||
-- name: GetTailnetTunnelPeerBindingsBatch :many
|
||||
SELECT tp.id AS peer_id, tp.coordinator_id, tp.updated_at, tp.node, tp.status,
|
||||
tunnels.lookup_id
|
||||
FROM (
|
||||
SELECT dst_id AS peer_id, src_id AS lookup_id
|
||||
FROM tailnet_tunnels WHERE src_id = ANY(@ids :: uuid[])
|
||||
UNION
|
||||
SELECT src_id AS peer_id, dst_id AS lookup_id
|
||||
FROM tailnet_tunnels WHERE dst_id = ANY(@ids :: uuid[])
|
||||
) tunnels
|
||||
INNER JOIN tailnet_peers tp ON tp.id = tunnels.peer_id;
|
||||
|
||||
@@ -1,14 +1,26 @@
|
||||
-- name: GetUserSecretByUserIDAndName :one
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2;
|
||||
|
||||
-- name: GetUserSecret :one
|
||||
SELECT * FROM user_secrets
|
||||
WHERE id = $1;
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
-- name: ListUserSecrets :many
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
-- Returns metadata only (no value or value_key_id) for the
|
||||
-- REST API list and get endpoints.
|
||||
SELECT
|
||||
id, user_id, name, description,
|
||||
env_name, file_path,
|
||||
created_at, updated_at
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: ListUserSecretsWithValues :many
|
||||
-- Returns all columns including the secret value. Used by the
|
||||
-- provisioner (build-time injection) and the agent manifest
|
||||
-- (runtime injection).
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: CreateUserSecret :one
|
||||
@@ -18,23 +30,32 @@ INSERT INTO user_secrets (
|
||||
name,
|
||||
description,
|
||||
value,
|
||||
value_key_id,
|
||||
env_name,
|
||||
file_path
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7
|
||||
@id,
|
||||
@user_id,
|
||||
@name,
|
||||
@description,
|
||||
@value,
|
||||
@value_key_id,
|
||||
@env_name,
|
||||
@file_path
|
||||
) RETURNING *;
|
||||
|
||||
-- name: UpdateUserSecret :one
|
||||
-- name: UpdateUserSecretByUserIDAndName :one
|
||||
UPDATE user_secrets
|
||||
SET
|
||||
description = $2,
|
||||
value = $3,
|
||||
env_name = $4,
|
||||
file_path = $5,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $1
|
||||
value = CASE WHEN @update_value::bool THEN @value ELSE value END,
|
||||
value_key_id = CASE WHEN @update_value::bool THEN @value_key_id ELSE value_key_id END,
|
||||
description = CASE WHEN @update_description::bool THEN @description ELSE description END,
|
||||
env_name = CASE WHEN @update_env_name::bool THEN @env_name ELSE env_name END,
|
||||
file_path = CASE WHEN @update_file_path::bool THEN @file_path ELSE file_path END,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = @user_id AND name = @name
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserSecret :exec
|
||||
-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
DELETE FROM user_secrets
|
||||
WHERE id = $1;
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
-- name: GetUserChatProviderKeys :many
|
||||
SELECT * FROM user_chat_provider_keys WHERE user_id = @user_id ORDER BY created_at ASC, id ASC;
|
||||
|
||||
-- name: UpsertUserChatProviderKey :one
|
||||
INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id)
|
||||
VALUES (@user_id, @chat_provider_id, @api_key, sqlc.narg('api_key_key_id')::text)
|
||||
ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET
|
||||
api_key = @api_key,
|
||||
api_key_key_id = sqlc.narg('api_key_key_id')::text,
|
||||
updated_at = NOW()
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateUserChatProviderKey :one
|
||||
UPDATE user_chat_provider_keys
|
||||
SET api_key = @api_key, api_key_key_id = sqlc.narg('api_key_key_id')::text, updated_at = NOW()
|
||||
WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserChatProviderKey :exec
|
||||
DELETE FROM user_chat_provider_keys WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id;
|
||||
@@ -247,6 +247,7 @@ sql:
|
||||
mcp_server_tool_snapshots: MCPServerToolSnapshots
|
||||
mcp_server_config_id: MCPServerConfigID
|
||||
mcp_server_ids: MCPServerIDs
|
||||
max_file_links: MaxFileLinks
|
||||
icon_url: IconURL
|
||||
oauth2_client_id: OAuth2ClientID
|
||||
oauth2_client_secret: OAuth2ClientSecret
|
||||
|
||||
@@ -16,6 +16,7 @@ const (
|
||||
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
|
||||
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
UniqueChatFileLinksChatIDFileIDKey UniqueConstraint = "chat_file_links_chat_id_file_id_key" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
|
||||
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
|
||||
@@ -90,6 +91,8 @@ const (
|
||||
UniqueTemplatesPkey UniqueConstraint = "templates_pkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_pkey PRIMARY KEY (id);
|
||||
UniqueUsageEventsDailyPkey UniqueConstraint = "usage_events_daily_pkey" // ALTER TABLE ONLY usage_events_daily ADD CONSTRAINT usage_events_daily_pkey PRIMARY KEY (day, event_type);
|
||||
UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
|
||||
UniqueUserChatProviderKeysPkey UniqueConstraint = "user_chat_provider_keys_pkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
|
||||
UniqueUserChatProviderKeysUserIDChatProviderIDKey UniqueConstraint = "user_chat_provider_keys_user_id_chat_provider_id_key" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
|
||||
UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
|
||||
UniqueUserDeletedPkey UniqueConstraint = "user_deleted_pkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_pkey PRIMARY KEY (id);
|
||||
UniqueUserLinksPkey UniqueConstraint = "user_links_pkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type);
|
||||
|
||||
+616
-107
File diff suppressed because it is too large
Load Diff
+1552
-10
File diff suppressed because it is too large
Load Diff
@@ -39,13 +39,14 @@ func TestChatParam(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.InsertChatProvider(context.Background(), database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
BaseUrl: "https://api.openai.com/v1",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
BaseUrl: "https://api.openai.com/v1",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
+29
-16
@@ -2,6 +2,7 @@ package prebuilds
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -22,7 +23,11 @@ type PubsubWorkspaceClaimPublisher struct {
|
||||
|
||||
func (p PubsubWorkspaceClaimPublisher) PublishWorkspaceClaim(claim agentsdk.ReinitializationEvent) error {
|
||||
channel := agentsdk.PrebuildClaimedChannel(claim.WorkspaceID)
|
||||
if err := p.ps.Publish(channel, []byte(claim.Reason)); err != nil {
|
||||
payload, err := json.Marshal(claim)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal claim event: %w", err)
|
||||
}
|
||||
if err := p.ps.Publish(channel, payload); err != nil {
|
||||
return xerrors.Errorf("failed to trigger prebuilt workspace agent reinitialization: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -37,33 +42,41 @@ type PubsubWorkspaceClaimListener struct {
|
||||
ps pubsub.Pubsub
|
||||
}
|
||||
|
||||
// ListenForWorkspaceClaims subscribes to a pubsub channel and sends any received events on the chan that it returns.
|
||||
// pubsub.Pubsub does not communicate when its last callback has been called after it has been closed. As such the chan
|
||||
// returned by this method is never closed. Call the returned cancel() function to close the subscription when it is no longer needed.
|
||||
// cancel() will be called if ctx expires or is canceled.
|
||||
func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID, reinitEvents chan<- agentsdk.ReinitializationEvent) (func(), error) {
|
||||
// ListenForWorkspaceClaims subscribes to a pubsub channel and returns a
|
||||
// receive-only channel that emits claim events for the given workspace.
|
||||
// The returned channel is owned by this function and is never closed,
|
||||
// because pubsub.Pubsub does not guarantee that all in-flight callbacks
|
||||
// have returned after unsubscribe. Call the returned cancel function to
|
||||
// unsubscribe when events are no longer needed; cancel is also called
|
||||
// automatically if ctx expires or is canceled.
|
||||
func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID) (<-chan agentsdk.ReinitializationEvent, func(), error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return func() {}, ctx.Err()
|
||||
return nil, func() {}, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
cancelSub, err := p.ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, reason []byte) {
|
||||
claim := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspaceID,
|
||||
Reason: agentsdk.ReinitializationReason(reason),
|
||||
reinitEvents := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
|
||||
cancelSub, err := p.ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, payload []byte) {
|
||||
var event agentsdk.ReinitializationEvent
|
||||
if err := json.Unmarshal(payload, &event); err != nil {
|
||||
// Rolling upgrade: old publishers send the raw reason
|
||||
// string instead of JSON.
|
||||
event = agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspaceID,
|
||||
Reason: agentsdk.ReinitializationReason(payload),
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-inner.Done():
|
||||
return
|
||||
case reinitEvents <- claim:
|
||||
case reinitEvents <- event:
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return func() {}, xerrors.Errorf("failed to subscribe to prebuild claimed channel: %w", err)
|
||||
return nil, func() {}, xerrors.Errorf("failed to subscribe to prebuild claimed channel: %w", err)
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
@@ -78,5 +91,5 @@ func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Conte
|
||||
cancel()
|
||||
}()
|
||||
|
||||
return cancel, nil
|
||||
return reinitEvents, cancel, nil
|
||||
}
|
||||
|
||||
@@ -25,24 +25,26 @@ func TestPubsubWorkspaceClaimPublisher(t *testing.T) {
|
||||
logger := testutil.Logger(t)
|
||||
ps := pubsub.NewInMemory()
|
||||
workspaceID := uuid.New()
|
||||
reinitEvents := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
publisher := prebuilds.NewPubsubWorkspaceClaimPublisher(ps)
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, logger)
|
||||
|
||||
cancel, err := listener.ListenForWorkspaceClaims(ctx, workspaceID, reinitEvents)
|
||||
events, cancel, err := listener.ListenForWorkspaceClaims(ctx, workspaceID)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
userID := uuid.New()
|
||||
claim := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspaceID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
OwnerID: userID,
|
||||
}
|
||||
err = publisher.PublishWorkspaceClaim(claim)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotEvent := testutil.RequireReceive(ctx, t, reinitEvents)
|
||||
gotEvent := testutil.RequireReceive(ctx, t, events)
|
||||
require.Equal(t, workspaceID, gotEvent.WorkspaceID)
|
||||
require.Equal(t, claim.Reason, gotEvent.Reason)
|
||||
require.Equal(t, userID, gotEvent.OwnerID)
|
||||
})
|
||||
|
||||
t.Run("fail to publish claim", func(t *testing.T) {
|
||||
@@ -69,10 +71,8 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
|
||||
ps := pubsub.NewInMemory()
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
|
||||
|
||||
claims := make(chan agentsdk.ReinitializationEvent, 1) // Buffer to avoid messing with goroutines in the rest of the test
|
||||
|
||||
workspaceID := uuid.New()
|
||||
cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID, claims)
|
||||
events, cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID)
|
||||
require.NoError(t, err)
|
||||
defer cancelFunc()
|
||||
|
||||
@@ -84,9 +84,10 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
|
||||
|
||||
// Verify we receive the claim
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
claim := testutil.RequireReceive(ctx, t, claims)
|
||||
claim := testutil.RequireReceive(ctx, t, events)
|
||||
require.Equal(t, workspaceID, claim.WorkspaceID)
|
||||
require.Equal(t, reason, claim.Reason)
|
||||
require.Equal(t, uuid.Nil, claim.OwnerID)
|
||||
})
|
||||
|
||||
t.Run("ignores claim events for other workspaces", func(t *testing.T) {
|
||||
@@ -95,10 +96,9 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
|
||||
ps := pubsub.NewInMemory()
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
|
||||
|
||||
claims := make(chan agentsdk.ReinitializationEvent)
|
||||
workspaceID := uuid.New()
|
||||
otherWorkspaceID := uuid.New()
|
||||
cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID, claims)
|
||||
events, cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID)
|
||||
require.NoError(t, err)
|
||||
defer cancelFunc()
|
||||
|
||||
@@ -109,7 +109,7 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
|
||||
|
||||
// Verify we don't receive the claim
|
||||
select {
|
||||
case <-claims:
|
||||
case <-events:
|
||||
t.Fatal("received claim for wrong workspace")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected - no claim received
|
||||
@@ -119,11 +119,10 @@ func TestPubsubWorkspaceClaimListener(t *testing.T) {
|
||||
t.Run("communicates the error if it can't subscribe", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
claims := make(chan agentsdk.ReinitializationEvent)
|
||||
ps := &brokenPubsub{}
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
|
||||
|
||||
_, err := listener.ListenForWorkspaceClaims(context.Background(), uuid.New(), claims)
|
||||
_, _, err := listener.ListenForWorkspaceClaims(context.Background(), uuid.New())
|
||||
require.ErrorContains(t, err, "failed to subscribe to prebuild claimed channel")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2539,6 +2539,7 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro
|
||||
err = prebuilds.NewPubsubWorkspaceClaimPublisher(s.Pubsub).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
OwnerID: workspace.OwnerID,
|
||||
})
|
||||
if err != nil {
|
||||
s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err))
|
||||
|
||||
@@ -51,7 +51,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/usage/usagetypes"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/provisionerd/proto"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
|
||||
@@ -2787,8 +2786,7 @@ func TestCompleteJob(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// GIVEN something is listening to process workspace reinitialization:
|
||||
reinitChan := make(chan agentsdk.ReinitializationEvent, 1) // Buffered to simplify test structure
|
||||
cancel, err := agplprebuilds.NewPubsubWorkspaceClaimListener(ps, testutil.Logger(t)).ListenForWorkspaceClaims(ctx, workspace.ID, reinitChan)
|
||||
reinitChan, cancel, err := agplprebuilds.NewPubsubWorkspaceClaimListener(ps, testutil.Logger(t)).ListenForWorkspaceClaims(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@@ -298,6 +298,40 @@ neq(input.object.owner, "");
|
||||
ExpectedSQL: p("'' = 'org-id'"),
|
||||
VariableConverter: regosql.ChatConverter(),
|
||||
},
|
||||
{
|
||||
Name: "AuditLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.AuditLogConverter(),
|
||||
},
|
||||
{
|
||||
Name: "ConnectionLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.ConnectionLogConverter(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter {
|
||||
func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
// Audit logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
@@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
func ConnectionLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
// Connection logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
_ VariableMatcher = astUUIDVar{}
|
||||
_ Node = astUUIDVar{}
|
||||
_ SupportsEquality = astUUIDVar{}
|
||||
)
|
||||
|
||||
// astUUIDVar is a variable that represents a UUID column. Unlike
|
||||
// astStringVar it emits native UUID comparisons (column = 'val'::uuid)
|
||||
// instead of text-based ones (COALESCE(column::text, ”) = 'val').
|
||||
// This allows PostgreSQL to use indexes on UUID columns.
|
||||
type astUUIDVar struct {
|
||||
Source RegoSource
|
||||
FieldPath []string
|
||||
ColumnString string
|
||||
}
|
||||
|
||||
func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher {
|
||||
return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn}
|
||||
}
|
||||
|
||||
func (astUUIDVar) UseAs() Node { return astUUIDVar{} }
|
||||
|
||||
func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||
left, err := RegoVarPath(u.FieldPath, rego)
|
||||
if err == nil && len(left) == 0 {
|
||||
return astUUIDVar{
|
||||
Source: RegoSource(rego.String()),
|
||||
FieldPath: u.FieldPath,
|
||||
ColumnString: u.ColumnString,
|
||||
}, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (u astUUIDVar) SQLString(_ *SQLGenerator) string {
|
||||
return u.ColumnString
|
||||
}
|
||||
|
||||
// EqualsSQLString handles equality comparisons for UUID columns.
|
||||
// Rego always produces string literals, so we accept AstString and
|
||||
// cast the literal to ::uuid in the output SQL. This lets PG use
|
||||
// native UUID indexes instead of falling back to text comparisons.
|
||||
// nolint:revive
|
||||
func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
switch other.UseAs().(type) {
|
||||
case AstString:
|
||||
// The other side is a rego string literal like
|
||||
// "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison
|
||||
// that casts the literal to uuid so PG can use indexes:
|
||||
// column = 'val'::uuid
|
||||
// instead of the text-based:
|
||||
// 'val' = COALESCE(column::text, '')
|
||||
s, ok := other.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString, got %T", other)
|
||||
}
|
||||
if s.Value == "" {
|
||||
// Empty string in rego means "no value". Compare the
|
||||
// column against NULL since UUID columns represent
|
||||
// absent values as NULL, not empty strings.
|
||||
op := "IS NULL"
|
||||
if not {
|
||||
op = "IS NOT NULL"
|
||||
}
|
||||
return fmt.Sprintf("%s %s", u.ColumnString, op), nil
|
||||
}
|
||||
return fmt.Sprintf("%s %s '%s'::uuid",
|
||||
u.ColumnString, equalsOp(not), s.Value), nil
|
||||
case astUUIDVar:
|
||||
return basicSQLEquality(cfg, not, u, other), nil
|
||||
default:
|
||||
return "", xerrors.Errorf("unsupported equality: %T %s %T",
|
||||
u, equalsOp(not), other)
|
||||
}
|
||||
}
|
||||
|
||||
// ContainedInSQL implements SupportsContainedIn so that a UUID column
|
||||
// can appear in membership checks like `col = ANY(ARRAY[...])`. The
|
||||
// array elements are rego strings, so we cast each to ::uuid.
|
||||
func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) {
|
||||
arr, ok := haystack.(ASTArray)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack)
|
||||
}
|
||||
|
||||
if len(arr.Value) == 0 {
|
||||
return "false", nil
|
||||
}
|
||||
|
||||
// Build ARRAY['uuid1'::uuid, 'uuid2'::uuid, ...]
|
||||
values := make([]string, 0, len(arr.Value))
|
||||
for _, v := range arr.Value {
|
||||
s, ok := v.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString array element, got %T", v)
|
||||
}
|
||||
values = append(values, fmt.Sprintf("'%s'::uuid", s.Value))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s = ANY(ARRAY [%s])",
|
||||
u.ColumnString,
|
||||
strings.Join(values, ",")), nil
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G
|
||||
}
|
||||
|
||||
// Prepare the count filter, which uses the same parameters as the GetAuditLogsOffsetParams.
|
||||
// nolint:exhaustruct // UserID is not obtained from the query parameters.
|
||||
// nolint:exhaustruct // UserID and CountCap are not obtained from the query parameters.
|
||||
countFilter := database.CountAuditLogsParams{
|
||||
RequestID: filter.RequestID,
|
||||
ResourceID: filter.ResourceID,
|
||||
@@ -123,6 +123,7 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey
|
||||
}
|
||||
|
||||
// This MUST be kept in sync with the above
|
||||
// nolint:exhaustruct // CountCap is not obtained from the query parameters.
|
||||
countFilter := database.CountConnectionLogsParams{
|
||||
OrganizationID: filter.OrganizationID,
|
||||
WorkspaceOwner: filter.WorkspaceOwner,
|
||||
|
||||
+37
-19
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -389,6 +390,7 @@ type MultiAgentController struct {
|
||||
// connections to the destination
|
||||
tickets map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
coordination *tailnet.BasicCoordination
|
||||
sendGroup singleflight.Group
|
||||
|
||||
cancel context.CancelFunc
|
||||
expireOldAgentsDone chan struct{}
|
||||
@@ -418,28 +420,44 @@ func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.Clo
|
||||
|
||||
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
_, ok := m.connectionTimes[agentID]
|
||||
// If we don't have the agent, subscribe.
|
||||
if !ok {
|
||||
m.logger.Debug(context.Background(),
|
||||
"subscribing to agent", slog.F("agent_id", agentID))
|
||||
if m.coordination != nil {
|
||||
err := m.coordination.Client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
err = xerrors.Errorf("subscribe agent: %w", err)
|
||||
m.coordination.SendErr(err)
|
||||
_ = m.coordination.Client.Close()
|
||||
m.coordination = nil
|
||||
return err
|
||||
}
|
||||
}
|
||||
m.tickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
if ok {
|
||||
m.connectionTimes[agentID] = time.Now()
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.logger.Debug(context.Background(),
|
||||
"subscribing to agent", slog.F("agent_id", agentID))
|
||||
|
||||
_, err, _ := m.sendGroup.Do(agentID.String(), func() (interface{}, error) {
|
||||
m.mu.Lock()
|
||||
coord := m.coordination
|
||||
m.mu.Unlock()
|
||||
if coord == nil {
|
||||
return nil, xerrors.New("no active coordination")
|
||||
}
|
||||
err := coord.Client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.tickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
m.mu.Unlock()
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
m.logger.Error(context.Background(), "ensureAgent send failed",
|
||||
slog.F("agent_id", agentID), slog.Error(err))
|
||||
return xerrors.Errorf("send AddTunnel: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.connectionTimes[agentID] = time.Now()
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1502,6 +1502,7 @@ type Snapshot struct {
|
||||
PrebuiltWorkspaces []PrebuiltWorkspace `json:"prebuilt_workspaces"`
|
||||
AIBridgeInterceptionsSummaries []AIBridgeInterceptionsSummary `json:"aibridge_interceptions_summaries"`
|
||||
BoundaryUsageSummary *BoundaryUsageSummary `json:"boundary_usage_summary"`
|
||||
FirstUserOnboarding *FirstUserOnboarding `json:"first_user_onboarding"`
|
||||
}
|
||||
|
||||
// Deployment contains information about the host running Coder.
|
||||
@@ -1551,6 +1552,14 @@ type User struct {
|
||||
LoginType string `json:"login_type,omitempty"`
|
||||
}
|
||||
|
||||
// FirstUserOnboarding contains optional newsletter preference data
|
||||
// collected during first user setup. This is sent once when the first
|
||||
// user is created.
|
||||
type FirstUserOnboarding struct {
|
||||
NewsletterMarketing bool `json:"newsletter_marketing"`
|
||||
NewsletterReleases bool `json:"newsletter_releases"`
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -223,10 +223,12 @@ func TestTelemetry(t *testing.T) {
|
||||
StartedAt: previousAIBridgeInterceptionPeriod.Add(-30 * time.Minute),
|
||||
}, nil)
|
||||
_ = dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
|
||||
InterceptionID: aiBridgeInterception1.ID,
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
|
||||
InterceptionID: aiBridgeInterception1.ID,
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
CacheReadInputTokens: 300,
|
||||
CacheWriteInputTokens: 400,
|
||||
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
|
||||
})
|
||||
_ = dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
|
||||
InterceptionID: aiBridgeInterception1.ID,
|
||||
@@ -248,10 +250,12 @@ func TestTelemetry(t *testing.T) {
|
||||
StartedAt: aiBridgeInterception1.StartedAt,
|
||||
}, nil)
|
||||
_ = dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
|
||||
InterceptionID: aiBridgeInterception2.ID,
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
|
||||
InterceptionID: aiBridgeInterception2.ID,
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
CacheReadInputTokens: 300,
|
||||
CacheWriteInputTokens: 400,
|
||||
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
|
||||
})
|
||||
_ = dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
|
||||
InterceptionID: aiBridgeInterception2.ID,
|
||||
|
||||
+12
-1
@@ -281,8 +281,19 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
||||
telemetryUser := telemetry.ConvertUser(user)
|
||||
// Send the initial users email address!
|
||||
telemetryUser.Email = &user.Email
|
||||
// Only populate onboarding data when the client actually sent it. A nil
|
||||
// OnboardingInfo means the request came from an older client, the CLI, or
|
||||
// the OIDC flow — not from a user who answered "no" to every question.
|
||||
var onboarding *telemetry.FirstUserOnboarding
|
||||
if createUser.OnboardingInfo != nil {
|
||||
onboarding = &telemetry.FirstUserOnboarding{
|
||||
NewsletterMarketing: createUser.OnboardingInfo.NewsletterMarketing,
|
||||
NewsletterReleases: createUser.OnboardingInfo.NewsletterReleases,
|
||||
}
|
||||
}
|
||||
api.Telemetry.Report(&telemetry.Snapshot{
|
||||
Users: []telemetry.User{telemetryUser},
|
||||
Users: []telemetry.User{telemetryUser},
|
||||
FirstUserOnboarding: onboarding,
|
||||
})
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, codersdk.CreateFirstUserResponse{
|
||||
|
||||
@@ -116,6 +116,77 @@ func TestFirstUser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestFirstUser_OnboardingTelemetry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("OnboardingInfoFlowsToSnapshot", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
fTelemetry := newFakeTelemetryReporter(ctx, t, 10)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
TelemetryReporter: fTelemetry,
|
||||
})
|
||||
|
||||
_, err := client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{
|
||||
Email: "admin@coder.com",
|
||||
Username: "admin",
|
||||
Password: "SomeSecurePassword!",
|
||||
OnboardingInfo: &codersdk.CreateFirstUserOnboardingInfo{
|
||||
NewsletterMarketing: false,
|
||||
NewsletterReleases: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
snapshot := testutil.TryReceive(ctx, t, fTelemetry.snapshots)
|
||||
require.NotNil(t, snapshot.FirstUserOnboarding)
|
||||
require.False(t, snapshot.FirstUserOnboarding.NewsletterMarketing)
|
||||
require.True(t, snapshot.FirstUserOnboarding.NewsletterReleases)
|
||||
})
|
||||
|
||||
t.Run("NilWhenOnboardingInfoOmitted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
fTelemetry := newFakeTelemetryReporter(ctx, t, 10)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
TelemetryReporter: fTelemetry,
|
||||
})
|
||||
|
||||
_, err := client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{
|
||||
Email: "admin@coder.com",
|
||||
Username: "admin",
|
||||
Password: "SomeSecurePassword!",
|
||||
// No OnboardingInfo — simulates old CLI or OIDC flow.
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
snapshot := testutil.TryReceive(ctx, t, fTelemetry.snapshots)
|
||||
require.Nil(t, snapshot.FirstUserOnboarding)
|
||||
})
|
||||
|
||||
t.Run("EmptyOnboardingInfoIsNonNilWithZeroFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
fTelemetry := newFakeTelemetryReporter(ctx, t, 10)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
TelemetryReporter: fTelemetry,
|
||||
})
|
||||
_, err := client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{
|
||||
Email: "admin@coder.com", Username: "admin",
|
||||
Password: "SomeSecurePassword!",
|
||||
OnboardingInfo: &codersdk.CreateFirstUserOnboardingInfo{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
snapshot := testutil.TryReceive(ctx, t, fTelemetry.snapshots)
|
||||
require.NotNil(t, snapshot.FirstUserOnboarding,
|
||||
"non-nil OnboardingInfo must produce non-nil telemetry")
|
||||
require.False(t, snapshot.FirstUserOnboarding.NewsletterMarketing)
|
||||
require.False(t, snapshot.FirstUserOnboarding.NewsletterReleases)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostLogin(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("InvalidUser", func(t *testing.T) {
|
||||
|
||||
+100
-3
@@ -1465,7 +1465,9 @@ func (api *API) workspaceAgentPostLogSource(rw http.ResponseWriter, r *http.Requ
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Agents
|
||||
// @Param wait query bool false "Opt in to durable reinit checks"
|
||||
// @Success 200 {object} agentsdk.ReinitializationEvent
|
||||
// @Failure 409 {object} codersdk.Response
|
||||
// @Router /workspaceagents/me/reinit [get]
|
||||
func (api *API) workspaceAgentReinit(rw http.ResponseWriter, r *http.Request) {
|
||||
// Allow us to interrupt watch via cancel.
|
||||
@@ -1482,18 +1484,113 @@ func (api *API) workspaceAgentReinit(rw http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
log.Error(ctx, "failed to retrieve workspace from agent token", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to determine workspace from agent token"))
|
||||
return
|
||||
}
|
||||
log = log.With(slog.F("workspace_id", workspace.ID))
|
||||
|
||||
log.Info(ctx, "agent waiting for reinit instruction")
|
||||
|
||||
reinitEvents := make(chan agentsdk.ReinitializationEvent)
|
||||
cancel, err = prebuilds.NewPubsubWorkspaceClaimListener(api.Pubsub, log).ListenForWorkspaceClaims(ctx, workspace.ID, reinitEvents)
|
||||
// Subscribe to claim events BEFORE any durable checks to avoid a
|
||||
// TOCTOU race: without this, a claim could fire between the
|
||||
// IsPrebuild() check and the subscribe call, and we'd miss the
|
||||
// pubsub event entirely. By subscribing first, any event that
|
||||
// fires during the checks below is buffered in the channel.
|
||||
pubsubCh, cancelSub, err := prebuilds.NewPubsubWorkspaceClaimListener(api.Pubsub, log).ListenForWorkspaceClaims(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
log.Error(ctx, "subscribe to prebuild claimed channel", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to subscribe to prebuild claimed channel"))
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
defer cancelSub()
|
||||
|
||||
reinitEvents := pubsubCh
|
||||
|
||||
// Only perform the durable claim check when the agent opts in via
|
||||
// the "wait" query parameter. Older agents don't send the
|
||||
// "wait" query parameter and lack the duplicate-reinit guard, so
|
||||
// they would enter an infinite reinit loop if we pre-seeded the
|
||||
// channel on every connection.
|
||||
waitParam, _ := strconv.ParseBool(r.URL.Query().Get("wait"))
|
||||
if waitParam && !workspace.IsPrebuild() {
|
||||
firstBuild, err := api.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx,
|
||||
database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
BuildNumber: 1,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(ctx, "failed to get first workspace build", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to get first workspace build"))
|
||||
return
|
||||
}
|
||||
if firstBuild.InitiatorID != database.PrebuildsSystemUserID {
|
||||
// Not a claimed prebuild — this is a regular workspace.
|
||||
// Return 409 so the agent stops reconnecting to this
|
||||
// endpoint.
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Workspace is not a prebuilt workspace waiting to be claimed.",
|
||||
Detail: "This endpoint is only for agents running in prebuilt workspaces.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// This workspace was a prebuild that got claimed. Check if
|
||||
// the claim build completed successfully before sending
|
||||
// reinit. We assume the latest build is the claim build
|
||||
// (build 2). If a third build (e.g. a restart) starts
|
||||
// between the claim and the agent's reconnection, this
|
||||
// would check that build instead. The window is extremely
|
||||
// small in practice, and a restart would trigger its own
|
||||
// reinit path.
|
||||
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
log.Error(ctx, "failed to get latest workspace build", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to get latest workspace build"))
|
||||
return
|
||||
}
|
||||
job, err := api.Database.GetProvisionerJobByID(ctx, latestBuild.JobID)
|
||||
if err != nil {
|
||||
log.Error(ctx, "failed to get provisioner job", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to get provisioner job"))
|
||||
return
|
||||
}
|
||||
|
||||
if job.CompletedAt.Valid && !job.Error.Valid {
|
||||
// Claim build succeeded — cancel the pubsub
|
||||
// subscription (no longer needed) and swap in a
|
||||
// pre-seeded channel so the transmitter delivers
|
||||
// exactly one reinit event.
|
||||
cancelSub()
|
||||
seeded := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
seeded <- agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
OwnerID: workspace.OwnerID,
|
||||
}
|
||||
reinitEvents = seeded
|
||||
} else if job.CompletedAt.Valid && job.Error.Valid {
|
||||
// Claim build failed permanently. Return 409 so the
|
||||
// agent treats this as terminal and stops retrying
|
||||
// (WaitForReinitLoop exits on any 409).
|
||||
cancelSub()
|
||||
log.Warn(ctx, "claim build failed",
|
||||
slog.F("job_id", job.ID),
|
||||
slog.F("error", job.Error.String))
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Claim build failed permanently.",
|
||||
Detail: job.Error.String,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Claim build still in progress — fall through to the
|
||||
// transmitter. The pubsub subscription (set up above)
|
||||
// will deliver the event when the build completes
|
||||
// successfully. Note: FailJob does not publish a claim
|
||||
// event, so a failed in-progress build will leave the
|
||||
// agent blocking here until it disconnects and
|
||||
// reconnects (at which point the durable check above
|
||||
// handles it).
|
||||
}
|
||||
|
||||
transmitter := agentsdk.NewSSEAgentReinitTransmitter(log, rw, r)
|
||||
|
||||
|
||||
+190
-35
@@ -2,6 +2,7 @@ package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -3278,51 +3279,205 @@ func TestAgentConnectionInfo(t *testing.T) {
|
||||
func TestReinit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
pubsubSpy := pubsubReinitSpy{
|
||||
Pubsub: ps,
|
||||
triedToSubscribe: make(chan string),
|
||||
// Helper to create the prebuilds system user's workspace (an
|
||||
// unclaimed prebuild) and return the build result. The first
|
||||
// build's InitiatorID defaults to PrebuildsSystemUserID via
|
||||
// dbfake.
|
||||
setupPrebuildWorkspace := func(t *testing.T, db database.Store, orgID uuid.UUID) dbfake.WorkspaceResponse {
|
||||
t.Helper()
|
||||
return dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: orgID,
|
||||
OwnerID: database.PrebuildsSystemUserID,
|
||||
}).WithAgent().Do()
|
||||
}
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: &pubsubSpy,
|
||||
|
||||
// Helper to simulate claiming a prebuild: change the workspace
|
||||
// owner to the real user and create a second (claim) build.
|
||||
claimPrebuild := func(t *testing.T, db database.Store, sqlDB *sql.DB, ws database.WorkspaceTable, claimerID uuid.UUID, templateVersionID uuid.UUID, complete bool) dbfake.WorkspaceResponse {
|
||||
t.Helper()
|
||||
// Change the workspace owner to the claiming user.
|
||||
_, err := sqlDB.Exec("UPDATE workspaces SET owner_id = $1 WHERE id = $2", claimerID, ws.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the in-memory workspace to reflect the new owner
|
||||
// so that dbfake uses it for the second build.
|
||||
ws.OwnerID = claimerID
|
||||
|
||||
builder := dbfake.WorkspaceBuild(t, db, ws).
|
||||
Seed(database.WorkspaceBuild{
|
||||
TemplateVersionID: templateVersionID,
|
||||
BuildNumber: 2,
|
||||
InitiatorID: claimerID,
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
}).
|
||||
WithAgent()
|
||||
if !complete {
|
||||
builder = builder.Starting()
|
||||
}
|
||||
return builder.Do()
|
||||
}
|
||||
|
||||
t.Run("unclaimed prebuild receives reinit via pubsub", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
pubsubSpy := pubsubReinitSpy{
|
||||
Pubsub: ps,
|
||||
triedToSubscribe: make(chan string),
|
||||
}
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: &pubsubSpy,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
r := setupPrebuildWorkspace(t, db, user.OrganizationID)
|
||||
|
||||
pubsubSpy.Lock()
|
||||
pubsubSpy.expectedEvent = agentsdk.PrebuildClaimedChannel(r.Workspace.ID)
|
||||
pubsubSpy.Unlock()
|
||||
|
||||
agentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
|
||||
agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent)
|
||||
go func() {
|
||||
reinitEvent, err := agentClient.WaitForReinit(agentCtx)
|
||||
assert.NoError(t, err)
|
||||
agentReinitializedCh <- reinitEvent
|
||||
}()
|
||||
|
||||
// We need to subscribe before we publish, lest we miss the
|
||||
// event.
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
testutil.TryReceive(ctx, t, pubsubSpy.triedToSubscribe)
|
||||
|
||||
// Now that we're subscribed, publish the event.
|
||||
err := prebuilds.NewPubsubWorkspaceClaimPublisher(ps).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: r.Workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
reinitEvent := testutil.TryReceive(ctx, t, agentReinitializedCh)
|
||||
require.NotNil(t, reinitEvent)
|
||||
require.Equal(t, r.Workspace.ID, reinitEvent.WorkspaceID)
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
// Verifies the durable claim check: when an agent reconnects
|
||||
// after missing the pubsub event, the handler detects that the
|
||||
// workspace was originally a prebuild (first build initiated by
|
||||
// PrebuildsSystemUserID), is now claimed (owner changed), and
|
||||
// the claim build completed, so it sends a one-shot reinit
|
||||
// event immediately.
|
||||
t.Run("claimed prebuild receives one-shot reinit on reconnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pubsubSpy.Lock()
|
||||
pubsubSpy.expectedEvent = agentsdk.PrebuildClaimedChannel(r.Workspace.ID)
|
||||
pubsubSpy.Unlock()
|
||||
db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
agentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
// Create an unclaimed prebuild (build 1, completed).
|
||||
r := setupPrebuildWorkspace(t, db, user.OrganizationID)
|
||||
|
||||
agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent)
|
||||
go func() {
|
||||
reinitEvent, err := agentClient.WaitForReinit(agentCtx)
|
||||
assert.NoError(t, err)
|
||||
agentReinitializedCh <- reinitEvent
|
||||
}()
|
||||
// Claim it: change owner + create build 2 (completed).
|
||||
claimR := claimPrebuild(t, db, sqlDB, r.Workspace, user.UserID, r.TemplateVersion.ID, true)
|
||||
|
||||
// We need to subscribe before we publish, lest we miss the event
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
testutil.TryReceive(ctx, t, pubsubSpy.triedToSubscribe)
|
||||
agentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(claimR.AgentToken))
|
||||
|
||||
// Now that we're subscribed, publish the event
|
||||
err := prebuilds.NewPubsubWorkspaceClaimPublisher(ps).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: r.Workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent)
|
||||
go func() {
|
||||
reinitEvent, err := agentClient.WaitForReinit(agentCtx)
|
||||
assert.NoError(t, err)
|
||||
agentReinitializedCh <- reinitEvent
|
||||
}()
|
||||
|
||||
// The agent should receive a reinit event immediately from
|
||||
// the durable claim check — no pubsub publish needed.
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
reinitEvent := testutil.TryReceive(ctx, t, agentReinitializedCh)
|
||||
require.NotNil(t, reinitEvent)
|
||||
require.Equal(t, r.Workspace.ID, reinitEvent.WorkspaceID)
|
||||
require.Equal(t, agentsdk.ReinitializeReasonPrebuildClaimed, reinitEvent.Reason)
|
||||
require.Equal(t, user.UserID, reinitEvent.OwnerID)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
reinitEvent := testutil.TryReceive(ctx, t, agentReinitializedCh)
|
||||
require.NotNil(t, reinitEvent)
|
||||
require.Equal(t, r.Workspace.ID, reinitEvent.WorkspaceID)
|
||||
// Verifies that when the claim build completed with an error,
|
||||
// the handler returns 409 so the agent treats it as terminal
|
||||
// and stops retrying (WaitForReinitLoop exits on any 409).
|
||||
t.Run("failed claim build returns terminal 409", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create an unclaimed prebuild (build 1, completed).
|
||||
r := setupPrebuildWorkspace(t, db, user.OrganizationID)
|
||||
|
||||
// Claim it: create build 2 as completed (so agent rows
|
||||
// exist and the token is valid for auth).
|
||||
claimR := claimPrebuild(t, db, sqlDB, r.Workspace, user.UserID, r.TemplateVersion.ID, true)
|
||||
|
||||
// Simulate a claim build failure: set an error on the
|
||||
// provisioner job. This models the case where terraform
|
||||
// apply partially succeeded (creating resources/agents)
|
||||
// but ultimately errored.
|
||||
_, err := sqlDB.Exec(
|
||||
"UPDATE provisioner_jobs SET error = 'simulated claim failure' WHERE id = $1",
|
||||
claimR.Build.JobID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
agentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(claimR.AgentToken))
|
||||
|
||||
_, err = agentClient.WaitForReinit(agentCtx)
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
// Verifies that a regular workspace (never a prebuild) gets a
|
||||
// 409 Conflict response, causing the agent's reinit loop to
|
||||
// close the channel gracefully.
|
||||
t.Run("regular workspace gets 409", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create a regular workspace (not a prebuild). The first
|
||||
// build's initiator will be the user, not the prebuilds
|
||||
// system user.
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
agentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
|
||||
// WaitForReinit should return an error wrapping a 409.
|
||||
_, err := agentClient.WaitForReinit(agentCtx)
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
type pubsubReinitSpy struct {
|
||||
|
||||
@@ -535,7 +535,7 @@ func (p *DBTokenProvider) connLogInitRequest(w http.ResponseWriter, r *http.Requ
|
||||
Int32: statusCode,
|
||||
Valid: true,
|
||||
},
|
||||
Ip: database.ParseIP(ip),
|
||||
IP: database.ParseIP(ip),
|
||||
UserAgent: sql.NullString{Valid: userAgent != "", String: userAgent},
|
||||
UserID: uuid.NullUUID{
|
||||
UUID: userID,
|
||||
|
||||
@@ -1281,7 +1281,7 @@ func assertConnLogContains(t *testing.T, rr *httptest.ResponseRecorder, r *http.
|
||||
WorkspaceName: workspace.Name,
|
||||
AgentName: agentName,
|
||||
Type: typ,
|
||||
Ip: database.ParseIP(r.RemoteAddr),
|
||||
IP: database.ParseIP(r.RemoteAddr),
|
||||
UserAgent: sql.NullString{Valid: r.UserAgent() != "", String: r.UserAgent()},
|
||||
Code: sql.NullInt32{
|
||||
Int32: int32(resp.StatusCode), // nolint:gosec
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user