Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| aefc75133a | |||
| b9181c3934 | |||
| a90471db53 | |||
| cb71f5e789 | |||
| f50707bc3e |
@@ -4,7 +4,7 @@ description: |
|
||||
inputs:
|
||||
version:
|
||||
description: "The Go version to use."
|
||||
default: "1.25.8"
|
||||
default: "1.25.7"
|
||||
use-preinstalled-go:
|
||||
description: "Whether to use preinstalled Go."
|
||||
default: "false"
|
||||
|
||||
+17
-17
@@ -35,7 +35,7 @@ jobs:
|
||||
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -157,7 +157,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -247,7 +247,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -272,7 +272,7 @@ jobs:
|
||||
if: ${{ !cancelled() }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -329,7 +329,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -381,7 +381,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -586,7 +586,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -648,7 +648,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -720,7 +720,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -747,7 +747,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -780,7 +780,7 @@ jobs:
|
||||
name: ${{ matrix.variant.name }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -860,7 +860,7 @@ jobs:
|
||||
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -941,7 +941,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1013,7 +1013,7 @@ jobs:
|
||||
if: always()
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1128,7 +1128,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1183,7 +1183,7 @@ jobs:
|
||||
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1580,7 +1580,7 @@ jobs:
|
||||
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -65,7 +65,7 @@ jobs:
|
||||
packages: write # to retag image as dogfood
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -146,7 +146,7 @@ jobs:
|
||||
needs: deploy
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
if: github.repository_owner == 'coder'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -125,7 +125,7 @@ jobs:
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
packages: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
runs-on: "ubuntu-latest"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
pull-requests: write # needed for commenting on PRs
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -288,7 +288,7 @@ jobs:
|
||||
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -158,7 +158,7 @@ jobs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -796,7 +796,7 @@ jobs:
|
||||
# TODO: skip this if it's not a new release (i.e. a backport). This is
|
||||
# fine right now because it just makes a PR that we can close.
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -872,7 +872,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -965,7 +965,7 @@ jobs:
|
||||
if: ${{ !inputs.dry_run }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -63,3 +63,116 @@ jobs:
|
||||
--data "{\"content\": \"$msg\"}" \
|
||||
"${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
|
||||
|
||||
trivy:
|
||||
permissions:
|
||||
security-events: write
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Go
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Setup Node
|
||||
uses: ./.github/actions/setup-node
|
||||
|
||||
- name: Setup sqlc
|
||||
uses: ./.github/actions/setup-sqlc
|
||||
|
||||
- name: Install cosign
|
||||
uses: ./.github/actions/install-cosign
|
||||
|
||||
- name: Install syft
|
||||
uses: ./.github/actions/install-syft
|
||||
|
||||
- name: Install yq
|
||||
run: go run github.com/mikefarah/yq/v4@v4.44.3
|
||||
- name: Install mockgen
|
||||
run: ./.github/scripts/retry.sh -- go install go.uber.org/mock/mockgen@v0.6.0
|
||||
- name: Install protoc-gen-go
|
||||
run: ./.github/scripts/retry.sh -- go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.30
|
||||
- name: Install protoc-gen-go-drpc
|
||||
run: ./.github/scripts/retry.sh -- go install storj.io/drpc/cmd/protoc-gen-go-drpc@v0.0.34
|
||||
- name: Install Protoc
|
||||
run: |
|
||||
# protoc must be in lockstep with our dogfood Dockerfile or the
|
||||
# version in the comments will differ. This is also defined in
|
||||
# ci.yaml.
|
||||
set -euxo pipefail
|
||||
cd dogfood/coder
|
||||
mkdir -p /usr/local/bin
|
||||
mkdir -p /usr/local/include
|
||||
|
||||
DOCKER_BUILDKIT=1 docker build . --target proto -t protoc
|
||||
protoc_path=/usr/local/bin/protoc
|
||||
docker run --rm --entrypoint cat protoc /tmp/bin/protoc > $protoc_path
|
||||
chmod +x $protoc_path
|
||||
protoc --version
|
||||
# Copy the generated files to the include directory.
|
||||
docker run --rm -v /usr/local/include:/target protoc cp -r /tmp/include/google /target/
|
||||
ls -la /usr/local/include/google/protobuf/
|
||||
stat /usr/local/include/google/protobuf/timestamp.proto
|
||||
|
||||
- name: Build Coder linux amd64 Docker image
|
||||
id: build
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
version="$(./scripts/version.sh)"
|
||||
image_job="build/coder_${version}_linux_amd64.tag"
|
||||
|
||||
# This environment variable force make to not build packages and
|
||||
# archives (which the Docker image depends on due to technical reasons
|
||||
# related to concurrent FS writes).
|
||||
export DOCKER_IMAGE_NO_PREREQUISITES=true
|
||||
# This environment variables forces scripts/build_docker.sh to build
|
||||
# the base image tag locally instead of using the cached version from
|
||||
# the registry.
|
||||
CODER_IMAGE_BUILD_BASE_TAG="$(CODER_IMAGE_BASE=coder-base ./scripts/image_tag.sh --version "$version")"
|
||||
export CODER_IMAGE_BUILD_BASE_TAG
|
||||
|
||||
# We would like to use make -j here, but it doesn't work with the some recent additions
|
||||
# to our code generation.
|
||||
make "$image_job"
|
||||
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
|
||||
with:
|
||||
image-ref: ${{ steps.build.outputs.image }}
|
||||
format: sarif
|
||||
output: trivy-results.sarif
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
- name: Upload Trivy scan results to GitHub Security tab
|
||||
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
with:
|
||||
sarif_file: trivy-results.sarif
|
||||
category: "Trivy"
|
||||
|
||||
- name: Upload Trivy scan results as an artifact
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: trivy
|
||||
path: trivy-results.sarif
|
||||
retention-days: 7
|
||||
|
||||
- name: Send Slack notification on failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
msg="❌ Trivy Failed\n\nhttps://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
|
||||
curl \
|
||||
-qfsSL \
|
||||
-X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
--data "{\"content\": \"$msg\"}" \
|
||||
"${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
|
||||
|
||||
@@ -18,12 +18,12 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: stale
|
||||
uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
|
||||
uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
|
||||
with:
|
||||
stale-issue-label: "stale"
|
||||
stale-pr-label: "stale"
|
||||
@@ -96,7 +96,7 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
actions: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ jobs:
|
||||
pull-requests: write # required to post PR review comments by the action
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -3040,62 +3040,6 @@ func TestAgent_Reconnect(t *testing.T) {
|
||||
closer.Close()
|
||||
}
|
||||
|
||||
func TestAgent_ReconnectNoLifecycleReemit(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
fCoordinator := tailnettest.NewFakeCoordinator()
|
||||
agentID := uuid.New()
|
||||
statsCh := make(chan *proto.Stats, 50)
|
||||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
|
||||
client := agenttest.NewClient(t,
|
||||
logger,
|
||||
agentID,
|
||||
agentsdk.Manifest{
|
||||
DERPMap: derpMap,
|
||||
Scripts: []codersdk.WorkspaceAgentScript{{
|
||||
Script: "echo hello",
|
||||
Timeout: 30 * time.Second,
|
||||
RunOnStart: true,
|
||||
}},
|
||||
},
|
||||
statsCh,
|
||||
fCoordinator,
|
||||
)
|
||||
defer client.Close()
|
||||
|
||||
closer := agent.New(agent.Options{
|
||||
Client: client,
|
||||
Logger: logger.Named("agent"),
|
||||
})
|
||||
defer closer.Close()
|
||||
|
||||
// Wait for the agent to reach Ready state.
|
||||
require.Eventually(t, func() bool {
|
||||
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
statesBefore := slices.Clone(client.GetLifecycleStates())
|
||||
|
||||
// Disconnect by closing the coordinator response channel.
|
||||
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
close(call1.Resps)
|
||||
|
||||
// Wait for reconnect.
|
||||
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
|
||||
// Wait for a stats report as a deterministic steady-state proof.
|
||||
testutil.RequireReceive(ctx, t, statsCh)
|
||||
|
||||
statesAfter := client.GetLifecycleStates()
|
||||
require.Equal(t, statesBefore, statesAfter,
|
||||
"lifecycle states should not be re-reported after reconnect")
|
||||
|
||||
closer.Close()
|
||||
}
|
||||
|
||||
func TestAgent_WriteVSCodeConfigs(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
@@ -235,10 +235,6 @@ type FakeAgentAPI struct {
|
||||
pushResourcesMonitoringUsageFunc func(*agentproto.PushResourcesMonitoringUsageRequest) (*agentproto.PushResourcesMonitoringUsageResponse, error)
|
||||
}
|
||||
|
||||
func (*FakeAgentAPI) UpdateAppStatus(context.Context, *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
|
||||
return f.manifest, nil
|
||||
}
|
||||
|
||||
+330
-544
File diff suppressed because it is too large
Load Diff
+1
-20
@@ -436,7 +436,7 @@ message CreateSubAgentRequest {
|
||||
}
|
||||
|
||||
repeated DisplayApp display_apps = 6;
|
||||
|
||||
|
||||
optional bytes id = 7;
|
||||
}
|
||||
|
||||
@@ -494,24 +494,6 @@ message ReportBoundaryLogsRequest {
|
||||
|
||||
message ReportBoundaryLogsResponse {}
|
||||
|
||||
// UpdateAppStatusRequest updates the given Workspace App's status. c.f. agentsdk.PatchAppStatus
|
||||
message UpdateAppStatusRequest {
|
||||
string slug = 1;
|
||||
|
||||
enum AppStatusState {
|
||||
WORKING = 0;
|
||||
IDLE = 1;
|
||||
COMPLETE = 2;
|
||||
FAILURE = 3;
|
||||
}
|
||||
AppStatusState state = 2;
|
||||
|
||||
string message = 3;
|
||||
string uri = 4;
|
||||
}
|
||||
|
||||
message UpdateAppStatusResponse {}
|
||||
|
||||
service Agent {
|
||||
rpc GetManifest(GetManifestRequest) returns (Manifest);
|
||||
rpc GetServiceBanner(GetServiceBannerRequest) returns (ServiceBanner);
|
||||
@@ -530,5 +512,4 @@ service Agent {
|
||||
rpc DeleteSubAgent(DeleteSubAgentRequest) returns (DeleteSubAgentResponse);
|
||||
rpc ListSubAgents(ListSubAgentsRequest) returns (ListSubAgentsResponse);
|
||||
rpc ReportBoundaryLogs(ReportBoundaryLogsRequest) returns (ReportBoundaryLogsResponse);
|
||||
rpc UpdateAppStatus(UpdateAppStatusRequest) returns (UpdateAppStatusResponse);
|
||||
}
|
||||
|
||||
@@ -56,7 +56,6 @@ type DRPCAgentClient interface {
|
||||
DeleteSubAgent(ctx context.Context, in *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
|
||||
ListSubAgents(ctx context.Context, in *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
|
||||
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
|
||||
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
|
||||
}
|
||||
|
||||
type drpcAgentClient struct {
|
||||
@@ -222,15 +221,6 @@ func (c *drpcAgentClient) ReportBoundaryLogs(ctx context.Context, in *ReportBoun
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentClient) UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
|
||||
out := new(UpdateAppStatusResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type DRPCAgentServer interface {
|
||||
GetManifest(context.Context, *GetManifestRequest) (*Manifest, error)
|
||||
GetServiceBanner(context.Context, *GetServiceBannerRequest) (*ServiceBanner, error)
|
||||
@@ -249,7 +239,6 @@ type DRPCAgentServer interface {
|
||||
DeleteSubAgent(context.Context, *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
|
||||
ListSubAgents(context.Context, *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
|
||||
ReportBoundaryLogs(context.Context, *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
|
||||
UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
|
||||
}
|
||||
|
||||
type DRPCAgentUnimplementedServer struct{}
|
||||
@@ -322,13 +311,9 @@ func (s *DRPCAgentUnimplementedServer) ReportBoundaryLogs(context.Context, *Repo
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentUnimplementedServer) UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
type DRPCAgentDescription struct{}
|
||||
|
||||
func (DRPCAgentDescription) NumMethods() int { return 18 }
|
||||
func (DRPCAgentDescription) NumMethods() int { return 17 }
|
||||
|
||||
func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
@@ -485,15 +470,6 @@ func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
|
||||
in1.(*ReportBoundaryLogsRequest),
|
||||
)
|
||||
}, DRPCAgentServer.ReportBoundaryLogs, true
|
||||
case 17:
|
||||
return "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentServer).
|
||||
UpdateAppStatus(
|
||||
ctx,
|
||||
in1.(*UpdateAppStatusRequest),
|
||||
)
|
||||
}, DRPCAgentServer.UpdateAppStatus, true
|
||||
default:
|
||||
return "", nil, nil, nil, false
|
||||
}
|
||||
@@ -774,19 +750,3 @@ func (x *drpcAgent_ReportBoundaryLogsStream) SendAndClose(m *ReportBoundaryLogsR
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgent_UpdateAppStatusStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*UpdateAppStatusResponse) error
|
||||
}
|
||||
|
||||
type drpcAgent_UpdateAppStatusStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgent_UpdateAppStatusStream) SendAndClose(m *UpdateAppStatusResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
@@ -73,13 +73,9 @@ type DRPCAgentClient27 interface {
|
||||
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
|
||||
}
|
||||
|
||||
// DRPCAgentClient28 is the Agent API at v2.8. It adds
|
||||
// - a SubagentId field to the WorkspaceAgentDevcontainer message
|
||||
// - an Id field to the CreateSubAgentRequest message.
|
||||
// - UpdateAppStatus RPC.
|
||||
//
|
||||
// Compatible with Coder v2.31+
|
||||
// DRPCAgentClient28 is the Agent API at v2.8. It adds a SubagentId field to the
|
||||
// WorkspaceAgentDevcontainer message, and a Id field to the CreateSubAgentRequest
|
||||
// message. Compatible with Coder v2.31+
|
||||
type DRPCAgentClient28 interface {
|
||||
DRPCAgentClient27
|
||||
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
|
||||
}
|
||||
|
||||
+45
-50
@@ -10,7 +10,6 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
@@ -24,7 +23,6 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/codersdk/toolsdk"
|
||||
"github.com/coder/retry"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
@@ -541,6 +539,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
|
||||
defer cancel()
|
||||
defer srv.queue.Close()
|
||||
|
||||
cliui.Infof(inv.Stderr, "Failed to watch screen events")
|
||||
// Start the reporter, watcher, and server. These are all tied to the
|
||||
// lifetime of the MCP server, which is itself tied to the lifetime of the
|
||||
// AI agent.
|
||||
@@ -614,51 +613,48 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
|
||||
}
|
||||
|
||||
func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
|
||||
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for retrier := retry.New(time.Second, 30*time.Second); retrier.Wait(ctx); {
|
||||
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
|
||||
if err == nil {
|
||||
retrier.Reset()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case event := <-eventsCh:
|
||||
switch ev := event.(type) {
|
||||
case agentapi.EventStatusChange:
|
||||
// If the screen is stable, report idle.
|
||||
state := codersdk.WorkspaceAppStatusStateWorking
|
||||
if ev.Status == agentapi.StatusStable {
|
||||
state = codersdk.WorkspaceAppStatusStateIdle
|
||||
}
|
||||
err := s.queue.Push(taskReport{
|
||||
state: state,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
case event := <-eventsCh:
|
||||
switch ev := event.(type) {
|
||||
case agentapi.EventStatusChange:
|
||||
state := codersdk.WorkspaceAppStatusStateWorking
|
||||
if ev.Status == agentapi.StatusStable {
|
||||
state = codersdk.WorkspaceAppStatusStateIdle
|
||||
}
|
||||
err := s.queue.Push(taskReport{
|
||||
state: state,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
}
|
||||
case agentapi.EventMessageUpdate:
|
||||
if ev.Role == agentapi.RoleUser {
|
||||
err := s.queue.Push(taskReport{
|
||||
messageID: &ev.Id,
|
||||
state: codersdk.WorkspaceAppStatusStateWorking,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
case agentapi.EventMessageUpdate:
|
||||
if ev.Role == agentapi.RoleUser {
|
||||
err := s.queue.Push(taskReport{
|
||||
messageID: &ev.Id,
|
||||
state: codersdk.WorkspaceAppStatusStateWorking,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
}
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
|
||||
}
|
||||
break loop
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -696,14 +692,13 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
|
||||
// Add tool dependencies.
|
||||
toolOpts := []func(*toolsdk.Deps){
|
||||
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
|
||||
state := codersdk.WorkspaceAppStatusState(args.State)
|
||||
// The agent does not reliably report idle, so when AgentAPI is
|
||||
// enabled we override idle to working and let the screen watcher
|
||||
// detect the real idle via StatusStable. Final states (failure,
|
||||
// complete) are trusted from the agent since the screen watcher
|
||||
// cannot produce them.
|
||||
if s.aiAgentAPIClient != nil && state == codersdk.WorkspaceAppStatusStateIdle {
|
||||
state = codersdk.WorkspaceAppStatusStateWorking
|
||||
// The agent does not reliably report its status correctly. If AgentAPI
|
||||
// is enabled, we will always set the status to "working" when we get an
|
||||
// MCP message, and rely on the screen watcher to eventually catch the
|
||||
// idle state.
|
||||
state := codersdk.WorkspaceAppStatusStateWorking
|
||||
if s.aiAgentAPIClient == nil {
|
||||
state = codersdk.WorkspaceAppStatusState(args.State)
|
||||
}
|
||||
return s.queue.Push(taskReport{
|
||||
link: args.Link,
|
||||
|
||||
+1
-185
@@ -921,7 +921,7 @@ func TestExpMcpReporter(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
// We override idle from the agent to working, but trust final states.
|
||||
// We ignore the state from the agent and assume "working".
|
||||
{
|
||||
name: "IgnoreAgentState",
|
||||
// AI agent reports that it is finished but the summary says it is doing
|
||||
@@ -953,46 +953,6 @@ func TestExpMcpReporter(t *testing.T) {
|
||||
Message: "finished",
|
||||
},
|
||||
},
|
||||
// Agent reports failure; trusted even with AgentAPI enabled.
|
||||
{
|
||||
state: codersdk.WorkspaceAppStatusStateFailure,
|
||||
summary: "something broke",
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateFailure,
|
||||
Message: "something broke",
|
||||
},
|
||||
},
|
||||
// After failure, watcher reports stable -> idle.
|
||||
{
|
||||
event: makeStatusEvent(agentapi.StatusStable),
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "something broke",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Final states pass through with AgentAPI enabled.
|
||||
{
|
||||
name: "AllowFinalStates",
|
||||
tests: []test{
|
||||
{
|
||||
state: codersdk.WorkspaceAppStatusStateWorking,
|
||||
summary: "doing work",
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateWorking,
|
||||
Message: "doing work",
|
||||
},
|
||||
},
|
||||
// Agent reports complete; not overridden.
|
||||
{
|
||||
state: codersdk.WorkspaceAppStatusStateComplete,
|
||||
summary: "all done",
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateComplete,
|
||||
Message: "all done",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// When AgentAPI is not being used, we accept agent state updates as-is.
|
||||
@@ -1150,148 +1110,4 @@ func TestExpMcpReporter(t *testing.T) {
|
||||
<-cmdDone
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Reconnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a test deployment and workspace.
|
||||
client, db := coderdtest.NewWithDatabase(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
|
||||
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user2.ID,
|
||||
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
|
||||
a[0].Apps = []*proto.App{
|
||||
{
|
||||
Slug: "vscode",
|
||||
},
|
||||
}
|
||||
return a
|
||||
}).Do()
|
||||
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitLong))
|
||||
|
||||
// Watch the workspace for changes.
|
||||
watcher, err := client.WatchWorkspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
var lastAppStatus codersdk.WorkspaceAppStatus
|
||||
nextUpdate := func() codersdk.WorkspaceAppStatus {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
require.FailNow(t, "timed out waiting for status update")
|
||||
case w, ok := <-watcher:
|
||||
require.True(t, ok, "watch channel closed")
|
||||
if w.LatestAppStatus != nil && w.LatestAppStatus.ID != lastAppStatus.ID {
|
||||
t.Logf("Got status update: %s > %s", lastAppStatus.State, w.LatestAppStatus.State)
|
||||
lastAppStatus = *w.LatestAppStatus
|
||||
return lastAppStatus
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mock AI AgentAPI server that supports disconnect/reconnect.
|
||||
disconnect := make(chan struct{})
|
||||
listening := make(chan func(sse codersdk.ServerSentEvent) error)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create a cancelable context so we can stop the SSE sender
|
||||
// goroutine on disconnect without waiting for the HTTP
|
||||
// serve loop to cancel r.Context().
|
||||
sseCtx, sseCancel := context.WithCancel(r.Context())
|
||||
defer sseCancel()
|
||||
r = r.WithContext(sseCtx)
|
||||
|
||||
send, closed, err := httpapi.ServerSentEventSender(w, r)
|
||||
if err != nil {
|
||||
httpapi.Write(sseCtx, w, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error setting up server-sent events.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Send initial message so the watcher knows the agent is active.
|
||||
send(*makeMessageEvent(0, agentapi.RoleAgent))
|
||||
select {
|
||||
case listening <- send:
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-closed:
|
||||
case <-disconnect:
|
||||
sseCancel()
|
||||
<-closed
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
inv, _ := clitest.New(t,
|
||||
"exp", "mcp", "server",
|
||||
"--agent-url", client.URL.String(),
|
||||
"--agent-token", r.AgentToken,
|
||||
"--app-status-slug", "vscode",
|
||||
"--allowed-tools=coder_report_task",
|
||||
"--ai-agentapi-url", srv.URL,
|
||||
)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
pty := ptytest.New(t)
|
||||
inv.Stdin = pty.Input()
|
||||
inv.Stdout = pty.Output()
|
||||
stderr := ptytest.New(t)
|
||||
inv.Stderr = stderr.Output()
|
||||
|
||||
// Run the MCP server.
|
||||
clitest.Start(t, inv)
|
||||
|
||||
// Initialize.
|
||||
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
|
||||
pty.WriteLine(payload)
|
||||
_ = pty.ReadLine(ctx) // ignore echo
|
||||
_ = pty.ReadLine(ctx) // ignore init response
|
||||
|
||||
// Get first sender from the initial SSE connection.
|
||||
sender := testutil.RequireReceive(ctx, t, listening)
|
||||
|
||||
// Self-report a working status via tool call.
|
||||
toolPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"doing work","link":""}}}`
|
||||
pty.WriteLine(toolPayload)
|
||||
_ = pty.ReadLine(ctx) // ignore echo
|
||||
_ = pty.ReadLine(ctx) // ignore response
|
||||
got := nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
|
||||
require.Equal(t, "doing work", got.Message)
|
||||
|
||||
// Watcher sends stable, verify idle is reported.
|
||||
err = sender(*makeStatusEvent(agentapi.StatusStable))
|
||||
require.NoError(t, err)
|
||||
got = nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
|
||||
|
||||
// Disconnect the SSE connection by signaling the handler to return.
|
||||
testutil.RequireSend(ctx, t, disconnect, struct{}{})
|
||||
|
||||
// Wait for the watcher to reconnect and get the new sender.
|
||||
sender = testutil.RequireReceive(ctx, t, listening)
|
||||
|
||||
// After reconnect, self-report a working status again.
|
||||
toolPayload = `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"reconnected","link":""}}}`
|
||||
pty.WriteLine(toolPayload)
|
||||
_ = pty.ReadLine(ctx) // ignore echo
|
||||
_ = pty.ReadLine(ctx) // ignore response
|
||||
got = nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
|
||||
require.Equal(t, "reconnected", got.Message)
|
||||
|
||||
// Verify the watcher still processes events after reconnect.
|
||||
err = sender(*makeStatusEvent(agentapi.StatusStable))
|
||||
require.NoError(t, err)
|
||||
got = nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
|
||||
|
||||
cancel()
|
||||
})
|
||||
}
|
||||
|
||||
+5
-1
@@ -106,7 +106,11 @@ func TestList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, member = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
sharedWorkspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
|
||||
@@ -297,7 +297,7 @@ func (pr *ParameterResolver) verifyConstraints(resolved []codersdk.WorkspaceBuil
|
||||
return xerrors.Errorf("ephemeral parameter %q can be used only with --prompt-ephemeral-parameters or --ephemeral-parameter flag", r.Name)
|
||||
}
|
||||
|
||||
if !tvp.Mutable && action != WorkspaceCreate && !pr.isFirstTimeUse(r.Name) {
|
||||
if !tvp.Mutable && action != WorkspaceCreate {
|
||||
return xerrors.Errorf("parameter %q is immutable and cannot be updated", r.Name)
|
||||
}
|
||||
}
|
||||
|
||||
+31
-7
@@ -25,7 +25,11 @@ func TestSharingShare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -64,8 +68,12 @@ func TestSharingShare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -119,7 +127,11 @@ func TestSharingShare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -170,7 +182,11 @@ func TestSharingStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -214,7 +230,11 @@ func TestSharingRemove(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -271,7 +291,11 @@ func TestSharingRemove(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
|
||||
+1
-1
@@ -120,7 +120,7 @@ func (r *RootCmd) start() *serpent.Command {
|
||||
func buildWorkspaceStartRequest(inv *serpent.Invocation, client *codersdk.Client, workspace codersdk.Workspace, parameterFlags workspaceParameterFlags, buildFlags buildFlags, action WorkspaceCLIAction) (codersdk.CreateWorkspaceBuildRequest, error) {
|
||||
version := workspace.LatestBuild.TemplateVersionID
|
||||
|
||||
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || workspace.TemplateRequireActiveVersion || action == WorkspaceUpdate {
|
||||
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || action == WorkspaceUpdate {
|
||||
version = workspace.TemplateActiveVersionID
|
||||
if version != workspace.LatestBuild.TemplateVersionID {
|
||||
action = WorkspaceUpdate
|
||||
|
||||
+4
-4
@@ -33,7 +33,7 @@ func TestStatePull(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
|
||||
Do()
|
||||
statefilePath := filepath.Join(t.TempDir(), "state")
|
||||
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name, statefilePath)
|
||||
@@ -54,7 +54,7 @@ func TestStatePull(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
|
||||
Do()
|
||||
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name)
|
||||
var gotState bytes.Buffer
|
||||
@@ -74,7 +74,7 @@ func TestStatePull(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
|
||||
Do()
|
||||
inv, root := clitest.New(t, "state", "pull", taUser.Username+"/"+r.Workspace.Name,
|
||||
"--build", fmt.Sprintf("%d", r.Build.BuildNumber))
|
||||
@@ -170,7 +170,7 @@ func TestStatePush(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(initialState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
|
||||
Do()
|
||||
wantState := []byte("updated state")
|
||||
stateFile, err := os.CreateTemp(t.TempDir(), "")
|
||||
|
||||
+4
-3
@@ -49,9 +49,10 @@ OPTIONS:
|
||||
security purposes if a --wildcard-access-url is configured.
|
||||
|
||||
--disable-workspace-sharing bool, $CODER_DISABLE_WORKSPACE_SHARING
|
||||
Disable workspace sharing. Workspace ACL checking is disabled and only
|
||||
owners can have ssh, apps and terminal access to workspaces. Access
|
||||
based on the 'owner' role is also allowed unless disabled via
|
||||
Disable workspace sharing (requires the "workspace-sharing" experiment
|
||||
to be enabled). Workspace ACL checking is disabled and only owners can
|
||||
have ssh, apps and terminal access to workspaces. Access based on the
|
||||
'owner' role is also allowed unless disabled via
|
||||
--disable-owner-workspace-access.
|
||||
|
||||
--swagger-enable bool, $CODER_SWAGGER_ENABLE
|
||||
|
||||
+4
-4
@@ -523,10 +523,10 @@ disablePathApps: false
|
||||
# workspaces.
|
||||
# (default: <unset>, type: bool)
|
||||
disableOwnerWorkspaceAccess: false
|
||||
# Disable workspace sharing. Workspace ACL checking is disabled and only owners
|
||||
# can have ssh, apps and terminal access to workspaces. Access based on the
|
||||
# 'owner' role is also allowed unless disabled via
|
||||
# --disable-owner-workspace-access.
|
||||
# Disable workspace sharing (requires the "workspace-sharing" experiment to be
|
||||
# enabled). Workspace ACL checking is disabled and only owners can have ssh, apps
|
||||
# and terminal access to workspaces. Access based on the 'owner' role is also
|
||||
# allowed unless disabled via --disable-owner-workspace-access.
|
||||
# (default: <unset>, type: bool)
|
||||
disableWorkspaceSharing: false
|
||||
# These options change the behavior of how clients interact with the Coder.
|
||||
|
||||
+15
-2
@@ -241,13 +241,26 @@ func (r *RootCmd) listTokens() *serpent.Command {
|
||||
}
|
||||
|
||||
tokens, err := client.Tokens(inv.Context(), codersdk.Me, codersdk.TokensFilter{
|
||||
IncludeAll: all,
|
||||
IncludeExpired: includeExpired,
|
||||
IncludeAll: all,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("list tokens: %w", err)
|
||||
}
|
||||
|
||||
// Filter out expired tokens unless --include-expired is set
|
||||
// TODO(Cian): This _could_ get too big for client-side filtering.
|
||||
// If it causes issues, we can filter server-side.
|
||||
if !includeExpired {
|
||||
now := time.Now()
|
||||
filtered := make([]codersdk.APIKeyWithOwner, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
if token.ExpiresAt.After(now) {
|
||||
filtered = append(filtered, token)
|
||||
}
|
||||
}
|
||||
tokens = filtered
|
||||
}
|
||||
|
||||
displayTokens = make([]tokenListRow, len(tokens))
|
||||
|
||||
for i, token := range tokens {
|
||||
|
||||
@@ -990,74 +990,4 @@ func TestUpdateValidateRichParameters(t *testing.T) {
|
||||
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
})
|
||||
|
||||
t.Run("NewImmutableParameterViaFlag", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create template and workspace with only a mutable parameter.
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
templateParameters := []*proto.RichParameter{
|
||||
{Name: stringParameterName, Type: "string", Mutable: true, Required: true, Options: []*proto.RichParameterOption{
|
||||
{Name: "First option", Description: "This is first option", Value: "1st"},
|
||||
{Name: "Second option", Description: "This is second option", Value: "2nd"},
|
||||
}},
|
||||
}
|
||||
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(templateParameters))
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
|
||||
|
||||
inv, root := clitest.New(t, "create", "my-workspace", "--yes", "--template", template.Name, "--parameter", fmt.Sprintf("%s=%s", stringParameterName, "1st"))
|
||||
clitest.SetupConfig(t, member, root)
|
||||
err := inv.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update template: add a new immutable parameter.
|
||||
updatedTemplateParameters := []*proto.RichParameter{
|
||||
templateParameters[0],
|
||||
{Name: immutableParameterName, Type: "string", Mutable: false, Required: true, Options: []*proto.RichParameterOption{
|
||||
{Name: "fir", Description: "First option for immutable parameter", Value: "I"},
|
||||
{Name: "sec", Description: "Second option for immutable parameter", Value: "II"},
|
||||
}},
|
||||
}
|
||||
|
||||
updatedVersion := coderdtest.UpdateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(updatedTemplateParameters), template.ID)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, updatedVersion.ID)
|
||||
err = client.UpdateActiveTemplateVersion(context.Background(), template.ID, codersdk.UpdateActiveTemplateVersion{
|
||||
ID: updatedVersion.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update workspace, supplying the new immutable parameter via
|
||||
// the --parameter flag. This should succeed because it's the
|
||||
// first time this parameter is being set.
|
||||
inv, root = clitest.New(t, "update", "my-workspace",
|
||||
"--parameter", fmt.Sprintf("%s=%s", immutableParameterName, "II"))
|
||||
clitest.SetupConfig(t, member, root)
|
||||
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
doneChan := make(chan struct{})
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
pty.ExpectMatch("Planning workspace")
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
|
||||
// Verify the immutable parameter was set correctly.
|
||||
workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), "my-workspace", codersdk.WorkspaceOptions{})
|
||||
require.NoError(t, err)
|
||||
actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, actualParameters, codersdk.WorkspaceBuildParameter{
|
||||
Name: immutableParameterName,
|
||||
Value: "II",
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -179,8 +179,6 @@ func New(opts Options, workspace database.Workspace) *API {
|
||||
Database: opts.Database,
|
||||
Log: opts.Log,
|
||||
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
|
||||
Clock: opts.Clock,
|
||||
NotificationsEnqueuer: opts.NotificationsEnqueuer,
|
||||
}
|
||||
|
||||
api.MetadataAPI = &MetadataAPI{
|
||||
|
||||
@@ -2,10 +2,6 @@ package agentapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -13,14 +9,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
strutil "github.com/coder/coder/v2/coderd/util/strings"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
type AppsAPI struct {
|
||||
@@ -28,8 +17,6 @@ type AppsAPI struct {
|
||||
Database database.Store
|
||||
Log slog.Logger
|
||||
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
|
||||
NotificationsEnqueuer notifications.Enqueuer
|
||||
Clock quartz.Clock
|
||||
}
|
||||
|
||||
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
|
||||
@@ -117,230 +104,3 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
|
||||
}
|
||||
return &agentproto.BatchUpdateAppHealthResponse{}, nil
|
||||
}
|
||||
|
||||
func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
|
||||
if len(req.Message) > 160 {
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Message is too long.",
|
||||
Detail: "Message must be less than 160 characters.",
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "message", Detail: "Message must be less than 160 characters."},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
var dbState database.WorkspaceAppStatusState
|
||||
switch req.State {
|
||||
case agentproto.UpdateAppStatusRequest_COMPLETE:
|
||||
dbState = database.WorkspaceAppStatusStateComplete
|
||||
case agentproto.UpdateAppStatusRequest_FAILURE:
|
||||
dbState = database.WorkspaceAppStatusStateFailure
|
||||
case agentproto.UpdateAppStatusRequest_WORKING:
|
||||
dbState = database.WorkspaceAppStatusStateWorking
|
||||
case agentproto.UpdateAppStatusRequest_IDLE:
|
||||
dbState = database.WorkspaceAppStatusStateIdle
|
||||
default:
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid state provided.",
|
||||
Detail: fmt.Sprintf("invalid state: %q", req.State),
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "state", Detail: "State must be one of: complete, failure, working, idle."},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
workspaceAgent, err := a.AgentFn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
|
||||
AgentID: workspaceAgent.ID,
|
||||
Slug: req.Slug,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to get workspace app.",
|
||||
Detail: fmt.Sprintf("No app found with slug %q", req.Slug),
|
||||
})
|
||||
}
|
||||
|
||||
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to get workspace.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Treat the message as untrusted input.
|
||||
cleaned := strutil.UISanitize(req.Message)
|
||||
|
||||
// Get the latest status for the workspace app to detect no-op updates
|
||||
// nolint:gocritic // This is a system restricted operation.
|
||||
latestAppStatus, err := a.Database.GetLatestWorkspaceAppStatusByAppID(dbauthz.AsSystemRestricted(ctx), app.ID)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get latest workspace app status.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
// If no rows found, latestAppStatus will be a zero-value struct (ID == uuid.Nil)
|
||||
|
||||
// nolint:gocritic // This is a system restricted operation.
|
||||
_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
|
||||
ID: uuid.New(),
|
||||
CreatedAt: dbtime.Now(),
|
||||
WorkspaceID: workspace.ID,
|
||||
AgentID: workspaceAgent.ID,
|
||||
AppID: app.ID,
|
||||
State: dbState,
|
||||
Message: cleaned,
|
||||
Uri: sql.NullString{
|
||||
String: req.Uri,
|
||||
Valid: req.Uri != "",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to insert workspace app status.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
if a.PublishWorkspaceUpdateFn != nil {
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to publish workspace update.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Notify on state change to Working/Idle for AI tasks
|
||||
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent)
|
||||
|
||||
if shouldBump(dbState, latestAppStatus) {
|
||||
// We pass time.Time{} for nextAutostart since we don't have access to
|
||||
// TemplateScheduleStore here. The activity bump logic handles this by
|
||||
// defaulting to the template's activity_bump duration (typically 1 hour).
|
||||
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{})
|
||||
}
|
||||
// just return a blank response because it doesn't contain any settable fields at present.
|
||||
return new(agentproto.UpdateAppStatusResponse), nil
|
||||
}
|
||||
|
||||
func shouldBump(dbState database.WorkspaceAppStatusState, latestAppStatus database.WorkspaceAppStatus) bool {
|
||||
// Bump deadline when agent reports working or transitions away from working.
|
||||
// This prevents auto-pause during active work and gives users time to interact
|
||||
// after work completes.
|
||||
|
||||
// Bump if reporting working state.
|
||||
if dbState == database.WorkspaceAppStatusStateWorking {
|
||||
return true
|
||||
}
|
||||
|
||||
// Bump if transitioning away from working state.
|
||||
if latestAppStatus.ID != uuid.Nil {
|
||||
prevState := latestAppStatus.State
|
||||
if prevState == database.WorkspaceAppStatusStateWorking {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// enqueueAITaskStateNotification enqueues a notification when an AI task's app
|
||||
// transitions to Working or Idle.
|
||||
// No-op if:
|
||||
// - the workspace agent app isn't configured as an AI task,
|
||||
// - the new state equals the latest persisted state,
|
||||
// - the workspace agent is not ready (still starting up).
|
||||
func (a *AppsAPI) enqueueAITaskStateNotification(
|
||||
ctx context.Context,
|
||||
appID uuid.UUID,
|
||||
latestAppStatus database.WorkspaceAppStatus,
|
||||
newAppStatus database.WorkspaceAppStatusState,
|
||||
workspace database.Workspace,
|
||||
agent database.WorkspaceAgent,
|
||||
) {
|
||||
var notificationTemplate uuid.UUID
|
||||
switch newAppStatus {
|
||||
case database.WorkspaceAppStatusStateWorking:
|
||||
notificationTemplate = notifications.TemplateTaskWorking
|
||||
case database.WorkspaceAppStatusStateIdle:
|
||||
notificationTemplate = notifications.TemplateTaskIdle
|
||||
case database.WorkspaceAppStatusStateComplete:
|
||||
notificationTemplate = notifications.TemplateTaskCompleted
|
||||
case database.WorkspaceAppStatusStateFailure:
|
||||
notificationTemplate = notifications.TemplateTaskFailed
|
||||
default:
|
||||
// Not a notifiable state, do nothing
|
||||
return
|
||||
}
|
||||
|
||||
if !workspace.TaskID.Valid {
|
||||
// Workspace has no task ID, do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
// Only send notifications when the agent is ready. We want to skip
|
||||
// any state transitions that occur whilst the workspace is starting
|
||||
// up as it doesn't make sense to receive them.
|
||||
if agent.LifecycleState != database.WorkspaceAgentLifecycleStateReady {
|
||||
a.Log.Debug(ctx, "skipping AI task notification because agent is not ready",
|
||||
slog.F("agent_id", agent.ID),
|
||||
slog.F("lifecycle_state", agent.LifecycleState),
|
||||
slog.F("new_app_status", newAppStatus),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
|
||||
if err != nil {
|
||||
a.Log.Warn(ctx, "failed to get task", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if !task.WorkspaceAppID.Valid || task.WorkspaceAppID.UUID != appID {
|
||||
// Non-task app, do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if the latest persisted state equals the new state (no new transition)
|
||||
// Note: uuid.Nil check is valid here. If no previous status exists,
|
||||
// GetLatestWorkspaceAppStatusByAppID returns sql.ErrNoRows and we get a zero-value struct.
|
||||
if latestAppStatus.ID != uuid.Nil && latestAppStatus.State == newAppStatus {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip the initial "Working" notification when the task first starts.
|
||||
// This is obvious to the user since they just created the task.
|
||||
// We still notify on the first "Idle" status and all subsequent transitions.
|
||||
if latestAppStatus.ID == uuid.Nil && newAppStatus == database.WorkspaceAppStatusStateWorking {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := a.NotificationsEnqueuer.EnqueueWithData(
|
||||
// nolint:gocritic // Need notifier actor to enqueue notifications
|
||||
dbauthz.AsNotifier(ctx),
|
||||
workspace.OwnerID,
|
||||
notificationTemplate,
|
||||
map[string]string{
|
||||
"task": task.Name,
|
||||
"workspace": workspace.Name,
|
||||
},
|
||||
map[string]any{
|
||||
// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
|
||||
// allowing identical content to resend within the same day
|
||||
// (but not more than once every 10s).
|
||||
"dedupe_bypass_ts": a.Clock.Now().UTC().Truncate(time.Minute),
|
||||
},
|
||||
"api-workspace-agent-app-status",
|
||||
// Associate this notification with related entities
|
||||
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
|
||||
); err != nil {
|
||||
a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
package agentapi
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
)
|
||||
|
||||
func TestShouldBump(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prevState *database.WorkspaceAppStatusState // nil means no previous state
|
||||
newState database.WorkspaceAppStatusState
|
||||
shouldBump bool
|
||||
}{
|
||||
{
|
||||
name: "FirstStatusBumps",
|
||||
prevState: nil,
|
||||
newState: database.WorkspaceAppStatusStateWorking,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "WorkingToIdleBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "WorkingToCompleteBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
|
||||
newState: database.WorkspaceAppStatusStateComplete,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "CompleteToIdleNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "CompleteToCompleteNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
|
||||
newState: database.WorkspaceAppStatusStateComplete,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "FailureToIdleNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "FailureToFailureNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
|
||||
newState: database.WorkspaceAppStatusStateFailure,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "CompleteToWorkingBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
|
||||
newState: database.WorkspaceAppStatusStateWorking,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "FailureToCompleteNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
|
||||
newState: database.WorkspaceAppStatusStateComplete,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "WorkingToFailureBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
|
||||
newState: database.WorkspaceAppStatusStateFailure,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "IdleToIdleNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "IdleToWorkingBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
|
||||
newState: database.WorkspaceAppStatusStateWorking,
|
||||
shouldBump: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var prevAppStatus database.WorkspaceAppStatus
|
||||
// If there's a previous state, report it first.
|
||||
if tt.prevState != nil {
|
||||
prevAppStatus.ID = uuid.UUID{1}
|
||||
prevAppStatus.State = *tt.prevState
|
||||
}
|
||||
|
||||
didBump := shouldBump(tt.newState, prevAppStatus)
|
||||
if tt.shouldBump {
|
||||
require.True(t, didBump, "wanted deadline to bump but it didn't")
|
||||
} else {
|
||||
require.False(t, didBump, "wanted deadline not to bump but it did")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,13 +2,9 @@ package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
@@ -16,12 +12,8 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
@@ -261,183 +253,3 @@ func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkspaceAgentAppStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
fEnq := ¬ificationstest.FakeEnqueuer{}
|
||||
mClock := quartz.NewMock(t)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
assert.Equal(t, *agnt, agent)
|
||||
testutil.AssertSend(ctx, t, workspaceUpdates, kind)
|
||||
return nil
|
||||
},
|
||||
NotificationsEnqueuer: fEnq,
|
||||
Clock: mClock,
|
||||
}
|
||||
|
||||
app := database.WorkspaceApp{
|
||||
ID: uuid.UUID{8},
|
||||
}
|
||||
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), database.GetWorkspaceAppByAgentIDAndSlugParams{
|
||||
AgentID: agent.ID,
|
||||
Slug: "vscode",
|
||||
}).Times(1).Return(app, nil)
|
||||
task := database.Task{
|
||||
ID: uuid.UUID{7},
|
||||
WorkspaceAppID: uuid.NullUUID{
|
||||
Valid: true,
|
||||
UUID: app.ID,
|
||||
},
|
||||
}
|
||||
mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
|
||||
workspace := database.Workspace{
|
||||
ID: uuid.UUID{9},
|
||||
TaskID: uuid.NullUUID{
|
||||
Valid: true,
|
||||
UUID: task.ID,
|
||||
},
|
||||
}
|
||||
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
|
||||
appStatus := database.WorkspaceAppStatus{
|
||||
ID: uuid.UUID{6},
|
||||
}
|
||||
mDB.EXPECT().GetLatestWorkspaceAppStatusByAppID(gomock.Any(), app.ID).Times(1).Return(appStatus, nil)
|
||||
mDB.EXPECT().InsertWorkspaceAppStatus(
|
||||
gomock.Any(),
|
||||
gomock.Cond(func(params database.InsertWorkspaceAppStatusParams) bool {
|
||||
if params.AgentID == agent.ID && params.AppID == app.ID {
|
||||
assert.Equal(t, "testing", params.Message)
|
||||
assert.Equal(t, database.WorkspaceAppStatusStateComplete, params.State)
|
||||
assert.True(t, params.Uri.Valid)
|
||||
assert.Equal(t, "https://example.com", params.Uri.String)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})).Times(1).Return(database.WorkspaceAppStatus{}, nil)
|
||||
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "vscode",
|
||||
Message: "testing",
|
||||
Uri: "https://example.com",
|
||||
State: agentproto.UpdateAppStatusRequest_COMPLETE,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
kind := testutil.RequireReceive(ctx, t, workspaceUpdates)
|
||||
require.Equal(t, wspubsub.WorkspaceEventKindAgentAppStatusUpdate, kind)
|
||||
sent := fEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskCompleted))
|
||||
require.Len(t, sent, 1)
|
||||
})
|
||||
|
||||
t.Run("FailUnknownApp", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
|
||||
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), gomock.Any()).
|
||||
Times(1).
|
||||
Return(database.WorkspaceApp{}, sql.ErrNoRows)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "unknown",
|
||||
Message: "testing",
|
||||
Uri: "https://example.com",
|
||||
State: agentproto.UpdateAppStatusRequest_COMPLETE,
|
||||
})
|
||||
require.ErrorContains(t, err, "No app found with slug")
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("FailUnknownState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "vscode",
|
||||
Message: "testing",
|
||||
Uri: "https://example.com",
|
||||
State: 77,
|
||||
})
|
||||
require.ErrorContains(t, err, "Invalid state")
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("FailTooLong", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "vscode",
|
||||
Message: strings.Repeat("a", 161),
|
||||
Uri: "https://example.com",
|
||||
State: agentproto.UpdateAppStatusRequest_COMPLETE,
|
||||
})
|
||||
require.ErrorContains(t, err, "Message is too long")
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -134,12 +134,9 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
|
||||
case database.WorkspaceAgentLifecycleStateReady,
|
||||
database.WorkspaceAgentLifecycleStateStartTimeout,
|
||||
database.WorkspaceAgentLifecycleStateStartError:
|
||||
// Only emit metrics for the parent agent, this metric is not intended to measure devcontainer durations.
|
||||
if !workspaceAgent.ParentID.Valid {
|
||||
a.emitMetricsOnce.Do(func() {
|
||||
a.emitBuildDurationMetric(ctx, workspaceAgent.ResourceID)
|
||||
})
|
||||
}
|
||||
a.emitMetricsOnce.Do(func() {
|
||||
a.emitBuildDurationMetric(ctx, workspaceAgent.ResourceID)
|
||||
})
|
||||
}
|
||||
|
||||
return req.Lifecycle, nil
|
||||
|
||||
@@ -582,64 +582,6 @@ func TestUpdateLifecycle(t *testing.T) {
|
||||
require.Equal(t, uint64(1), got.GetSampleCount())
|
||||
require.Equal(t, expectedDuration, got.GetSampleSum())
|
||||
})
|
||||
|
||||
t.Run("SubAgentDoesNotEmitMetric", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
parentID := uuid.New()
|
||||
subAgent := database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
ParentID: uuid.NullUUID{UUID: parentID, Valid: true},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateStarting,
|
||||
StartedAt: sql.NullTime{Valid: true, Time: someTime},
|
||||
ReadyAt: sql.NullTime{Valid: false},
|
||||
}
|
||||
lifecycle := &agentproto.Lifecycle{
|
||||
State: agentproto.Lifecycle_READY,
|
||||
ChangedAt: timestamppb.New(now),
|
||||
}
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
||||
ID: subAgent.ID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
StartedAt: subAgent.StartedAt,
|
||||
ReadyAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
}).Return(nil)
|
||||
// GetWorkspaceBuildMetricsByResourceID should NOT be called
|
||||
// because sub-agents should be skipped before querying.
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := agentapi.NewLifecycleMetrics(reg)
|
||||
api := &agentapi.LifecycleAPI{
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return subAgent, nil
|
||||
},
|
||||
WorkspaceID: workspaceID,
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
Metrics: metrics,
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
}
|
||||
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
|
||||
Lifecycle: lifecycle,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lifecycle, resp)
|
||||
|
||||
// We don't expect the metric to be emitted for sub-agents, by default this will fail anyway but it doesn't hurt
|
||||
// to document the test explicitly.
|
||||
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), gomock.Any()).Times(0)
|
||||
|
||||
// If we were emitting the metric we would have failed by now since it would include a call to the database that we're not expecting.
|
||||
pm, err := reg.Gather()
|
||||
require.NoError(t, err)
|
||||
for _, m := range pm {
|
||||
if m.GetName() == fullMetricName {
|
||||
t.Fatal("metric should not be emitted for sub-agent")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateStartup(t *testing.T) {
|
||||
|
||||
+4
-2
@@ -466,6 +466,7 @@ func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks
|
||||
|
||||
apiWorkspaces, err := convertWorkspaces(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
requesterID,
|
||||
workspaces,
|
||||
@@ -545,6 +546,7 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ws, err := convertWorkspace(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
apiKey.UserID,
|
||||
workspace,
|
||||
@@ -1248,7 +1250,7 @@ func (api *API) postWorkspaceAgentTaskLogSnapshot(rw http.ResponseWriter, r *htt
|
||||
// @Summary Pause task
|
||||
// @ID pause-task
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Accept json
|
||||
// @Tags Tasks
|
||||
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
|
||||
// @Param task path string true "Task ID" format(uuid)
|
||||
@@ -1325,7 +1327,7 @@ func (api *API) pauseTask(rw http.ResponseWriter, r *http.Request) {
|
||||
// @Summary Resume task
|
||||
// @ID resume-task
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Accept json
|
||||
// @Tags Tasks
|
||||
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
|
||||
// @Param task path string true "Task ID" format(uuid)
|
||||
|
||||
Generated
+9
-18
@@ -5894,7 +5894,7 @@ const docTemplate = `{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
@@ -5936,7 +5936,7 @@ const docTemplate = `{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
@@ -8238,12 +8238,6 @@ const docTemplate = `{
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "boolean",
|
||||
"description": "Include expired tokens in the list",
|
||||
"name": "include_expired",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -9551,7 +9545,6 @@ const docTemplate = `{
|
||||
],
|
||||
"summary": "Patch workspace agent app status",
|
||||
"operationId": "patch-workspace-agent-app-status",
|
||||
"deprecated": true,
|
||||
"parameters": [
|
||||
{
|
||||
"description": "app status",
|
||||
@@ -13440,9 +13433,6 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13756,9 +13746,6 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -15110,7 +15097,8 @@ const docTemplate = `{
|
||||
"workspace-usage",
|
||||
"web-push",
|
||||
"oauth2",
|
||||
"mcp-server-http"
|
||||
"mcp-server-http",
|
||||
"workspace-sharing"
|
||||
],
|
||||
"x-enum-comments": {
|
||||
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
|
||||
@@ -15119,6 +15107,7 @@ const docTemplate = `{
|
||||
"ExperimentNotifications": "Sends notifications via SMTP and webhooks following certain events.",
|
||||
"ExperimentOAuth2": "Enables OAuth2 provider functionality.",
|
||||
"ExperimentWebPush": "Enables web push notifications through the browser.",
|
||||
"ExperimentWorkspaceSharing": "Enables updating workspace ACLs for sharing with users and groups.",
|
||||
"ExperimentWorkspaceUsage": "Enables the new workspace usage tracking."
|
||||
},
|
||||
"x-enum-descriptions": [
|
||||
@@ -15128,7 +15117,8 @@ const docTemplate = `{
|
||||
"Enables the new workspace usage tracking.",
|
||||
"Enables web push notifications through the browser.",
|
||||
"Enables OAuth2 provider functionality.",
|
||||
"Enables the MCP HTTP server functionality."
|
||||
"Enables the MCP HTTP server functionality.",
|
||||
"Enables updating workspace ACLs for sharing with users and groups."
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ExperimentExample",
|
||||
@@ -15137,7 +15127,8 @@ const docTemplate = `{
|
||||
"ExperimentWorkspaceUsage",
|
||||
"ExperimentWebPush",
|
||||
"ExperimentOAuth2",
|
||||
"ExperimentMCPServerHTTP"
|
||||
"ExperimentMCPServerHTTP",
|
||||
"ExperimentWorkspaceSharing"
|
||||
]
|
||||
},
|
||||
"codersdk.ExternalAPIKeyScopes": {
|
||||
|
||||
Generated
+9
-18
@@ -5213,7 +5213,7 @@
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": ["application/json"],
|
||||
"consumes": ["application/json"],
|
||||
"tags": ["Tasks"],
|
||||
"summary": "Pause task",
|
||||
"operationId": "pause-task",
|
||||
@@ -5251,7 +5251,7 @@
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": ["application/json"],
|
||||
"consumes": ["application/json"],
|
||||
"tags": ["Tasks"],
|
||||
"summary": "Resume task",
|
||||
"operationId": "resume-task",
|
||||
@@ -7285,12 +7285,6 @@
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "boolean",
|
||||
"description": "Include expired tokens in the list",
|
||||
"name": "include_expired",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -8450,7 +8444,6 @@
|
||||
"tags": ["Agents"],
|
||||
"summary": "Patch workspace agent app status",
|
||||
"operationId": "patch-workspace-agent-app-status",
|
||||
"deprecated": true,
|
||||
"parameters": [
|
||||
{
|
||||
"description": "app status",
|
||||
@@ -12042,9 +12035,6 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -12337,9 +12327,6 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13637,7 +13624,8 @@
|
||||
"workspace-usage",
|
||||
"web-push",
|
||||
"oauth2",
|
||||
"mcp-server-http"
|
||||
"mcp-server-http",
|
||||
"workspace-sharing"
|
||||
],
|
||||
"x-enum-comments": {
|
||||
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
|
||||
@@ -13646,6 +13634,7 @@
|
||||
"ExperimentNotifications": "Sends notifications via SMTP and webhooks following certain events.",
|
||||
"ExperimentOAuth2": "Enables OAuth2 provider functionality.",
|
||||
"ExperimentWebPush": "Enables web push notifications through the browser.",
|
||||
"ExperimentWorkspaceSharing": "Enables updating workspace ACLs for sharing with users and groups.",
|
||||
"ExperimentWorkspaceUsage": "Enables the new workspace usage tracking."
|
||||
},
|
||||
"x-enum-descriptions": [
|
||||
@@ -13655,7 +13644,8 @@
|
||||
"Enables the new workspace usage tracking.",
|
||||
"Enables web push notifications through the browser.",
|
||||
"Enables OAuth2 provider functionality.",
|
||||
"Enables the MCP HTTP server functionality."
|
||||
"Enables the MCP HTTP server functionality.",
|
||||
"Enables updating workspace ACLs for sharing with users and groups."
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ExperimentExample",
|
||||
@@ -13664,7 +13654,8 @@
|
||||
"ExperimentWorkspaceUsage",
|
||||
"ExperimentWebPush",
|
||||
"ExperimentOAuth2",
|
||||
"ExperimentMCPServerHTTP"
|
||||
"ExperimentMCPServerHTTP",
|
||||
"ExperimentWorkspaceSharing"
|
||||
]
|
||||
},
|
||||
"codersdk.ExternalAPIKeyScopes": {
|
||||
|
||||
+8
-14
@@ -307,26 +307,20 @@ func (api *API) apiKeyByName(rw http.ResponseWriter, r *http.Request) {
|
||||
// @Tags Users
|
||||
// @Param user path string true "User ID, name, or me"
|
||||
// @Success 200 {array} codersdk.APIKey
|
||||
// @Param include_expired query bool false "Include expired tokens in the list"
|
||||
// @Router /users/{user}/keys/tokens [get]
|
||||
func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
keys []database.APIKey
|
||||
err error
|
||||
queryStr = r.URL.Query().Get("include_all")
|
||||
includeAll, _ = strconv.ParseBool(queryStr)
|
||||
expiredStr = r.URL.Query().Get("include_expired")
|
||||
includeExpired, _ = strconv.ParseBool(expiredStr)
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
keys []database.APIKey
|
||||
err error
|
||||
queryStr = r.URL.Query().Get("include_all")
|
||||
includeAll, _ = strconv.ParseBool(queryStr)
|
||||
)
|
||||
|
||||
if includeAll {
|
||||
// get tokens for all users
|
||||
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.GetAPIKeysByLoginTypeParams{
|
||||
LoginType: database.LoginTypeToken,
|
||||
IncludeExpired: includeExpired,
|
||||
})
|
||||
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.LoginTypeToken)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching API keys.",
|
||||
@@ -336,7 +330,7 @@ func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
} else {
|
||||
// get user's tokens only
|
||||
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID, IncludeExpired: includeExpired})
|
||||
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching API keys.",
|
||||
|
||||
@@ -69,44 +69,6 @@ func TestTokenCRUD(t *testing.T) {
|
||||
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[numLogs-1].Action)
|
||||
}
|
||||
|
||||
func TestTokensFilterExpired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
adminClient := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
// Create a token.
|
||||
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
|
||||
Lifetime: time.Hour * 24 * 7,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
keyID := strings.Split(res.Key, "-")[0]
|
||||
|
||||
// List tokens without including expired - should see the token.
|
||||
keys, err := adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
|
||||
// Expire the token.
|
||||
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List tokens without including expired - should NOT see expired token.
|
||||
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, keys)
|
||||
|
||||
// List tokens WITH including expired - should see expired token.
|
||||
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{
|
||||
IncludeExpired: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, keyID, keys[0].ID)
|
||||
}
|
||||
|
||||
func TestTokenScoped(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+1
-8
@@ -26,11 +26,6 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// Limit the count query to avoid a slow sequential scan due to joins
|
||||
// on a large table. Set to 0 to disable capping (but also see the note
|
||||
// in the SQL query).
|
||||
const auditLogCountCap = 2000
|
||||
|
||||
// @Summary Get audit logs
|
||||
// @ID get-audit-logs
|
||||
// @Security CoderSessionToken
|
||||
@@ -71,7 +66,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
countFilter.Username = ""
|
||||
}
|
||||
|
||||
countFilter.CountCap = auditLogCountCap
|
||||
// Use the same filters to count the number of audit logs
|
||||
count, err := api.Database.CountAuditLogs(ctx, countFilter)
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
@@ -86,7 +81,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: []codersdk.AuditLog{},
|
||||
Count: 0,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -104,7 +98,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: api.convertAuditLogs(ctx, dblogs),
|
||||
Count: count,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -500,7 +500,7 @@ func (e *Executor) runOnce(t time.Time) Stats {
|
||||
"task": task.Name,
|
||||
"task_id": task.ID.String(),
|
||||
"workspace": ws.Name,
|
||||
"pause_reason": "idle timeout",
|
||||
"pause_reason": "inactivity exceeded the dormancy threshold",
|
||||
},
|
||||
"lifecycle_executor",
|
||||
ws.ID, ws.OwnerID, ws.OrganizationID,
|
||||
|
||||
@@ -2082,6 +2082,6 @@ func TestExecutorTaskWorkspace(t *testing.T) {
|
||||
require.Equal(t, task.Name, sent[0].Labels["task"])
|
||||
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
|
||||
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
|
||||
require.Equal(t, "idle timeout", sent[0].Labels["pause_reason"])
|
||||
require.Equal(t, "inactivity exceeded the dormancy threshold", sent[0].Labels["pause_reason"])
|
||||
})
|
||||
}
|
||||
|
||||
+9
-10
@@ -98,7 +98,6 @@ import (
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
"github.com/coder/coder/v2/site"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/derpmetrics"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
@@ -330,10 +329,9 @@ func New(options *Options) *API {
|
||||
panic("developer error: options.PrometheusRegistry is nil and not running a unit test")
|
||||
}
|
||||
|
||||
if options.DeploymentValues.DisableOwnerWorkspaceExec || options.DeploymentValues.DisableWorkspaceSharing {
|
||||
if options.DeploymentValues.DisableOwnerWorkspaceExec {
|
||||
rbac.ReloadBuiltinRoles(&rbac.RoleOptions{
|
||||
NoOwnerWorkspaceExec: bool(options.DeploymentValues.DisableOwnerWorkspaceExec),
|
||||
NoWorkspaceSharing: bool(options.DeploymentValues.DisableWorkspaceSharing),
|
||||
NoOwnerWorkspaceExec: true,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -884,18 +882,17 @@ func New(options *Options) *API {
|
||||
apiRateLimiter := httpmw.RateLimit(options.APIRateLimit, time.Minute)
|
||||
|
||||
// Register DERP on expvar HTTP handler, which we serve below in the router, c.f. expvar.Handler()
|
||||
// These are the metrics the DERP server exposes.
|
||||
// TODO: export via prometheus
|
||||
expDERPOnce.Do(func() {
|
||||
// We need to do this via a global Once because expvar registry is global and panics if we
|
||||
// register multiple times. In production there is only one Coderd and one DERP server per
|
||||
// process, but in testing, we create multiple of both, so the Once protects us from
|
||||
// panicking.
|
||||
if options.DERPServer != nil && expvar.Get("derp") == nil {
|
||||
if options.DERPServer != nil {
|
||||
expvar.Publish("derp", api.DERPServer.ExpVar())
|
||||
}
|
||||
})
|
||||
if options.PrometheusRegistry != nil && options.DERPServer != nil {
|
||||
options.PrometheusRegistry.MustRegister(derpmetrics.NewDERPExpvarCollector(options.DERPServer))
|
||||
}
|
||||
cors := httpmw.Cors(options.DeploymentValues.Dangerous.AllowAllCors.Value())
|
||||
prometheusMW := httpmw.Prometheus(options.PrometheusRegistry)
|
||||
|
||||
@@ -1529,6 +1526,10 @@ func New(options *Options) *API {
|
||||
})
|
||||
r.Get("/timings", api.workspaceTimings)
|
||||
r.Route("/acl", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceSharing),
|
||||
)
|
||||
|
||||
r.Get("/", api.workspaceACL)
|
||||
r.Patch("/", api.patchWorkspaceACL)
|
||||
r.Delete("/", api.deleteWorkspaceACL)
|
||||
@@ -1737,8 +1738,6 @@ func New(options *Options) *API {
|
||||
r.Patch("/input", api.taskUpdateInput)
|
||||
r.Post("/send", api.taskSend)
|
||||
r.Get("/logs", api.taskLogs)
|
||||
r.Post("/pause", api.pauseTask)
|
||||
r.Post("/resume", api.resumeTask)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
+2
-28
@@ -384,35 +384,9 @@ func TestCSRFExempt(t *testing.T) {
|
||||
data, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// A StatusNotFound means Coderd tried to proxy to the agent and failed because the agent
|
||||
// A StatusBadGateway means Coderd tried to proxy to the agent and failed because the agent
|
||||
// was not there. This means CSRF did not block the app request, which is what we want.
|
||||
require.Equal(t, http.StatusNotFound, resp.StatusCode, "status code 500 is CSRF failure")
|
||||
require.Equal(t, http.StatusBadGateway, resp.StatusCode, "status code 500 is CSRF failure")
|
||||
require.NotContains(t, string(data), "CSRF")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDERPMetrics(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, api := coderdtest.NewWithAPI(t, nil)
|
||||
|
||||
require.NotNil(t, api.Options.DERPServer, "DERP server should be configured")
|
||||
require.NotNil(t, api.Options.PrometheusRegistry, "Prometheus registry should be configured")
|
||||
|
||||
// The registry is created internally by coderd. Gather from it
|
||||
// to verify DERP metrics were registered during startup.
|
||||
metrics, err := api.Options.PrometheusRegistry.Gather()
|
||||
require.NoError(t, err)
|
||||
|
||||
names := make(map[string]struct{})
|
||||
for _, m := range metrics {
|
||||
names[m.GetName()] = struct{}{}
|
||||
}
|
||||
|
||||
assert.Contains(t, names, "coder_derp_server_connections",
|
||||
"expected coder_derp_server_connections to be registered")
|
||||
assert.Contains(t, names, "coder_derp_server_bytes_received_total",
|
||||
"expected coder_derp_server_bytes_received_total to be registered")
|
||||
assert.Contains(t, names, "coder_derp_server_packets_dropped_reason_total",
|
||||
"expected coder_derp_server_packets_dropped_reason_total to be registered")
|
||||
}
|
||||
|
||||
@@ -106,8 +106,6 @@ import (
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const DefaultDERPMeshKey = "test-key"
|
||||
|
||||
const defaultTestDaemonName = "test-daemon"
|
||||
|
||||
type Options struct {
|
||||
@@ -514,18 +512,8 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
|
||||
stunAddresses = options.DeploymentValues.DERP.Server.STUNAddresses.Value()
|
||||
}
|
||||
|
||||
const derpMeshKey = "test-key"
|
||||
// Technically AGPL coderd servers don't set this value, but it doesn't
|
||||
// change any behavior. It's useful for enterprise tests.
|
||||
err = options.Database.InsertDERPMeshKey(dbauthz.AsSystemRestricted(ctx), derpMeshKey) //nolint:gocritic // test
|
||||
if !database.IsUniqueViolation(err, database.UniqueSiteConfigsKeyKey) {
|
||||
require.NoError(t, err, "insert DERP mesh key")
|
||||
}
|
||||
var derpServer *derp.Server
|
||||
if options.DeploymentValues.DERP.Server.Enable.Value() {
|
||||
derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp").Leveled(slog.LevelDebug)))
|
||||
derpServer.SetMeshKey(derpMeshKey)
|
||||
}
|
||||
derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp").Leveled(slog.LevelDebug)))
|
||||
derpServer.SetMeshKey("test-key")
|
||||
|
||||
// match default with cli default
|
||||
if options.SSHKeygenAlgorithm == "" {
|
||||
|
||||
@@ -668,31 +668,6 @@ var (
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
|
||||
subjectWorkspaceBuilder = rbac.Subject{
|
||||
Type: rbac.SubjectTypeWorkspaceBuilder,
|
||||
FriendlyName: "Workspace Builder",
|
||||
ID: uuid.Nil.String(),
|
||||
Roles: rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleIdentifier{Name: "workspace-builder"},
|
||||
DisplayName: "Workspace Builder",
|
||||
Site: rbac.Permissions(map[string][]policy.Action{
|
||||
// Reading provisioner daemons to check eligibility.
|
||||
rbac.ResourceProvisionerDaemon.Type: {policy.ActionRead},
|
||||
// Updating provisioner jobs (e.g. marking prebuild
|
||||
// jobs complete).
|
||||
rbac.ResourceProvisionerJobs.Type: {policy.ActionUpdate},
|
||||
// Reading provisioner state requires template update
|
||||
// permission.
|
||||
rbac.ResourceTemplate.Type: {policy.ActionUpdate},
|
||||
}),
|
||||
User: []rbac.Permission{},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{},
|
||||
},
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
)
|
||||
|
||||
// AsProvisionerd returns a context with an actor that has permissions required
|
||||
@@ -799,14 +774,6 @@ func AsBoundaryUsageTracker(ctx context.Context) context.Context {
|
||||
return As(ctx, subjectBoundaryUsageTracker)
|
||||
}
|
||||
|
||||
// AsWorkspaceBuilder returns a context with an actor that has permissions
|
||||
// required for the workspace builder to prepare workspace builds. This
|
||||
// includes reading provisioner daemons, updating provisioner jobs, and
|
||||
// reading provisioner state (which requires template update permission).
|
||||
func AsWorkspaceBuilder(ctx context.Context) context.Context {
|
||||
return As(ctx, subjectWorkspaceBuilder)
|
||||
}
|
||||
|
||||
var AsRemoveActor = rbac.Subject{
|
||||
ID: "remove-actor",
|
||||
}
|
||||
@@ -2194,12 +2161,12 @@ func (q *querier) GetAPIKeyByName(ctx context.Context, arg database.GetAPIKeyByN
|
||||
return fetch(q.log, q.auth, q.db.GetAPIKeyByName)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
|
||||
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByLoginType)(ctx, loginType)
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysByUserID(ctx context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, params)
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, database.GetAPIKeysByUserIDParams{LoginType: params.LoginType, UserID: params.UserID})
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]database.APIKey, error) {
|
||||
@@ -2290,7 +2257,7 @@ func (q *querier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditL
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow, error) {
|
||||
// This is a system function.
|
||||
// This is a system function
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow{}, err
|
||||
}
|
||||
@@ -3166,13 +3133,6 @@ func (q *querier) GetTelemetryItems(ctx context.Context) ([]database.TelemetryIt
|
||||
return q.db.GetTelemetryItems(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetTelemetryTaskEvents(ctx context.Context, arg database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTask.All()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetTelemetryTaskEvents(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
|
||||
if err := q.authorizeTemplateInsights(ctx, arg.TemplateIDs); err != nil {
|
||||
return nil, err
|
||||
@@ -3954,11 +3914,6 @@ func (q *querier) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, wor
|
||||
return q.db.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, prep)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, buildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
|
||||
// Fetching the provisioner state requires Update permission on the template.
|
||||
return fetchWithAction(q.log, q.auth, policy.ActionUpdate, q.db.GetWorkspaceBuildProvisionerStateByID)(ctx, buildID)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -237,8 +237,8 @@ func (s *MethodTestSuite) TestAPIKey() {
|
||||
s.Run("GetAPIKeysByLoginType", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
a := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
|
||||
b := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
|
||||
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Return([]database.APIKey{a, b}, nil).AnyTimes()
|
||||
check.Args(database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
|
||||
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.LoginTypePassword).Return([]database.APIKey{a, b}, nil).AnyTimes()
|
||||
check.Args(database.LoginTypePassword).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
|
||||
}))
|
||||
s.Run("GetAPIKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u1 := testutil.Fake(s.T(), faker, database.User{})
|
||||
@@ -1326,11 +1326,6 @@ func (s *MethodTestSuite) TestTemplate() {
|
||||
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
|
||||
}))
|
||||
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetTelemetryTaskEventsParams{}
|
||||
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceTask.All(), policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTemplateAppInsights", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetTemplateAppInsightsParams{}
|
||||
dbm.EXPECT().GetTemplateAppInsights(gomock.Any(), arg).Return([]database.GetTemplateAppInsightsRow{}, nil).AnyTimes()
|
||||
@@ -1974,15 +1969,6 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
|
||||
check.Args(build.ID).Asserts(ws, policy.ActionRead).Returns(build)
|
||||
}))
|
||||
s.Run("GetWorkspaceBuildProvisionerStateByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
row := database.GetWorkspaceBuildProvisionerStateByIDRow{
|
||||
ProvisionerState: []byte("state"),
|
||||
TemplateID: uuid.New(),
|
||||
TemplateOrganizationID: uuid.New(),
|
||||
}
|
||||
dbm.EXPECT().GetWorkspaceBuildProvisionerStateByID(gomock.Any(), gomock.Any()).Return(row, nil).AnyTimes()
|
||||
check.Args(uuid.New()).Asserts(row, policy.ActionUpdate).Returns(row)
|
||||
}))
|
||||
s.Run("GetWorkspaceBuildByJobID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ws := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
build := testutil.Fake(s.T(), faker, database.WorkspaceBuild{WorkspaceID: ws.ID})
|
||||
|
||||
@@ -67,8 +67,6 @@ type WorkspaceBuildBuilder struct {
|
||||
|
||||
jobError string // Error message for failed jobs
|
||||
jobErrorCode string // Error code for failed jobs
|
||||
|
||||
provisionerState []byte
|
||||
}
|
||||
|
||||
// BuilderOption is a functional option for customizing job timestamps
|
||||
@@ -140,15 +138,6 @@ func (b WorkspaceBuildBuilder) Seed(seed database.WorkspaceBuild) WorkspaceBuild
|
||||
return b
|
||||
}
|
||||
|
||||
// ProvisionerState sets the provisioner state for the workspace build.
|
||||
// This is stored separately from the seed because ProvisionerState is
|
||||
// not part of the WorkspaceBuild view struct.
|
||||
func (b WorkspaceBuildBuilder) ProvisionerState(state []byte) WorkspaceBuildBuilder {
|
||||
//nolint: revive // returns modified struct
|
||||
b.provisionerState = state
|
||||
return b
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) Resource(resource ...*sdkproto.Resource) WorkspaceBuildBuilder {
|
||||
//nolint: revive // returns modified struct
|
||||
b.resources = append(b.resources, resource...)
|
||||
@@ -475,14 +464,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
|
||||
}
|
||||
|
||||
resp.Build = dbgen.WorkspaceBuild(b.t, b.db, b.seed)
|
||||
if len(b.provisionerState) > 0 {
|
||||
err = b.db.UpdateWorkspaceBuildProvisionerStateByID(ownerCtx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
|
||||
ID: resp.Build.ID,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
ProvisionerState: b.provisionerState,
|
||||
})
|
||||
require.NoError(b.t, err, "update provisioner state")
|
||||
}
|
||||
b.logger.Debug(context.Background(), "created workspace build",
|
||||
slog.F("build_id", resp.Build.ID),
|
||||
slog.F("workspace_id", resp.Workspace.ID),
|
||||
|
||||
@@ -504,7 +504,7 @@ func WorkspaceBuild(t testing.TB, db database.Store, orig database.WorkspaceBuil
|
||||
Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart),
|
||||
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
|
||||
JobID: jobID,
|
||||
ProvisionerState: []byte{},
|
||||
ProvisionerState: takeFirstSlice(orig.ProvisionerState, []byte{}),
|
||||
Deadline: takeFirst(orig.Deadline, dbtime.Now().Add(time.Hour)),
|
||||
MaxDeadline: takeFirst(orig.MaxDeadline, time.Time{}),
|
||||
Reason: takeFirst(orig.Reason, database.BuildReasonInitiator),
|
||||
@@ -1373,8 +1373,6 @@ func OAuth2ProviderAppCode(t testing.TB, db database.Store, seed database.OAuth2
|
||||
ResourceUri: seed.ResourceUri,
|
||||
CodeChallenge: seed.CodeChallenge,
|
||||
CodeChallengeMethod: seed.CodeChallengeMethod,
|
||||
StateHash: seed.StateHash,
|
||||
RedirectUri: seed.RedirectUri,
|
||||
})
|
||||
require.NoError(t, err, "insert oauth2 app code")
|
||||
return code
|
||||
|
||||
@@ -774,7 +774,7 @@ func (m queryMetricsStore) GetAPIKeyByName(ctx context.Context, arg database.Get
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
|
||||
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAPIKeysByLoginType(ctx, loginType)
|
||||
m.queryLatencies.WithLabelValues("GetAPIKeysByLoginType").Observe(time.Since(start).Seconds())
|
||||
@@ -1790,14 +1790,6 @@ func (m queryMetricsStore) GetTelemetryItems(ctx context.Context) ([]database.Te
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetTelemetryTaskEvents(ctx context.Context, createdAfter database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetTelemetryTaskEvents(ctx, createdAfter)
|
||||
m.queryLatencies.WithLabelValues("GetTelemetryTaskEvents").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTelemetryTaskEvents").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetTemplateAppInsights(ctx, arg)
|
||||
@@ -2438,14 +2430,6 @@ func (m queryMetricsStore) GetWorkspaceBuildParametersByBuildIDs(ctx context.Con
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuildID)
|
||||
m.queryLatencies.WithLabelValues("GetWorkspaceBuildProvisionerStateByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceBuildProvisionerStateByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetWorkspaceBuildStatsByTemplates(ctx, since)
|
||||
|
||||
@@ -1305,18 +1305,18 @@ func (mr *MockStoreMockRecorder) GetAPIKeyByName(ctx, arg any) *gomock.Call {
|
||||
}
|
||||
|
||||
// GetAPIKeysByLoginType mocks base method.
|
||||
func (m *MockStore) GetAPIKeysByLoginType(ctx context.Context, arg database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
|
||||
func (m *MockStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAPIKeysByLoginType", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "GetAPIKeysByLoginType", ctx, loginType)
|
||||
ret0, _ := ret[0].([]database.APIKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAPIKeysByLoginType indicates an expected call of GetAPIKeysByLoginType.
|
||||
func (mr *MockStoreMockRecorder) GetAPIKeysByLoginType(ctx, arg any) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) GetAPIKeysByLoginType(ctx, loginType any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysByLoginType", reflect.TypeOf((*MockStore)(nil).GetAPIKeysByLoginType), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysByLoginType", reflect.TypeOf((*MockStore)(nil).GetAPIKeysByLoginType), ctx, loginType)
|
||||
}
|
||||
|
||||
// GetAPIKeysByUserID mocks base method.
|
||||
@@ -3314,21 +3314,6 @@ func (mr *MockStoreMockRecorder) GetTelemetryItems(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTelemetryItems", reflect.TypeOf((*MockStore)(nil).GetTelemetryItems), ctx)
|
||||
}
|
||||
|
||||
// GetTelemetryTaskEvents mocks base method.
|
||||
func (m *MockStore) GetTelemetryTaskEvents(ctx context.Context, arg database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTelemetryTaskEvents", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetTelemetryTaskEventsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetTelemetryTaskEvents indicates an expected call of GetTelemetryTaskEvents.
|
||||
func (mr *MockStoreMockRecorder) GetTelemetryTaskEvents(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTelemetryTaskEvents", reflect.TypeOf((*MockStore)(nil).GetTelemetryTaskEvents), ctx, arg)
|
||||
}
|
||||
|
||||
// GetTemplateAppInsights mocks base method.
|
||||
func (m *MockStore) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4559,21 +4544,6 @@ func (mr *MockStoreMockRecorder) GetWorkspaceBuildParametersByBuildIDs(ctx, work
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildParametersByBuildIDs", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildParametersByBuildIDs), ctx, workspaceBuildIds)
|
||||
}
|
||||
|
||||
// GetWorkspaceBuildProvisionerStateByID mocks base method.
|
||||
func (m *MockStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetWorkspaceBuildProvisionerStateByID", ctx, workspaceBuildID)
|
||||
ret0, _ := ret[0].(database.GetWorkspaceBuildProvisionerStateByIDRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetWorkspaceBuildProvisionerStateByID indicates an expected call of GetWorkspaceBuildProvisionerStateByID.
|
||||
func (mr *MockStoreMockRecorder) GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuildID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildProvisionerStateByID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildProvisionerStateByID), ctx, workspaceBuildID)
|
||||
}
|
||||
|
||||
// GetWorkspaceBuildStatsByTemplates mocks base method.
|
||||
func (m *MockStore) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+2
-7
@@ -1471,9 +1471,7 @@ CREATE TABLE oauth2_provider_app_codes (
|
||||
app_id uuid NOT NULL,
|
||||
resource_uri text,
|
||||
code_challenge text,
|
||||
code_challenge_method text,
|
||||
state_hash text,
|
||||
redirect_uri text
|
||||
code_challenge_method text
|
||||
);
|
||||
|
||||
COMMENT ON TABLE oauth2_provider_app_codes IS 'Codes are meant to be exchanged for access tokens.';
|
||||
@@ -1484,10 +1482,6 @@ COMMENT ON COLUMN oauth2_provider_app_codes.code_challenge IS 'PKCE code challen
|
||||
|
||||
COMMENT ON COLUMN oauth2_provider_app_codes.code_challenge_method IS 'PKCE challenge method (S256)';
|
||||
|
||||
COMMENT ON COLUMN oauth2_provider_app_codes.state_hash IS 'SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.';
|
||||
|
||||
COMMENT ON COLUMN oauth2_provider_app_codes.redirect_uri IS 'The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).';
|
||||
|
||||
CREATE TABLE oauth2_provider_app_secrets (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
@@ -2708,6 +2702,7 @@ CREATE VIEW workspace_build_with_user AS
|
||||
workspace_builds.build_number,
|
||||
workspace_builds.transition,
|
||||
workspace_builds.initiator_id,
|
||||
workspace_builds.provisioner_state,
|
||||
workspace_builds.job_id,
|
||||
workspace_builds.deadline,
|
||||
workspace_builds.reason,
|
||||
|
||||
@@ -51,34 +51,15 @@ func TestViewSubsetTemplateVersion(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestViewSubsetWorkspaceBuild ensures WorkspaceBuildTable is a subset of
|
||||
// WorkspaceBuild, with the exception of ProvisionerState which is
|
||||
// intentionally excluded from the workspace_build_with_user view to avoid
|
||||
// loading the large Terraform state blob on hot paths.
|
||||
// TestViewSubsetWorkspaceBuild ensures WorkspaceBuildTable is a subset of WorkspaceBuild
|
||||
func TestViewSubsetWorkspaceBuild(t *testing.T) {
|
||||
t.Parallel()
|
||||
table := reflect.TypeOf(database.WorkspaceBuildTable{})
|
||||
joined := reflect.TypeOf(database.WorkspaceBuild{})
|
||||
|
||||
tableFields := fieldNames(allFields(table))
|
||||
joinedFields := fieldNames(allFields(joined))
|
||||
|
||||
// ProvisionerState is intentionally excluded from the
|
||||
// workspace_build_with_user view to avoid loading multi-MB Terraform
|
||||
// state blobs on hot paths. Callers that need it use
|
||||
// GetWorkspaceBuildProvisionerStateByID instead.
|
||||
excludedFields := map[string]bool{
|
||||
"ProvisionerState": true,
|
||||
}
|
||||
|
||||
var filtered []string
|
||||
for _, name := range tableFields {
|
||||
if !excludedFields[name] {
|
||||
filtered = append(filtered, name)
|
||||
}
|
||||
}
|
||||
|
||||
if !assert.Subset(t, joinedFields, filtered, "table is not subset") {
|
||||
tableFields := allFields(table)
|
||||
joinedFields := allFields(joined)
|
||||
if !assert.Subset(t, fieldNames(joinedFields), fieldNames(tableFields), "table is not subset") {
|
||||
t.Log("Some fields were added to the WorkspaceBuild Table without updating the 'workspace_build_with_user' view.")
|
||||
t.Log("See migration 000141_join_users_build_version.up.sql to create the view.")
|
||||
}
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
ALTER TABLE oauth2_provider_app_codes
|
||||
DROP COLUMN state_hash,
|
||||
DROP COLUMN redirect_uri;
|
||||
@@ -1,9 +0,0 @@
|
||||
ALTER TABLE oauth2_provider_app_codes
|
||||
ADD COLUMN state_hash text,
|
||||
ADD COLUMN redirect_uri text;
|
||||
|
||||
COMMENT ON COLUMN oauth2_provider_app_codes.state_hash IS
|
||||
'SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.';
|
||||
|
||||
COMMENT ON COLUMN oauth2_provider_app_codes.redirect_uri IS
|
||||
'The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).';
|
||||
-31
@@ -1,31 +0,0 @@
|
||||
-- Restore provisioner_state to workspace_build_with_user view.
|
||||
DROP VIEW workspace_build_with_user;
|
||||
|
||||
CREATE VIEW workspace_build_with_user AS
|
||||
SELECT
|
||||
workspace_builds.id,
|
||||
workspace_builds.created_at,
|
||||
workspace_builds.updated_at,
|
||||
workspace_builds.workspace_id,
|
||||
workspace_builds.template_version_id,
|
||||
workspace_builds.build_number,
|
||||
workspace_builds.transition,
|
||||
workspace_builds.initiator_id,
|
||||
workspace_builds.provisioner_state,
|
||||
workspace_builds.job_id,
|
||||
workspace_builds.deadline,
|
||||
workspace_builds.reason,
|
||||
workspace_builds.daily_cost,
|
||||
workspace_builds.max_deadline,
|
||||
workspace_builds.template_version_preset_id,
|
||||
workspace_builds.has_ai_task,
|
||||
workspace_builds.has_external_agent,
|
||||
COALESCE(visible_users.avatar_url, ''::text) AS initiator_by_avatar_url,
|
||||
COALESCE(visible_users.username, ''::text) AS initiator_by_username,
|
||||
COALESCE(visible_users.name, ''::text) AS initiator_by_name
|
||||
FROM
|
||||
workspace_builds
|
||||
LEFT JOIN
|
||||
visible_users ON workspace_builds.initiator_id = visible_users.id;
|
||||
|
||||
COMMENT ON VIEW workspace_build_with_user IS 'Joins in the username + avatar url of the initiated by user.';
|
||||
@@ -1,33 +0,0 @@
|
||||
-- Drop and recreate workspace_build_with_user to exclude provisioner_state.
|
||||
-- This avoids loading the large Terraform state blob (1-5 MB per workspace)
|
||||
-- on every query that uses this view. The callers that need provisioner_state
|
||||
-- now fetch it separately via GetWorkspaceBuildProvisionerStateByID.
|
||||
DROP VIEW workspace_build_with_user;
|
||||
|
||||
CREATE VIEW workspace_build_with_user AS
|
||||
SELECT
|
||||
workspace_builds.id,
|
||||
workspace_builds.created_at,
|
||||
workspace_builds.updated_at,
|
||||
workspace_builds.workspace_id,
|
||||
workspace_builds.template_version_id,
|
||||
workspace_builds.build_number,
|
||||
workspace_builds.transition,
|
||||
workspace_builds.initiator_id,
|
||||
workspace_builds.job_id,
|
||||
workspace_builds.deadline,
|
||||
workspace_builds.reason,
|
||||
workspace_builds.daily_cost,
|
||||
workspace_builds.max_deadline,
|
||||
workspace_builds.template_version_preset_id,
|
||||
workspace_builds.has_ai_task,
|
||||
workspace_builds.has_external_agent,
|
||||
COALESCE(visible_users.avatar_url, ''::text) AS initiator_by_avatar_url,
|
||||
COALESCE(visible_users.username, ''::text) AS initiator_by_username,
|
||||
COALESCE(visible_users.name, ''::text) AS initiator_by_name
|
||||
FROM
|
||||
workspace_builds
|
||||
LEFT JOIN
|
||||
visible_users ON workspace_builds.initiator_id = visible_users.id;
|
||||
|
||||
COMMENT ON VIEW workspace_build_with_user IS 'Joins in the username + avatar url of the initiated by user.';
|
||||
@@ -316,14 +316,6 @@ func (t GetFileTemplatesRow) RBACObject() rbac.Object {
|
||||
WithGroupACL(t.GroupACL)
|
||||
}
|
||||
|
||||
// RBACObject for a workspace build's provisioner state requires Update access of the template.
|
||||
func (t GetWorkspaceBuildProvisionerStateByIDRow) RBACObject() rbac.Object {
|
||||
return rbac.ResourceTemplate.WithID(t.TemplateID).
|
||||
InOrg(t.TemplateOrganizationID).
|
||||
WithACLUserList(t.UserACL).
|
||||
WithGroupACL(t.GroupACL)
|
||||
}
|
||||
|
||||
func (t Template) DeepCopy() Template {
|
||||
cpy := t
|
||||
cpy.UserACL = maps.Clone(t.UserACL)
|
||||
|
||||
@@ -610,7 +610,6 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -747,7 +746,6 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
@@ -145,13 +145,5 @@ func extractWhereClause(query string) string {
|
||||
// Remove SQL comments
|
||||
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
|
||||
|
||||
// Normalize indentation so subquery wrapping doesn't cause
|
||||
// mismatches.
|
||||
lines := strings.Split(whereClause, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = strings.TrimLeft(line, " \t")
|
||||
}
|
||||
whereClause = strings.Join(lines, "\n")
|
||||
|
||||
return strings.TrimSpace(whereClause)
|
||||
}
|
||||
|
||||
@@ -4026,10 +4026,6 @@ type OAuth2ProviderAppCode struct {
|
||||
CodeChallenge sql.NullString `db:"code_challenge" json:"code_challenge"`
|
||||
// PKCE challenge method (S256)
|
||||
CodeChallengeMethod sql.NullString `db:"code_challenge_method" json:"code_challenge_method"`
|
||||
// SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.
|
||||
StateHash sql.NullString `db:"state_hash" json:"state_hash"`
|
||||
// The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).
|
||||
RedirectUri sql.NullString `db:"redirect_uri" json:"redirect_uri"`
|
||||
}
|
||||
|
||||
type OAuth2ProviderAppSecret struct {
|
||||
@@ -4987,6 +4983,7 @@ type WorkspaceBuild struct {
|
||||
BuildNumber int32 `db:"build_number" json:"build_number"`
|
||||
Transition WorkspaceTransition `db:"transition" json:"transition"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
ProvisionerState []byte `db:"provisioner_state" json:"provisioner_state"`
|
||||
JobID uuid.UUID `db:"job_id" json:"job_id"`
|
||||
Deadline time.Time `db:"deadline" json:"deadline"`
|
||||
Reason BuildReason `db:"reason" json:"reason"`
|
||||
|
||||
@@ -169,7 +169,7 @@ type sqlcQuerier interface {
|
||||
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
|
||||
// there is no unique constraint on empty token names
|
||||
GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error)
|
||||
GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error)
|
||||
GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error)
|
||||
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
|
||||
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
|
||||
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
|
||||
@@ -361,23 +361,6 @@ type sqlcQuerier interface {
|
||||
GetTaskSnapshot(ctx context.Context, taskID uuid.UUID) (TaskSnapshot, error)
|
||||
GetTelemetryItem(ctx context.Context, key string) (TelemetryItem, error)
|
||||
GetTelemetryItems(ctx context.Context) ([]TelemetryItem, error)
|
||||
// Returns all data needed to build task lifecycle events for telemetry
|
||||
// in a single round-trip. For each task whose workspace is in the
|
||||
// given set, fetches:
|
||||
// - the latest workspace app binding (task_workspace_apps)
|
||||
// - the most recent stop and start builds (workspace_builds)
|
||||
// - the last "working" app status (workspace_app_statuses)
|
||||
// - the first app status after resume, for active workspaces
|
||||
//
|
||||
// Assumptions:
|
||||
// - 1:1 relationship between tasks and workspaces. All builds on the
|
||||
// workspace are considered task-related.
|
||||
// - Idle duration approximation: If the agent reports "working", does
|
||||
// work, then reports "done", we miss that working time.
|
||||
// - lws and active_dur join across all historical app IDs for the task,
|
||||
// because each resume cycle provisions a new app ID. This ensures
|
||||
// pre-pause statuses contribute to idle duration and active duration.
|
||||
GetTelemetryTaskEvents(ctx context.Context, arg GetTelemetryTaskEventsParams) ([]GetTelemetryTaskEventsRow, error)
|
||||
// GetTemplateAppInsights returns the aggregate usage of each app in a given
|
||||
// timeframe. The result can be filtered on template_ids, meaning only user data
|
||||
// from workspaces based on those templates will be included.
|
||||
@@ -523,11 +506,6 @@ type sqlcQuerier interface {
|
||||
GetWorkspaceBuildMetricsByResourceID(ctx context.Context, id uuid.UUID) (GetWorkspaceBuildMetricsByResourceIDRow, error)
|
||||
GetWorkspaceBuildParameters(ctx context.Context, workspaceBuildID uuid.UUID) ([]WorkspaceBuildParameter, error)
|
||||
GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIds []uuid.UUID) ([]WorkspaceBuildParameter, error)
|
||||
// Fetches the provisioner state of a workspace build, joined through to the
|
||||
// template so that dbauthz can enforce policy.ActionUpdate on the template.
|
||||
// Provisioner state contains sensitive Terraform state and should only be
|
||||
// accessible to template administrators.
|
||||
GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (GetWorkspaceBuildProvisionerStateByIDRow, error)
|
||||
GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]GetWorkspaceBuildStatsByTemplatesRow, error)
|
||||
GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg GetWorkspaceBuildsByWorkspaceIDParams) ([]WorkspaceBuild, error)
|
||||
GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceBuild, error)
|
||||
|
||||
@@ -8195,9 +8195,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
|
||||
// All keys are present before deletion
|
||||
keys, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
IncludeExpired: true,
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, len(expiredTimes)+len(unexpiredTimes))
|
||||
@@ -8213,9 +8212,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
|
||||
// Ensure it was deleted
|
||||
remaining, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
IncludeExpired: true,
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, remaining, len(expiredTimes)+len(unexpiredTimes)-1)
|
||||
@@ -8230,9 +8228,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
|
||||
// Ensure only unexpired keys remain
|
||||
remaining, err = db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
IncludeExpired: true,
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, remaining, len(unexpiredTimes))
|
||||
@@ -8742,123 +8739,3 @@ func TestInsertWorkspaceAgentDevcontainers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetWorkspaceBuildMetricsByResourceID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
tmpl := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
OrganizationID: org.ID,
|
||||
TemplateID: uuid.NullUUID{UUID: tmpl.ID, Valid: true},
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
TemplateID: tmpl.ID,
|
||||
OwnerID: user.ID,
|
||||
AutomaticUpdates: database.AutomaticUpdatesNever,
|
||||
})
|
||||
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
OrganizationID: org.ID,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
})
|
||||
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
WorkspaceID: ws.ID,
|
||||
TemplateVersionID: tv.ID,
|
||||
JobID: job.ID,
|
||||
InitiatorID: user.ID,
|
||||
})
|
||||
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
||||
JobID: job.ID,
|
||||
})
|
||||
|
||||
parentReadyAt := dbtime.Now()
|
||||
parentStartedAt := parentReadyAt.Add(-time.Second)
|
||||
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ResourceID: resource.ID,
|
||||
StartedAt: sql.NullTime{Time: parentStartedAt, Valid: true},
|
||||
ReadyAt: sql.NullTime{Time: parentReadyAt, Valid: true},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
})
|
||||
|
||||
row, err := db.GetWorkspaceBuildMetricsByResourceID(ctx, resource.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, row.AllAgentsReady)
|
||||
require.True(t, parentReadyAt.Equal(row.LastAgentReadyAt))
|
||||
require.Equal(t, "success", row.WorstStatus)
|
||||
})
|
||||
|
||||
t.Run("SubAgentExcluded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
tmpl := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
OrganizationID: org.ID,
|
||||
TemplateID: uuid.NullUUID{UUID: tmpl.ID, Valid: true},
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
TemplateID: tmpl.ID,
|
||||
OwnerID: user.ID,
|
||||
AutomaticUpdates: database.AutomaticUpdatesNever,
|
||||
})
|
||||
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
OrganizationID: org.ID,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
})
|
||||
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
WorkspaceID: ws.ID,
|
||||
TemplateVersionID: tv.ID,
|
||||
JobID: job.ID,
|
||||
InitiatorID: user.ID,
|
||||
})
|
||||
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
||||
JobID: job.ID,
|
||||
})
|
||||
|
||||
parentReadyAt := dbtime.Now()
|
||||
parentStartedAt := parentReadyAt.Add(-time.Second)
|
||||
parentAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ResourceID: resource.ID,
|
||||
StartedAt: sql.NullTime{Time: parentStartedAt, Valid: true},
|
||||
ReadyAt: sql.NullTime{Time: parentReadyAt, Valid: true},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
})
|
||||
|
||||
// Sub-agent with ready_at 1 hour later should be excluded.
|
||||
subAgentReadyAt := parentReadyAt.Add(time.Hour)
|
||||
subAgentStartedAt := subAgentReadyAt.Add(-time.Second)
|
||||
_ = dbgen.WorkspaceSubAgent(t, db, parentAgent, database.WorkspaceAgent{
|
||||
StartedAt: sql.NullTime{Time: subAgentStartedAt, Valid: true},
|
||||
ReadyAt: sql.NullTime{Time: subAgentReadyAt, Valid: true},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
})
|
||||
|
||||
row, err := db.GetWorkspaceBuildMetricsByResourceID(ctx, resource.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, row.AllAgentsReady)
|
||||
// LastAgentReadyAt should be the parent's, not the sub-agent's.
|
||||
require.True(t, parentReadyAt.Equal(row.LastAgentReadyAt))
|
||||
require.Equal(t, "success", row.WorstStatus)
|
||||
})
|
||||
}
|
||||
|
||||
+220
-491
@@ -1270,16 +1270,10 @@ func (q *sqlQuerier) GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNamePar
|
||||
|
||||
const getAPIKeysByLoginType = `-- name: GetAPIKeysByLoginType :many
|
||||
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, token_name, scopes, allow_list FROM api_keys WHERE login_type = $1
|
||||
AND ($2::bool OR expires_at > now())
|
||||
`
|
||||
|
||||
type GetAPIKeysByLoginTypeParams struct {
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
IncludeExpired bool `db:"include_expired" json:"include_expired"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByLoginType, arg.LoginType, arg.IncludeExpired)
|
||||
func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByLoginType, loginType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1317,17 +1311,15 @@ func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysBy
|
||||
|
||||
const getAPIKeysByUserID = `-- name: GetAPIKeysByUserID :many
|
||||
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, token_name, scopes, allow_list FROM api_keys WHERE login_type = $1 AND user_id = $2
|
||||
AND ($3::bool OR expires_at > now())
|
||||
`
|
||||
|
||||
type GetAPIKeysByUserIDParams struct {
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
IncludeExpired bool `db:"include_expired" json:"include_expired"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByUserID, arg.LoginType, arg.UserID, arg.IncludeExpired)
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByUserID, arg.LoginType, arg.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1511,105 +1503,93 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
|
||||
}
|
||||
|
||||
const countAuditLogs = `-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN $1::text != '' THEN resource_type = $1::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN resource_target = $4
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN action = $5::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower($7)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF($13::int, 0) + 1
|
||||
) AS limited_count
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN $1::text != '' THEN resource_type = $1::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN resource_target = $4
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN action = $5::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower($7)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
`
|
||||
|
||||
type CountAuditLogsParams struct {
|
||||
@@ -1625,7 +1605,6 @@ type CountAuditLogsParams struct {
|
||||
DateTo time.Time `db:"date_to" json:"date_to"`
|
||||
BuildReason string `db:"build_reason" json:"build_reason"`
|
||||
RequestID uuid.UUID `db:"request_id" json:"request_id"`
|
||||
CountCap int32 `db:"count_cap" json:"count_cap"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) {
|
||||
@@ -1642,7 +1621,6 @@ func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParam
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
@@ -2121,113 +2099,110 @@ func (q *sqlQuerier) UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBou
|
||||
}
|
||||
|
||||
const countConnectionLogs = `-- name: CountConnectionLogs :one
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = $1
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN $2 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($2) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN $4 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = $4 AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN $5 :: text != '' THEN
|
||||
type = $5 :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7 :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($7) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8 :: text != '' THEN
|
||||
users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN $13 :: text != '' THEN
|
||||
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF($14::int, 0) + 1
|
||||
) AS limited_count
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = $1
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN $2 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($2) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN $4 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = $4 AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN $5 :: text != '' THEN
|
||||
type = $5 :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7 :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($7) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8 :: text != '' THEN
|
||||
users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN $13 :: text != '' THEN
|
||||
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
`
|
||||
|
||||
type CountConnectionLogsParams struct {
|
||||
@@ -2244,7 +2219,6 @@ type CountConnectionLogsParams struct {
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"`
|
||||
Status string `db:"status" json:"status"`
|
||||
CountCap int32 `db:"count_cap" json:"count_cap"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) {
|
||||
@@ -2262,7 +2236,6 @@ func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectio
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
@@ -6795,7 +6768,7 @@ func (q *sqlQuerier) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context
|
||||
}
|
||||
|
||||
const getOAuth2ProviderAppCodeByID = `-- name: GetOAuth2ProviderAppCodeByID :one
|
||||
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri FROM oauth2_provider_app_codes WHERE id = $1
|
||||
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method FROM oauth2_provider_app_codes WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) (OAuth2ProviderAppCode, error) {
|
||||
@@ -6812,14 +6785,12 @@ func (q *sqlQuerier) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.U
|
||||
&i.ResourceUri,
|
||||
&i.CodeChallenge,
|
||||
&i.CodeChallengeMethod,
|
||||
&i.StateHash,
|
||||
&i.RedirectUri,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOAuth2ProviderAppCodeByPrefix = `-- name: GetOAuth2ProviderAppCodeByPrefix :one
|
||||
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri FROM oauth2_provider_app_codes WHERE secret_prefix = $1
|
||||
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method FROM oauth2_provider_app_codes WHERE secret_prefix = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secretPrefix []byte) (OAuth2ProviderAppCode, error) {
|
||||
@@ -6836,8 +6807,6 @@ func (q *sqlQuerier) GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secre
|
||||
&i.ResourceUri,
|
||||
&i.CodeChallenge,
|
||||
&i.CodeChallengeMethod,
|
||||
&i.StateHash,
|
||||
&i.RedirectUri,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -7241,9 +7210,7 @@ INSERT INTO oauth2_provider_app_codes (
|
||||
user_id,
|
||||
resource_uri,
|
||||
code_challenge,
|
||||
code_challenge_method,
|
||||
state_hash,
|
||||
redirect_uri
|
||||
code_challenge_method
|
||||
) VALUES(
|
||||
$1,
|
||||
$2,
|
||||
@@ -7254,10 +7221,8 @@ INSERT INTO oauth2_provider_app_codes (
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12
|
||||
) RETURNING id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri
|
||||
$10
|
||||
) RETURNING id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method
|
||||
`
|
||||
|
||||
type InsertOAuth2ProviderAppCodeParams struct {
|
||||
@@ -7271,8 +7236,6 @@ type InsertOAuth2ProviderAppCodeParams struct {
|
||||
ResourceUri sql.NullString `db:"resource_uri" json:"resource_uri"`
|
||||
CodeChallenge sql.NullString `db:"code_challenge" json:"code_challenge"`
|
||||
CodeChallengeMethod sql.NullString `db:"code_challenge_method" json:"code_challenge_method"`
|
||||
StateHash sql.NullString `db:"state_hash" json:"state_hash"`
|
||||
RedirectUri sql.NullString `db:"redirect_uri" json:"redirect_uri"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg InsertOAuth2ProviderAppCodeParams) (OAuth2ProviderAppCode, error) {
|
||||
@@ -7287,8 +7250,6 @@ func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg Insert
|
||||
arg.ResourceUri,
|
||||
arg.CodeChallenge,
|
||||
arg.CodeChallengeMethod,
|
||||
arg.StateHash,
|
||||
arg.RedirectUri,
|
||||
)
|
||||
var i OAuth2ProviderAppCode
|
||||
err := row.Scan(
|
||||
@@ -7302,8 +7263,6 @@ func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg Insert
|
||||
&i.ResourceUri,
|
||||
&i.CodeChallenge,
|
||||
&i.CodeChallengeMethod,
|
||||
&i.StateHash,
|
||||
&i.RedirectUri,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -13353,203 +13312,6 @@ func (q *sqlQuerier) GetTaskSnapshot(ctx context.Context, taskID uuid.UUID) (Tas
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTelemetryTaskEvents = `-- name: GetTelemetryTaskEvents :many
|
||||
WITH task_app_ids AS (
|
||||
SELECT task_id, workspace_app_id
|
||||
FROM task_workspace_apps
|
||||
),
|
||||
task_status_timeline AS (
|
||||
-- All app statuses across every historical app for each task,
|
||||
-- plus synthetic "boundary" rows at each stop/start build transition.
|
||||
-- This allows us to correctly take gaps due to pause/resume into account.
|
||||
SELECT tai.task_id, was.created_at, was.state::text AS state
|
||||
FROM workspace_app_statuses was
|
||||
JOIN task_app_ids tai ON tai.workspace_app_id = was.app_id
|
||||
UNION ALL
|
||||
SELECT t.id AS task_id, wb.created_at, '_boundary' AS state
|
||||
FROM tasks t
|
||||
JOIN workspace_builds wb ON wb.workspace_id = t.workspace_id
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.workspace_id IS NOT NULL
|
||||
AND wb.build_number > 1
|
||||
),
|
||||
task_event_data AS (
|
||||
SELECT
|
||||
t.id AS task_id,
|
||||
t.workspace_id,
|
||||
twa.workspace_app_id,
|
||||
-- Latest stop build.
|
||||
stop_build.created_at AS stop_build_created_at,
|
||||
stop_build.reason AS stop_build_reason,
|
||||
-- Latest start build (task_resume only).
|
||||
start_build.created_at AS start_build_created_at,
|
||||
start_build.reason AS start_build_reason,
|
||||
start_build.build_number AS start_build_number,
|
||||
-- Last "working" app status (for idle duration).
|
||||
lws.created_at AS last_working_status_at,
|
||||
-- First app status after resume (for resume-to-status duration).
|
||||
-- Only populated for workspaces in an active phase (started more
|
||||
-- recently than stopped).
|
||||
fsar.created_at AS first_status_after_resume_at,
|
||||
-- Cumulative time spent in "working" state.
|
||||
active_dur.total_working_ms AS active_duration_ms
|
||||
FROM tasks t
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT task_app.workspace_app_id
|
||||
FROM task_workspace_apps task_app
|
||||
WHERE task_app.task_id = t.id
|
||||
ORDER BY task_app.workspace_build_number DESC
|
||||
LIMIT 1
|
||||
) twa ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT wb.created_at, wb.reason, wb.build_number
|
||||
FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.transition = 'stop'
|
||||
ORDER BY wb.build_number DESC
|
||||
LIMIT 1
|
||||
) stop_build ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT wb.created_at, wb.reason, wb.build_number
|
||||
FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.transition = 'start'
|
||||
ORDER BY wb.build_number DESC
|
||||
LIMIT 1
|
||||
) start_build ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT tst.created_at
|
||||
FROM task_status_timeline tst
|
||||
WHERE tst.task_id = t.id
|
||||
AND tst.state = 'working'
|
||||
-- Only consider status before the latest pause so that
|
||||
-- post-resume statuses don't mask pre-pause idle time.
|
||||
AND (stop_build.created_at IS NULL
|
||||
OR tst.created_at <= stop_build.created_at)
|
||||
ORDER BY tst.created_at DESC
|
||||
LIMIT 1
|
||||
) lws ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT was.created_at
|
||||
FROM workspace_app_statuses was
|
||||
WHERE was.app_id = twa.workspace_app_id
|
||||
AND was.created_at > start_build.created_at
|
||||
ORDER BY was.created_at ASC
|
||||
LIMIT 1
|
||||
) fsar ON twa.workspace_app_id IS NOT NULL
|
||||
AND start_build.created_at IS NOT NULL
|
||||
AND (stop_build.created_at IS NULL
|
||||
OR start_build.created_at > stop_build.created_at)
|
||||
-- Active duration: cumulative time spent in "working" state across all
|
||||
-- historical app IDs for this task. Uses LEAD() to convert point-in-time
|
||||
-- statuses into intervals, then sums intervals where state='working'. For
|
||||
-- the last status, falls back to stop_build time (if paused) or @now (if
|
||||
-- still running).
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT COALESCE(
|
||||
SUM(EXTRACT(EPOCH FROM (interval_end - interval_start)) * 1000)::bigint,
|
||||
0
|
||||
)::bigint AS total_working_ms
|
||||
FROM (
|
||||
SELECT
|
||||
tst.created_at AS interval_start,
|
||||
COALESCE(
|
||||
LEAD(tst.created_at) OVER (ORDER BY tst.created_at ASC, CASE WHEN tst.state = '_boundary' THEN 1 ELSE 0 END ASC),
|
||||
CASE WHEN stop_build.created_at IS NOT NULL
|
||||
AND (start_build.created_at IS NULL
|
||||
OR stop_build.created_at > start_build.created_at)
|
||||
THEN stop_build.created_at
|
||||
ELSE $1::timestamptz
|
||||
END
|
||||
) AS interval_end,
|
||||
tst.state
|
||||
FROM task_status_timeline tst
|
||||
WHERE tst.task_id = t.id
|
||||
) intervals
|
||||
WHERE intervals.state = 'working'
|
||||
) active_dur ON TRUE
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.workspace_id IS NOT NULL
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.created_at > $2
|
||||
)
|
||||
)
|
||||
SELECT task_id, workspace_id, workspace_app_id, stop_build_created_at, stop_build_reason, start_build_created_at, start_build_reason, start_build_number, last_working_status_at, first_status_after_resume_at, active_duration_ms FROM task_event_data
|
||||
ORDER BY task_id
|
||||
`
|
||||
|
||||
type GetTelemetryTaskEventsParams struct {
|
||||
Now time.Time `db:"now" json:"now"`
|
||||
CreatedAfter time.Time `db:"created_after" json:"created_after"`
|
||||
}
|
||||
|
||||
type GetTelemetryTaskEventsRow struct {
|
||||
TaskID uuid.UUID `db:"task_id" json:"task_id"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
WorkspaceAppID uuid.NullUUID `db:"workspace_app_id" json:"workspace_app_id"`
|
||||
StopBuildCreatedAt sql.NullTime `db:"stop_build_created_at" json:"stop_build_created_at"`
|
||||
StopBuildReason NullBuildReason `db:"stop_build_reason" json:"stop_build_reason"`
|
||||
StartBuildCreatedAt sql.NullTime `db:"start_build_created_at" json:"start_build_created_at"`
|
||||
StartBuildReason NullBuildReason `db:"start_build_reason" json:"start_build_reason"`
|
||||
StartBuildNumber sql.NullInt32 `db:"start_build_number" json:"start_build_number"`
|
||||
LastWorkingStatusAt sql.NullTime `db:"last_working_status_at" json:"last_working_status_at"`
|
||||
FirstStatusAfterResumeAt sql.NullTime `db:"first_status_after_resume_at" json:"first_status_after_resume_at"`
|
||||
ActiveDurationMs int64 `db:"active_duration_ms" json:"active_duration_ms"`
|
||||
}
|
||||
|
||||
// Returns all data needed to build task lifecycle events for telemetry
|
||||
// in a single round-trip. For each task whose workspace is in the
|
||||
// given set, fetches:
|
||||
// - the latest workspace app binding (task_workspace_apps)
|
||||
// - the most recent stop and start builds (workspace_builds)
|
||||
// - the last "working" app status (workspace_app_statuses)
|
||||
// - the first app status after resume, for active workspaces
|
||||
//
|
||||
// Assumptions:
|
||||
// - 1:1 relationship between tasks and workspaces. All builds on the
|
||||
// workspace are considered task-related.
|
||||
// - Idle duration approximation: If the agent reports "working", does
|
||||
// work, then reports "done", we miss that working time.
|
||||
// - lws and active_dur join across all historical app IDs for the task,
|
||||
// because each resume cycle provisions a new app ID. This ensures
|
||||
// pre-pause statuses contribute to idle duration and active duration.
|
||||
func (q *sqlQuerier) GetTelemetryTaskEvents(ctx context.Context, arg GetTelemetryTaskEventsParams) ([]GetTelemetryTaskEventsRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getTelemetryTaskEvents, arg.Now, arg.CreatedAfter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetTelemetryTaskEventsRow
|
||||
for rows.Next() {
|
||||
var i GetTelemetryTaskEventsRow
|
||||
if err := rows.Scan(
|
||||
&i.TaskID,
|
||||
&i.WorkspaceID,
|
||||
&i.WorkspaceAppID,
|
||||
&i.StopBuildCreatedAt,
|
||||
&i.StopBuildReason,
|
||||
&i.StartBuildCreatedAt,
|
||||
&i.StartBuildReason,
|
||||
&i.StartBuildNumber,
|
||||
&i.LastWorkingStatusAt,
|
||||
&i.FirstStatusAfterResumeAt,
|
||||
&i.ActiveDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertTask = `-- name: InsertTask :one
|
||||
INSERT INTO tasks
|
||||
(id, organization_id, owner_id, name, display_name, workspace_id, template_version_id, template_parameters, prompt, created_at)
|
||||
@@ -18179,7 +17941,7 @@ const getAuthenticatedWorkspaceAgentAndBuildByAuthToken = `-- name: GetAuthentic
|
||||
SELECT
|
||||
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl,
|
||||
workspace_agents.id, workspace_agents.created_at, workspace_agents.updated_at, workspace_agents.name, workspace_agents.first_connected_at, workspace_agents.last_connected_at, workspace_agents.disconnected_at, workspace_agents.resource_id, workspace_agents.auth_token, workspace_agents.auth_instance_id, workspace_agents.architecture, workspace_agents.environment_variables, workspace_agents.operating_system, workspace_agents.instance_metadata, workspace_agents.resource_metadata, workspace_agents.directory, workspace_agents.version, workspace_agents.last_connected_replica_id, workspace_agents.connection_timeout_seconds, workspace_agents.troubleshooting_url, workspace_agents.motd_file, workspace_agents.lifecycle_state, workspace_agents.expanded_directory, workspace_agents.logs_length, workspace_agents.logs_overflowed, workspace_agents.started_at, workspace_agents.ready_at, workspace_agents.subsystems, workspace_agents.display_apps, workspace_agents.api_version, workspace_agents.display_order, workspace_agents.parent_id, workspace_agents.api_key_scope, workspace_agents.deleted,
|
||||
workspace_build_with_user.id, workspace_build_with_user.created_at, workspace_build_with_user.updated_at, workspace_build_with_user.workspace_id, workspace_build_with_user.template_version_id, workspace_build_with_user.build_number, workspace_build_with_user.transition, workspace_build_with_user.initiator_id, workspace_build_with_user.job_id, workspace_build_with_user.deadline, workspace_build_with_user.reason, workspace_build_with_user.daily_cost, workspace_build_with_user.max_deadline, workspace_build_with_user.template_version_preset_id, workspace_build_with_user.has_ai_task, workspace_build_with_user.has_external_agent, workspace_build_with_user.initiator_by_avatar_url, workspace_build_with_user.initiator_by_username, workspace_build_with_user.initiator_by_name,
|
||||
workspace_build_with_user.id, workspace_build_with_user.created_at, workspace_build_with_user.updated_at, workspace_build_with_user.workspace_id, workspace_build_with_user.template_version_id, workspace_build_with_user.build_number, workspace_build_with_user.transition, workspace_build_with_user.initiator_id, workspace_build_with_user.provisioner_state, workspace_build_with_user.job_id, workspace_build_with_user.deadline, workspace_build_with_user.reason, workspace_build_with_user.daily_cost, workspace_build_with_user.max_deadline, workspace_build_with_user.template_version_preset_id, workspace_build_with_user.has_ai_task, workspace_build_with_user.has_external_agent, workspace_build_with_user.initiator_by_avatar_url, workspace_build_with_user.initiator_by_username, workspace_build_with_user.initiator_by_name,
|
||||
tasks.id AS task_id
|
||||
FROM
|
||||
workspace_agents
|
||||
@@ -18317,6 +18079,7 @@ func (q *sqlQuerier) GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx conte
|
||||
&i.WorkspaceBuild.BuildNumber,
|
||||
&i.WorkspaceBuild.Transition,
|
||||
&i.WorkspaceBuild.InitiatorID,
|
||||
&i.WorkspaceBuild.ProvisionerState,
|
||||
&i.WorkspaceBuild.JobID,
|
||||
&i.WorkspaceBuild.Deadline,
|
||||
&i.WorkspaceBuild.Reason,
|
||||
@@ -21230,7 +20993,7 @@ func (q *sqlQuerier) InsertWorkspaceBuildParameters(ctx context.Context, arg Ins
|
||||
}
|
||||
|
||||
const getActiveWorkspaceBuildsByTemplateID = `-- name: GetActiveWorkspaceBuildsByTemplateID :many
|
||||
SELECT wb.id, wb.created_at, wb.updated_at, wb.workspace_id, wb.template_version_id, wb.build_number, wb.transition, wb.initiator_id, wb.job_id, wb.deadline, wb.reason, wb.daily_cost, wb.max_deadline, wb.template_version_preset_id, wb.has_ai_task, wb.has_external_agent, wb.initiator_by_avatar_url, wb.initiator_by_username, wb.initiator_by_name
|
||||
SELECT wb.id, wb.created_at, wb.updated_at, wb.workspace_id, wb.template_version_id, wb.build_number, wb.transition, wb.initiator_id, wb.provisioner_state, wb.job_id, wb.deadline, wb.reason, wb.daily_cost, wb.max_deadline, wb.template_version_preset_id, wb.has_ai_task, wb.has_external_agent, wb.initiator_by_avatar_url, wb.initiator_by_username, wb.initiator_by_name
|
||||
FROM (
|
||||
SELECT
|
||||
workspace_id, MAX(build_number) as max_build_number
|
||||
@@ -21278,6 +21041,7 @@ func (q *sqlQuerier) GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, t
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21385,7 +21149,7 @@ func (q *sqlQuerier) GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, a
|
||||
|
||||
const getLatestWorkspaceBuildByWorkspaceID = `-- name: GetLatestWorkspaceBuildByWorkspaceID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21408,6 +21172,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, w
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21426,7 +21191,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, w
|
||||
const getLatestWorkspaceBuildsByWorkspaceIDs = `-- name: GetLatestWorkspaceBuildsByWorkspaceIDs :many
|
||||
SELECT
|
||||
DISTINCT ON (workspace_id)
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21453,6 +21218,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context,
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21480,7 +21246,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context,
|
||||
|
||||
const getWorkspaceBuildByID = `-- name: GetWorkspaceBuildByID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21501,6 +21267,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (W
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21518,7 +21285,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (W
|
||||
|
||||
const getWorkspaceBuildByJobID = `-- name: GetWorkspaceBuildByJobID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21539,6 +21306,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UU
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21556,7 +21324,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UU
|
||||
|
||||
const getWorkspaceBuildByWorkspaceIDAndBuildNumber = `-- name: GetWorkspaceBuildByWorkspaceIDAndBuildNumber :one
|
||||
SELECT
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21581,6 +21349,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx context.Co
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21618,7 +21387,7 @@ JOIN workspaces w ON wb.workspace_id = w.id
|
||||
JOIN templates t ON w.template_id = t.id
|
||||
JOIN organizations o ON t.organization_id = o.id
|
||||
JOIN workspace_resources wr ON wr.job_id = wb.job_id
|
||||
JOIN workspace_agents wa ON wa.resource_id = wr.id AND wa.parent_id IS NULL
|
||||
JOIN workspace_agents wa ON wa.resource_id = wr.id
|
||||
WHERE wb.job_id = (SELECT job_id FROM workspace_resources WHERE workspace_resources.id = $1)
|
||||
GROUP BY wb.created_at, wb.transition, t.name, o.name, w.owner_id
|
||||
`
|
||||
@@ -21652,48 +21421,6 @@ func (q *sqlQuerier) GetWorkspaceBuildMetricsByResourceID(ctx context.Context, i
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceBuildProvisionerStateByID = `-- name: GetWorkspaceBuildProvisionerStateByID :one
|
||||
SELECT
|
||||
workspace_builds.provisioner_state,
|
||||
templates.id AS template_id,
|
||||
templates.organization_id AS template_organization_id,
|
||||
templates.user_acl,
|
||||
templates.group_acl
|
||||
FROM
|
||||
workspace_builds
|
||||
INNER JOIN
|
||||
workspaces ON workspaces.id = workspace_builds.workspace_id
|
||||
INNER JOIN
|
||||
templates ON templates.id = workspaces.template_id
|
||||
WHERE
|
||||
workspace_builds.id = $1
|
||||
`
|
||||
|
||||
type GetWorkspaceBuildProvisionerStateByIDRow struct {
|
||||
ProvisionerState []byte `db:"provisioner_state" json:"provisioner_state"`
|
||||
TemplateID uuid.UUID `db:"template_id" json:"template_id"`
|
||||
TemplateOrganizationID uuid.UUID `db:"template_organization_id" json:"template_organization_id"`
|
||||
UserACL TemplateACL `db:"user_acl" json:"user_acl"`
|
||||
GroupACL TemplateACL `db:"group_acl" json:"group_acl"`
|
||||
}
|
||||
|
||||
// Fetches the provisioner state of a workspace build, joined through to the
|
||||
// template so that dbauthz can enforce policy.ActionUpdate on the template.
|
||||
// Provisioner state contains sensitive Terraform state and should only be
|
||||
// accessible to template administrators.
|
||||
func (q *sqlQuerier) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (GetWorkspaceBuildProvisionerStateByIDRow, error) {
|
||||
row := q.db.QueryRowContext(ctx, getWorkspaceBuildProvisionerStateByID, workspaceBuildID)
|
||||
var i GetWorkspaceBuildProvisionerStateByIDRow
|
||||
err := row.Scan(
|
||||
&i.ProvisionerState,
|
||||
&i.TemplateID,
|
||||
&i.TemplateOrganizationID,
|
||||
&i.UserACL,
|
||||
&i.GroupACL,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceBuildStatsByTemplates = `-- name: GetWorkspaceBuildStatsByTemplates :many
|
||||
SELECT
|
||||
w.template_id,
|
||||
@@ -21763,7 +21490,7 @@ func (q *sqlQuerier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, sinc
|
||||
|
||||
const getWorkspaceBuildsByWorkspaceID = `-- name: GetWorkspaceBuildsByWorkspaceID :many
|
||||
SELECT
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21827,6 +21554,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg Ge
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21853,7 +21581,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg Ge
|
||||
}
|
||||
|
||||
const getWorkspaceBuildsCreatedAfter = `-- name: GetWorkspaceBuildsCreatedAfter :many
|
||||
SELECT id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name FROM workspace_build_with_user WHERE created_at > $1
|
||||
SELECT id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name FROM workspace_build_with_user WHERE created_at > $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceBuild, error) {
|
||||
@@ -21874,6 +21602,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, created
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
|
||||
@@ -25,12 +25,10 @@ LIMIT
|
||||
SELECT * FROM api_keys WHERE last_used > $1;
|
||||
|
||||
-- name: GetAPIKeysByLoginType :many
|
||||
SELECT * FROM api_keys WHERE login_type = $1
|
||||
AND (@include_expired::bool OR expires_at > now());
|
||||
SELECT * FROM api_keys WHERE login_type = $1;
|
||||
|
||||
-- name: GetAPIKeysByUserID :many
|
||||
SELECT * FROM api_keys WHERE login_type = $1 AND user_id = $2
|
||||
AND (@include_expired::bool OR expires_at > now());
|
||||
SELECT * FROM api_keys WHERE login_type = $1 AND user_id = $2;
|
||||
|
||||
-- name: InsertAPIKey :one
|
||||
INSERT INTO
|
||||
|
||||
@@ -149,105 +149,94 @@ VALUES (
|
||||
RETURNING *;
|
||||
|
||||
-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
|
||||
-- name: DeleteOldAuditLogConnectionEvents :exec
|
||||
DELETE FROM audit_logs
|
||||
|
||||
@@ -133,113 +133,111 @@ OFFSET
|
||||
@offset_opt;
|
||||
|
||||
-- name: CountConnectionLogs :one
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
|
||||
-- name: DeleteOldConnectionLogs :execrows
|
||||
WITH old_logs AS (
|
||||
|
||||
@@ -140,9 +140,7 @@ INSERT INTO oauth2_provider_app_codes (
|
||||
user_id,
|
||||
resource_uri,
|
||||
code_challenge,
|
||||
code_challenge_method,
|
||||
state_hash,
|
||||
redirect_uri
|
||||
code_challenge_method
|
||||
) VALUES(
|
||||
$1,
|
||||
$2,
|
||||
@@ -153,9 +151,7 @@ INSERT INTO oauth2_provider_app_codes (
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12
|
||||
$10
|
||||
) RETURNING *;
|
||||
|
||||
-- name: DeleteOAuth2ProviderAppCodeByID :exec
|
||||
|
||||
@@ -100,146 +100,3 @@ FROM
|
||||
task_snapshots
|
||||
WHERE
|
||||
task_id = $1;
|
||||
|
||||
-- name: GetTelemetryTaskEvents :many
|
||||
-- Returns all data needed to build task lifecycle events for telemetry
|
||||
-- in a single round-trip. For each task whose workspace is in the
|
||||
-- given set, fetches:
|
||||
-- - the latest workspace app binding (task_workspace_apps)
|
||||
-- - the most recent stop and start builds (workspace_builds)
|
||||
-- - the last "working" app status (workspace_app_statuses)
|
||||
-- - the first app status after resume, for active workspaces
|
||||
--
|
||||
-- Assumptions:
|
||||
-- - 1:1 relationship between tasks and workspaces. All builds on the
|
||||
-- workspace are considered task-related.
|
||||
-- - Idle duration approximation: If the agent reports "working", does
|
||||
-- work, then reports "done", we miss that working time.
|
||||
-- - lws and active_dur join across all historical app IDs for the task,
|
||||
-- because each resume cycle provisions a new app ID. This ensures
|
||||
-- pre-pause statuses contribute to idle duration and active duration.
|
||||
WITH task_app_ids AS (
|
||||
SELECT task_id, workspace_app_id
|
||||
FROM task_workspace_apps
|
||||
),
|
||||
task_status_timeline AS (
|
||||
-- All app statuses across every historical app for each task,
|
||||
-- plus synthetic "boundary" rows at each stop/start build transition.
|
||||
-- This allows us to correctly take gaps due to pause/resume into account.
|
||||
SELECT tai.task_id, was.created_at, was.state::text AS state
|
||||
FROM workspace_app_statuses was
|
||||
JOIN task_app_ids tai ON tai.workspace_app_id = was.app_id
|
||||
UNION ALL
|
||||
SELECT t.id AS task_id, wb.created_at, '_boundary' AS state
|
||||
FROM tasks t
|
||||
JOIN workspace_builds wb ON wb.workspace_id = t.workspace_id
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.workspace_id IS NOT NULL
|
||||
AND wb.build_number > 1
|
||||
),
|
||||
task_event_data AS (
|
||||
SELECT
|
||||
t.id AS task_id,
|
||||
t.workspace_id,
|
||||
twa.workspace_app_id,
|
||||
-- Latest stop build.
|
||||
stop_build.created_at AS stop_build_created_at,
|
||||
stop_build.reason AS stop_build_reason,
|
||||
-- Latest start build (task_resume only).
|
||||
start_build.created_at AS start_build_created_at,
|
||||
start_build.reason AS start_build_reason,
|
||||
start_build.build_number AS start_build_number,
|
||||
-- Last "working" app status (for idle duration).
|
||||
lws.created_at AS last_working_status_at,
|
||||
-- First app status after resume (for resume-to-status duration).
|
||||
-- Only populated for workspaces in an active phase (started more
|
||||
-- recently than stopped).
|
||||
fsar.created_at AS first_status_after_resume_at,
|
||||
-- Cumulative time spent in "working" state.
|
||||
active_dur.total_working_ms AS active_duration_ms
|
||||
FROM tasks t
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT task_app.workspace_app_id
|
||||
FROM task_workspace_apps task_app
|
||||
WHERE task_app.task_id = t.id
|
||||
ORDER BY task_app.workspace_build_number DESC
|
||||
LIMIT 1
|
||||
) twa ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT wb.created_at, wb.reason, wb.build_number
|
||||
FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.transition = 'stop'
|
||||
ORDER BY wb.build_number DESC
|
||||
LIMIT 1
|
||||
) stop_build ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT wb.created_at, wb.reason, wb.build_number
|
||||
FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.transition = 'start'
|
||||
ORDER BY wb.build_number DESC
|
||||
LIMIT 1
|
||||
) start_build ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT tst.created_at
|
||||
FROM task_status_timeline tst
|
||||
WHERE tst.task_id = t.id
|
||||
AND tst.state = 'working'
|
||||
-- Only consider status before the latest pause so that
|
||||
-- post-resume statuses don't mask pre-pause idle time.
|
||||
AND (stop_build.created_at IS NULL
|
||||
OR tst.created_at <= stop_build.created_at)
|
||||
ORDER BY tst.created_at DESC
|
||||
LIMIT 1
|
||||
) lws ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT was.created_at
|
||||
FROM workspace_app_statuses was
|
||||
WHERE was.app_id = twa.workspace_app_id
|
||||
AND was.created_at > start_build.created_at
|
||||
ORDER BY was.created_at ASC
|
||||
LIMIT 1
|
||||
) fsar ON twa.workspace_app_id IS NOT NULL
|
||||
AND start_build.created_at IS NOT NULL
|
||||
AND (stop_build.created_at IS NULL
|
||||
OR start_build.created_at > stop_build.created_at)
|
||||
-- Active duration: cumulative time spent in "working" state across all
|
||||
-- historical app IDs for this task. Uses LEAD() to convert point-in-time
|
||||
-- statuses into intervals, then sums intervals where state='working'. For
|
||||
-- the last status, falls back to stop_build time (if paused) or @now (if
|
||||
-- still running).
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT COALESCE(
|
||||
SUM(EXTRACT(EPOCH FROM (interval_end - interval_start)) * 1000)::bigint,
|
||||
0
|
||||
)::bigint AS total_working_ms
|
||||
FROM (
|
||||
SELECT
|
||||
tst.created_at AS interval_start,
|
||||
COALESCE(
|
||||
LEAD(tst.created_at) OVER (ORDER BY tst.created_at ASC, CASE WHEN tst.state = '_boundary' THEN 1 ELSE 0 END ASC),
|
||||
CASE WHEN stop_build.created_at IS NOT NULL
|
||||
AND (start_build.created_at IS NULL
|
||||
OR stop_build.created_at > start_build.created_at)
|
||||
THEN stop_build.created_at
|
||||
ELSE @now::timestamptz
|
||||
END
|
||||
) AS interval_end,
|
||||
tst.state
|
||||
FROM task_status_timeline tst
|
||||
WHERE tst.task_id = t.id
|
||||
) intervals
|
||||
WHERE intervals.state = 'working'
|
||||
) active_dur ON TRUE
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.workspace_id IS NOT NULL
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.created_at > @created_after
|
||||
)
|
||||
)
|
||||
SELECT * FROM task_event_data
|
||||
ORDER BY task_id;
|
||||
|
||||
|
||||
@@ -87,4 +87,3 @@ SELECT DISTINCT ON (workspace_id)
|
||||
FROM workspace_app_statuses
|
||||
WHERE workspace_id = ANY(@ids :: uuid[])
|
||||
ORDER BY workspace_id, created_at DESC;
|
||||
|
||||
|
||||
@@ -268,26 +268,6 @@ JOIN workspaces w ON wb.workspace_id = w.id
|
||||
JOIN templates t ON w.template_id = t.id
|
||||
JOIN organizations o ON t.organization_id = o.id
|
||||
JOIN workspace_resources wr ON wr.job_id = wb.job_id
|
||||
JOIN workspace_agents wa ON wa.resource_id = wr.id AND wa.parent_id IS NULL
|
||||
JOIN workspace_agents wa ON wa.resource_id = wr.id
|
||||
WHERE wb.job_id = (SELECT job_id FROM workspace_resources WHERE workspace_resources.id = $1)
|
||||
GROUP BY wb.created_at, wb.transition, t.name, o.name, w.owner_id;
|
||||
|
||||
-- name: GetWorkspaceBuildProvisionerStateByID :one
|
||||
-- Fetches the provisioner state of a workspace build, joined through to the
|
||||
-- template so that dbauthz can enforce policy.ActionUpdate on the template.
|
||||
-- Provisioner state contains sensitive Terraform state and should only be
|
||||
-- accessible to template administrators.
|
||||
SELECT
|
||||
workspace_builds.provisioner_state,
|
||||
templates.id AS template_id,
|
||||
templates.organization_id AS template_organization_id,
|
||||
templates.user_acl,
|
||||
templates.group_acl
|
||||
FROM
|
||||
workspace_builds
|
||||
INNER JOIN
|
||||
workspaces ON workspaces.id = workspace_builds.workspace_id
|
||||
INNER JOIN
|
||||
templates ON templates.id = workspaces.template_id
|
||||
WHERE
|
||||
workspace_builds.id = @workspace_build_id;
|
||||
|
||||
@@ -124,24 +124,6 @@ sql:
|
||||
- column: "tasks_with_status.workspace_app_health"
|
||||
go_type:
|
||||
type: "NullWorkspaceAppHealth"
|
||||
# Workaround for sqlc not interpreting the left join correctly
|
||||
# in the combined telemetry query.
|
||||
- column: "task_event_data.start_build_number"
|
||||
go_type: "database/sql.NullInt32"
|
||||
- column: "task_event_data.stop_build_created_at"
|
||||
go_type: "database/sql.NullTime"
|
||||
- column: "task_event_data.stop_build_reason"
|
||||
go_type:
|
||||
type: "NullBuildReason"
|
||||
- column: "task_event_data.start_build_created_at"
|
||||
go_type: "database/sql.NullTime"
|
||||
- column: "task_event_data.start_build_reason"
|
||||
go_type:
|
||||
type: "NullBuildReason"
|
||||
- column: "task_event_data.last_working_status_at"
|
||||
go_type: "database/sql.NullTime"
|
||||
- column: "task_event_data.first_status_after_resume_at"
|
||||
go_type: "database/sql.NullTime"
|
||||
rename:
|
||||
group_member: GroupMemberTable
|
||||
group_members_expanded: GroupMember
|
||||
|
||||
@@ -228,11 +228,12 @@ func (p *QueryParamParser) RedirectURL(vals url.Values, base *url.URL, queryPara
|
||||
})
|
||||
}
|
||||
|
||||
// OAuth 2.1 requires exact redirect URI matching.
|
||||
if v.String() != base.String() {
|
||||
// It can be a sub-directory but not a sub-domain, as we have apps on
|
||||
// sub-domains and that seems too dangerous.
|
||||
if v.Host != base.Host || !strings.HasPrefix(v.Path, base.Path) {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
Field: queryParam,
|
||||
Detail: fmt.Sprintf("Query param %q must exactly match %s", queryParam, base),
|
||||
Detail: fmt.Sprintf("Query param %q must be a subset of %s", queryParam, base),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -62,12 +62,8 @@ func CSRF(cookieCfg codersdk.HTTPCookieConfig) func(next http.Handler) http.Hand
|
||||
mw.ExemptRegexp(regexp.MustCompile("/organizations/[^/]+/provisionerdaemons/*"))
|
||||
|
||||
mw.ExemptFunc(func(r *http.Request) bool {
|
||||
// Enforce CSRF on API routes and the OAuth2 authorize
|
||||
// endpoint. The authorize endpoint serves a browser consent
|
||||
// form whose POST must be CSRF-protected to prevent
|
||||
// cross-site authorization code theft (coder/security#121).
|
||||
if !strings.HasPrefix(r.URL.Path, "/api") &&
|
||||
!strings.HasPrefix(r.URL.Path, "/oauth2/authorize") {
|
||||
// Only enforce CSRF on API routes.
|
||||
if !strings.HasPrefix(r.URL.Path, "/api") {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -51,26 +51,6 @@ func TestCSRFExemptList(t *testing.T) {
|
||||
URL: "https://coder.com/api/v2/me",
|
||||
Exempt: false,
|
||||
},
|
||||
{
|
||||
Name: "OAuth2Authorize",
|
||||
URL: "https://coder.com/oauth2/authorize",
|
||||
Exempt: false,
|
||||
},
|
||||
{
|
||||
Name: "OAuth2AuthorizeQuery",
|
||||
URL: "https://coder.com/oauth2/authorize?client_id=test",
|
||||
Exempt: false,
|
||||
},
|
||||
{
|
||||
Name: "OAuth2Tokens",
|
||||
URL: "https://coder.com/oauth2/tokens",
|
||||
Exempt: true,
|
||||
},
|
||||
{
|
||||
Name: "OAuth2Register",
|
||||
URL: "https://coder.com/oauth2/register",
|
||||
Exempt: true,
|
||||
},
|
||||
}
|
||||
|
||||
mw := httpmw.CSRF(codersdk.HTTPCookieConfig{})
|
||||
|
||||
@@ -348,12 +348,8 @@ func reapJob(ctx context.Context, log slog.Logger, db database.Store, pub pubsub
|
||||
|
||||
// Only copy the provisioner state if there's no state in
|
||||
// the current build.
|
||||
currentStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get workspace build provisioner state: %w", err)
|
||||
}
|
||||
if len(currentStateRow.ProvisionerState) == 0 {
|
||||
// Get the previous build's state if it exists.
|
||||
if len(build.ProvisionerState) == 0 {
|
||||
// Get the previous build if it exists.
|
||||
prevBuild, err := db.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
|
||||
WorkspaceID: build.WorkspaceID,
|
||||
BuildNumber: build.BuildNumber - 1,
|
||||
@@ -362,14 +358,10 @@ func reapJob(ctx context.Context, log slog.Logger, db database.Store, pub pubsub
|
||||
return xerrors.Errorf("get previous workspace build: %w", err)
|
||||
}
|
||||
if err == nil {
|
||||
prevStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, prevBuild.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get previous workspace build provisioner state: %w", err)
|
||||
}
|
||||
err = db.UpdateWorkspaceBuildProvisionerStateByID(ctx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
|
||||
ID: build.ID,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
ProvisionerState: prevStateRow.ProvisionerState,
|
||||
ProvisionerState: prevBuild.ProvisionerState,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update workspace build by id: %w", err)
|
||||
|
||||
@@ -126,9 +126,9 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
|
||||
previousBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
|
||||
ProvisionerState(expectedWorkspaceBuildState).
|
||||
Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
|
||||
ProvisionerState: expectedWorkspaceBuildState,
|
||||
}).Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
|
||||
Do()
|
||||
|
||||
// Current build (hung - running job with UpdatedAt > 5 min ago).
|
||||
@@ -163,9 +163,7 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
|
||||
// Check that the provisioner state was copied.
|
||||
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
|
||||
require.NoError(t, err)
|
||||
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
|
||||
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
|
||||
|
||||
detector.Close()
|
||||
detector.Wait()
|
||||
@@ -196,9 +194,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
|
||||
previousBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
|
||||
ProvisionerState([]byte(`{"dean":"NOT cool","colin":"also NOT cool"}`)).
|
||||
Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
|
||||
ProvisionerState: []byte(`{"dean":"NOT cool","colin":"also NOT cool"}`),
|
||||
}).Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
|
||||
Do()
|
||||
|
||||
// Current build (hung - running job with UpdatedAt > 5 min ago).
|
||||
@@ -206,8 +204,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
|
||||
currentBuild := dbfake.WorkspaceBuild(t, db, previousBuild.Workspace).
|
||||
Pubsub(pubsub).
|
||||
Seed(database.WorkspaceBuild{
|
||||
BuildNumber: 2,
|
||||
}).ProvisionerState(expectedWorkspaceBuildState).
|
||||
BuildNumber: 2,
|
||||
ProvisionerState: expectedWorkspaceBuildState,
|
||||
}).
|
||||
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
|
||||
Do()
|
||||
|
||||
@@ -236,9 +235,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
|
||||
// Check that the provisioner state was NOT copied.
|
||||
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
|
||||
require.NoError(t, err)
|
||||
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
|
||||
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
|
||||
|
||||
detector.Close()
|
||||
detector.Wait()
|
||||
@@ -269,9 +266,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T
|
||||
currentBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
|
||||
ProvisionerState(expectedWorkspaceBuildState).
|
||||
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
|
||||
ProvisionerState: expectedWorkspaceBuildState,
|
||||
}).Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
|
||||
Do()
|
||||
|
||||
t.Log("current job ID: ", currentBuild.Build.JobID)
|
||||
@@ -298,9 +295,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T
|
||||
// Check that the provisioner state was NOT updated.
|
||||
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
|
||||
require.NoError(t, err)
|
||||
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
|
||||
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
|
||||
|
||||
detector.Close()
|
||||
detector.Wait()
|
||||
@@ -330,9 +325,9 @@ func TestDetectorPendingWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testin
|
||||
currentBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
|
||||
ProvisionerState(expectedWorkspaceBuildState).
|
||||
Pending(dbfake.WithJobCreatedAt(thirtyFiveMinAgo), dbfake.WithJobUpdatedAt(thirtyFiveMinAgo)).
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
|
||||
ProvisionerState: expectedWorkspaceBuildState,
|
||||
}).Pending(dbfake.WithJobCreatedAt(thirtyFiveMinAgo), dbfake.WithJobUpdatedAt(thirtyFiveMinAgo)).
|
||||
Do()
|
||||
|
||||
t.Log("current job ID: ", currentBuild.Build.JobID)
|
||||
@@ -361,9 +356,7 @@ func TestDetectorPendingWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testin
|
||||
// Check that the provisioner state was NOT updated.
|
||||
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
|
||||
require.NoError(t, err)
|
||||
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
|
||||
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
|
||||
|
||||
detector.Close()
|
||||
detector.Wait()
|
||||
@@ -405,9 +398,9 @@ func TestDetectorWorkspaceBuildForDormantWorkspace(t *testing.T) {
|
||||
Time: now.Add(-time.Hour),
|
||||
Valid: true,
|
||||
},
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
|
||||
ProvisionerState(expectedWorkspaceBuildState).
|
||||
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
|
||||
ProvisionerState: expectedWorkspaceBuildState,
|
||||
}).Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
|
||||
Do()
|
||||
|
||||
t.Log("current job ID: ", currentBuild.Build.JobID)
|
||||
|
||||
+22
-40
@@ -2,8 +2,6 @@ package mcp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -12,7 +10,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
mcpclient "github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
@@ -27,15 +24,6 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// mcpGeneratePKCE creates a PKCE verifier and S256 challenge for MCP
|
||||
// e2e tests.
|
||||
func mcpGeneratePKCE() (verifier, challenge string) {
|
||||
verifier = uuid.NewString() + uuid.NewString()
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return verifier, challenge
|
||||
}
|
||||
|
||||
func TestMCPHTTP_E2E_ClientIntegration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -565,32 +553,31 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
// In a real flow, this would be done through the browser consent page
|
||||
// For testing, we'll create the code directly using the internal API
|
||||
|
||||
// First, we need to authorize the app (simulating user consent).
|
||||
staticVerifier, staticChallenge := mcpGeneratePKCE()
|
||||
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=test_state&code_challenge=%s&code_challenge_method=S256",
|
||||
api.AccessURL.String(), app.ID, "http://localhost:3000/callback", staticChallenge)
|
||||
// First, we need to authorize the app (simulating user consent)
|
||||
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=test_state",
|
||||
api.AccessURL.String(), app.ID, "http://localhost:3000/callback")
|
||||
|
||||
// Create an HTTP client that follows redirects but captures the final redirect.
|
||||
// Create an HTTP client that follows redirects but captures the final redirect
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse // Stop following redirects
|
||||
},
|
||||
}
|
||||
|
||||
// Make the authorization request (this would normally be done in a browser).
|
||||
// Make the authorization request (this would normally be done in a browser)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", authURL, nil)
|
||||
require.NoError(t, err)
|
||||
// Use RFC 6750 Bearer token for authentication.
|
||||
// Use RFC 6750 Bearer token for authentication
|
||||
req.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// The response should be a redirect to the consent page or directly to callback.
|
||||
// For testing purposes, let's simulate the POST consent approval.
|
||||
// The response should be a redirect to the consent page or directly to callback
|
||||
// For testing purposes, let's simulate the POST consent approval
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
// This means we got the consent page, now we need to POST consent.
|
||||
// This means we got the consent page, now we need to POST consent
|
||||
consentReq, err := http.NewRequestWithContext(ctx, "POST", authURL, nil)
|
||||
require.NoError(t, err)
|
||||
consentReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
|
||||
@@ -601,7 +588,7 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// Extract authorization code from redirect URL.
|
||||
// Extract authorization code from redirect URL
|
||||
require.True(t, resp.StatusCode >= 300 && resp.StatusCode < 400, "Expected redirect response")
|
||||
location := resp.Header.Get("Location")
|
||||
require.NotEmpty(t, location, "Expected Location header in redirect")
|
||||
@@ -613,14 +600,13 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
|
||||
t.Logf("Successfully obtained authorization code: %s", authCode[:10]+"...")
|
||||
|
||||
// Step 2: Exchange authorization code for access token and refresh token.
|
||||
// Step 2: Exchange authorization code for access token and refresh token
|
||||
tokenRequestBody := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {app.ID.String()},
|
||||
"client_secret": {secret.ClientSecretFull},
|
||||
"code": {authCode},
|
||||
"redirect_uri": {"http://localhost:3000/callback"},
|
||||
"code_verifier": {staticVerifier},
|
||||
}
|
||||
|
||||
tokenReq, err := http.NewRequestWithContext(ctx, "POST", api.AccessURL.String()+"/oauth2/tokens",
|
||||
@@ -882,44 +868,41 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
|
||||
t.Logf("Successfully registered dynamic client: %s", clientID)
|
||||
|
||||
// Step 3: Perform OAuth2 authorization code flow with dynamically registered client.
|
||||
dynamicVerifier, dynamicChallenge := mcpGeneratePKCE()
|
||||
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=dynamic_state&code_challenge=%s&code_challenge_method=S256",
|
||||
api.AccessURL.String(), clientID, "http://localhost:3000/callback", dynamicChallenge)
|
||||
// Step 3: Perform OAuth2 authorization code flow with dynamically registered client
|
||||
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=dynamic_state",
|
||||
api.AccessURL.String(), clientID, "http://localhost:3000/callback")
|
||||
|
||||
// Create an HTTP client that captures redirects.
|
||||
// Create an HTTP client that captures redirects
|
||||
authClient := &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse // Stop following redirects
|
||||
},
|
||||
}
|
||||
|
||||
// Make the authorization request with authentication.
|
||||
// Make the authorization request with authentication
|
||||
authReq, err := http.NewRequestWithContext(ctx, "GET", authURL, nil)
|
||||
require.NoError(t, err)
|
||||
authReq.Header.Set("Cookie", fmt.Sprintf("coder_session_token=%s", coderClient.SessionToken()))
|
||||
authReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
|
||||
|
||||
authResp, err := authClient.Do(authReq)
|
||||
require.NoError(t, err)
|
||||
defer authResp.Body.Close()
|
||||
|
||||
// Handle the response - check for error first.
|
||||
// Handle the response - check for error first
|
||||
if authResp.StatusCode == http.StatusBadRequest {
|
||||
// Read error response for debugging.
|
||||
// Read error response for debugging
|
||||
bodyBytes, err := io.ReadAll(authResp.Body)
|
||||
require.NoError(t, err)
|
||||
t.Logf("OAuth2 authorization error: %s", string(bodyBytes))
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Handle consent flow if needed.
|
||||
// Handle consent flow if needed
|
||||
if authResp.StatusCode == http.StatusOK {
|
||||
// This means we got the consent page, now we need to POST consent.
|
||||
// This means we got the consent page, now we need to POST consent
|
||||
consentReq, err := http.NewRequestWithContext(ctx, "POST", authURL, nil)
|
||||
require.NoError(t, err)
|
||||
consentReq.Header.Set("Cookie", fmt.Sprintf("coder_session_token=%s", coderClient.SessionToken()))
|
||||
consentReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
|
||||
consentReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
authResp, err = authClient.Do(consentReq)
|
||||
@@ -927,7 +910,7 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
defer authResp.Body.Close()
|
||||
}
|
||||
|
||||
// Extract authorization code from redirect.
|
||||
// Extract authorization code from redirect
|
||||
require.True(t, authResp.StatusCode >= 300 && authResp.StatusCode < 400,
|
||||
"Expected redirect response, got %d", authResp.StatusCode)
|
||||
location := authResp.Header.Get("Location")
|
||||
@@ -940,14 +923,13 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
|
||||
t.Logf("Successfully obtained authorization code: %s", authCode[:10]+"...")
|
||||
|
||||
// Step 4: Exchange authorization code for access token.
|
||||
// Step 4: Exchange authorization code for access token
|
||||
tokenRequestBody := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {clientID},
|
||||
"client_secret": {clientSecret},
|
||||
"code": {authCode},
|
||||
"redirect_uri": {"http://localhost:3000/callback"},
|
||||
"code_verifier": {dynamicVerifier},
|
||||
}
|
||||
|
||||
tokenReq, err := http.NewRequestWithContext(ctx, "POST", api.AccessURL.String()+"/oauth2/tokens",
|
||||
|
||||
@@ -287,18 +287,9 @@ func (api *API) paginatedMembers(rw http.ResponseWriter, r *http.Request) {
|
||||
memberRows = append(memberRows, row)
|
||||
}
|
||||
|
||||
if len(paginatedMemberRows) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.PaginatedMembersResponse{
|
||||
Members: []codersdk.OrganizationMemberWithUserData{},
|
||||
Count: 0,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
members, err := convertOrganizationMembersWithUserData(ctx, api.Database, memberRows)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := codersdk.PaginatedMembersResponse{
|
||||
|
||||
+41
-92
@@ -2,8 +2,6 @@ package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -291,6 +289,7 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
|
||||
authError: "Invalid query params:",
|
||||
},
|
||||
{
|
||||
// TODO: This is valid for now, but should it be?
|
||||
name: "DifferentProtocol",
|
||||
app: apps.Default,
|
||||
preAuth: func(valid *oauth2.Config) {
|
||||
@@ -298,7 +297,6 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
|
||||
newURL.Scheme = "https"
|
||||
valid.RedirectURL = newURL.String()
|
||||
},
|
||||
authError: "Invalid query params:",
|
||||
},
|
||||
{
|
||||
name: "NestedPath",
|
||||
@@ -308,7 +306,6 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
|
||||
newURL.Path = path.Join(newURL.Path, "nested")
|
||||
valid.RedirectURL = newURL.String()
|
||||
},
|
||||
authError: "Invalid query params:",
|
||||
},
|
||||
{
|
||||
// Some oauth implementations allow this, but our users can host
|
||||
@@ -484,12 +481,11 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
|
||||
}
|
||||
|
||||
var code string
|
||||
var verifier string
|
||||
if test.defaultCode != nil {
|
||||
code = *test.defaultCode
|
||||
} else {
|
||||
var err error
|
||||
code, verifier, err = authorizationFlow(ctx, userClient, valid)
|
||||
code, err = authorizationFlow(ctx, userClient, valid)
|
||||
if test.authError != "" {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, test.authError)
|
||||
@@ -504,12 +500,8 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
|
||||
test.preToken(valid)
|
||||
}
|
||||
|
||||
// Do the actual exchange. Include PKCE code_verifier when
|
||||
// we obtained a code through the authorization flow.
|
||||
exchangeOpts := append([]oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier),
|
||||
}, test.exchangeMutate...)
|
||||
token, err := valid.Exchange(ctx, code, exchangeOpts...)
|
||||
// Do the actual exchange.
|
||||
token, err := valid.Exchange(ctx, code, test.exchangeMutate...)
|
||||
if test.tokenError != "" {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, test.tokenError)
|
||||
@@ -691,11 +683,10 @@ func TestOAuth2ProviderTokenRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
type exchangeSetup struct {
|
||||
cfg *oauth2.Config
|
||||
app codersdk.OAuth2ProviderApp
|
||||
secret codersdk.OAuth2ProviderAppSecretFull
|
||||
code string
|
||||
verifier string
|
||||
cfg *oauth2.Config
|
||||
app codersdk.OAuth2ProviderApp
|
||||
secret codersdk.OAuth2ProviderAppSecretFull
|
||||
code string
|
||||
}
|
||||
|
||||
func TestOAuth2ProviderRevoke(t *testing.T) {
|
||||
@@ -739,13 +730,11 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
|
||||
name: "OverrideCodeAndToken",
|
||||
fn: func(ctx context.Context, client *codersdk.Client, s exchangeSetup) {
|
||||
// Generating a new code should wipe out the old code.
|
||||
code, verifier, err := authorizationFlow(ctx, client, s.cfg)
|
||||
code, err := authorizationFlow(ctx, client, s.cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generating a new token should wipe out the old token.
|
||||
_, err = s.cfg.Exchange(ctx, code,
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier),
|
||||
)
|
||||
_, err = s.cfg.Exchange(ctx, code)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
replacesToken: true,
|
||||
@@ -781,15 +770,14 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
|
||||
}
|
||||
|
||||
// Go through the auth flow to get a code.
|
||||
code, verifier, err := authorizationFlow(ctx, testClient, cfg)
|
||||
code, err := authorizationFlow(ctx, testClient, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
return exchangeSetup{
|
||||
cfg: cfg,
|
||||
app: app,
|
||||
secret: secret,
|
||||
code: code,
|
||||
verifier: verifier,
|
||||
cfg: cfg,
|
||||
app: app,
|
||||
secret: secret,
|
||||
code: code,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -806,16 +794,12 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
|
||||
test.fn(ctx, testClient, testEntities)
|
||||
|
||||
// Exchange should fail because the code should be gone.
|
||||
_, err := testEntities.cfg.Exchange(ctx, testEntities.code,
|
||||
oauth2.SetAuthURLParam("code_verifier", testEntities.verifier),
|
||||
)
|
||||
_, err := testEntities.cfg.Exchange(ctx, testEntities.code)
|
||||
require.Error(t, err)
|
||||
|
||||
// Try again, this time letting the exchange complete first.
|
||||
testEntities = setup(ctx, testClient, test.name+"-2")
|
||||
token, err := testEntities.cfg.Exchange(ctx, testEntities.code,
|
||||
oauth2.SetAuthURLParam("code_verifier", testEntities.verifier),
|
||||
)
|
||||
token, err := testEntities.cfg.Exchange(ctx, testEntities.code)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate the returned access token and that the app is listed.
|
||||
@@ -888,38 +872,25 @@ func generateApps(ctx context.Context, t *testing.T, client *codersdk.Client, su
|
||||
}
|
||||
}
|
||||
|
||||
// generatePKCE creates a PKCE verifier and S256 challenge for testing.
|
||||
func generatePKCE() (verifier, challenge string) {
|
||||
verifier = uuid.NewString() + uuid.NewString()
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return verifier, challenge
|
||||
}
|
||||
|
||||
func authorizationFlow(ctx context.Context, client *codersdk.Client, cfg *oauth2.Config) (code, codeVerifier string, err error) {
|
||||
func authorizationFlow(ctx context.Context, client *codersdk.Client, cfg *oauth2.Config) (string, error) {
|
||||
state := uuid.NewString()
|
||||
codeVerifier, challenge := generatePKCE()
|
||||
authURL := cfg.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
)
|
||||
authURL := cfg.AuthCodeURL(state)
|
||||
|
||||
// Make a POST request to simulate clicking "Allow" on the authorization page.
|
||||
// This bypasses the HTML consent page and directly processes the authorization.
|
||||
code, err = oidctest.OAuth2GetCode(
|
||||
// Make a POST request to simulate clicking "Allow" on the authorization page
|
||||
// This bypasses the HTML consent page and directly processes the authorization
|
||||
return oidctest.OAuth2GetCode(
|
||||
authURL,
|
||||
func(req *http.Request) (*http.Response, error) {
|
||||
// Change to POST to simulate the form submission.
|
||||
// Change to POST to simulate the form submission
|
||||
req.Method = http.MethodPost
|
||||
|
||||
// Prevent automatic redirect following.
|
||||
// Prevent automatic redirect following
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
return client.Request(ctx, req.Method, req.URL.String(), nil)
|
||||
},
|
||||
)
|
||||
return code, codeVerifier, err
|
||||
}
|
||||
|
||||
func must[T any](value T, err error) T {
|
||||
@@ -1026,15 +997,11 @@ func TestOAuth2ProviderResourceIndicators(t *testing.T) {
|
||||
Scopes: []string{},
|
||||
}
|
||||
|
||||
// Step 1: Authorization with resource parameter and PKCE.
|
||||
// Step 1: Authorization with resource parameter
|
||||
state := uuid.NewString()
|
||||
verifier, challenge := generatePKCE()
|
||||
authURL := cfg.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
)
|
||||
authURL := cfg.AuthCodeURL(state)
|
||||
if test.authResource != "" {
|
||||
// Add resource parameter to auth URL.
|
||||
// Add resource parameter to auth URL
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
require.NoError(t, err)
|
||||
query := parsedURL.Query()
|
||||
@@ -1063,7 +1030,7 @@ func TestOAuth2ProviderResourceIndicators(t *testing.T) {
|
||||
|
||||
// Step 2: Token exchange with resource parameter
|
||||
// Use custom token exchange since golang.org/x/oauth2 doesn't support resource parameter in token requests
|
||||
token, err := customTokenExchange(ctx, ownerClient.URL.String(), apps.Default.ID.String(), secret.ClientSecretFull, code, apps.Default.CallbackURL, test.tokenResource, verifier)
|
||||
token, err := customTokenExchange(ctx, ownerClient.URL.String(), apps.Default.ID.String(), secret.ClientSecretFull, code, apps.Default.CallbackURL, test.tokenResource)
|
||||
if test.expectTokenError {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid_target")
|
||||
@@ -1160,13 +1127,9 @@ func TestOAuth2ProviderCrossResourceAudienceValidation(t *testing.T) {
|
||||
Scopes: []string{},
|
||||
}
|
||||
|
||||
// Authorization with resource parameter for server1 and PKCE.
|
||||
// Authorization with resource parameter for server1
|
||||
state := uuid.NewString()
|
||||
verifier, challenge := generatePKCE()
|
||||
authURL := cfg.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
)
|
||||
authURL := cfg.AuthCodeURL(state)
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
require.NoError(t, err)
|
||||
query := parsedURL.Query()
|
||||
@@ -1186,11 +1149,8 @@ func TestOAuth2ProviderCrossResourceAudienceValidation(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exchange code for token with resource parameter and PKCE verifier.
|
||||
token, err := cfg.Exchange(ctx, code,
|
||||
oauth2.SetAuthURLParam("resource", resource1),
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier),
|
||||
)
|
||||
// Exchange code for token with resource parameter
|
||||
token, err := cfg.Exchange(ctx, code, oauth2.SetAuthURLParam("resource", resource1))
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token.AccessToken)
|
||||
|
||||
@@ -1266,11 +1226,9 @@ func TestOAuth2RefreshExpiryOutlivesAccess(t *testing.T) {
|
||||
}
|
||||
|
||||
// Authorization and token exchange
|
||||
code, verifier, err := authorizationFlow(ctx, ownerClient, cfg)
|
||||
code, err := authorizationFlow(ctx, ownerClient, cfg)
|
||||
require.NoError(t, err)
|
||||
tok, err := cfg.Exchange(ctx, code,
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier),
|
||||
)
|
||||
tok, err := cfg.Exchange(ctx, code)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, tok.AccessToken)
|
||||
require.NotEmpty(t, tok.RefreshToken)
|
||||
@@ -1295,7 +1253,7 @@ func TestOAuth2RefreshExpiryOutlivesAccess(t *testing.T) {
|
||||
|
||||
// customTokenExchange performs a custom OAuth2 token exchange with support for resource parameter
|
||||
// This is needed because golang.org/x/oauth2 doesn't support custom parameters in token requests
|
||||
func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, code, redirectURI, resource, codeVerifier string) (*oauth2.Token, error) {
|
||||
func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, code, redirectURI, resource string) (*oauth2.Token, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
@@ -1305,9 +1263,6 @@ func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, c
|
||||
if resource != "" {
|
||||
data.Set("resource", resource)
|
||||
}
|
||||
if codeVerifier != "" {
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/oauth2/tokens", strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
@@ -1682,21 +1637,17 @@ func TestOAuth2CoderClient(t *testing.T) {
|
||||
// Make a new user
|
||||
client, user := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID)
|
||||
|
||||
// Do an OAuth2 token exchange and get a new client with an oauth token.
|
||||
// Do an OAuth2 token exchange and get a new client with an oauth token
|
||||
state := uuid.NewString()
|
||||
verifier, challenge := generatePKCE()
|
||||
|
||||
// Get an OAuth2 code for a token exchange.
|
||||
// Get an OAuth2 code for a token exchange
|
||||
code, err := oidctest.OAuth2GetCode(
|
||||
cfg.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
),
|
||||
cfg.AuthCodeURL(state),
|
||||
func(req *http.Request) (*http.Response, error) {
|
||||
// Change to POST to simulate the form submission.
|
||||
// Change to POST to simulate the form submission
|
||||
req.Method = http.MethodPost
|
||||
|
||||
// Prevent automatic redirect following.
|
||||
// Prevent automatic redirect following
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
@@ -1705,9 +1656,7 @@ func TestOAuth2CoderClient(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := cfg.Exchange(ctx, code,
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier),
|
||||
)
|
||||
token, err := cfg.Exchange(ctx, code)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use the oauth client's authentication
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package oauth2provider
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -11,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/justinas/nosurf"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -25,7 +22,6 @@ import (
|
||||
type authorizeParams struct {
|
||||
clientID string
|
||||
redirectURL *url.URL
|
||||
redirectURIProvided bool
|
||||
responseType codersdk.OAuth2ProviderResponseType
|
||||
scope []string
|
||||
state string
|
||||
@@ -38,13 +34,11 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
|
||||
p := httpapi.NewQueryParamParser()
|
||||
vals := r.URL.Query()
|
||||
|
||||
// response_type and client_id are always required.
|
||||
p.RequiredNotEmpty("response_type", "client_id")
|
||||
|
||||
params := authorizeParams{
|
||||
clientID: p.String(vals, "", "client_id"),
|
||||
redirectURL: p.RedirectURL(vals, callbackURL, "redirect_uri"),
|
||||
redirectURIProvided: vals.Get("redirect_uri") != "",
|
||||
responseType: httpapi.ParseCustom(p, vals, "", "response_type", httpapi.ParseEnum[codersdk.OAuth2ProviderResponseType]),
|
||||
scope: strings.Fields(strings.TrimSpace(p.String(vals, "", "scope"))),
|
||||
state: p.String(vals, "", "state"),
|
||||
@@ -52,15 +46,6 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
|
||||
codeChallenge: p.String(vals, "", "code_challenge"),
|
||||
codeChallengeMethod: p.String(vals, "", "code_challenge_method"),
|
||||
}
|
||||
|
||||
// PKCE is required for authorization code flow requests.
|
||||
if params.responseType == codersdk.OAuth2ProviderResponseTypeCode && params.codeChallenge == "" {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
Field: "code_challenge",
|
||||
Detail: `Query param "code_challenge" is required and cannot be empty`,
|
||||
})
|
||||
}
|
||||
|
||||
// Validate resource indicator syntax (RFC 8707): must be absolute URI without fragment
|
||||
if err := validateResourceParameter(params.resource); err != nil {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
@@ -127,22 +112,6 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if params.responseType != codersdk.OAuth2ProviderResponseTypeCode {
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
Status: http.StatusBadRequest,
|
||||
HideStatus: false,
|
||||
Title: "Unsupported Response Type",
|
||||
Description: "Only response_type=code is supported.",
|
||||
Actions: []site.Action{
|
||||
{
|
||||
URL: accessURL.String(),
|
||||
Text: "Back to site",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cancel := params.redirectURL
|
||||
cancelQuery := params.redirectURL.Query()
|
||||
cancelQuery.Add("error", "access_denied")
|
||||
@@ -153,7 +122,6 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
|
||||
AppName: app.Name,
|
||||
CancelURI: cancel.String(),
|
||||
RedirectURI: r.URL.String(),
|
||||
CSRFToken: nosurf.Token(r),
|
||||
Username: ua.FriendlyName,
|
||||
})
|
||||
}
|
||||
@@ -179,23 +147,16 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// OAuth 2.1 removes the implicit grant. Only
|
||||
// authorization code flow is supported.
|
||||
if params.responseType != codersdk.OAuth2ProviderResponseTypeCode {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest,
|
||||
codersdk.OAuth2ErrorCodeUnsupportedResponseType,
|
||||
"Only response_type=code is supported")
|
||||
return
|
||||
}
|
||||
|
||||
// code_challenge is required (enforced by RequiredNotEmpty above),
|
||||
// but default the method to S256 if omitted.
|
||||
if params.codeChallengeMethod == "" {
|
||||
params.codeChallengeMethod = string(codersdk.OAuth2PKCECodeChallengeMethodS256)
|
||||
}
|
||||
if err := codersdk.ValidatePKCECodeChallengeMethod(params.codeChallengeMethod); err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
|
||||
return
|
||||
// Validate PKCE for public clients (MCP requirement)
|
||||
if params.codeChallenge != "" {
|
||||
// If code_challenge is provided but method is not, default to S256
|
||||
if params.codeChallengeMethod == "" {
|
||||
params.codeChallengeMethod = string(codersdk.OAuth2PKCECodeChallengeMethodS256)
|
||||
}
|
||||
if err := codersdk.ValidatePKCECodeChallengeMethod(params.codeChallengeMethod); err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Ignoring scope for now, but should look into implementing.
|
||||
@@ -233,8 +194,6 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
ResourceUri: sql.NullString{String: params.resource, Valid: params.resource != ""},
|
||||
CodeChallenge: sql.NullString{String: params.codeChallenge, Valid: params.codeChallenge != ""},
|
||||
CodeChallengeMethod: sql.NullString{String: params.codeChallengeMethod, Valid: params.codeChallengeMethod != ""},
|
||||
StateHash: hashOAuth2State(params.state),
|
||||
RedirectUri: sql.NullString{String: params.redirectURL.String(), Valid: params.redirectURIProvided},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert oauth2 authorization code: %w", err)
|
||||
@@ -259,16 +218,3 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
http.Redirect(rw, r, params.redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
// hashOAuth2State returns a SHA-256 hash of the OAuth2 state parameter. If
|
||||
// the state is empty, it returns a null string.
|
||||
func hashOAuth2State(state string) sql.NullString {
|
||||
if state == "" {
|
||||
return sql.NullString{}
|
||||
}
|
||||
hash := sha256.Sum256([]byte(state))
|
||||
return sql.NullString{
|
||||
String: hex.EncodeToString(hash[:]),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
//nolint:testpackage // Internal test for unexported hashOAuth2State helper.
|
||||
package oauth2provider
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHashOAuth2State(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptyState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := hashOAuth2State("")
|
||||
assert.False(t, result.Valid, "empty state should return invalid NullString")
|
||||
assert.Empty(t, result.String, "empty state should return empty string")
|
||||
})
|
||||
|
||||
t.Run("NonEmptyState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
state := "test-state-value"
|
||||
result := hashOAuth2State(state)
|
||||
require.True(t, result.Valid, "non-empty state should return valid NullString")
|
||||
|
||||
// Verify it's a proper SHA-256 hash.
|
||||
expected := sha256.Sum256([]byte(state))
|
||||
assert.Equal(t, hex.EncodeToString(expected[:]), result.String,
|
||||
"state hash should be SHA-256 hex digest")
|
||||
})
|
||||
|
||||
t.Run("DifferentStatesProduceDifferentHashes", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hash1 := hashOAuth2State("state-a")
|
||||
hash2 := hashOAuth2State("state-b")
|
||||
require.True(t, hash1.Valid)
|
||||
require.True(t, hash2.Valid)
|
||||
assert.NotEqual(t, hash1.String, hash2.String,
|
||||
"different states should produce different hashes")
|
||||
})
|
||||
|
||||
t.Run("SameStateProducesSameHash", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hash1 := hashOAuth2State("deterministic")
|
||||
hash2 := hashOAuth2State("deterministic")
|
||||
require.True(t, hash1.Valid)
|
||||
assert.Equal(t, hash1.String, hash2.String,
|
||||
"same state should produce identical hash")
|
||||
})
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package oauth2provider_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/site"
|
||||
)
|
||||
|
||||
func TestOAuthConsentFormIncludesCSRFToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const csrfFieldValue = "csrf-field-value"
|
||||
req := httptest.NewRequest(http.MethodGet, "https://coder.com/oauth2/authorize", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
site.RenderOAuthAllowPage(rec, req, site.RenderOAuthAllowData{
|
||||
AppName: "Test OAuth App",
|
||||
CancelURI: "https://coder.com/cancel",
|
||||
RedirectURI: "https://coder.com/oauth2/authorize?client_id=test",
|
||||
CSRFToken: csrfFieldValue,
|
||||
Username: "test-user",
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Result().StatusCode)
|
||||
assert.Contains(t, rec.Body.String(), `name="csrf_token"`)
|
||||
assert.Contains(t, rec.Body.String(), `value="`+csrfFieldValue+`"`)
|
||||
}
|
||||
@@ -158,9 +158,7 @@ func TestOAuth2InvalidPKCE(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
// TestOAuth2WithoutPKCEIsRejected verifies that authorization requests without
|
||||
// a code_challenge are rejected now that PKCE is mandatory.
|
||||
func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
|
||||
func TestOAuth2WithoutPKCE(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
@@ -168,15 +166,15 @@ func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create OAuth2 app.
|
||||
app, _ := oauth2providertest.CreateTestOAuth2App(t, client)
|
||||
// Create OAuth2 app
|
||||
app, clientSecret := oauth2providertest.CreateTestOAuth2App(t, client)
|
||||
t.Cleanup(func() {
|
||||
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
|
||||
})
|
||||
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
|
||||
// Authorization without code_challenge should be rejected.
|
||||
// Perform authorization without PKCE
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
@@ -184,9 +182,21 @@ func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
|
||||
State: state,
|
||||
}
|
||||
|
||||
oauth2providertest.AuthorizeOAuth2AppExpectingError(
|
||||
t, client, client.URL.String(), authParams, http.StatusBadRequest,
|
||||
)
|
||||
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
|
||||
require.NotEmpty(t, code, "should receive authorization code")
|
||||
|
||||
// Exchange code for token without PKCE
|
||||
tokenParams := oauth2providertest.TokenExchangeParams{
|
||||
GrantType: "authorization_code",
|
||||
Code: code,
|
||||
ClientID: app.ID.String(),
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
}
|
||||
|
||||
token := oauth2providertest.ExchangeCodeForToken(t, client.URL.String(), tokenParams)
|
||||
require.NotEmpty(t, token.AccessToken, "should receive access token")
|
||||
require.NotEmpty(t, token.RefreshToken, "should receive refresh token")
|
||||
}
|
||||
|
||||
func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
|
||||
@@ -202,16 +212,13 @@ func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
|
||||
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
|
||||
})
|
||||
|
||||
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: "S256",
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
}
|
||||
|
||||
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
|
||||
@@ -222,7 +229,6 @@ func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", oauth2providertest.TestRedirectURI)
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", client.URL.String()+"/oauth2/tokens", strings.NewReader(data.Encode()))
|
||||
require.NoError(t, err, "failed to create token request")
|
||||
@@ -259,16 +265,13 @@ func TestOAuth2TokenExchangeClientSecretBasicInvalidSecret(t *testing.T) {
|
||||
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
|
||||
})
|
||||
|
||||
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: "S256",
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
}
|
||||
|
||||
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
|
||||
@@ -279,7 +282,6 @@ func TestOAuth2TokenExchangeClientSecretBasicInvalidSecret(t *testing.T) {
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", oauth2providertest.TestRedirectURI)
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
|
||||
wrongSecret := clientSecret + "x"
|
||||
|
||||
@@ -347,30 +349,26 @@ func TestOAuth2ResourceParameter(t *testing.T) {
|
||||
})
|
||||
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
|
||||
|
||||
// Perform authorization with resource parameter.
|
||||
// Perform authorization with resource parameter
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: "S256",
|
||||
Resource: oauth2providertest.TestResourceURI,
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
Resource: oauth2providertest.TestResourceURI,
|
||||
}
|
||||
|
||||
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
|
||||
require.NotEmpty(t, code, "should receive authorization code")
|
||||
|
||||
// Exchange code for token with resource parameter.
|
||||
// Exchange code for token with resource parameter
|
||||
tokenParams := oauth2providertest.TokenExchangeParams{
|
||||
GrantType: "authorization_code",
|
||||
Code: code,
|
||||
ClientID: app.ID.String(),
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
Resource: oauth2providertest.TestResourceURI,
|
||||
}
|
||||
|
||||
@@ -394,16 +392,13 @@ func TestOAuth2TokenRefresh(t *testing.T) {
|
||||
})
|
||||
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
|
||||
|
||||
// Get initial token.
|
||||
// Get initial token
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: "S256",
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
}
|
||||
|
||||
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
|
||||
@@ -414,7 +409,6 @@ func TestOAuth2TokenRefresh(t *testing.T) {
|
||||
ClientID: app.ID.String(),
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
}
|
||||
|
||||
initialToken := oauth2providertest.ExchangeCodeForToken(t, client.URL.String(), tokenParams)
|
||||
|
||||
@@ -254,27 +254,14 @@ func authorizationCodeGrant(ctx context.Context, db database.Store, app database
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
}
|
||||
|
||||
// Verify redirect_uri matches the one used during authorization
|
||||
// (RFC 6749 §4.1.3).
|
||||
if dbCode.RedirectUri.Valid && dbCode.RedirectUri.String != "" {
|
||||
if req.RedirectURI != dbCode.RedirectUri.String {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
// Verify PKCE challenge if present
|
||||
if dbCode.CodeChallenge.Valid && dbCode.CodeChallenge.String != "" {
|
||||
if req.CodeVerifier == "" {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
if !VerifyPKCE(dbCode.CodeChallenge.String, req.CodeVerifier) {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
}
|
||||
|
||||
// PKCE is mandatory for all authorization code flows
|
||||
// (OAuth 2.1). Verify the code verifier against the stored
|
||||
// challenge.
|
||||
if req.CodeVerifier == "" {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
if !dbCode.CodeChallenge.Valid || dbCode.CodeChallenge.String == "" {
|
||||
// Code was issued without a challenge — should not happen
|
||||
// with authorize endpoint enforcement, but defend in depth.
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
if !VerifyPKCE(dbCode.CodeChallenge.String, req.CodeVerifier) {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
|
||||
// Verify resource parameter consistency (RFC 8707)
|
||||
|
||||
@@ -318,7 +318,6 @@ func TestExtractAuthorizeParams_Scopes(t *testing.T) {
|
||||
query.Set("response_type", "code")
|
||||
query.Set("client_id", "test-client")
|
||||
query.Set("redirect_uri", "http://localhost:3000/callback")
|
||||
query.Set("code_challenge", "test-challenge")
|
||||
if tc.scopeParam != "" {
|
||||
query.Set("scope", tc.scopeParam)
|
||||
}
|
||||
@@ -342,34 +341,6 @@ func TestExtractAuthorizeParams_Scopes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractAuthorizeParams_TokenResponseTypeDoesNotRequirePKCE ensures
|
||||
// response_type=token is parsed without requiring PKCE fields so callers can
|
||||
// return unsupported_response_type instead of invalid_request.
|
||||
func TestExtractAuthorizeParams_TokenResponseTypeDoesNotRequirePKCE(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
callbackURL, err := url.Parse("http://localhost:3000/callback")
|
||||
require.NoError(t, err)
|
||||
|
||||
query := url.Values{}
|
||||
query.Set("response_type", string(codersdk.OAuth2ProviderResponseTypeToken))
|
||||
query.Set("client_id", "test-client")
|
||||
query.Set("redirect_uri", "http://localhost:3000/callback")
|
||||
|
||||
reqURL, err := url.Parse("http://localhost:8080/oauth2/authorize?" + query.Encode())
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: reqURL,
|
||||
}
|
||||
|
||||
params, validationErrs, err := extractAuthorizeParams(req, callbackURL)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, validationErrs)
|
||||
require.Equal(t, codersdk.OAuth2ProviderResponseTypeToken, params.responseType)
|
||||
}
|
||||
|
||||
// TestRefreshTokenGrant_Scopes tests that scopes can be requested during refresh
|
||||
func TestRefreshTokenGrant_Scopes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -19,9 +19,9 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
templatesActiveUsersDesc = prometheus.NewDesc("coderd_insights_templates_active_users", "The number of active users of the template.", []string{"template_name", "organization_name"}, nil)
|
||||
applicationsUsageSecondsDesc = prometheus.NewDesc("coderd_insights_applications_usage_seconds", "The application usage per template.", []string{"template_name", "application_name", "slug", "organization_name"}, nil)
|
||||
parametersDesc = prometheus.NewDesc("coderd_insights_parameters", "The parameter usage per template.", []string{"template_name", "parameter_name", "parameter_type", "parameter_value", "organization_name"}, nil)
|
||||
templatesActiveUsersDesc = prometheus.NewDesc("coderd_insights_templates_active_users", "The number of active users of the template.", []string{"template_name"}, nil)
|
||||
applicationsUsageSecondsDesc = prometheus.NewDesc("coderd_insights_applications_usage_seconds", "The application usage per template.", []string{"template_name", "application_name", "slug"}, nil)
|
||||
parametersDesc = prometheus.NewDesc("coderd_insights_parameters", "The parameter usage per template.", []string{"template_name", "parameter_name", "parameter_type", "parameter_value"}, nil)
|
||||
)
|
||||
|
||||
type MetricsCollector struct {
|
||||
@@ -38,8 +38,7 @@ type insightsData struct {
|
||||
apps []database.GetTemplateAppInsightsByTemplateRow
|
||||
params []parameterRow
|
||||
|
||||
templateNames map[uuid.UUID]string
|
||||
organizationNames map[uuid.UUID]string // template ID → org name
|
||||
templateNames map[uuid.UUID]string
|
||||
}
|
||||
|
||||
type parameterRow struct {
|
||||
@@ -138,7 +137,6 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
|
||||
templateIDs := uniqueTemplateIDs(templateInsights, appInsights, paramInsights)
|
||||
|
||||
templateNames := make(map[uuid.UUID]string, len(templateIDs))
|
||||
organizationNames := make(map[uuid.UUID]string, len(templateIDs))
|
||||
if len(templateIDs) > 0 {
|
||||
templates, err := mc.database.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{
|
||||
IDs: templateIDs,
|
||||
@@ -148,31 +146,6 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
|
||||
return
|
||||
}
|
||||
templateNames = onlyTemplateNames(templates)
|
||||
|
||||
// Build org name lookup so that metrics can
|
||||
// distinguish templates with the same name across
|
||||
// different organizations.
|
||||
orgIDs := make([]uuid.UUID, 0, len(templates))
|
||||
for _, t := range templates {
|
||||
orgIDs = append(orgIDs, t.OrganizationID)
|
||||
}
|
||||
orgIDs = slice.Unique(orgIDs)
|
||||
|
||||
orgs, err := mc.database.GetOrganizations(ctx, database.GetOrganizationsParams{
|
||||
IDs: orgIDs,
|
||||
})
|
||||
if err != nil {
|
||||
mc.logger.Error(ctx, "unable to fetch organizations from database", slog.Error(err))
|
||||
return
|
||||
}
|
||||
orgNameByID := make(map[uuid.UUID]string, len(orgs))
|
||||
for _, o := range orgs {
|
||||
orgNameByID[o.ID] = o.Name
|
||||
}
|
||||
organizationNames = make(map[uuid.UUID]string, len(templates))
|
||||
for _, t := range templates {
|
||||
organizationNames[t.ID] = orgNameByID[t.OrganizationID]
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh the collector state
|
||||
@@ -181,8 +154,7 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
|
||||
apps: appInsights,
|
||||
params: paramInsights,
|
||||
|
||||
templateNames: templateNames,
|
||||
organizationNames: organizationNames,
|
||||
templateNames: templateNames,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -222,46 +194,44 @@ func (mc *MetricsCollector) Collect(metricsCh chan<- prometheus.Metric) {
|
||||
// Custom apps
|
||||
for _, appRow := range data.apps {
|
||||
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue, float64(appRow.UsageSeconds), data.templateNames[appRow.TemplateID],
|
||||
appRow.DisplayName, appRow.SlugOrPort, data.organizationNames[appRow.TemplateID])
|
||||
appRow.DisplayName, appRow.SlugOrPort)
|
||||
}
|
||||
|
||||
// Built-in apps
|
||||
for _, templateRow := range data.templates {
|
||||
orgName := data.organizationNames[templateRow.TemplateID]
|
||||
|
||||
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
|
||||
float64(templateRow.UsageVscodeSeconds),
|
||||
data.templateNames[templateRow.TemplateID],
|
||||
codersdk.TemplateBuiltinAppDisplayNameVSCode,
|
||||
"", orgName)
|
||||
"")
|
||||
|
||||
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
|
||||
float64(templateRow.UsageJetbrainsSeconds),
|
||||
data.templateNames[templateRow.TemplateID],
|
||||
codersdk.TemplateBuiltinAppDisplayNameJetBrains,
|
||||
"", orgName)
|
||||
"")
|
||||
|
||||
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
|
||||
float64(templateRow.UsageReconnectingPtySeconds),
|
||||
data.templateNames[templateRow.TemplateID],
|
||||
codersdk.TemplateBuiltinAppDisplayNameWebTerminal,
|
||||
"", orgName)
|
||||
"")
|
||||
|
||||
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
|
||||
float64(templateRow.UsageSshSeconds),
|
||||
data.templateNames[templateRow.TemplateID],
|
||||
codersdk.TemplateBuiltinAppDisplayNameSSH,
|
||||
"", orgName)
|
||||
"")
|
||||
}
|
||||
|
||||
// Templates
|
||||
for _, templateRow := range data.templates {
|
||||
metricsCh <- prometheus.MustNewConstMetric(templatesActiveUsersDesc, prometheus.GaugeValue, float64(templateRow.ActiveUsers), data.templateNames[templateRow.TemplateID], data.organizationNames[templateRow.TemplateID])
|
||||
metricsCh <- prometheus.MustNewConstMetric(templatesActiveUsersDesc, prometheus.GaugeValue, float64(templateRow.ActiveUsers), data.templateNames[templateRow.TemplateID])
|
||||
}
|
||||
|
||||
// Parameters
|
||||
for _, parameterRow := range data.params {
|
||||
metricsCh <- prometheus.MustNewConstMetric(parametersDesc, prometheus.GaugeValue, float64(parameterRow.count), data.templateNames[parameterRow.templateID], parameterRow.name, parameterRow.aType, parameterRow.value, data.organizationNames[parameterRow.templateID])
|
||||
metricsCh <- prometheus.MustNewConstMetric(parametersDesc, prometheus.GaugeValue, float64(parameterRow.count), data.templateNames[parameterRow.templateID], parameterRow.name, parameterRow.aType, parameterRow.value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
{
|
||||
"coderd_insights_applications_usage_seconds[application_name=JetBrains,organization_name=coder,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Visual Studio Code,organization_name=coder,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Web Terminal,organization_name=coder,slug=,template_name=golden-template]": 0,
|
||||
"coderd_insights_applications_usage_seconds[application_name=SSH,organization_name=coder,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Golden Slug,organization_name=coder,slug=golden-slug,template_name=golden-template]": 180,
|
||||
"coderd_insights_parameters[organization_name=coder,parameter_name=first_parameter,parameter_type=string,parameter_value=Foobar,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[organization_name=coder,parameter_name=first_parameter,parameter_type=string,parameter_value=Baz,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[organization_name=coder,parameter_name=second_parameter,parameter_type=bool,parameter_value=true,template_name=golden-template]": 2,
|
||||
"coderd_insights_parameters[organization_name=coder,parameter_name=third_parameter,parameter_type=number,parameter_value=789,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[organization_name=coder,parameter_name=third_parameter,parameter_type=number,parameter_value=999,template_name=golden-template]": 1,
|
||||
"coderd_insights_templates_active_users[organization_name=coder,template_name=golden-template]": 1
|
||||
"coderd_insights_applications_usage_seconds[application_name=JetBrains,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Visual Studio Code,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Web Terminal,slug=,template_name=golden-template]": 0,
|
||||
"coderd_insights_applications_usage_seconds[application_name=SSH,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Golden Slug,slug=golden-slug,template_name=golden-template]": 180,
|
||||
"coderd_insights_parameters[parameter_name=first_parameter,parameter_type=string,parameter_value=Foobar,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[parameter_name=first_parameter,parameter_type=string,parameter_value=Baz,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[parameter_name=second_parameter,parameter_type=bool,parameter_value=true,template_name=golden-template]": 2,
|
||||
"coderd_insights_parameters[parameter_name=third_parameter,parameter_type=number,parameter_value=789,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[parameter_name=third_parameter,parameter_type=number,parameter_value=999,template_name=golden-template]": 1,
|
||||
"coderd_insights_templates_active_users[template_name=golden-template]": 1
|
||||
}
|
||||
|
||||
@@ -564,7 +564,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
// The check `s.OIDCConfig != nil` is not as strict, since it can be an interface
|
||||
// pointing to a typed nil.
|
||||
if !reflect.ValueOf(s.OIDCConfig).IsNil() {
|
||||
workspaceOwnerOIDCAccessToken, err = ObtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID)
|
||||
workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, s.Database, s.OIDCConfig, owner.ID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err))
|
||||
}
|
||||
@@ -725,16 +725,11 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
}
|
||||
}
|
||||
|
||||
provisionerStateRow, err := s.Database.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuild.ID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("get workspace build provisioner state: %s", err))
|
||||
}
|
||||
|
||||
protoJob.Type = &proto.AcquiredJob_WorkspaceBuild_{
|
||||
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
|
||||
WorkspaceBuildId: workspaceBuild.ID.String(),
|
||||
WorkspaceName: workspace.Name,
|
||||
State: provisionerStateRow.ProvisionerState,
|
||||
State: workspaceBuild.ProvisionerState,
|
||||
RichParameterValues: convertRichParameterValues(workspaceBuildParameters),
|
||||
PreviousParameterValues: convertRichParameterValues(lastWorkspaceBuildParameters),
|
||||
VariableValues: asVariableValues(templateVariables),
|
||||
@@ -845,11 +840,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
|
||||
// Record the time the job spent waiting in the queue.
|
||||
if s.metrics != nil && job.StartedAt.Valid && job.Provisioner.Valid() {
|
||||
// These timestamps lose their monotonic clock component after a Postgres
|
||||
// round-trip, so the subtraction is based purely on wall-clock time. Floor at
|
||||
// 1ms as a defensive measure against clock adjustments producing a negative
|
||||
// delta while acknowledging there's a non-zero queue time.
|
||||
queueWaitSeconds := max(job.StartedAt.Time.Sub(job.CreatedAt).Seconds(), 0.001)
|
||||
queueWaitSeconds := job.StartedAt.Time.Sub(job.CreatedAt).Seconds()
|
||||
s.metrics.ObserveJobQueueWait(string(job.Provisioner), string(job.Type), jobTransition, jobBuildReason, queueWaitSeconds)
|
||||
}
|
||||
|
||||
@@ -3075,37 +3066,9 @@ func deleteSessionTokenForUserAndWorkspace(ctx context.Context, db database.Stor
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) {
|
||||
if link.OAuthRefreshToken == "" {
|
||||
// We cannot refresh even if we wanted to
|
||||
return false, link.OAuthExpiry
|
||||
}
|
||||
|
||||
if link.OAuthExpiry.IsZero() {
|
||||
// 0 expire means the token never expires, so we shouldn't refresh
|
||||
return false, link.OAuthExpiry
|
||||
}
|
||||
|
||||
// This handles an edge case where the token is about to expire. A workspace
|
||||
// build takes a non-trivial amount of time. If the token is to expire during the
|
||||
// build, then the build risks failure. To mitigate this, refresh the token
|
||||
// prematurely.
|
||||
//
|
||||
// If an OIDC provider issues short-lived tokens less than our defined period,
|
||||
// the token will always be refreshed on every workspace build.
|
||||
//
|
||||
// By setting the expiration backwards, we are effectively shortening the
|
||||
// time a token can be alive for by 10 minutes.
|
||||
// Note: This is how it is done in the oauth2 package's own token refreshing logic.
|
||||
expiresAt := link.OAuthExpiry.Add(-time.Minute * 10)
|
||||
|
||||
// Return if the token is assumed to be expired.
|
||||
return expiresAt.Before(dbtime.Now()), expiresAt
|
||||
}
|
||||
|
||||
// ObtainOIDCAccessToken returns a valid OpenID Connect access token
|
||||
// obtainOIDCAccessToken returns a valid OpenID Connect access token
|
||||
// for the user if it's able to obtain one, otherwise it returns an empty string.
|
||||
func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
|
||||
func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
|
||||
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: userID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
@@ -3117,13 +3080,11 @@ func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.
|
||||
return "", xerrors.Errorf("get owner oidc link: %w", err)
|
||||
}
|
||||
|
||||
if shouldRefresh, expiresAt := shouldRefreshOIDCToken(link); shouldRefresh {
|
||||
if link.OAuthExpiry.Before(dbtime.Now()) && !link.OAuthExpiry.IsZero() && link.OAuthRefreshToken != "" {
|
||||
token, err := oidcConfig.TokenSource(ctx, &oauth2.Token{
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
// Use the expiresAt returned by shouldRefreshOIDCToken.
|
||||
// It will force a refresh with an expired time.
|
||||
Expiry: expiresAt,
|
||||
Expiry: link.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
// If OIDC fails to refresh, we return an empty string and don't fail.
|
||||
@@ -3148,7 +3109,6 @@ func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("update user link: %w", err)
|
||||
}
|
||||
logger.Info(ctx, "refreshed expired OIDC token for user during workspace build", slog.F("user_id", userID))
|
||||
}
|
||||
|
||||
return link.OAuthAccessToken, nil
|
||||
|
||||
@@ -16,109 +16,13 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestShouldRefreshOIDCToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := dbtime.Now()
|
||||
testCases := []struct {
|
||||
name string
|
||||
link database.UserLink
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "NoRefreshToken",
|
||||
link: database.UserLink{OAuthExpiry: now.Add(-time.Hour)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ZeroExpiry",
|
||||
link: database.UserLink{OAuthRefreshToken: "refresh"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "LongExpired",
|
||||
link: database.UserLink{
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: now.Add(-1 * time.Hour),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
// Edge being "+/- 10 minutes"
|
||||
name: "EdgeExpired",
|
||||
link: database.UserLink{
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: now.Add(-1 * time.Minute * 10),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Expired",
|
||||
link: database.UserLink{
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: now.Add(-1 * time.Minute),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "SoonToBeExpired",
|
||||
link: database.UserLink{
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: now.Add(5 * time.Minute),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "SoonToBeExpiredEdge",
|
||||
link: database.UserLink{
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: now.Add(9 * time.Minute),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "AfterEdge",
|
||||
link: database.UserLink{
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: now.Add(11 * time.Minute),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "NotExpired",
|
||||
link: database.UserLink{
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: now.Add(time.Hour),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "NotEvenCloseExpired",
|
||||
link: database.UserLink{
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: now.Add(time.Hour * 24),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
shouldRefresh, _ := shouldRefreshOIDCToken(tc.link)
|
||||
require.Equal(t, tc.want, shouldRefresh)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestObtainOIDCAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
t.Run("NoToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, nil, uuid.Nil)
|
||||
_, err := obtainOIDCAccessToken(ctx, db, nil, uuid.Nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("InvalidConfig", func(t *testing.T) {
|
||||
@@ -131,7 +35,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
OAuthExpiry: dbtime.Now().Add(-time.Hour),
|
||||
})
|
||||
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
|
||||
_, err := obtainOIDCAccessToken(ctx, db, &oauth2.Config{}, user.ID)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("MissingLink", func(t *testing.T) {
|
||||
@@ -140,7 +44,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
|
||||
user := dbgen.User(t, db, database.User{
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
})
|
||||
tok, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
|
||||
tok, err := obtainOIDCAccessToken(ctx, db, &oauth2.Config{}, user.ID)
|
||||
require.Empty(t, tok)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
@@ -153,7 +57,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
OAuthExpiry: dbtime.Now().Add(-time.Hour),
|
||||
})
|
||||
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &testutil.OAuth2Config{
|
||||
_, err := obtainOIDCAccessToken(ctx, db, &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "token",
|
||||
},
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -31,7 +30,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
@@ -60,175 +58,6 @@ import (
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
// TestTokenIsRefreshedEarly creates a fake OIDC IDP that sets expiration times
|
||||
// of the token to values that are "near expiration". Expiration being 10minutes
|
||||
// earlier than it needs to be. The `ObtainOIDCAccessToken` should refresh these
|
||||
// tokens early.
|
||||
func TestTokenIsRefreshedEarly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("WithCoderd", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokenRefreshCount := 0
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithServing(),
|
||||
oidctest.WithDefaultExpire(time.Minute*8),
|
||||
oidctest.WithRefresh(func(email string) error {
|
||||
tokenRefreshCount++
|
||||
return nil
|
||||
}),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
})
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
owner := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: cfg,
|
||||
IncludeProvisionerDaemon: true,
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
})
|
||||
first := coderdtest.CreateFirstUser(t, owner)
|
||||
version := coderdtest.CreateTemplateVersion(t, owner, first.OrganizationID, nil)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, owner, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, owner, first.OrganizationID, version.ID)
|
||||
|
||||
// Setup an OIDC user.
|
||||
client, _ := fake.Login(t, owner, jwt.MapClaims{
|
||||
"email": "user@unauthorized.com",
|
||||
"email_verified": true,
|
||||
"sub": uuid.NewString(),
|
||||
})
|
||||
|
||||
// Creating a workspace should refresh the oidc early.
|
||||
tokenRefreshCount = 0
|
||||
wrk := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wrk.LatestBuild.ID)
|
||||
require.Equal(t, 1, tokenRefreshCount)
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // Sub tests need to run sequentially.
|
||||
func TestTokenIsRefreshedEarlyWithoutCoderd(t *testing.T) {
|
||||
t.Parallel()
|
||||
tokenRefreshCount := 0
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithServing(),
|
||||
oidctest.WithDefaultExpire(time.Minute*8),
|
||||
oidctest.WithRefresh(func(email string) error {
|
||||
tokenRefreshCount++
|
||||
return nil
|
||||
}),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, nil)
|
||||
|
||||
// Fetch a valid token from the fake OIDC provider
|
||||
token, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{
|
||||
"email": "user@unauthorized.com",
|
||||
"email_verified": true,
|
||||
"sub": uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
dbgen.UserLink(t, db, database.UserLink{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
LinkedID: "foo",
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
// The oauth expiry does not really matter, since each test will manually control
|
||||
// this value.
|
||||
OAuthExpiry: dbtime.Now().Add(time.Hour),
|
||||
})
|
||||
|
||||
setLinkExpiration := func(t *testing.T, exp time.Time) database.UserLink {
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
links, err := db.GetUserLinksByUserID(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, links, 1)
|
||||
link := links[0]
|
||||
|
||||
newLink, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: link.OAuthAccessTokenKeyID,
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID,
|
||||
OAuthExpiry: exp,
|
||||
Claims: link.Claims,
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return newLink
|
||||
}
|
||||
|
||||
for _, c := range []struct {
|
||||
name string
|
||||
// expires is a function to return a more up to date "now".
|
||||
// Because the oauth library is calling `time.Now()`, we cannot use
|
||||
// mocked clocks.
|
||||
expires func() time.Time
|
||||
refreshExpected bool
|
||||
}{
|
||||
{
|
||||
name: "ZeroExpiry",
|
||||
expires: func() time.Time { return time.Time{} },
|
||||
refreshExpected: false,
|
||||
},
|
||||
{
|
||||
name: "LongExpired",
|
||||
expires: func() time.Time { return dbtime.Now().Add(-time.Hour) },
|
||||
refreshExpected: true,
|
||||
},
|
||||
{
|
||||
name: "EdgeExpired",
|
||||
expires: func() time.Time { return dbtime.Now().Add(-time.Minute * 10) },
|
||||
refreshExpected: true,
|
||||
},
|
||||
{
|
||||
name: "RecentExpired",
|
||||
expires: func() time.Time { return dbtime.Now().Add(-time.Second * -1) },
|
||||
refreshExpected: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "Future",
|
||||
expires: func() time.Time { return dbtime.Now().Add(time.Hour) },
|
||||
refreshExpected: false,
|
||||
},
|
||||
{
|
||||
name: "FutureWithinRefreshWindow",
|
||||
expires: func() time.Time { return dbtime.Now().Add(time.Minute * 8) },
|
||||
refreshExpected: true,
|
||||
},
|
||||
} {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
oldLink := setLinkExpiration(t, c.expires())
|
||||
tokenRefreshCount = 0
|
||||
_, err := provisionerdserver.ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, cfg, user.ID)
|
||||
require.NoError(t, err)
|
||||
links, err := db.GetUserLinksByUserID(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, links, 1)
|
||||
newLink := links[0]
|
||||
|
||||
if c.refreshExpected {
|
||||
require.Equal(t, 1, tokenRefreshCount)
|
||||
|
||||
require.NotEqual(t, oldLink.OAuthAccessToken, newLink.OAuthAccessToken)
|
||||
require.NotEqual(t, oldLink.OAuthRefreshToken, newLink.OAuthRefreshToken)
|
||||
} else {
|
||||
require.Equal(t, 0, tokenRefreshCount)
|
||||
require.Equal(t, oldLink.OAuthAccessToken, newLink.OAuthAccessToken)
|
||||
require.Equal(t, oldLink.OAuthRefreshToken, newLink.OAuthRefreshToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testTemplateScheduleStore() *atomic.Pointer[schedule.TemplateScheduleStore] {
|
||||
poitr := &atomic.Pointer[schedule.TemplateScheduleStore]{}
|
||||
store := schedule.NewAGPLTemplateScheduleStore()
|
||||
@@ -1492,9 +1321,7 @@ func TestFailJob(t *testing.T) {
|
||||
<-publishedLogs
|
||||
build, err := db.GetWorkspaceBuildByID(ctx, buildID)
|
||||
require.NoError(t, err)
|
||||
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "some state", string(provisionerStateRow.ProvisionerState))
|
||||
require.Equal(t, "some state", string(build.ProvisionerState))
|
||||
require.Len(t, auditor.AuditLogs(), 1)
|
||||
|
||||
// Assert that the workspace_id field get populated
|
||||
|
||||
@@ -81,7 +81,6 @@ const (
|
||||
SubjectAibridged SubjectType = "aibridged"
|
||||
SubjectTypeDBPurge SubjectType = "dbpurge"
|
||||
SubjectTypeBoundaryUsageTracker SubjectType = "boundary_usage_tracker"
|
||||
SubjectTypeWorkspaceBuilder SubjectType = "workspace_builder"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -282,40 +282,6 @@ neq(input.object.owner, "");
|
||||
p("'10d03e62-7703-4df5-a358-4f76577d4e2f' = id :: text") + " AND " + p("id :: text != ''") + " AND " + p("'' = ''"),
|
||||
),
|
||||
},
|
||||
{
|
||||
Name: "AuditLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.AuditLogConverter(),
|
||||
},
|
||||
{
|
||||
Name: "ConnectionLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.ConnectionLogConverter(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter {
|
||||
func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
// Audit logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
@@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
func ConnectionLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
// Connection logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
_ VariableMatcher = astUUIDVar{}
|
||||
_ Node = astUUIDVar{}
|
||||
_ SupportsEquality = astUUIDVar{}
|
||||
)
|
||||
|
||||
// astUUIDVar is a variable that represents a UUID column. Unlike
|
||||
// astStringVar it emits native UUID comparisons (column = 'val'::uuid)
|
||||
// instead of text-based ones (COALESCE(column::text, ”) = 'val').
|
||||
// This allows PostgreSQL to use indexes on UUID columns.
|
||||
type astUUIDVar struct {
|
||||
Source RegoSource
|
||||
FieldPath []string
|
||||
ColumnString string
|
||||
}
|
||||
|
||||
func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher {
|
||||
return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn}
|
||||
}
|
||||
|
||||
func (astUUIDVar) UseAs() Node { return astUUIDVar{} }
|
||||
|
||||
func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||
left, err := RegoVarPath(u.FieldPath, rego)
|
||||
if err == nil && len(left) == 0 {
|
||||
return astUUIDVar{
|
||||
Source: RegoSource(rego.String()),
|
||||
FieldPath: u.FieldPath,
|
||||
ColumnString: u.ColumnString,
|
||||
}, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (u astUUIDVar) SQLString(_ *SQLGenerator) string {
|
||||
return u.ColumnString
|
||||
}
|
||||
|
||||
// EqualsSQLString handles equality comparisons for UUID columns.
|
||||
// Rego always produces string literals, so we accept AstString and
|
||||
// cast the literal to ::uuid in the output SQL. This lets PG use
|
||||
// native UUID indexes instead of falling back to text comparisons.
|
||||
// nolint:revive
|
||||
func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
switch other.UseAs().(type) {
|
||||
case AstString:
|
||||
// The other side is a rego string literal like
|
||||
// "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison
|
||||
// that casts the literal to uuid so PG can use indexes:
|
||||
// column = 'val'::uuid
|
||||
// instead of the text-based:
|
||||
// 'val' = COALESCE(column::text, '')
|
||||
s, ok := other.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString, got %T", other)
|
||||
}
|
||||
if s.Value == "" {
|
||||
// Empty string in rego means "no value". Compare the
|
||||
// column against NULL since UUID columns represent
|
||||
// absent values as NULL, not empty strings.
|
||||
op := "IS NULL"
|
||||
if not {
|
||||
op = "IS NOT NULL"
|
||||
}
|
||||
return fmt.Sprintf("%s %s", u.ColumnString, op), nil
|
||||
}
|
||||
return fmt.Sprintf("%s %s '%s'::uuid",
|
||||
u.ColumnString, equalsOp(not), s.Value), nil
|
||||
case astUUIDVar:
|
||||
return basicSQLEquality(cfg, not, u, other), nil
|
||||
default:
|
||||
return "", xerrors.Errorf("unsupported equality: %T %s %T",
|
||||
u, equalsOp(not), other)
|
||||
}
|
||||
}
|
||||
|
||||
// ContainedInSQL implements SupportsContainedIn so that a UUID column
|
||||
// can appear in membership checks like `col = ANY(ARRAY[...])`. The
|
||||
// array elements are rego strings, so we cast each to ::uuid.
|
||||
func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) {
|
||||
arr, ok := haystack.(ASTArray)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack)
|
||||
}
|
||||
|
||||
if len(arr.Value) == 0 {
|
||||
return "false", nil
|
||||
}
|
||||
|
||||
// Build ARRAY['uuid1'::uuid, 'uuid2'::uuid, ...]
|
||||
values := make([]string, 0, len(arr.Value))
|
||||
for _, v := range arr.Value {
|
||||
s, ok := v.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString array element, got %T", v)
|
||||
}
|
||||
values = append(values, fmt.Sprintf("'%s'::uuid", s.Value))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s = ANY(ARRAY [%s])",
|
||||
u.ColumnString,
|
||||
strings.Join(values, ",")), nil
|
||||
}
|
||||
+18
-38
@@ -244,7 +244,6 @@ func SystemRoleName(name string) bool {
|
||||
|
||||
type RoleOptions struct {
|
||||
NoOwnerWorkspaceExec bool
|
||||
NoWorkspaceSharing bool
|
||||
}
|
||||
|
||||
// ReservedRoleName exists because the database should only allow unique role
|
||||
@@ -268,23 +267,12 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
opts = &RoleOptions{}
|
||||
}
|
||||
|
||||
denyPermissions := []Permission{}
|
||||
if opts.NoWorkspaceSharing {
|
||||
denyPermissions = append(denyPermissions, Permission{
|
||||
Negate: true,
|
||||
ResourceType: ResourceWorkspace.Type,
|
||||
Action: policy.ActionShare,
|
||||
})
|
||||
}
|
||||
|
||||
ownerWorkspaceActions := ResourceWorkspace.AvailableActions()
|
||||
if opts.NoOwnerWorkspaceExec {
|
||||
// Remove ssh and application connect from the owner role. This
|
||||
// prevents owners from have exec access to all workspaces.
|
||||
ownerWorkspaceActions = slice.Omit(
|
||||
ownerWorkspaceActions,
|
||||
policy.ActionApplicationConnect, policy.ActionSSH,
|
||||
)
|
||||
ownerWorkspaceActions = slice.Omit(ownerWorkspaceActions,
|
||||
policy.ActionApplicationConnect, policy.ActionSSH)
|
||||
}
|
||||
|
||||
// Static roles that never change should be allocated in a closure.
|
||||
@@ -307,8 +295,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
// Explicitly setting PrebuiltWorkspace permissions for clarity.
|
||||
// Note: even without PrebuiltWorkspace permissions, access is still granted via Workspace permissions.
|
||||
ResourcePrebuiltWorkspace.Type: {policy.ActionUpdate, policy.ActionDelete},
|
||||
})...,
|
||||
),
|
||||
})...),
|
||||
User: []Permission{},
|
||||
ByOrgID: map[string]OrgPermissions{},
|
||||
}.withCachedRegoValue()
|
||||
@@ -316,17 +303,13 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
memberRole := Role{
|
||||
Identifier: RoleMember(),
|
||||
DisplayName: "Member",
|
||||
Site: append(
|
||||
Permissions(map[string][]policy.Action{
|
||||
ResourceAssignRole.Type: {policy.ActionRead},
|
||||
// All users can see OAuth2 provider applications.
|
||||
ResourceOauth2App.Type: {policy.ActionRead},
|
||||
ResourceWorkspaceProxy.Type: {policy.ActionRead},
|
||||
}),
|
||||
denyPermissions...,
|
||||
),
|
||||
User: append(
|
||||
allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember, ResourceBoundaryUsage),
|
||||
Site: Permissions(map[string][]policy.Action{
|
||||
ResourceAssignRole.Type: {policy.ActionRead},
|
||||
// All users can see OAuth2 provider applications.
|
||||
ResourceOauth2App.Type: {policy.ActionRead},
|
||||
ResourceWorkspaceProxy.Type: {policy.ActionRead},
|
||||
}),
|
||||
User: append(allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember, ResourceBoundaryUsage),
|
||||
Permissions(map[string][]policy.Action{
|
||||
// Users cannot do create/update/delete on themselves, but they
|
||||
// can read their own details.
|
||||
@@ -450,17 +433,14 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
ByOrgID: map[string]OrgPermissions{
|
||||
// Org admins should not have workspace exec perms.
|
||||
organizationID.String(): {
|
||||
Org: append(
|
||||
allPermsExcept(ResourceWorkspace, ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceAssignRole, ResourceUserSecret, ResourceBoundaryUsage),
|
||||
Permissions(map[string][]policy.Action{
|
||||
ResourceWorkspace.Type: slice.Omit(ResourceWorkspace.AvailableActions(), policy.ActionApplicationConnect, policy.ActionSSH),
|
||||
ResourceWorkspaceDormant.Type: {policy.ActionRead, policy.ActionDelete, policy.ActionCreate, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent, policy.ActionDeleteAgent, policy.ActionUpdateAgent},
|
||||
// PrebuiltWorkspaces are a subset of Workspaces.
|
||||
// Explicitly setting PrebuiltWorkspace permissions for clarity.
|
||||
// Note: even without PrebuiltWorkspace permissions, access is still granted via Workspace permissions.
|
||||
ResourcePrebuiltWorkspace.Type: {policy.ActionUpdate, policy.ActionDelete},
|
||||
})...,
|
||||
),
|
||||
Org: append(allPermsExcept(ResourceWorkspace, ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceAssignRole, ResourceUserSecret, ResourceBoundaryUsage), Permissions(map[string][]policy.Action{
|
||||
ResourceWorkspaceDormant.Type: {policy.ActionRead, policy.ActionDelete, policy.ActionCreate, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent, policy.ActionDeleteAgent, policy.ActionUpdateAgent},
|
||||
ResourceWorkspace.Type: slice.Omit(ResourceWorkspace.AvailableActions(), policy.ActionApplicationConnect, policy.ActionSSH),
|
||||
// PrebuiltWorkspaces are a subset of Workspaces.
|
||||
// Explicitly setting PrebuiltWorkspace permissions for clarity.
|
||||
// Note: even without PrebuiltWorkspace permissions, access is still granted via Workspace permissions.
|
||||
ResourcePrebuiltWorkspace.Type: {policy.ActionUpdate, policy.ActionDelete},
|
||||
})...),
|
||||
Member: []Permission{},
|
||||
},
|
||||
},
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user