Compare commits

..

5 Commits

Author SHA1 Message Date
Jon Ayers aefc75133a fix: use separate http.Transports for wsproxy tests 2026-02-24 23:02:37 +00:00
Jon Ayers b9181c3934 feat(wsproxy): add /debug/expvar endpoint for DERP server stats 2026-02-21 00:19:41 +00:00
Jon Ayers a90471db53 feat(monitoring): add wsproxy DERP section to Grafana dashboard
Adds a new 'Workspace Proxy - DERP' row with 6 panels:
- DERP Connections (current connections and home connections)
- DERP Client Breakdown (local, remote, total)
- DERP Throughput (bytes received/sent rate)
- DERP Packets (received/sent/forwarded rate)
- DERP Packet Drops (by reason label)
- DERP Queue Duration (average queue duration)
2026-02-20 23:44:24 +00:00
Jon Ayers cb71f5e789 feat(wsproxy): add DERP websocket throughput metrics
Add Prometheus metrics tracking active DERP websocket connections and
bytes relayed through the wsproxy:

- coder_wsproxy_derp_websocket_active_connections (gauge)
- coder_wsproxy_derp_websocket_bytes_total (counter, direction=read|write)

Implementation adds a DERPWebsocketMetrics hook struct and countingConn
wrapper in tailnet/, and a new WithWebsocketSupportAndMetrics function
that instruments the websocket connection lifecycle. The existing
WithWebsocketSupport function delegates to the new one with nil metrics.
2026-02-20 23:44:21 +00:00
Jon Ayers f50707bc3e feat(wsproxy): add Prometheus collector for DERP server expvar metrics
Create a prometheus.Collector that bridges the tailscale derp.Server's
expvar-based stats to Prometheus metrics with namespace coder, subsystem
wsproxy_derp. Handles counters, gauges, labeled metrics (nested
metrics.Set for drop reasons, packet types, etc.), and the average
queue duration (converted from ms to seconds).

Register the collector in the wsproxy server after derpServer creation.
2026-02-20 23:40:03 +00:00
247 changed files with 4247 additions and 8444 deletions
+1 -1
View File
@@ -4,7 +4,7 @@ description: |
inputs:
version:
description: "The Go version to use."
default: "1.25.8"
default: "1.25.7"
use-preinstalled-go:
description: "Whether to use preinstalled Go."
default: "false"
+17 -17
View File
@@ -35,7 +35,7 @@ jobs:
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -157,7 +157,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -247,7 +247,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -272,7 +272,7 @@ jobs:
if: ${{ !cancelled() }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -329,7 +329,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -381,7 +381,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -586,7 +586,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -648,7 +648,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -720,7 +720,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -747,7 +747,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -780,7 +780,7 @@ jobs:
name: ${{ matrix.variant.name }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -860,7 +860,7 @@ jobs:
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -941,7 +941,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1013,7 +1013,7 @@ jobs:
if: always()
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1128,7 +1128,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1183,7 +1183,7 @@ jobs:
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1580,7 +1580,7 @@ jobs:
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+3 -3
View File
@@ -36,7 +36,7 @@ jobs:
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -65,7 +65,7 @@ jobs:
packages: write # to retag image as dogfood
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -146,7 +146,7 @@ jobs:
needs: deploy
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -38,7 +38,7 @@ jobs:
if: github.repository_owner == 'coder'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -26,7 +26,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -125,7 +125,7 @@ jobs:
id-token: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -28,7 +28,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
packages: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+5 -5
View File
@@ -39,7 +39,7 @@ jobs:
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -76,7 +76,7 @@ jobs:
runs-on: "ubuntu-latest"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -184,7 +184,7 @@ jobs:
pull-requests: write # needed for commenting on PRs
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -228,7 +228,7 @@ jobs:
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -288,7 +288,7 @@ jobs:
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -14,7 +14,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+4 -4
View File
@@ -158,7 +158,7 @@ jobs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -796,7 +796,7 @@ jobs:
# TODO: skip this if it's not a new release (i.e. a backport). This is
# fine right now because it just makes a PR that we can close.
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -872,7 +872,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -965,7 +965,7 @@ jobs:
if: ${{ !inputs.dry_run }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -20,7 +20,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+114 -1
View File
@@ -27,7 +27,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -63,3 +63,116 @@ jobs:
--data "{\"content\": \"$msg\"}" \
"${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
trivy:
permissions:
security-events: write
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Setup Node
uses: ./.github/actions/setup-node
- name: Setup sqlc
uses: ./.github/actions/setup-sqlc
- name: Install cosign
uses: ./.github/actions/install-cosign
- name: Install syft
uses: ./.github/actions/install-syft
- name: Install yq
run: go run github.com/mikefarah/yq/v4@v4.44.3
- name: Install mockgen
run: ./.github/scripts/retry.sh -- go install go.uber.org/mock/mockgen@v0.6.0
- name: Install protoc-gen-go
run: ./.github/scripts/retry.sh -- go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.30
- name: Install protoc-gen-go-drpc
run: ./.github/scripts/retry.sh -- go install storj.io/drpc/cmd/protoc-gen-go-drpc@v0.0.34
- name: Install Protoc
run: |
# protoc must be in lockstep with our dogfood Dockerfile or the
# version in the comments will differ. This is also defined in
# ci.yaml.
set -euxo pipefail
cd dogfood/coder
mkdir -p /usr/local/bin
mkdir -p /usr/local/include
DOCKER_BUILDKIT=1 docker build . --target proto -t protoc
protoc_path=/usr/local/bin/protoc
docker run --rm --entrypoint cat protoc /tmp/bin/protoc > $protoc_path
chmod +x $protoc_path
protoc --version
# Copy the generated files to the include directory.
docker run --rm -v /usr/local/include:/target protoc cp -r /tmp/include/google /target/
ls -la /usr/local/include/google/protobuf/
stat /usr/local/include/google/protobuf/timestamp.proto
- name: Build Coder linux amd64 Docker image
id: build
run: |
set -euo pipefail
version="$(./scripts/version.sh)"
image_job="build/coder_${version}_linux_amd64.tag"
# This environment variable forces make to not build packages and
# archives (which the Docker image depends on due to technical reasons
# related to concurrent FS writes).
export DOCKER_IMAGE_NO_PREREQUISITES=true
# This environment variable forces scripts/build_docker.sh to build
# the base image tag locally instead of using the cached version from
# the registry.
CODER_IMAGE_BUILD_BASE_TAG="$(CODER_IMAGE_BASE=coder-base ./scripts/image_tag.sh --version "$version")"
export CODER_IMAGE_BUILD_BASE_TAG
# We would like to use make -j here, but it doesn't work with some recent additions
# to our code generation.
make "$image_job"
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif
output: trivy-results.sarif
severity: "CRITICAL,HIGH"
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
with:
sarif_file: trivy-results.sarif
category: "Trivy"
- name: Upload Trivy scan results as an artifact
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: trivy
path: trivy-results.sarif
retention-days: 7
- name: Send Slack notification on failure
if: ${{ failure() }}
run: |
msg="❌ Trivy Failed\n\nhttps://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
curl \
-qfsSL \
-X POST \
-H "Content-Type: application/json" \
--data "{\"content\": \"$msg\"}" \
"${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
+4 -4
View File
@@ -18,12 +18,12 @@ jobs:
pull-requests: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
- name: stale
uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
with:
stale-issue-label: "stale"
stale-pr-label: "stale"
@@ -96,7 +96,7 @@ jobs:
contents: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -120,7 +120,7 @@ jobs:
actions: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -21,7 +21,7 @@ jobs:
pull-requests: write # required to post PR review comments by the action
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
-56
View File
@@ -3040,62 +3040,6 @@ func TestAgent_Reconnect(t *testing.T) {
closer.Close()
}
func TestAgent_ReconnectNoLifecycleReemit(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := testutil.Logger(t)
fCoordinator := tailnettest.NewFakeCoordinator()
agentID := uuid.New()
statsCh := make(chan *proto.Stats, 50)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
client := agenttest.NewClient(t,
logger,
agentID,
agentsdk.Manifest{
DERPMap: derpMap,
Scripts: []codersdk.WorkspaceAgentScript{{
Script: "echo hello",
Timeout: 30 * time.Second,
RunOnStart: true,
}},
},
statsCh,
fCoordinator,
)
defer client.Close()
closer := agent.New(agent.Options{
Client: client,
Logger: logger.Named("agent"),
})
defer closer.Close()
// Wait for the agent to reach Ready state.
require.Eventually(t, func() bool {
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
}, testutil.WaitShort, testutil.IntervalFast)
statesBefore := slices.Clone(client.GetLifecycleStates())
// Disconnect by closing the coordinator response channel.
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
close(call1.Resps)
// Wait for reconnect.
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
// Wait for a stats report as a deterministic steady-state proof.
testutil.RequireReceive(ctx, t, statsCh)
statesAfter := client.GetLifecycleStates()
require.Equal(t, statesBefore, statesAfter,
"lifecycle states should not be re-reported after reconnect")
closer.Close()
}
func TestAgent_WriteVSCodeConfigs(t *testing.T) {
t.Parallel()
logger := testutil.Logger(t)
-4
View File
@@ -235,10 +235,6 @@ type FakeAgentAPI struct {
pushResourcesMonitoringUsageFunc func(*agentproto.PushResourcesMonitoringUsageRequest) (*agentproto.PushResourcesMonitoringUsageResponse, error)
}
func (*FakeAgentAPI) UpdateAppStatus(context.Context, *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
panic("unimplemented")
}
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
return f.manifest, nil
}
+330 -544
View File
File diff suppressed because it is too large Load Diff
+1 -20
View File
@@ -436,7 +436,7 @@ message CreateSubAgentRequest {
}
repeated DisplayApp display_apps = 6;
optional bytes id = 7;
}
@@ -494,24 +494,6 @@ message ReportBoundaryLogsRequest {
message ReportBoundaryLogsResponse {}
// UpdateAppStatusRequest updates the given Workspace App's status. c.f. agentsdk.PatchAppStatus
message UpdateAppStatusRequest {
string slug = 1;
enum AppStatusState {
WORKING = 0;
IDLE = 1;
COMPLETE = 2;
FAILURE = 3;
}
AppStatusState state = 2;
string message = 3;
string uri = 4;
}
message UpdateAppStatusResponse {}
service Agent {
rpc GetManifest(GetManifestRequest) returns (Manifest);
rpc GetServiceBanner(GetServiceBannerRequest) returns (ServiceBanner);
@@ -530,5 +512,4 @@ service Agent {
rpc DeleteSubAgent(DeleteSubAgentRequest) returns (DeleteSubAgentResponse);
rpc ListSubAgents(ListSubAgentsRequest) returns (ListSubAgentsResponse);
rpc ReportBoundaryLogs(ReportBoundaryLogsRequest) returns (ReportBoundaryLogsResponse);
rpc UpdateAppStatus(UpdateAppStatusRequest) returns (UpdateAppStatusResponse);
}
+1 -41
View File
@@ -56,7 +56,6 @@ type DRPCAgentClient interface {
DeleteSubAgent(ctx context.Context, in *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(ctx context.Context, in *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type drpcAgentClient struct {
@@ -222,15 +221,6 @@ func (c *drpcAgentClient) ReportBoundaryLogs(ctx context.Context, in *ReportBoun
return out, nil
}
func (c *drpcAgentClient) UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
out := new(UpdateAppStatusResponse)
err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
type DRPCAgentServer interface {
GetManifest(context.Context, *GetManifestRequest) (*Manifest, error)
GetServiceBanner(context.Context, *GetServiceBannerRequest) (*ServiceBanner, error)
@@ -249,7 +239,6 @@ type DRPCAgentServer interface {
DeleteSubAgent(context.Context, *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(context.Context, *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(context.Context, *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type DRPCAgentUnimplementedServer struct{}
@@ -322,13 +311,9 @@ func (s *DRPCAgentUnimplementedServer) ReportBoundaryLogs(context.Context, *Repo
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentUnimplementedServer) UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCAgentDescription struct{}
func (DRPCAgentDescription) NumMethods() int { return 18 }
func (DRPCAgentDescription) NumMethods() int { return 17 }
func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
@@ -485,15 +470,6 @@ func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
in1.(*ReportBoundaryLogsRequest),
)
}, DRPCAgentServer.ReportBoundaryLogs, true
case 17:
return "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentServer).
UpdateAppStatus(
ctx,
in1.(*UpdateAppStatusRequest),
)
}, DRPCAgentServer.UpdateAppStatus, true
default:
return "", nil, nil, nil, false
}
@@ -774,19 +750,3 @@ func (x *drpcAgent_ReportBoundaryLogsStream) SendAndClose(m *ReportBoundaryLogsR
}
return x.CloseSend()
}
type DRPCAgent_UpdateAppStatusStream interface {
drpc.Stream
SendAndClose(*UpdateAppStatusResponse) error
}
type drpcAgent_UpdateAppStatusStream struct {
drpc.Stream
}
func (x *drpcAgent_UpdateAppStatusStream) SendAndClose(m *UpdateAppStatusResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
return err
}
return x.CloseSend()
}
+3 -7
View File
@@ -73,13 +73,9 @@ type DRPCAgentClient27 interface {
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
}
// DRPCAgentClient28 is the Agent API at v2.8. It adds
// - a SubagentId field to the WorkspaceAgentDevcontainer message
// - an Id field to the CreateSubAgentRequest message.
// - UpdateAppStatus RPC.
//
// Compatible with Coder v2.31+
// DRPCAgentClient28 is the Agent API at v2.8. It adds a SubagentId field to the
// WorkspaceAgentDevcontainer message, and a Id field to the CreateSubAgentRequest
// message. Compatible with Coder v2.31+
type DRPCAgentClient28 interface {
DRPCAgentClient27
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
+45 -50
View File
@@ -10,7 +10,6 @@ import (
"path/filepath"
"slices"
"strings"
"time"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
@@ -24,7 +23,6 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/toolsdk"
"github.com/coder/retry"
"github.com/coder/serpent"
)
@@ -541,6 +539,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
defer cancel()
defer srv.queue.Close()
cliui.Infof(inv.Stderr, "Failed to watch screen events")
// Start the reporter, watcher, and server. These are all tied to the
// lifetime of the MCP server, which is itself tied to the lifetime of the
// AI agent.
@@ -614,51 +613,48 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
}
func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
return
}
go func() {
for retrier := retry.New(time.Second, 30*time.Second); retrier.Wait(ctx); {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err == nil {
retrier.Reset()
loop:
for {
select {
case <-ctx.Done():
for {
select {
case <-ctx.Done():
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
// If the screen is stable, report idle.
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
}
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
break loop
}
}
} else {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
return
}
}
}()
@@ -696,14 +692,13 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
// Add tool dependencies.
toolOpts := []func(*toolsdk.Deps){
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
state := codersdk.WorkspaceAppStatusState(args.State)
// The agent does not reliably report idle, so when AgentAPI is
// enabled we override idle to working and let the screen watcher
// detect the real idle via StatusStable. Final states (failure,
// complete) are trusted from the agent since the screen watcher
// cannot produce them.
if s.aiAgentAPIClient != nil && state == codersdk.WorkspaceAppStatusStateIdle {
state = codersdk.WorkspaceAppStatusStateWorking
// The agent does not reliably report its status correctly. If AgentAPI
// is enabled, we will always set the status to "working" when we get an
// MCP message, and rely on the screen watcher to eventually catch the
// idle state.
state := codersdk.WorkspaceAppStatusStateWorking
if s.aiAgentAPIClient == nil {
state = codersdk.WorkspaceAppStatusState(args.State)
}
return s.queue.Push(taskReport{
link: args.Link,
+1 -185
View File
@@ -921,7 +921,7 @@ func TestExpMcpReporter(t *testing.T) {
},
},
},
// We override idle from the agent to working, but trust final states.
// We ignore the state from the agent and assume "working".
{
name: "IgnoreAgentState",
// AI agent reports that it is finished but the summary says it is doing
@@ -953,46 +953,6 @@ func TestExpMcpReporter(t *testing.T) {
Message: "finished",
},
},
// Agent reports failure; trusted even with AgentAPI enabled.
{
state: codersdk.WorkspaceAppStatusStateFailure,
summary: "something broke",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateFailure,
Message: "something broke",
},
},
// After failure, watcher reports stable -> idle.
{
event: makeStatusEvent(agentapi.StatusStable),
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "something broke",
},
},
},
},
// Final states pass through with AgentAPI enabled.
{
name: "AllowFinalStates",
tests: []test{
{
state: codersdk.WorkspaceAppStatusStateWorking,
summary: "doing work",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateWorking,
Message: "doing work",
},
},
// Agent reports complete; not overridden.
{
state: codersdk.WorkspaceAppStatusStateComplete,
summary: "all done",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateComplete,
Message: "all done",
},
},
},
},
// When AgentAPI is not being used, we accept agent state updates as-is.
@@ -1150,148 +1110,4 @@ func TestExpMcpReporter(t *testing.T) {
<-cmdDone
})
}
t.Run("Reconnect", func(t *testing.T) {
t.Parallel()
// Create a test deployment and workspace.
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user2.ID,
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
a[0].Apps = []*proto.App{
{
Slug: "vscode",
},
}
return a
}).Do()
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitLong))
// Watch the workspace for changes.
watcher, err := client.WatchWorkspace(ctx, r.Workspace.ID)
require.NoError(t, err)
var lastAppStatus codersdk.WorkspaceAppStatus
nextUpdate := func() codersdk.WorkspaceAppStatus {
for {
select {
case <-ctx.Done():
require.FailNow(t, "timed out waiting for status update")
case w, ok := <-watcher:
require.True(t, ok, "watch channel closed")
if w.LatestAppStatus != nil && w.LatestAppStatus.ID != lastAppStatus.ID {
t.Logf("Got status update: %s > %s", lastAppStatus.State, w.LatestAppStatus.State)
lastAppStatus = *w.LatestAppStatus
return lastAppStatus
}
}
}
}
// Mock AI AgentAPI server that supports disconnect/reconnect.
disconnect := make(chan struct{})
listening := make(chan func(sse codersdk.ServerSentEvent) error)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create a cancelable context so we can stop the SSE sender
// goroutine on disconnect without waiting for the HTTP
// serve loop to cancel r.Context().
sseCtx, sseCancel := context.WithCancel(r.Context())
defer sseCancel()
r = r.WithContext(sseCtx)
send, closed, err := httpapi.ServerSentEventSender(w, r)
if err != nil {
httpapi.Write(sseCtx, w, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up server-sent events.",
Detail: err.Error(),
})
return
}
// Send initial message so the watcher knows the agent is active.
send(*makeMessageEvent(0, agentapi.RoleAgent))
select {
case listening <- send:
case <-r.Context().Done():
return
}
select {
case <-closed:
case <-disconnect:
sseCancel()
<-closed
}
}))
t.Cleanup(srv.Close)
inv, _ := clitest.New(t,
"exp", "mcp", "server",
"--agent-url", client.URL.String(),
"--agent-token", r.AgentToken,
"--app-status-slug", "vscode",
"--allowed-tools=coder_report_task",
"--ai-agentapi-url", srv.URL,
)
inv = inv.WithContext(ctx)
pty := ptytest.New(t)
inv.Stdin = pty.Input()
inv.Stdout = pty.Output()
stderr := ptytest.New(t)
inv.Stderr = stderr.Output()
// Run the MCP server.
clitest.Start(t, inv)
// Initialize.
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
pty.WriteLine(payload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore init response
// Get first sender from the initial SSE connection.
sender := testutil.RequireReceive(ctx, t, listening)
// Self-report a working status via tool call.
toolPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"doing work","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got := nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "doing work", got.Message)
// Watcher sends stable, verify idle is reported.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
// Disconnect the SSE connection by signaling the handler to return.
testutil.RequireSend(ctx, t, disconnect, struct{}{})
// Wait for the watcher to reconnect and get the new sender.
sender = testutil.RequireReceive(ctx, t, listening)
// After reconnect, self-report a working status again.
toolPayload = `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"reconnected","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "reconnected", got.Message)
// Verify the watcher still processes events after reconnect.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
cancel()
})
}
+5 -1
View File
@@ -106,7 +106,11 @@ func TestList(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
memberClient, member = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
sharedWorkspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
+1 -1
View File
@@ -297,7 +297,7 @@ func (pr *ParameterResolver) verifyConstraints(resolved []codersdk.WorkspaceBuil
return xerrors.Errorf("ephemeral parameter %q can be used only with --prompt-ephemeral-parameters or --ephemeral-parameter flag", r.Name)
}
if !tvp.Mutable && action != WorkspaceCreate && !pr.isFirstTimeUse(r.Name) {
if !tvp.Mutable && action != WorkspaceCreate {
return xerrors.Errorf("parameter %q is immutable and cannot be updated", r.Name)
}
}
+31 -7
View File
@@ -25,7 +25,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -64,8 +68,12 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
orgOwner = coderdtest.CreateFirstUser(t, client)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -119,7 +127,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -170,7 +182,11 @@ func TestSharingStatus(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -214,7 +230,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -271,7 +291,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
+1 -1
View File
@@ -120,7 +120,7 @@ func (r *RootCmd) start() *serpent.Command {
func buildWorkspaceStartRequest(inv *serpent.Invocation, client *codersdk.Client, workspace codersdk.Workspace, parameterFlags workspaceParameterFlags, buildFlags buildFlags, action WorkspaceCLIAction) (codersdk.CreateWorkspaceBuildRequest, error) {
version := workspace.LatestBuild.TemplateVersionID
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || workspace.TemplateRequireActiveVersion || action == WorkspaceUpdate {
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || action == WorkspaceUpdate {
version = workspace.TemplateActiveVersionID
if version != workspace.LatestBuild.TemplateVersionID {
action = WorkspaceUpdate
+4 -4
View File
@@ -33,7 +33,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
statefilePath := filepath.Join(t.TempDir(), "state")
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name, statefilePath)
@@ -54,7 +54,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name)
var gotState bytes.Buffer
@@ -74,7 +74,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", taUser.Username+"/"+r.Workspace.Name,
"--build", fmt.Sprintf("%d", r.Build.BuildNumber))
@@ -170,7 +170,7 @@ func TestStatePush(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(initialState).
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
Do()
wantState := []byte("updated state")
stateFile, err := os.CreateTemp(t.TempDir(), "")
+4 -3
View File
@@ -49,9 +49,10 @@ OPTIONS:
security purposes if a --wildcard-access-url is configured.
--disable-workspace-sharing bool, $CODER_DISABLE_WORKSPACE_SHARING
Disable workspace sharing. Workspace ACL checking is disabled and only
owners can have ssh, apps and terminal access to workspaces. Access
based on the 'owner' role is also allowed unless disabled via
Disable workspace sharing (requires the "workspace-sharing" experiment
to be enabled). Workspace ACL checking is disabled and only owners can
have ssh, apps and terminal access to workspaces. Access based on the
'owner' role is also allowed unless disabled via
--disable-owner-workspace-access.
--swagger-enable bool, $CODER_SWAGGER_ENABLE
+4 -4
View File
@@ -523,10 +523,10 @@ disablePathApps: false
# workspaces.
# (default: <unset>, type: bool)
disableOwnerWorkspaceAccess: false
# Disable workspace sharing. Workspace ACL checking is disabled and only owners
# can have ssh, apps and terminal access to workspaces. Access based on the
# 'owner' role is also allowed unless disabled via
# --disable-owner-workspace-access.
# Disable workspace sharing (requires the "workspace-sharing" experiment to be
# enabled). Workspace ACL checking is disabled and only owners can have ssh, apps
# and terminal access to workspaces. Access based on the 'owner' role is also
# allowed unless disabled via --disable-owner-workspace-access.
# (default: <unset>, type: bool)
disableWorkspaceSharing: false
# These options change the behavior of how clients interact with the Coder.
+15 -2
View File
@@ -241,13 +241,26 @@ func (r *RootCmd) listTokens() *serpent.Command {
}
tokens, err := client.Tokens(inv.Context(), codersdk.Me, codersdk.TokensFilter{
IncludeAll: all,
IncludeExpired: includeExpired,
IncludeAll: all,
})
if err != nil {
return xerrors.Errorf("list tokens: %w", err)
}
// Filter out expired tokens unless --include-expired is set
// TODO(Cian): This _could_ get too big for client-side filtering.
// If it causes issues, we can filter server-side.
if !includeExpired {
now := time.Now()
filtered := make([]codersdk.APIKeyWithOwner, 0, len(tokens))
for _, token := range tokens {
if token.ExpiresAt.After(now) {
filtered = append(filtered, token)
}
}
tokens = filtered
}
displayTokens = make([]tokenListRow, len(tokens))
for i, token := range tokens {
-70
View File
@@ -990,74 +990,4 @@ func TestUpdateValidateRichParameters(t *testing.T) {
_ = testutil.TryReceive(ctx, t, doneChan)
})
t.Run("NewImmutableParameterViaFlag", func(t *testing.T) {
t.Parallel()
// Create template and workspace with only a mutable parameter.
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
templateParameters := []*proto.RichParameter{
{Name: stringParameterName, Type: "string", Mutable: true, Required: true, Options: []*proto.RichParameterOption{
{Name: "First option", Description: "This is first option", Value: "1st"},
{Name: "Second option", Description: "This is second option", Value: "2nd"},
}},
}
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(templateParameters))
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
inv, root := clitest.New(t, "create", "my-workspace", "--yes", "--template", template.Name, "--parameter", fmt.Sprintf("%s=%s", stringParameterName, "1st"))
clitest.SetupConfig(t, member, root)
err := inv.Run()
require.NoError(t, err)
// Update template: add a new immutable parameter.
updatedTemplateParameters := []*proto.RichParameter{
templateParameters[0],
{Name: immutableParameterName, Type: "string", Mutable: false, Required: true, Options: []*proto.RichParameterOption{
{Name: "fir", Description: "First option for immutable parameter", Value: "I"},
{Name: "sec", Description: "Second option for immutable parameter", Value: "II"},
}},
}
updatedVersion := coderdtest.UpdateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(updatedTemplateParameters), template.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, updatedVersion.ID)
err = client.UpdateActiveTemplateVersion(context.Background(), template.ID, codersdk.UpdateActiveTemplateVersion{
ID: updatedVersion.ID,
})
require.NoError(t, err)
// Update workspace, supplying the new immutable parameter via
// the --parameter flag. This should succeed because it's the
// first time this parameter is being set.
inv, root = clitest.New(t, "update", "my-workspace",
"--parameter", fmt.Sprintf("%s=%s", immutableParameterName, "II"))
clitest.SetupConfig(t, member, root)
pty := ptytest.New(t).Attach(inv)
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatch("Planning workspace")
ctx := testutil.Context(t, testutil.WaitLong)
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify the immutable parameter was set correctly.
workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
require.NoError(t, err)
require.Contains(t, actualParameters, codersdk.WorkspaceBuildParameter{
Name: immutableParameterName,
Value: "II",
})
})
}
-2
View File
@@ -179,8 +179,6 @@ func New(opts Options, workspace database.Workspace) *API {
Database: opts.Database,
Log: opts.Log,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
Clock: opts.Clock,
NotificationsEnqueuer: opts.NotificationsEnqueuer,
}
api.MetadataAPI = &MetadataAPI{
-240
View File
@@ -2,10 +2,6 @@ package agentapi
import (
"context"
"database/sql"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
@@ -13,14 +9,7 @@ import (
"cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/notifications"
strutil "github.com/coder/coder/v2/coderd/util/strings"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
type AppsAPI struct {
@@ -28,8 +17,6 @@ type AppsAPI struct {
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
NotificationsEnqueuer notifications.Enqueuer
Clock quartz.Clock
}
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
@@ -117,230 +104,3 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
}
return &agentproto.BatchUpdateAppHealthResponse{}, nil
}
func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
if len(req.Message) > 160 {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Message is too long.",
Detail: "Message must be less than 160 characters.",
Validations: []codersdk.ValidationError{
{Field: "message", Detail: "Message must be less than 160 characters."},
},
})
}
var dbState database.WorkspaceAppStatusState
switch req.State {
case agentproto.UpdateAppStatusRequest_COMPLETE:
dbState = database.WorkspaceAppStatusStateComplete
case agentproto.UpdateAppStatusRequest_FAILURE:
dbState = database.WorkspaceAppStatusStateFailure
case agentproto.UpdateAppStatusRequest_WORKING:
dbState = database.WorkspaceAppStatusStateWorking
case agentproto.UpdateAppStatusRequest_IDLE:
dbState = database.WorkspaceAppStatusStateIdle
default:
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Invalid state provided.",
Detail: fmt.Sprintf("invalid state: %q", req.State),
Validations: []codersdk.ValidationError{
{Field: "state", Detail: "State must be one of: complete, failure, working, idle."},
},
})
}
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: workspaceAgent.ID,
Slug: req.Slug,
})
if err != nil {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace app.",
Detail: fmt.Sprintf("No app found with slug %q", req.Slug),
})
}
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace.",
Detail: err.Error(),
})
}
// Treat the message as untrusted input.
cleaned := strutil.UISanitize(req.Message)
// Get the latest status for the workspace app to detect no-op updates
// nolint:gocritic // This is a system restricted operation.
latestAppStatus, err := a.Database.GetLatestWorkspaceAppStatusByAppID(dbauthz.AsSystemRestricted(ctx), app.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get latest workspace app status.",
Detail: err.Error(),
})
}
// If no rows found, latestAppStatus will be a zero-value struct (ID == uuid.Nil)
// nolint:gocritic // This is a system restricted operation.
_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
ID: uuid.New(),
CreatedAt: dbtime.Now(),
WorkspaceID: workspace.ID,
AgentID: workspaceAgent.ID,
AppID: app.ID,
State: dbState,
Message: cleaned,
Uri: sql.NullString{
String: req.Uri,
Valid: req.Uri != "",
},
})
if err != nil {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to insert workspace app status.",
Detail: err.Error(),
})
}
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
if err != nil {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to publish workspace update.",
Detail: err.Error(),
})
}
}
// Notify on state change to Working/Idle for AI tasks
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent)
if shouldBump(dbState, latestAppStatus) {
// We pass time.Time{} for nextAutostart since we don't have access to
// TemplateScheduleStore here. The activity bump logic handles this by
// defaulting to the template's activity_bump duration (typically 1 hour).
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{})
}
// just return a blank response because it doesn't contain any settable fields at present.
return new(agentproto.UpdateAppStatusResponse), nil
}
func shouldBump(dbState database.WorkspaceAppStatusState, latestAppStatus database.WorkspaceAppStatus) bool {
// Bump deadline when agent reports working or transitions away from working.
// This prevents auto-pause during active work and gives users time to interact
// after work completes.
// Bump if reporting working state.
if dbState == database.WorkspaceAppStatusStateWorking {
return true
}
// Bump if transitioning away from working state.
if latestAppStatus.ID != uuid.Nil {
prevState := latestAppStatus.State
if prevState == database.WorkspaceAppStatusStateWorking {
return true
}
}
return false
}
// enqueueAITaskStateNotification enqueues a notification when an AI task's app
// transitions to Working or Idle.
// No-op if:
// - the workspace agent app isn't configured as an AI task,
// - the new state equals the latest persisted state,
// - the workspace agent is not ready (still starting up).
func (a *AppsAPI) enqueueAITaskStateNotification(
ctx context.Context,
appID uuid.UUID,
latestAppStatus database.WorkspaceAppStatus,
newAppStatus database.WorkspaceAppStatusState,
workspace database.Workspace,
agent database.WorkspaceAgent,
) {
var notificationTemplate uuid.UUID
switch newAppStatus {
case database.WorkspaceAppStatusStateWorking:
notificationTemplate = notifications.TemplateTaskWorking
case database.WorkspaceAppStatusStateIdle:
notificationTemplate = notifications.TemplateTaskIdle
case database.WorkspaceAppStatusStateComplete:
notificationTemplate = notifications.TemplateTaskCompleted
case database.WorkspaceAppStatusStateFailure:
notificationTemplate = notifications.TemplateTaskFailed
default:
// Not a notifiable state, do nothing
return
}
if !workspace.TaskID.Valid {
// Workspace has no task ID, do nothing.
return
}
// Only send notifications when the agent is ready. We want to skip
// any state transitions that occur whilst the workspace is starting
// up as it doesn't make sense to receive them.
if agent.LifecycleState != database.WorkspaceAgentLifecycleStateReady {
a.Log.Debug(ctx, "skipping AI task notification because agent is not ready",
slog.F("agent_id", agent.ID),
slog.F("lifecycle_state", agent.LifecycleState),
slog.F("new_app_status", newAppStatus),
)
return
}
task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
if err != nil {
a.Log.Warn(ctx, "failed to get task", slog.Error(err))
return
}
if !task.WorkspaceAppID.Valid || task.WorkspaceAppID.UUID != appID {
// Non-task app, do nothing.
return
}
// Skip if the latest persisted state equals the new state (no new transition)
// Note: uuid.Nil check is valid here. If no previous status exists,
// GetLatestWorkspaceAppStatusByAppID returns sql.ErrNoRows and we get a zero-value struct.
if latestAppStatus.ID != uuid.Nil && latestAppStatus.State == newAppStatus {
return
}
// Skip the initial "Working" notification when the task first starts.
// This is obvious to the user since they just created the task.
// We still notify on the first "Idle" status and all subsequent transitions.
if latestAppStatus.ID == uuid.Nil && newAppStatus == database.WorkspaceAppStatusStateWorking {
return
}
if _, err := a.NotificationsEnqueuer.EnqueueWithData(
// nolint:gocritic // Need notifier actor to enqueue notifications
dbauthz.AsNotifier(ctx),
workspace.OwnerID,
notificationTemplate,
map[string]string{
"task": task.Name,
"workspace": workspace.Name,
},
map[string]any{
// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
// allowing identical content to resend within the same day
// (but not more than once every 10s).
"dedupe_bypass_ts": a.Clock.Now().UTC().Truncate(time.Minute),
},
"api-workspace-agent-app-status",
// Associate this notification with related entities
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
); err != nil {
a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
return
}
}
-115
View File
@@ -1,115 +0,0 @@
package agentapi
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/util/ptr"
)
func TestShouldBump(t *testing.T) {
t.Parallel()
tests := []struct {
name string
prevState *database.WorkspaceAppStatusState // nil means no previous state
newState database.WorkspaceAppStatusState
shouldBump bool
}{
{
name: "FirstStatusBumps",
prevState: nil,
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
{
name: "WorkingToIdleBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: true,
},
{
name: "WorkingToCompleteBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: true,
},
{
name: "CompleteToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "CompleteToCompleteNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: false,
},
{
name: "FailureToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "FailureToFailureNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateFailure,
shouldBump: false,
},
{
name: "CompleteToWorkingBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
{
name: "FailureToCompleteNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: false,
},
{
name: "WorkingToFailureBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateFailure,
shouldBump: true,
},
{
name: "IdleToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "IdleToWorkingBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var prevAppStatus database.WorkspaceAppStatus
// If there's a previous state, report it first.
if tt.prevState != nil {
prevAppStatus.ID = uuid.UUID{1}
prevAppStatus.State = *tt.prevState
}
didBump := shouldBump(tt.newState, prevAppStatus)
if tt.shouldBump {
require.True(t, didBump, "wanted deadline to bump but it didn't")
} else {
require.False(t, didBump, "wanted deadline not to bump but it did")
}
})
}
}
-188
View File
@@ -2,13 +2,9 @@ package agentapi_test
import (
"context"
"database/sql"
"net/http"
"strings"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
@@ -16,12 +12,8 @@ import (
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestBatchUpdateAppHealths(t *testing.T) {
@@ -261,183 +253,3 @@ func TestBatchUpdateAppHealths(t *testing.T) {
require.Nil(t, resp)
})
}
func TestWorkspaceAgentAppStatus(t *testing.T) {
t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
fEnq := &notificationstest.FakeEnqueuer{}
mClock := quartz.NewMock(t)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
assert.Equal(t, *agnt, agent)
testutil.AssertSend(ctx, t, workspaceUpdates, kind)
return nil
},
NotificationsEnqueuer: fEnq,
Clock: mClock,
}
app := database.WorkspaceApp{
ID: uuid.UUID{8},
}
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: agent.ID,
Slug: "vscode",
}).Times(1).Return(app, nil)
task := database.Task{
ID: uuid.UUID{7},
WorkspaceAppID: uuid.NullUUID{
Valid: true,
UUID: app.ID,
},
}
mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
workspace := database.Workspace{
ID: uuid.UUID{9},
TaskID: uuid.NullUUID{
Valid: true,
UUID: task.ID,
},
}
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
appStatus := database.WorkspaceAppStatus{
ID: uuid.UUID{6},
}
mDB.EXPECT().GetLatestWorkspaceAppStatusByAppID(gomock.Any(), app.ID).Times(1).Return(appStatus, nil)
mDB.EXPECT().InsertWorkspaceAppStatus(
gomock.Any(),
gomock.Cond(func(params database.InsertWorkspaceAppStatusParams) bool {
if params.AgentID == agent.ID && params.AppID == app.ID {
assert.Equal(t, "testing", params.Message)
assert.Equal(t, database.WorkspaceAppStatusStateComplete, params.State)
assert.True(t, params.Uri.Valid)
assert.Equal(t, "https://example.com", params.Uri.String)
return true
}
return false
})).Times(1).Return(database.WorkspaceAppStatus{}, nil)
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: "testing",
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.NoError(t, err)
kind := testutil.RequireReceive(ctx, t, workspaceUpdates)
require.Equal(t, wspubsub.WorkspaceEventKindAgentAppStatusUpdate, kind)
sent := fEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskCompleted))
require.Len(t, sent, 1)
})
t.Run("FailUnknownApp", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), gomock.Any()).
Times(1).
Return(database.WorkspaceApp{}, sql.ErrNoRows)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "unknown",
Message: "testing",
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.ErrorContains(t, err, "No app found with slug")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("FailUnknownState", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: "testing",
Uri: "https://example.com",
State: 77,
})
require.ErrorContains(t, err, "Invalid state")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("FailTooLong", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: strings.Repeat("a", 161),
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.ErrorContains(t, err, "Message is too long")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
}
+3 -6
View File
@@ -134,12 +134,9 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
case database.WorkspaceAgentLifecycleStateReady,
database.WorkspaceAgentLifecycleStateStartTimeout,
database.WorkspaceAgentLifecycleStateStartError:
// Only emit metrics for the parent agent, this metric is not intended to measure devcontainer durations.
if !workspaceAgent.ParentID.Valid {
a.emitMetricsOnce.Do(func() {
a.emitBuildDurationMetric(ctx, workspaceAgent.ResourceID)
})
}
a.emitMetricsOnce.Do(func() {
a.emitBuildDurationMetric(ctx, workspaceAgent.ResourceID)
})
}
return req.Lifecycle, nil
-58
View File
@@ -582,64 +582,6 @@ func TestUpdateLifecycle(t *testing.T) {
require.Equal(t, uint64(1), got.GetSampleCount())
require.Equal(t, expectedDuration, got.GetSampleSum())
})
t.Run("SubAgentDoesNotEmitMetric", func(t *testing.T) {
t.Parallel()
parentID := uuid.New()
subAgent := database.WorkspaceAgent{
ID: uuid.New(),
ParentID: uuid.NullUUID{UUID: parentID, Valid: true},
LifecycleState: database.WorkspaceAgentLifecycleStateStarting,
StartedAt: sql.NullTime{Valid: true, Time: someTime},
ReadyAt: sql.NullTime{Valid: false},
}
lifecycle := &agentproto.Lifecycle{
State: agentproto.Lifecycle_READY,
ChangedAt: timestamppb.New(now),
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: subAgent.ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
StartedAt: subAgent.StartedAt,
ReadyAt: sql.NullTime{
Time: now,
Valid: true,
},
}).Return(nil)
// GetWorkspaceBuildMetricsByResourceID should NOT be called
// because sub-agents should be skipped before querying.
reg := prometheus.NewRegistry()
metrics := agentapi.NewLifecycleMetrics(reg)
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return subAgent, nil
},
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: nil,
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
// We don't expect the metric to be emitted for sub-agents, by default this will fail anyway but it doesn't hurt
// to document the test explicitly.
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), gomock.Any()).Times(0)
// If we were emitting the metric we would have failed by now since it would include a call to the database that we're not expecting.
pm, err := reg.Gather()
require.NoError(t, err)
for _, m := range pm {
if m.GetName() == fullMetricName {
t.Fatal("metric should not be emitted for sub-agent")
}
}
})
}
func TestUpdateStartup(t *testing.T) {
+4 -2
View File
@@ -466,6 +466,7 @@ func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks
apiWorkspaces, err := convertWorkspaces(
ctx,
api.Experiments,
api.Logger,
requesterID,
workspaces,
@@ -545,6 +546,7 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
ws, err := convertWorkspace(
ctx,
api.Experiments,
api.Logger,
apiKey.UserID,
workspace,
@@ -1248,7 +1250,7 @@ func (api *API) postWorkspaceAgentTaskLogSnapshot(rw http.ResponseWriter, r *htt
// @Summary Pause task
// @ID pause-task
// @Security CoderSessionToken
// @Produce json
// @Accept json
// @Tags Tasks
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
@@ -1325,7 +1327,7 @@ func (api *API) pauseTask(rw http.ResponseWriter, r *http.Request) {
// @Summary Resume task
// @ID resume-task
// @Security CoderSessionToken
// @Produce json
// @Accept json
// @Tags Tasks
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
+9 -18
View File
@@ -5894,7 +5894,7 @@ const docTemplate = `{
"CoderSessionToken": []
}
],
"produces": [
"consumes": [
"application/json"
],
"tags": [
@@ -5936,7 +5936,7 @@ const docTemplate = `{
"CoderSessionToken": []
}
],
"produces": [
"consumes": [
"application/json"
],
"tags": [
@@ -8238,12 +8238,6 @@ const docTemplate = `{
"name": "user",
"in": "path",
"required": true
},
{
"type": "boolean",
"description": "Include expired tokens in the list",
"name": "include_expired",
"in": "query"
}
],
"responses": {
@@ -9551,7 +9545,6 @@ const docTemplate = `{
],
"summary": "Patch workspace agent app status",
"operationId": "patch-workspace-agent-app-status",
"deprecated": true,
"parameters": [
{
"description": "app status",
@@ -13440,9 +13433,6 @@ const docTemplate = `{
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -13756,9 +13746,6 @@ const docTemplate = `{
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -15110,7 +15097,8 @@ const docTemplate = `{
"workspace-usage",
"web-push",
"oauth2",
"mcp-server-http"
"mcp-server-http",
"workspace-sharing"
],
"x-enum-comments": {
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
@@ -15119,6 +15107,7 @@ const docTemplate = `{
"ExperimentNotifications": "Sends notifications via SMTP and webhooks following certain events.",
"ExperimentOAuth2": "Enables OAuth2 provider functionality.",
"ExperimentWebPush": "Enables web push notifications through the browser.",
"ExperimentWorkspaceSharing": "Enables updating workspace ACLs for sharing with users and groups.",
"ExperimentWorkspaceUsage": "Enables the new workspace usage tracking."
},
"x-enum-descriptions": [
@@ -15128,7 +15117,8 @@ const docTemplate = `{
"Enables the new workspace usage tracking.",
"Enables web push notifications through the browser.",
"Enables OAuth2 provider functionality.",
"Enables the MCP HTTP server functionality."
"Enables the MCP HTTP server functionality.",
"Enables updating workspace ACLs for sharing with users and groups."
],
"x-enum-varnames": [
"ExperimentExample",
@@ -15137,7 +15127,8 @@ const docTemplate = `{
"ExperimentWorkspaceUsage",
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentMCPServerHTTP"
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceSharing"
]
},
"codersdk.ExternalAPIKeyScopes": {
+9 -18
View File
@@ -5213,7 +5213,7 @@
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"consumes": ["application/json"],
"tags": ["Tasks"],
"summary": "Pause task",
"operationId": "pause-task",
@@ -5251,7 +5251,7 @@
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"consumes": ["application/json"],
"tags": ["Tasks"],
"summary": "Resume task",
"operationId": "resume-task",
@@ -7285,12 +7285,6 @@
"name": "user",
"in": "path",
"required": true
},
{
"type": "boolean",
"description": "Include expired tokens in the list",
"name": "include_expired",
"in": "query"
}
],
"responses": {
@@ -8450,7 +8444,6 @@
"tags": ["Agents"],
"summary": "Patch workspace agent app status",
"operationId": "patch-workspace-agent-app-status",
"deprecated": true,
"parameters": [
{
"description": "app status",
@@ -12042,9 +12035,6 @@
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -12337,9 +12327,6 @@
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -13637,7 +13624,8 @@
"workspace-usage",
"web-push",
"oauth2",
"mcp-server-http"
"mcp-server-http",
"workspace-sharing"
],
"x-enum-comments": {
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
@@ -13646,6 +13634,7 @@
"ExperimentNotifications": "Sends notifications via SMTP and webhooks following certain events.",
"ExperimentOAuth2": "Enables OAuth2 provider functionality.",
"ExperimentWebPush": "Enables web push notifications through the browser.",
"ExperimentWorkspaceSharing": "Enables updating workspace ACLs for sharing with users and groups.",
"ExperimentWorkspaceUsage": "Enables the new workspace usage tracking."
},
"x-enum-descriptions": [
@@ -13655,7 +13644,8 @@
"Enables the new workspace usage tracking.",
"Enables web push notifications through the browser.",
"Enables OAuth2 provider functionality.",
"Enables the MCP HTTP server functionality."
"Enables the MCP HTTP server functionality.",
"Enables updating workspace ACLs for sharing with users and groups."
],
"x-enum-varnames": [
"ExperimentExample",
@@ -13664,7 +13654,8 @@
"ExperimentWorkspaceUsage",
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentMCPServerHTTP"
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceSharing"
]
},
"codersdk.ExternalAPIKeyScopes": {
+8 -14
View File
@@ -307,26 +307,20 @@ func (api *API) apiKeyByName(rw http.ResponseWriter, r *http.Request) {
// @Tags Users
// @Param user path string true "User ID, name, or me"
// @Success 200 {array} codersdk.APIKey
// @Param include_expired query bool false "Include expired tokens in the list"
// @Router /users/{user}/keys/tokens [get]
func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
user = httpmw.UserParam(r)
keys []database.APIKey
err error
queryStr = r.URL.Query().Get("include_all")
includeAll, _ = strconv.ParseBool(queryStr)
expiredStr = r.URL.Query().Get("include_expired")
includeExpired, _ = strconv.ParseBool(expiredStr)
ctx = r.Context()
user = httpmw.UserParam(r)
keys []database.APIKey
err error
queryStr = r.URL.Query().Get("include_all")
includeAll, _ = strconv.ParseBool(queryStr)
)
if includeAll {
// get tokens for all users
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.GetAPIKeysByLoginTypeParams{
LoginType: database.LoginTypeToken,
IncludeExpired: includeExpired,
})
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.LoginTypeToken)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching API keys.",
@@ -336,7 +330,7 @@ func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
}
} else {
// get user's tokens only
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID, IncludeExpired: includeExpired})
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching API keys.",
-38
View File
@@ -69,44 +69,6 @@ func TestTokenCRUD(t *testing.T) {
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[numLogs-1].Action)
}
func TestTokensFilterExpired(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
adminClient := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, adminClient)
// Create a token.
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
Lifetime: time.Hour * 24 * 7,
})
require.NoError(t, err)
keyID := strings.Split(res.Key, "-")[0]
// List tokens without including expired - should see the token.
keys, err := adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
require.NoError(t, err)
require.Len(t, keys, 1)
// Expire the token.
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
require.NoError(t, err)
// List tokens without including expired - should NOT see expired token.
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
require.NoError(t, err)
require.Empty(t, keys)
// List tokens WITH including expired - should see expired token.
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{
IncludeExpired: true,
})
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, keyID, keys[0].ID)
}
func TestTokenScoped(t *testing.T) {
t.Parallel()
+1 -8
View File
@@ -26,11 +26,6 @@ import (
"github.com/coder/coder/v2/codersdk"
)
// Limit the count query to avoid a slow sequential scan due to joins
// on a large table. Set to 0 to disable capping (but also see the note
// in the SQL query).
const auditLogCountCap = 2000
// @Summary Get audit logs
// @ID get-audit-logs
// @Security CoderSessionToken
@@ -71,7 +66,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
countFilter.Username = ""
}
countFilter.CountCap = auditLogCountCap
// Use the same filters to count the number of audit logs
count, err := api.Database.CountAuditLogs(ctx, countFilter)
if dbauthz.IsNotAuthorizedError(err) {
httpapi.Forbidden(rw)
@@ -86,7 +81,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: []codersdk.AuditLog{},
Count: 0,
CountCap: auditLogCountCap,
})
return
}
@@ -104,7 +98,6 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: api.convertAuditLogs(ctx, dblogs),
Count: count,
CountCap: auditLogCountCap,
})
}
+1 -1
View File
@@ -500,7 +500,7 @@ func (e *Executor) runOnce(t time.Time) Stats {
"task": task.Name,
"task_id": task.ID.String(),
"workspace": ws.Name,
"pause_reason": "idle timeout",
"pause_reason": "inactivity exceeded the dormancy threshold",
},
"lifecycle_executor",
ws.ID, ws.OwnerID, ws.OrganizationID,
+1 -1
View File
@@ -2082,6 +2082,6 @@ func TestExecutorTaskWorkspace(t *testing.T) {
require.Equal(t, task.Name, sent[0].Labels["task"])
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
require.Equal(t, "idle timeout", sent[0].Labels["pause_reason"])
require.Equal(t, "inactivity exceeded the dormancy threshold", sent[0].Labels["pause_reason"])
})
}
+9 -10
View File
@@ -98,7 +98,6 @@ import (
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/site"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/derpmetrics"
"github.com/coder/quartz"
"github.com/coder/serpent"
)
@@ -330,10 +329,9 @@ func New(options *Options) *API {
panic("developer error: options.PrometheusRegistry is nil and not running a unit test")
}
if options.DeploymentValues.DisableOwnerWorkspaceExec || options.DeploymentValues.DisableWorkspaceSharing {
if options.DeploymentValues.DisableOwnerWorkspaceExec {
rbac.ReloadBuiltinRoles(&rbac.RoleOptions{
NoOwnerWorkspaceExec: bool(options.DeploymentValues.DisableOwnerWorkspaceExec),
NoWorkspaceSharing: bool(options.DeploymentValues.DisableWorkspaceSharing),
NoOwnerWorkspaceExec: true,
})
}
@@ -884,18 +882,17 @@ func New(options *Options) *API {
apiRateLimiter := httpmw.RateLimit(options.APIRateLimit, time.Minute)
// Register DERP on expvar HTTP handler, which we serve below in the router, c.f. expvar.Handler()
// These are the metrics the DERP server exposes.
// TODO: export via prometheus
expDERPOnce.Do(func() {
// We need to do this via a global Once because expvar registry is global and panics if we
// register multiple times. In production there is only one Coderd and one DERP server per
// process, but in testing, we create multiple of both, so the Once protects us from
// panicking.
if options.DERPServer != nil && expvar.Get("derp") == nil {
if options.DERPServer != nil {
expvar.Publish("derp", api.DERPServer.ExpVar())
}
})
if options.PrometheusRegistry != nil && options.DERPServer != nil {
options.PrometheusRegistry.MustRegister(derpmetrics.NewDERPExpvarCollector(options.DERPServer))
}
cors := httpmw.Cors(options.DeploymentValues.Dangerous.AllowAllCors.Value())
prometheusMW := httpmw.Prometheus(options.PrometheusRegistry)
@@ -1529,6 +1526,10 @@ func New(options *Options) *API {
})
r.Get("/timings", api.workspaceTimings)
r.Route("/acl", func(r chi.Router) {
r.Use(
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceSharing),
)
r.Get("/", api.workspaceACL)
r.Patch("/", api.patchWorkspaceACL)
r.Delete("/", api.deleteWorkspaceACL)
@@ -1737,8 +1738,6 @@ func New(options *Options) *API {
r.Patch("/input", api.taskUpdateInput)
r.Post("/send", api.taskSend)
r.Get("/logs", api.taskLogs)
r.Post("/pause", api.pauseTask)
r.Post("/resume", api.resumeTask)
})
})
})
+2 -28
View File
@@ -384,35 +384,9 @@ func TestCSRFExempt(t *testing.T) {
data, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
// A StatusNotFound means Coderd tried to proxy to the agent and failed because the agent
// A StatusBadGateway means Coderd tried to proxy to the agent and failed because the agent
// was not there. This means CSRF did not block the app request, which is what we want.
require.Equal(t, http.StatusNotFound, resp.StatusCode, "status code 500 is CSRF failure")
require.Equal(t, http.StatusBadGateway, resp.StatusCode, "status code 500 is CSRF failure")
require.NotContains(t, string(data), "CSRF")
})
}
func TestDERPMetrics(t *testing.T) {
t.Parallel()
_, _, api := coderdtest.NewWithAPI(t, nil)
require.NotNil(t, api.Options.DERPServer, "DERP server should be configured")
require.NotNil(t, api.Options.PrometheusRegistry, "Prometheus registry should be configured")
// The registry is created internally by coderd. Gather from it
// to verify DERP metrics were registered during startup.
metrics, err := api.Options.PrometheusRegistry.Gather()
require.NoError(t, err)
names := make(map[string]struct{})
for _, m := range metrics {
names[m.GetName()] = struct{}{}
}
assert.Contains(t, names, "coder_derp_server_connections",
"expected coder_derp_server_connections to be registered")
assert.Contains(t, names, "coder_derp_server_bytes_received_total",
"expected coder_derp_server_bytes_received_total to be registered")
assert.Contains(t, names, "coder_derp_server_packets_dropped_reason_total",
"expected coder_derp_server_packets_dropped_reason_total to be registered")
}
+2 -14
View File
@@ -106,8 +106,6 @@ import (
"github.com/coder/quartz"
)
const DefaultDERPMeshKey = "test-key"
const defaultTestDaemonName = "test-daemon"
type Options struct {
@@ -514,18 +512,8 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
stunAddresses = options.DeploymentValues.DERP.Server.STUNAddresses.Value()
}
const derpMeshKey = "test-key"
// Technically AGPL coderd servers don't set this value, but it doesn't
// change any behavior. It's useful for enterprise tests.
err = options.Database.InsertDERPMeshKey(dbauthz.AsSystemRestricted(ctx), derpMeshKey) //nolint:gocritic // test
if !database.IsUniqueViolation(err, database.UniqueSiteConfigsKeyKey) {
require.NoError(t, err, "insert DERP mesh key")
}
var derpServer *derp.Server
if options.DeploymentValues.DERP.Server.Enable.Value() {
derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp").Leveled(slog.LevelDebug)))
derpServer.SetMeshKey(derpMeshKey)
}
derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp").Leveled(slog.LevelDebug)))
derpServer.SetMeshKey("test-key")
// match default with cli default
if options.SSHKeygenAlgorithm == "" {
+3 -48
View File
@@ -668,31 +668,6 @@ var (
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectWorkspaceBuilder = rbac.Subject{
Type: rbac.SubjectTypeWorkspaceBuilder,
FriendlyName: "Workspace Builder",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Identifier: rbac.RoleIdentifier{Name: "workspace-builder"},
DisplayName: "Workspace Builder",
Site: rbac.Permissions(map[string][]policy.Action{
// Reading provisioner daemons to check eligibility.
rbac.ResourceProvisionerDaemon.Type: {policy.ActionRead},
// Updating provisioner jobs (e.g. marking prebuild
// jobs complete).
rbac.ResourceProvisionerJobs.Type: {policy.ActionUpdate},
// Reading provisioner state requires template update
// permission.
rbac.ResourceTemplate.Type: {policy.ActionUpdate},
}),
User: []rbac.Permission{},
ByOrgID: map[string]rbac.OrgPermissions{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
)
// AsProvisionerd returns a context with an actor that has permissions required
@@ -799,14 +774,6 @@ func AsBoundaryUsageTracker(ctx context.Context) context.Context {
return As(ctx, subjectBoundaryUsageTracker)
}
// AsWorkspaceBuilder returns a context with an actor that has permissions
// required for the workspace builder to prepare workspace builds. This
// includes reading provisioner daemons, updating provisioner jobs, and
// reading provisioner state (which requires template update permission).
func AsWorkspaceBuilder(ctx context.Context) context.Context {
return As(ctx, subjectWorkspaceBuilder)
}
var AsRemoveActor = rbac.Subject{
ID: "remove-actor",
}
@@ -2194,12 +2161,12 @@ func (q *querier) GetAPIKeyByName(ctx context.Context, arg database.GetAPIKeyByN
return fetch(q.log, q.auth, q.db.GetAPIKeyByName)(ctx, arg)
}
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByLoginType)(ctx, loginType)
}
func (q *querier) GetAPIKeysByUserID(ctx context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, params)
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, database.GetAPIKeysByUserIDParams{LoginType: params.LoginType, UserID: params.UserID})
}
func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]database.APIKey, error) {
@@ -2290,7 +2257,7 @@ func (q *querier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditL
}
func (q *querier) GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow, error) {
// This is a system function.
// This is a system function
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow{}, err
}
@@ -3166,13 +3133,6 @@ func (q *querier) GetTelemetryItems(ctx context.Context) ([]database.TelemetryIt
return q.db.GetTelemetryItems(ctx)
}
func (q *querier) GetTelemetryTaskEvents(ctx context.Context, arg database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTask.All()); err != nil {
return nil, err
}
return q.db.GetTelemetryTaskEvents(ctx, arg)
}
func (q *querier) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
if err := q.authorizeTemplateInsights(ctx, arg.TemplateIDs); err != nil {
return nil, err
@@ -3954,11 +3914,6 @@ func (q *querier) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, wor
return q.db.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, prep)
}
func (q *querier) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, buildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
// Fetching the provisioner state requires Update permission on the template.
return fetchWithAction(q.log, q.auth, policy.ActionUpdate, q.db.GetWorkspaceBuildProvisionerStateByID)(ctx, buildID)
}
func (q *querier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
+2 -16
View File
@@ -237,8 +237,8 @@ func (s *MethodTestSuite) TestAPIKey() {
s.Run("GetAPIKeysByLoginType", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
a := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
b := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Return([]database.APIKey{a, b}, nil).AnyTimes()
check.Args(database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.LoginTypePassword).Return([]database.APIKey{a, b}, nil).AnyTimes()
check.Args(database.LoginTypePassword).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
}))
s.Run("GetAPIKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u1 := testutil.Fake(s.T(), faker, database.User{})
@@ -1326,11 +1326,6 @@ func (s *MethodTestSuite) TestTemplate() {
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
}))
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetTelemetryTaskEventsParams{}
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceTask.All(), policy.ActionRead)
}))
s.Run("GetTemplateAppInsights", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetTemplateAppInsightsParams{}
dbm.EXPECT().GetTemplateAppInsights(gomock.Any(), arg).Return([]database.GetTemplateAppInsightsRow{}, nil).AnyTimes()
@@ -1974,15 +1969,6 @@ func (s *MethodTestSuite) TestWorkspace() {
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
check.Args(build.ID).Asserts(ws, policy.ActionRead).Returns(build)
}))
s.Run("GetWorkspaceBuildProvisionerStateByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
row := database.GetWorkspaceBuildProvisionerStateByIDRow{
ProvisionerState: []byte("state"),
TemplateID: uuid.New(),
TemplateOrganizationID: uuid.New(),
}
dbm.EXPECT().GetWorkspaceBuildProvisionerStateByID(gomock.Any(), gomock.Any()).Return(row, nil).AnyTimes()
check.Args(uuid.New()).Asserts(row, policy.ActionUpdate).Returns(row)
}))
s.Run("GetWorkspaceBuildByJobID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
build := testutil.Fake(s.T(), faker, database.WorkspaceBuild{WorkspaceID: ws.ID})
-19
View File
@@ -67,8 +67,6 @@ type WorkspaceBuildBuilder struct {
jobError string // Error message for failed jobs
jobErrorCode string // Error code for failed jobs
provisionerState []byte
}
// BuilderOption is a functional option for customizing job timestamps
@@ -140,15 +138,6 @@ func (b WorkspaceBuildBuilder) Seed(seed database.WorkspaceBuild) WorkspaceBuild
return b
}
// ProvisionerState sets the provisioner state for the workspace build.
// This is stored separately from the seed because ProvisionerState is
// not part of the WorkspaceBuild view struct.
func (b WorkspaceBuildBuilder) ProvisionerState(state []byte) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
b.provisionerState = state
return b
}
func (b WorkspaceBuildBuilder) Resource(resource ...*sdkproto.Resource) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
b.resources = append(b.resources, resource...)
@@ -475,14 +464,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
}
resp.Build = dbgen.WorkspaceBuild(b.t, b.db, b.seed)
if len(b.provisionerState) > 0 {
err = b.db.UpdateWorkspaceBuildProvisionerStateByID(ownerCtx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
ID: resp.Build.ID,
UpdatedAt: dbtime.Now(),
ProvisionerState: b.provisionerState,
})
require.NoError(b.t, err, "update provisioner state")
}
b.logger.Debug(context.Background(), "created workspace build",
slog.F("build_id", resp.Build.ID),
slog.F("workspace_id", resp.Workspace.ID),
+1 -3
View File
@@ -504,7 +504,7 @@ func WorkspaceBuild(t testing.TB, db database.Store, orig database.WorkspaceBuil
Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart),
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
JobID: jobID,
ProvisionerState: []byte{},
ProvisionerState: takeFirstSlice(orig.ProvisionerState, []byte{}),
Deadline: takeFirst(orig.Deadline, dbtime.Now().Add(time.Hour)),
MaxDeadline: takeFirst(orig.MaxDeadline, time.Time{}),
Reason: takeFirst(orig.Reason, database.BuildReasonInitiator),
@@ -1373,8 +1373,6 @@ func OAuth2ProviderAppCode(t testing.TB, db database.Store, seed database.OAuth2
ResourceUri: seed.ResourceUri,
CodeChallenge: seed.CodeChallenge,
CodeChallengeMethod: seed.CodeChallengeMethod,
StateHash: seed.StateHash,
RedirectUri: seed.RedirectUri,
})
require.NoError(t, err, "insert oauth2 app code")
return code
+1 -17
View File
@@ -774,7 +774,7 @@ func (m queryMetricsStore) GetAPIKeyByName(ctx context.Context, arg database.Get
return r0, r1
}
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
start := time.Now()
r0, r1 := m.s.GetAPIKeysByLoginType(ctx, loginType)
m.queryLatencies.WithLabelValues("GetAPIKeysByLoginType").Observe(time.Since(start).Seconds())
@@ -1790,14 +1790,6 @@ func (m queryMetricsStore) GetTelemetryItems(ctx context.Context) ([]database.Te
return r0, r1
}
func (m queryMetricsStore) GetTelemetryTaskEvents(ctx context.Context, createdAfter database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTelemetryTaskEvents(ctx, createdAfter)
m.queryLatencies.WithLabelValues("GetTelemetryTaskEvents").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTelemetryTaskEvents").Inc()
return r0, r1
}
func (m queryMetricsStore) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTemplateAppInsights(ctx, arg)
@@ -2438,14 +2430,6 @@ func (m queryMetricsStore) GetWorkspaceBuildParametersByBuildIDs(ctx context.Con
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuildID)
m.queryLatencies.WithLabelValues("GetWorkspaceBuildProvisionerStateByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceBuildProvisionerStateByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceBuildStatsByTemplates(ctx, since)
+4 -34
View File
@@ -1305,18 +1305,18 @@ func (mr *MockStoreMockRecorder) GetAPIKeyByName(ctx, arg any) *gomock.Call {
}
// GetAPIKeysByLoginType mocks base method.
func (m *MockStore) GetAPIKeysByLoginType(ctx context.Context, arg database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
func (m *MockStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAPIKeysByLoginType", ctx, arg)
ret := m.ctrl.Call(m, "GetAPIKeysByLoginType", ctx, loginType)
ret0, _ := ret[0].([]database.APIKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAPIKeysByLoginType indicates an expected call of GetAPIKeysByLoginType.
func (mr *MockStoreMockRecorder) GetAPIKeysByLoginType(ctx, arg any) *gomock.Call {
func (mr *MockStoreMockRecorder) GetAPIKeysByLoginType(ctx, loginType any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysByLoginType", reflect.TypeOf((*MockStore)(nil).GetAPIKeysByLoginType), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysByLoginType", reflect.TypeOf((*MockStore)(nil).GetAPIKeysByLoginType), ctx, loginType)
}
// GetAPIKeysByUserID mocks base method.
@@ -3314,21 +3314,6 @@ func (mr *MockStoreMockRecorder) GetTelemetryItems(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTelemetryItems", reflect.TypeOf((*MockStore)(nil).GetTelemetryItems), ctx)
}
// GetTelemetryTaskEvents mocks base method.
func (m *MockStore) GetTelemetryTaskEvents(ctx context.Context, arg database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTelemetryTaskEvents", ctx, arg)
ret0, _ := ret[0].([]database.GetTelemetryTaskEventsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTelemetryTaskEvents indicates an expected call of GetTelemetryTaskEvents.
func (mr *MockStoreMockRecorder) GetTelemetryTaskEvents(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTelemetryTaskEvents", reflect.TypeOf((*MockStore)(nil).GetTelemetryTaskEvents), ctx, arg)
}
// GetTemplateAppInsights mocks base method.
func (m *MockStore) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
m.ctrl.T.Helper()
@@ -4559,21 +4544,6 @@ func (mr *MockStoreMockRecorder) GetWorkspaceBuildParametersByBuildIDs(ctx, work
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildParametersByBuildIDs", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildParametersByBuildIDs), ctx, workspaceBuildIds)
}
// GetWorkspaceBuildProvisionerStateByID mocks base method.
func (m *MockStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWorkspaceBuildProvisionerStateByID", ctx, workspaceBuildID)
ret0, _ := ret[0].(database.GetWorkspaceBuildProvisionerStateByIDRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetWorkspaceBuildProvisionerStateByID indicates an expected call of GetWorkspaceBuildProvisionerStateByID.
func (mr *MockStoreMockRecorder) GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuildID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildProvisionerStateByID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildProvisionerStateByID), ctx, workspaceBuildID)
}
// GetWorkspaceBuildStatsByTemplates mocks base method.
func (m *MockStore) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
m.ctrl.T.Helper()
+2 -7
View File
@@ -1471,9 +1471,7 @@ CREATE TABLE oauth2_provider_app_codes (
app_id uuid NOT NULL,
resource_uri text,
code_challenge text,
code_challenge_method text,
state_hash text,
redirect_uri text
code_challenge_method text
);
COMMENT ON TABLE oauth2_provider_app_codes IS 'Codes are meant to be exchanged for access tokens.';
@@ -1484,10 +1482,6 @@ COMMENT ON COLUMN oauth2_provider_app_codes.code_challenge IS 'PKCE code challen
COMMENT ON COLUMN oauth2_provider_app_codes.code_challenge_method IS 'PKCE challenge method (S256)';
COMMENT ON COLUMN oauth2_provider_app_codes.state_hash IS 'SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.';
COMMENT ON COLUMN oauth2_provider_app_codes.redirect_uri IS 'The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).';
CREATE TABLE oauth2_provider_app_secrets (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
@@ -2708,6 +2702,7 @@ CREATE VIEW workspace_build_with_user AS
workspace_builds.build_number,
workspace_builds.transition,
workspace_builds.initiator_id,
workspace_builds.provisioner_state,
workspace_builds.job_id,
workspace_builds.deadline,
workspace_builds.reason,
+4 -23
View File
@@ -51,34 +51,15 @@ func TestViewSubsetTemplateVersion(t *testing.T) {
}
}
// TestViewSubsetWorkspaceBuild ensures WorkspaceBuildTable is a subset of
// WorkspaceBuild, with the exception of ProvisionerState which is
// intentionally excluded from the workspace_build_with_user view to avoid
// loading the large Terraform state blob on hot paths.
// TestViewSubsetWorkspaceBuild ensures WorkspaceBuildTable is a subset of WorkspaceBuild
func TestViewSubsetWorkspaceBuild(t *testing.T) {
t.Parallel()
table := reflect.TypeOf(database.WorkspaceBuildTable{})
joined := reflect.TypeOf(database.WorkspaceBuild{})
tableFields := fieldNames(allFields(table))
joinedFields := fieldNames(allFields(joined))
// ProvisionerState is intentionally excluded from the
// workspace_build_with_user view to avoid loading multi-MB Terraform
// state blobs on hot paths. Callers that need it use
// GetWorkspaceBuildProvisionerStateByID instead.
excludedFields := map[string]bool{
"ProvisionerState": true,
}
var filtered []string
for _, name := range tableFields {
if !excludedFields[name] {
filtered = append(filtered, name)
}
}
if !assert.Subset(t, joinedFields, filtered, "table is not subset") {
tableFields := allFields(table)
joinedFields := allFields(joined)
if !assert.Subset(t, fieldNames(joinedFields), fieldNames(tableFields), "table is not subset") {
t.Log("Some fields were added to the WorkspaceBuild Table without updating the 'workspace_build_with_user' view.")
t.Log("See migration 000141_join_users_build_version.up.sql to create the view.")
}
@@ -1,3 +0,0 @@
ALTER TABLE oauth2_provider_app_codes
DROP COLUMN state_hash,
DROP COLUMN redirect_uri;
@@ -1,9 +0,0 @@
ALTER TABLE oauth2_provider_app_codes
ADD COLUMN state_hash text,
ADD COLUMN redirect_uri text;
COMMENT ON COLUMN oauth2_provider_app_codes.state_hash IS
'SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.';
COMMENT ON COLUMN oauth2_provider_app_codes.redirect_uri IS
'The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).';
@@ -1,31 +0,0 @@
-- Restore provisioner_state to workspace_build_with_user view.
DROP VIEW workspace_build_with_user;
CREATE VIEW workspace_build_with_user AS
SELECT
workspace_builds.id,
workspace_builds.created_at,
workspace_builds.updated_at,
workspace_builds.workspace_id,
workspace_builds.template_version_id,
workspace_builds.build_number,
workspace_builds.transition,
workspace_builds.initiator_id,
workspace_builds.provisioner_state,
workspace_builds.job_id,
workspace_builds.deadline,
workspace_builds.reason,
workspace_builds.daily_cost,
workspace_builds.max_deadline,
workspace_builds.template_version_preset_id,
workspace_builds.has_ai_task,
workspace_builds.has_external_agent,
COALESCE(visible_users.avatar_url, ''::text) AS initiator_by_avatar_url,
COALESCE(visible_users.username, ''::text) AS initiator_by_username,
COALESCE(visible_users.name, ''::text) AS initiator_by_name
FROM
workspace_builds
LEFT JOIN
visible_users ON workspace_builds.initiator_id = visible_users.id;
COMMENT ON VIEW workspace_build_with_user IS 'Joins in the username + avatar url of the initiated by user.';
@@ -1,33 +0,0 @@
-- Drop and recreate workspace_build_with_user to exclude provisioner_state.
-- This avoids loading the large Terraform state blob (1-5 MB per workspace)
-- on every query that uses this view. The callers that need provisioner_state
-- now fetch it separately via GetWorkspaceBuildProvisionerStateByID.
DROP VIEW workspace_build_with_user;
CREATE VIEW workspace_build_with_user AS
SELECT
workspace_builds.id,
workspace_builds.created_at,
workspace_builds.updated_at,
workspace_builds.workspace_id,
workspace_builds.template_version_id,
workspace_builds.build_number,
workspace_builds.transition,
workspace_builds.initiator_id,
workspace_builds.job_id,
workspace_builds.deadline,
workspace_builds.reason,
workspace_builds.daily_cost,
workspace_builds.max_deadline,
workspace_builds.template_version_preset_id,
workspace_builds.has_ai_task,
workspace_builds.has_external_agent,
COALESCE(visible_users.avatar_url, ''::text) AS initiator_by_avatar_url,
COALESCE(visible_users.username, ''::text) AS initiator_by_username,
COALESCE(visible_users.name, ''::text) AS initiator_by_name
FROM
workspace_builds
LEFT JOIN
visible_users ON workspace_builds.initiator_id = visible_users.id;
COMMENT ON VIEW workspace_build_with_user IS 'Joins in the username + avatar url of the initiated by user.';
-8
View File
@@ -316,14 +316,6 @@ func (t GetFileTemplatesRow) RBACObject() rbac.Object {
WithGroupACL(t.GroupACL)
}
// RBACObject for a workspace build's provisioner state requires Update access of the template.
func (t GetWorkspaceBuildProvisionerStateByIDRow) RBACObject() rbac.Object {
return rbac.ResourceTemplate.WithID(t.TemplateID).
InOrg(t.TemplateOrganizationID).
WithACLUserList(t.UserACL).
WithGroupACL(t.GroupACL)
}
func (t Template) DeepCopy() Template {
cpy := t
cpy.UserACL = maps.Clone(t.UserACL)
-2
View File
@@ -610,7 +610,6 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
arg.DateTo,
arg.BuildReason,
arg.RequestID,
arg.CountCap,
)
if err != nil {
return 0, err
@@ -747,7 +746,6 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
arg.WorkspaceID,
arg.ConnectionID,
arg.Status,
arg.CountCap,
)
if err != nil {
return 0, err
@@ -145,13 +145,5 @@ func extractWhereClause(query string) string {
// Remove SQL comments
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
// Normalize indentation so subquery wrapping doesn't cause
// mismatches.
lines := strings.Split(whereClause, "\n")
for i, line := range lines {
lines[i] = strings.TrimLeft(line, " \t")
}
whereClause = strings.Join(lines, "\n")
return strings.TrimSpace(whereClause)
}
+1 -4
View File
@@ -4026,10 +4026,6 @@ type OAuth2ProviderAppCode struct {
CodeChallenge sql.NullString `db:"code_challenge" json:"code_challenge"`
// PKCE challenge method (S256)
CodeChallengeMethod sql.NullString `db:"code_challenge_method" json:"code_challenge_method"`
// SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.
StateHash sql.NullString `db:"state_hash" json:"state_hash"`
// The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).
RedirectUri sql.NullString `db:"redirect_uri" json:"redirect_uri"`
}
type OAuth2ProviderAppSecret struct {
@@ -4987,6 +4983,7 @@ type WorkspaceBuild struct {
BuildNumber int32 `db:"build_number" json:"build_number"`
Transition WorkspaceTransition `db:"transition" json:"transition"`
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
ProvisionerState []byte `db:"provisioner_state" json:"provisioner_state"`
JobID uuid.UUID `db:"job_id" json:"job_id"`
Deadline time.Time `db:"deadline" json:"deadline"`
Reason BuildReason `db:"reason" json:"reason"`
+1 -23
View File
@@ -169,7 +169,7 @@ type sqlcQuerier interface {
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
// there is no unique constraint on empty token names
GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error)
GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error)
GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error)
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
@@ -361,23 +361,6 @@ type sqlcQuerier interface {
GetTaskSnapshot(ctx context.Context, taskID uuid.UUID) (TaskSnapshot, error)
GetTelemetryItem(ctx context.Context, key string) (TelemetryItem, error)
GetTelemetryItems(ctx context.Context) ([]TelemetryItem, error)
// Returns all data needed to build task lifecycle events for telemetry
// in a single round-trip. For each task whose workspace is in the
// given set, fetches:
// - the latest workspace app binding (task_workspace_apps)
// - the most recent stop and start builds (workspace_builds)
// - the last "working" app status (workspace_app_statuses)
// - the first app status after resume, for active workspaces
//
// Assumptions:
// - 1:1 relationship between tasks and workspaces. All builds on the
// workspace are considered task-related.
// - Idle duration approximation: If the agent reports "working", does
// work, then reports "done", we miss that working time.
// - lws and active_dur join across all historical app IDs for the task,
// because each resume cycle provisions a new app ID. This ensures
// pre-pause statuses contribute to idle duration and active duration.
GetTelemetryTaskEvents(ctx context.Context, arg GetTelemetryTaskEventsParams) ([]GetTelemetryTaskEventsRow, error)
// GetTemplateAppInsights returns the aggregate usage of each app in a given
// timeframe. The result can be filtered on template_ids, meaning only user data
// from workspaces based on those templates will be included.
@@ -523,11 +506,6 @@ type sqlcQuerier interface {
GetWorkspaceBuildMetricsByResourceID(ctx context.Context, id uuid.UUID) (GetWorkspaceBuildMetricsByResourceIDRow, error)
GetWorkspaceBuildParameters(ctx context.Context, workspaceBuildID uuid.UUID) ([]WorkspaceBuildParameter, error)
GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIds []uuid.UUID) ([]WorkspaceBuildParameter, error)
// Fetches the provisioner state of a workspace build, joined through to the
// template so that dbauthz can enforce policy.ActionUpdate on the template.
// Provisioner state contains sensitive Terraform state and should only be
// accessible to template administrators.
GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (GetWorkspaceBuildProvisionerStateByIDRow, error)
GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]GetWorkspaceBuildStatsByTemplatesRow, error)
GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg GetWorkspaceBuildsByWorkspaceIDParams) ([]WorkspaceBuild, error)
GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceBuild, error)
+6 -129
View File
@@ -8195,9 +8195,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
// All keys are present before deletion
keys, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
LoginType: user.LoginType,
UserID: user.ID,
IncludeExpired: true,
LoginType: user.LoginType,
UserID: user.ID,
})
require.NoError(t, err)
require.Len(t, keys, len(expiredTimes)+len(unexpiredTimes))
@@ -8213,9 +8212,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
// Ensure it was deleted
remaining, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
LoginType: user.LoginType,
UserID: user.ID,
IncludeExpired: true,
LoginType: user.LoginType,
UserID: user.ID,
})
require.NoError(t, err)
require.Len(t, remaining, len(expiredTimes)+len(unexpiredTimes)-1)
@@ -8230,9 +8228,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
// Ensure only unexpired keys remain
remaining, err = db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
LoginType: user.LoginType,
UserID: user.ID,
IncludeExpired: true,
LoginType: user.LoginType,
UserID: user.ID,
})
require.NoError(t, err)
require.Len(t, remaining, len(unexpiredTimes))
@@ -8742,123 +8739,3 @@ func TestInsertWorkspaceAgentDevcontainers(t *testing.T) {
})
}
}
func TestGetWorkspaceBuildMetricsByResourceID(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
org := dbgen.Organization(t, db, database.Organization{})
user := dbgen.User(t, db, database.User{})
tmpl := dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
TemplateID: uuid.NullUUID{UUID: tmpl.ID, Valid: true},
CreatedBy: user.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
TemplateID: tmpl.ID,
OwnerID: user.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
})
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: ws.ID,
TemplateVersionID: tv.ID,
JobID: job.ID,
InitiatorID: user.ID,
})
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
JobID: job.ID,
})
parentReadyAt := dbtime.Now()
parentStartedAt := parentReadyAt.Add(-time.Second)
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resource.ID,
StartedAt: sql.NullTime{Time: parentStartedAt, Valid: true},
ReadyAt: sql.NullTime{Time: parentReadyAt, Valid: true},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
})
row, err := db.GetWorkspaceBuildMetricsByResourceID(ctx, resource.ID)
require.NoError(t, err)
require.True(t, row.AllAgentsReady)
require.True(t, parentReadyAt.Equal(row.LastAgentReadyAt))
require.Equal(t, "success", row.WorstStatus)
})
t.Run("SubAgentExcluded", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
org := dbgen.Organization(t, db, database.Organization{})
user := dbgen.User(t, db, database.User{})
tmpl := dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
TemplateID: uuid.NullUUID{UUID: tmpl.ID, Valid: true},
CreatedBy: user.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
TemplateID: tmpl.ID,
OwnerID: user.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
})
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: ws.ID,
TemplateVersionID: tv.ID,
JobID: job.ID,
InitiatorID: user.ID,
})
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
JobID: job.ID,
})
parentReadyAt := dbtime.Now()
parentStartedAt := parentReadyAt.Add(-time.Second)
parentAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resource.ID,
StartedAt: sql.NullTime{Time: parentStartedAt, Valid: true},
ReadyAt: sql.NullTime{Time: parentReadyAt, Valid: true},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
})
// Sub-agent with ready_at 1 hour later should be excluded.
subAgentReadyAt := parentReadyAt.Add(time.Hour)
subAgentStartedAt := subAgentReadyAt.Add(-time.Second)
_ = dbgen.WorkspaceSubAgent(t, db, parentAgent, database.WorkspaceAgent{
StartedAt: sql.NullTime{Time: subAgentStartedAt, Valid: true},
ReadyAt: sql.NullTime{Time: subAgentReadyAt, Valid: true},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
})
row, err := db.GetWorkspaceBuildMetricsByResourceID(ctx, resource.ID)
require.NoError(t, err)
require.True(t, row.AllAgentsReady)
// LastAgentReadyAt should be the parent's, not the sub-agent's.
require.True(t, parentReadyAt.Equal(row.LastAgentReadyAt))
require.Equal(t, "success", row.WorstStatus)
})
}
+220 -491
View File
@@ -1270,16 +1270,10 @@ func (q *sqlQuerier) GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNamePar
const getAPIKeysByLoginType = `-- name: GetAPIKeysByLoginType :many
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, token_name, scopes, allow_list FROM api_keys WHERE login_type = $1
AND ($2::bool OR expires_at > now())
`
type GetAPIKeysByLoginTypeParams struct {
LoginType LoginType `db:"login_type" json:"login_type"`
IncludeExpired bool `db:"include_expired" json:"include_expired"`
}
func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error) {
rows, err := q.db.QueryContext(ctx, getAPIKeysByLoginType, arg.LoginType, arg.IncludeExpired)
func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error) {
rows, err := q.db.QueryContext(ctx, getAPIKeysByLoginType, loginType)
if err != nil {
return nil, err
}
@@ -1317,17 +1311,15 @@ func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysBy
const getAPIKeysByUserID = `-- name: GetAPIKeysByUserID :many
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, token_name, scopes, allow_list FROM api_keys WHERE login_type = $1 AND user_id = $2
AND ($3::bool OR expires_at > now())
`
type GetAPIKeysByUserIDParams struct {
LoginType LoginType `db:"login_type" json:"login_type"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
IncludeExpired bool `db:"include_expired" json:"include_expired"`
LoginType LoginType `db:"login_type" json:"login_type"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
}
func (q *sqlQuerier) GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error) {
rows, err := q.db.QueryContext(ctx, getAPIKeysByUserID, arg.LoginType, arg.UserID, arg.IncludeExpired)
rows, err := q.db.QueryContext(ctx, getAPIKeysByUserID, arg.LoginType, arg.UserID)
if err != nil {
return nil, err
}
@@ -1511,105 +1503,93 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
}
const countAuditLogs = `-- name: CountAuditLogs :one
SELECT COUNT(*) FROM (
SELECT 1
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN $1::text != '' THEN resource_type = $1::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
ELSE true
END
-- Filter organization_id
AND CASE
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN $4::text != '' THEN resource_target = $4
ELSE true
END
-- Filter action
AND CASE
WHEN $5::text != '' THEN action = $5::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower($7)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8::text != '' THEN users.email = $8
ELSE true
END
-- Filter by date_from
AND CASE
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
ELSE true
END
-- Filter by date_to
AND CASE
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
ELSE true
END
-- Filter request_id
AND CASE
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
-- Avoid a slow scan on a large table with joins. The caller
-- passes the count cap and we add 1 so the frontend can detect
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
-- -> NULL + 1 = NULL).
-- NOTE: Parameterizing this so that we can easily change from,
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
-- here if disabling the capping on a large table permanently.
-- This way the PG planner can plan parallel execution for
-- potential large wins.
LIMIT NULLIF($13::int, 0) + 1
) AS limited_count
SELECT COUNT(*)
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN $1::text != '' THEN resource_type = $1::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
ELSE true
END
-- Filter organization_id
AND CASE
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN $4::text != '' THEN resource_target = $4
ELSE true
END
-- Filter action
AND CASE
WHEN $5::text != '' THEN action = $5::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower($7)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8::text != '' THEN users.email = $8
ELSE true
END
-- Filter by date_from
AND CASE
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
ELSE true
END
-- Filter by date_to
AND CASE
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
ELSE true
END
-- Filter request_id
AND CASE
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
`
type CountAuditLogsParams struct {
@@ -1625,7 +1605,6 @@ type CountAuditLogsParams struct {
DateTo time.Time `db:"date_to" json:"date_to"`
BuildReason string `db:"build_reason" json:"build_reason"`
RequestID uuid.UUID `db:"request_id" json:"request_id"`
CountCap int32 `db:"count_cap" json:"count_cap"`
}
func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) {
@@ -1642,7 +1621,6 @@ func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParam
arg.DateTo,
arg.BuildReason,
arg.RequestID,
arg.CountCap,
)
var count int64
err := row.Scan(&count)
@@ -2121,113 +2099,110 @@ func (q *sqlQuerier) UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBou
}
const countConnectionLogs = `-- name: CountConnectionLogs :one
SELECT COUNT(*) AS count FROM (
SELECT 1
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = $1
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN $2 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower($2) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = $3
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN $4 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = $4 AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN $5 :: text != '' THEN
type = $5 :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7 :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower($7) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8 :: text != '' THEN
users.email = $8
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= $9
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= $10
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = $11
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = $12
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN $13 :: text != '' THEN
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
-- NOTE: See the CountAuditLogs LIMIT note.
LIMIT NULLIF($14::int, 0) + 1
) AS limited_count
SELECT
COUNT(*) AS count
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = $1
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN $2 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower($2) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = $3
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN $4 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = $4 AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN $5 :: text != '' THEN
type = $5 :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7 :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower($7) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8 :: text != '' THEN
users.email = $8
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= $9
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= $10
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = $11
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = $12
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN $13 :: text != '' THEN
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
`
type CountConnectionLogsParams struct {
@@ -2244,7 +2219,6 @@ type CountConnectionLogsParams struct {
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"`
Status string `db:"status" json:"status"`
CountCap int32 `db:"count_cap" json:"count_cap"`
}
func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) {
@@ -2262,7 +2236,6 @@ func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectio
arg.WorkspaceID,
arg.ConnectionID,
arg.Status,
arg.CountCap,
)
var count int64
err := row.Scan(&count)
@@ -6795,7 +6768,7 @@ func (q *sqlQuerier) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context
}
const getOAuth2ProviderAppCodeByID = `-- name: GetOAuth2ProviderAppCodeByID :one
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri FROM oauth2_provider_app_codes WHERE id = $1
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method FROM oauth2_provider_app_codes WHERE id = $1
`
func (q *sqlQuerier) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) (OAuth2ProviderAppCode, error) {
@@ -6812,14 +6785,12 @@ func (q *sqlQuerier) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.U
&i.ResourceUri,
&i.CodeChallenge,
&i.CodeChallengeMethod,
&i.StateHash,
&i.RedirectUri,
)
return i, err
}
const getOAuth2ProviderAppCodeByPrefix = `-- name: GetOAuth2ProviderAppCodeByPrefix :one
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri FROM oauth2_provider_app_codes WHERE secret_prefix = $1
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method FROM oauth2_provider_app_codes WHERE secret_prefix = $1
`
func (q *sqlQuerier) GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secretPrefix []byte) (OAuth2ProviderAppCode, error) {
@@ -6836,8 +6807,6 @@ func (q *sqlQuerier) GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secre
&i.ResourceUri,
&i.CodeChallenge,
&i.CodeChallengeMethod,
&i.StateHash,
&i.RedirectUri,
)
return i, err
}
@@ -7241,9 +7210,7 @@ INSERT INTO oauth2_provider_app_codes (
user_id,
resource_uri,
code_challenge,
code_challenge_method,
state_hash,
redirect_uri
code_challenge_method
) VALUES(
$1,
$2,
@@ -7254,10 +7221,8 @@ INSERT INTO oauth2_provider_app_codes (
$7,
$8,
$9,
$10,
$11,
$12
) RETURNING id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri
$10
) RETURNING id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method
`
type InsertOAuth2ProviderAppCodeParams struct {
@@ -7271,8 +7236,6 @@ type InsertOAuth2ProviderAppCodeParams struct {
ResourceUri sql.NullString `db:"resource_uri" json:"resource_uri"`
CodeChallenge sql.NullString `db:"code_challenge" json:"code_challenge"`
CodeChallengeMethod sql.NullString `db:"code_challenge_method" json:"code_challenge_method"`
StateHash sql.NullString `db:"state_hash" json:"state_hash"`
RedirectUri sql.NullString `db:"redirect_uri" json:"redirect_uri"`
}
func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg InsertOAuth2ProviderAppCodeParams) (OAuth2ProviderAppCode, error) {
@@ -7287,8 +7250,6 @@ func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg Insert
arg.ResourceUri,
arg.CodeChallenge,
arg.CodeChallengeMethod,
arg.StateHash,
arg.RedirectUri,
)
var i OAuth2ProviderAppCode
err := row.Scan(
@@ -7302,8 +7263,6 @@ func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg Insert
&i.ResourceUri,
&i.CodeChallenge,
&i.CodeChallengeMethod,
&i.StateHash,
&i.RedirectUri,
)
return i, err
}
@@ -13353,203 +13312,6 @@ func (q *sqlQuerier) GetTaskSnapshot(ctx context.Context, taskID uuid.UUID) (Tas
return i, err
}
const getTelemetryTaskEvents = `-- name: GetTelemetryTaskEvents :many
WITH task_app_ids AS (
SELECT task_id, workspace_app_id
FROM task_workspace_apps
),
task_status_timeline AS (
-- All app statuses across every historical app for each task,
-- plus synthetic "boundary" rows at each stop/start build transition.
-- This allows us to correctly take gaps due to pause/resume into account.
SELECT tai.task_id, was.created_at, was.state::text AS state
FROM workspace_app_statuses was
JOIN task_app_ids tai ON tai.workspace_app_id = was.app_id
UNION ALL
SELECT t.id AS task_id, wb.created_at, '_boundary' AS state
FROM tasks t
JOIN workspace_builds wb ON wb.workspace_id = t.workspace_id
WHERE t.deleted_at IS NULL
AND t.workspace_id IS NOT NULL
AND wb.build_number > 1
),
task_event_data AS (
SELECT
t.id AS task_id,
t.workspace_id,
twa.workspace_app_id,
-- Latest stop build.
stop_build.created_at AS stop_build_created_at,
stop_build.reason AS stop_build_reason,
-- Latest start build (task_resume only).
start_build.created_at AS start_build_created_at,
start_build.reason AS start_build_reason,
start_build.build_number AS start_build_number,
-- Last "working" app status (for idle duration).
lws.created_at AS last_working_status_at,
-- First app status after resume (for resume-to-status duration).
-- Only populated for workspaces in an active phase (started more
-- recently than stopped).
fsar.created_at AS first_status_after_resume_at,
-- Cumulative time spent in "working" state.
active_dur.total_working_ms AS active_duration_ms
FROM tasks t
LEFT JOIN LATERAL (
SELECT task_app.workspace_app_id
FROM task_workspace_apps task_app
WHERE task_app.task_id = t.id
ORDER BY task_app.workspace_build_number DESC
LIMIT 1
) twa ON TRUE
LEFT JOIN LATERAL (
SELECT wb.created_at, wb.reason, wb.build_number
FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.transition = 'stop'
ORDER BY wb.build_number DESC
LIMIT 1
) stop_build ON TRUE
LEFT JOIN LATERAL (
SELECT wb.created_at, wb.reason, wb.build_number
FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.transition = 'start'
ORDER BY wb.build_number DESC
LIMIT 1
) start_build ON TRUE
LEFT JOIN LATERAL (
SELECT tst.created_at
FROM task_status_timeline tst
WHERE tst.task_id = t.id
AND tst.state = 'working'
-- Only consider status before the latest pause so that
-- post-resume statuses don't mask pre-pause idle time.
AND (stop_build.created_at IS NULL
OR tst.created_at <= stop_build.created_at)
ORDER BY tst.created_at DESC
LIMIT 1
) lws ON TRUE
LEFT JOIN LATERAL (
SELECT was.created_at
FROM workspace_app_statuses was
WHERE was.app_id = twa.workspace_app_id
AND was.created_at > start_build.created_at
ORDER BY was.created_at ASC
LIMIT 1
) fsar ON twa.workspace_app_id IS NOT NULL
AND start_build.created_at IS NOT NULL
AND (stop_build.created_at IS NULL
OR start_build.created_at > stop_build.created_at)
-- Active duration: cumulative time spent in "working" state across all
-- historical app IDs for this task. Uses LEAD() to convert point-in-time
-- statuses into intervals, then sums intervals where state='working'. For
-- the last status, falls back to stop_build time (if paused) or @now (if
-- still running).
LEFT JOIN LATERAL (
SELECT COALESCE(
SUM(EXTRACT(EPOCH FROM (interval_end - interval_start)) * 1000)::bigint,
0
)::bigint AS total_working_ms
FROM (
SELECT
tst.created_at AS interval_start,
COALESCE(
LEAD(tst.created_at) OVER (ORDER BY tst.created_at ASC, CASE WHEN tst.state = '_boundary' THEN 1 ELSE 0 END ASC),
CASE WHEN stop_build.created_at IS NOT NULL
AND (start_build.created_at IS NULL
OR stop_build.created_at > start_build.created_at)
THEN stop_build.created_at
ELSE $1::timestamptz
END
) AS interval_end,
tst.state
FROM task_status_timeline tst
WHERE tst.task_id = t.id
) intervals
WHERE intervals.state = 'working'
) active_dur ON TRUE
WHERE t.deleted_at IS NULL
AND t.workspace_id IS NOT NULL
AND EXISTS (
SELECT 1 FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.created_at > $2
)
)
SELECT task_id, workspace_id, workspace_app_id, stop_build_created_at, stop_build_reason, start_build_created_at, start_build_reason, start_build_number, last_working_status_at, first_status_after_resume_at, active_duration_ms FROM task_event_data
ORDER BY task_id
`
// GetTelemetryTaskEventsParams holds the arguments for GetTelemetryTaskEvents.
type GetTelemetryTaskEventsParams struct {
	// Now is the reference "current time" the query uses as the end of the
	// final "working" interval for tasks that are still running (i.e. have
	// no stop build newer than their latest start build).
	Now time.Time `db:"now" json:"now"`
	// CreatedAfter limits results to tasks whose workspace has at least one
	// workspace build created after this time.
	CreatedAfter time.Time `db:"created_after" json:"created_after"`
}
// GetTelemetryTaskEventsRow is one task's lifecycle data as produced by the
// GetTelemetryTaskEvents query.
type GetTelemetryTaskEventsRow struct {
	TaskID uuid.UUID `db:"task_id" json:"task_id"`
	// WorkspaceID is nullable in the schema, though the query filters out
	// tasks with a NULL workspace_id.
	WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
	// WorkspaceAppID is the task's latest workspace app binding (highest
	// workspace_build_number in task_workspace_apps), if any.
	WorkspaceAppID uuid.NullUUID `db:"workspace_app_id" json:"workspace_app_id"`
	// Stop* come from the workspace's most recent 'stop' build, if any.
	StopBuildCreatedAt sql.NullTime `db:"stop_build_created_at" json:"stop_build_created_at"`
	StopBuildReason NullBuildReason `db:"stop_build_reason" json:"stop_build_reason"`
	// Start* come from the workspace's most recent 'start' build, if any.
	StartBuildCreatedAt sql.NullTime `db:"start_build_created_at" json:"start_build_created_at"`
	StartBuildReason NullBuildReason `db:"start_build_reason" json:"start_build_reason"`
	StartBuildNumber sql.NullInt32 `db:"start_build_number" json:"start_build_number"`
	// LastWorkingStatusAt is the most recent 'working' app status at or
	// before the latest stop build (used for idle-duration calculations).
	LastWorkingStatusAt sql.NullTime `db:"last_working_status_at" json:"last_working_status_at"`
	// FirstStatusAfterResumeAt is the first app status after the latest
	// start build; only populated when the workspace was started more
	// recently than it was stopped.
	FirstStatusAfterResumeAt sql.NullTime `db:"first_status_after_resume_at" json:"first_status_after_resume_at"`
	// ActiveDurationMs is the cumulative milliseconds spent in the
	// 'working' state across all historical app IDs for the task.
	ActiveDurationMs int64 `db:"active_duration_ms" json:"active_duration_ms"`
}
// GetTelemetryTaskEvents returns all data needed to build task lifecycle
// events for telemetry in a single round-trip. For each task whose workspace
// is in the given set, it fetches:
//   - the latest workspace app binding (task_workspace_apps)
//   - the most recent stop and start builds (workspace_builds)
//   - the last "working" app status (workspace_app_statuses)
//   - the first app status after resume, for active workspaces
//
// Assumptions:
//   - 1:1 relationship between tasks and workspaces; all builds on the
//     workspace are considered task-related.
//   - Idle duration approximation: if the agent reports "working", does
//     work, then reports "done", that working time is missed.
//   - lws and active_dur join across all historical app IDs for the task,
//     because each resume cycle provisions a new app ID. This ensures
//     pre-pause statuses contribute to idle duration and active duration.
func (q *sqlQuerier) GetTelemetryTaskEvents(ctx context.Context, arg GetTelemetryTaskEventsParams) ([]GetTelemetryTaskEventsRow, error) {
	rows, err := q.db.QueryContext(ctx, getTelemetryTaskEvents, arg.Now, arg.CreatedAfter)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var results []GetTelemetryTaskEventsRow
	for rows.Next() {
		var r GetTelemetryTaskEventsRow
		err = rows.Scan(
			&r.TaskID,
			&r.WorkspaceID,
			&r.WorkspaceAppID,
			&r.StopBuildCreatedAt,
			&r.StopBuildReason,
			&r.StartBuildCreatedAt,
			&r.StartBuildReason,
			&r.StartBuildNumber,
			&r.LastWorkingStatusAt,
			&r.FirstStatusAfterResumeAt,
			&r.ActiveDurationMs,
		)
		if err != nil {
			return nil, err
		}
		results = append(results, r)
	}
	// Surface any error from releasing the rows before checking iteration
	// errors, mirroring the generated sqlc pattern.
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}
const insertTask = `-- name: InsertTask :one
INSERT INTO tasks
(id, organization_id, owner_id, name, display_name, workspace_id, template_version_id, template_parameters, prompt, created_at)
@@ -18179,7 +17941,7 @@ const getAuthenticatedWorkspaceAgentAndBuildByAuthToken = `-- name: GetAuthentic
SELECT
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl,
workspace_agents.id, workspace_agents.created_at, workspace_agents.updated_at, workspace_agents.name, workspace_agents.first_connected_at, workspace_agents.last_connected_at, workspace_agents.disconnected_at, workspace_agents.resource_id, workspace_agents.auth_token, workspace_agents.auth_instance_id, workspace_agents.architecture, workspace_agents.environment_variables, workspace_agents.operating_system, workspace_agents.instance_metadata, workspace_agents.resource_metadata, workspace_agents.directory, workspace_agents.version, workspace_agents.last_connected_replica_id, workspace_agents.connection_timeout_seconds, workspace_agents.troubleshooting_url, workspace_agents.motd_file, workspace_agents.lifecycle_state, workspace_agents.expanded_directory, workspace_agents.logs_length, workspace_agents.logs_overflowed, workspace_agents.started_at, workspace_agents.ready_at, workspace_agents.subsystems, workspace_agents.display_apps, workspace_agents.api_version, workspace_agents.display_order, workspace_agents.parent_id, workspace_agents.api_key_scope, workspace_agents.deleted,
workspace_build_with_user.id, workspace_build_with_user.created_at, workspace_build_with_user.updated_at, workspace_build_with_user.workspace_id, workspace_build_with_user.template_version_id, workspace_build_with_user.build_number, workspace_build_with_user.transition, workspace_build_with_user.initiator_id, workspace_build_with_user.job_id, workspace_build_with_user.deadline, workspace_build_with_user.reason, workspace_build_with_user.daily_cost, workspace_build_with_user.max_deadline, workspace_build_with_user.template_version_preset_id, workspace_build_with_user.has_ai_task, workspace_build_with_user.has_external_agent, workspace_build_with_user.initiator_by_avatar_url, workspace_build_with_user.initiator_by_username, workspace_build_with_user.initiator_by_name,
workspace_build_with_user.id, workspace_build_with_user.created_at, workspace_build_with_user.updated_at, workspace_build_with_user.workspace_id, workspace_build_with_user.template_version_id, workspace_build_with_user.build_number, workspace_build_with_user.transition, workspace_build_with_user.initiator_id, workspace_build_with_user.provisioner_state, workspace_build_with_user.job_id, workspace_build_with_user.deadline, workspace_build_with_user.reason, workspace_build_with_user.daily_cost, workspace_build_with_user.max_deadline, workspace_build_with_user.template_version_preset_id, workspace_build_with_user.has_ai_task, workspace_build_with_user.has_external_agent, workspace_build_with_user.initiator_by_avatar_url, workspace_build_with_user.initiator_by_username, workspace_build_with_user.initiator_by_name,
tasks.id AS task_id
FROM
workspace_agents
@@ -18317,6 +18079,7 @@ func (q *sqlQuerier) GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx conte
&i.WorkspaceBuild.BuildNumber,
&i.WorkspaceBuild.Transition,
&i.WorkspaceBuild.InitiatorID,
&i.WorkspaceBuild.ProvisionerState,
&i.WorkspaceBuild.JobID,
&i.WorkspaceBuild.Deadline,
&i.WorkspaceBuild.Reason,
@@ -21230,7 +20993,7 @@ func (q *sqlQuerier) InsertWorkspaceBuildParameters(ctx context.Context, arg Ins
}
const getActiveWorkspaceBuildsByTemplateID = `-- name: GetActiveWorkspaceBuildsByTemplateID :many
SELECT wb.id, wb.created_at, wb.updated_at, wb.workspace_id, wb.template_version_id, wb.build_number, wb.transition, wb.initiator_id, wb.job_id, wb.deadline, wb.reason, wb.daily_cost, wb.max_deadline, wb.template_version_preset_id, wb.has_ai_task, wb.has_external_agent, wb.initiator_by_avatar_url, wb.initiator_by_username, wb.initiator_by_name
SELECT wb.id, wb.created_at, wb.updated_at, wb.workspace_id, wb.template_version_id, wb.build_number, wb.transition, wb.initiator_id, wb.provisioner_state, wb.job_id, wb.deadline, wb.reason, wb.daily_cost, wb.max_deadline, wb.template_version_preset_id, wb.has_ai_task, wb.has_external_agent, wb.initiator_by_avatar_url, wb.initiator_by_username, wb.initiator_by_name
FROM (
SELECT
workspace_id, MAX(build_number) as max_build_number
@@ -21278,6 +21041,7 @@ func (q *sqlQuerier) GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, t
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21385,7 +21149,7 @@ func (q *sqlQuerier) GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, a
const getLatestWorkspaceBuildByWorkspaceID = `-- name: GetLatestWorkspaceBuildByWorkspaceID :one
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21408,6 +21172,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, w
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21426,7 +21191,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, w
const getLatestWorkspaceBuildsByWorkspaceIDs = `-- name: GetLatestWorkspaceBuildsByWorkspaceIDs :many
SELECT
DISTINCT ON (workspace_id)
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21453,6 +21218,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context,
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21480,7 +21246,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context,
const getWorkspaceBuildByID = `-- name: GetWorkspaceBuildByID :one
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21501,6 +21267,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (W
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21518,7 +21285,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (W
const getWorkspaceBuildByJobID = `-- name: GetWorkspaceBuildByJobID :one
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21539,6 +21306,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UU
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21556,7 +21324,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UU
const getWorkspaceBuildByWorkspaceIDAndBuildNumber = `-- name: GetWorkspaceBuildByWorkspaceIDAndBuildNumber :one
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21581,6 +21349,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx context.Co
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21618,7 +21387,7 @@ JOIN workspaces w ON wb.workspace_id = w.id
JOIN templates t ON w.template_id = t.id
JOIN organizations o ON t.organization_id = o.id
JOIN workspace_resources wr ON wr.job_id = wb.job_id
JOIN workspace_agents wa ON wa.resource_id = wr.id AND wa.parent_id IS NULL
JOIN workspace_agents wa ON wa.resource_id = wr.id
WHERE wb.job_id = (SELECT job_id FROM workspace_resources WHERE workspace_resources.id = $1)
GROUP BY wb.created_at, wb.transition, t.name, o.name, w.owner_id
`
@@ -21652,48 +21421,6 @@ func (q *sqlQuerier) GetWorkspaceBuildMetricsByResourceID(ctx context.Context, i
return i, err
}
const getWorkspaceBuildProvisionerStateByID = `-- name: GetWorkspaceBuildProvisionerStateByID :one
SELECT
workspace_builds.provisioner_state,
templates.id AS template_id,
templates.organization_id AS template_organization_id,
templates.user_acl,
templates.group_acl
FROM
workspace_builds
INNER JOIN
workspaces ON workspaces.id = workspace_builds.workspace_id
INNER JOIN
templates ON templates.id = workspaces.template_id
WHERE
workspace_builds.id = $1
`
// GetWorkspaceBuildProvisionerStateByIDRow is the result of
// GetWorkspaceBuildProvisionerStateByID: the build's provisioner state plus
// the owning template's identity and ACLs (the latter are joined in so that
// dbauthz can authorize access against the template).
type GetWorkspaceBuildProvisionerStateByIDRow struct {
	// ProvisionerState contains sensitive Terraform state.
	ProvisionerState []byte `db:"provisioner_state" json:"provisioner_state"`
	TemplateID uuid.UUID `db:"template_id" json:"template_id"`
	TemplateOrganizationID uuid.UUID `db:"template_organization_id" json:"template_organization_id"`
	UserACL TemplateACL `db:"user_acl" json:"user_acl"`
	GroupACL TemplateACL `db:"group_acl" json:"group_acl"`
}
// GetWorkspaceBuildProvisionerStateByID fetches the provisioner state of a
// workspace build, joined through to the template so that dbauthz can enforce
// policy.ActionUpdate on the template. Provisioner state contains sensitive
// Terraform state and should only be accessible to template administrators.
func (q *sqlQuerier) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (GetWorkspaceBuildProvisionerStateByIDRow, error) {
	var result GetWorkspaceBuildProvisionerStateByIDRow
	err := q.db.QueryRowContext(ctx, getWorkspaceBuildProvisionerStateByID, workspaceBuildID).Scan(
		&result.ProvisionerState,
		&result.TemplateID,
		&result.TemplateOrganizationID,
		&result.UserACL,
		&result.GroupACL,
	)
	return result, err
}
const getWorkspaceBuildStatsByTemplates = `-- name: GetWorkspaceBuildStatsByTemplates :many
SELECT
w.template_id,
@@ -21763,7 +21490,7 @@ func (q *sqlQuerier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, sinc
const getWorkspaceBuildsByWorkspaceID = `-- name: GetWorkspaceBuildsByWorkspaceID :many
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21827,6 +21554,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg Ge
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21853,7 +21581,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg Ge
}
const getWorkspaceBuildsCreatedAfter = `-- name: GetWorkspaceBuildsCreatedAfter :many
SELECT id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name FROM workspace_build_with_user WHERE created_at > $1
SELECT id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name FROM workspace_build_with_user WHERE created_at > $1
`
func (q *sqlQuerier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceBuild, error) {
@@ -21874,6 +21602,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, created
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
+2 -4
View File
@@ -25,12 +25,10 @@ LIMIT
SELECT * FROM api_keys WHERE last_used > $1;
-- name: GetAPIKeysByLoginType :many
SELECT * FROM api_keys WHERE login_type = $1
AND (@include_expired::bool OR expires_at > now());
SELECT * FROM api_keys WHERE login_type = $1;
-- name: GetAPIKeysByUserID :many
SELECT * FROM api_keys WHERE login_type = $1 AND user_id = $2
AND (@include_expired::bool OR expires_at > now());
SELECT * FROM api_keys WHERE login_type = $1 AND user_id = $2;
-- name: InsertAPIKey :one
INSERT INTO
+88 -99
View File
@@ -149,105 +149,94 @@ VALUES (
RETURNING *;
-- name: CountAuditLogs :one
SELECT COUNT(*) FROM (
SELECT 1
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
ELSE true
END
-- Filter organization_id
AND CASE
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN @resource_target::text != '' THEN resource_target = @resource_target
ELSE true
END
-- Filter action
AND CASE
WHEN @action::text != '' THEN action = @action::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower(@username)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @email::text != '' THEN users.email = @email
ELSE true
END
-- Filter by date_from
AND CASE
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
ELSE true
END
-- Filter by date_to
AND CASE
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
ELSE true
END
-- Filter request_id
AND CASE
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
-- Avoid a slow scan on a large table with joins. The caller
-- passes the count cap and we add 1 so the frontend can detect
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
-- -> NULL + 1 = NULL).
-- NOTE: Parameterizing this so that we can easily change from,
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
-- here if disabling the capping on a large table permanently.
-- This way the PG planner can plan parallel execution for
-- potential large wins.
LIMIT NULLIF(@count_cap::int, 0) + 1
) AS limited_count;
SELECT COUNT(*)
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create is
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
ELSE true
END
-- Filter organization_id
AND CASE
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN @resource_target::text != '' THEN resource_target = @resource_target
ELSE true
END
-- Filter action
AND CASE
WHEN @action::text != '' THEN action = @action::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower(@username)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @email::text != '' THEN users.email = @email
ELSE true
END
-- Filter by date_from
AND CASE
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
ELSE true
END
-- Filter by date_to
AND CASE
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
ELSE true
END
-- Filter request_id
AND CASE
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
;
-- name: DeleteOldAuditLogConnectionEvents :exec
DELETE FROM audit_logs
+105 -107
View File
@@ -133,113 +133,111 @@ OFFSET
@offset_opt;
-- name: CountConnectionLogs :one
SELECT COUNT(*) AS count FROM (
SELECT 1
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = @organization_id
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN @workspace_owner :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = @workspace_owner_id
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN @workspace_owner_email :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = @workspace_owner_email AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN @type :: text != '' THEN
type = @type :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower(@username) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @user_email :: text != '' THEN
users.email = @user_email
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= @connected_after
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= @connected_before
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = @workspace_id
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = @connection_id
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
-- NOTE: See the CountAuditLogs LIMIT note.
LIMIT NULLIF(@count_cap::int, 0) + 1
) AS limited_count;
SELECT
COUNT(*) AS count
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = @organization_id
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN @workspace_owner :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = @workspace_owner_id
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN @workspace_owner_email :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = @workspace_owner_email AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN @type :: text != '' THEN
type = @type :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower(@username) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @user_email :: text != '' THEN
users.email = @user_email
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= @connected_after
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= @connected_before
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = @workspace_id
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = @connection_id
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
;
-- name: DeleteOldConnectionLogs :execrows
WITH old_logs AS (
+2 -6
View File
@@ -140,9 +140,7 @@ INSERT INTO oauth2_provider_app_codes (
user_id,
resource_uri,
code_challenge,
code_challenge_method,
state_hash,
redirect_uri
code_challenge_method
) VALUES(
$1,
$2,
@@ -153,9 +151,7 @@ INSERT INTO oauth2_provider_app_codes (
$7,
$8,
$9,
$10,
$11,
$12
$10
) RETURNING *;
-- name: DeleteOAuth2ProviderAppCodeByID :exec
-143
View File
@@ -100,146 +100,3 @@ FROM
task_snapshots
WHERE
task_id = $1;
-- name: GetTelemetryTaskEvents :many
-- Returns all data needed to build task lifecycle events for telemetry
-- in a single round-trip. For each task whose workspace is in the
-- given set, fetches:
-- - the latest workspace app binding (task_workspace_apps)
-- - the most recent stop and start builds (workspace_builds)
-- - the last "working" app status (workspace_app_statuses)
-- - the first app status after resume, for active workspaces
--
-- Assumptions:
-- - 1:1 relationship between tasks and workspaces. All builds on the
-- workspace are considered task-related.
-- - Idle duration approximation: If the agent reports "working", does
-- work, then reports "done", we miss that working time.
-- - lws and active_dur join across all historical app IDs for the task,
-- because each resume cycle provisions a new app ID. This ensures
-- pre-pause statuses contribute to idle duration and active duration.
WITH task_app_ids AS (
SELECT task_id, workspace_app_id
FROM task_workspace_apps
),
task_status_timeline AS (
-- All app statuses across every historical app for each task,
-- plus synthetic "boundary" rows at each stop/start build transition.
-- This allows us to correctly take gaps due to pause/resume into account.
SELECT tai.task_id, was.created_at, was.state::text AS state
FROM workspace_app_statuses was
JOIN task_app_ids tai ON tai.workspace_app_id = was.app_id
UNION ALL
SELECT t.id AS task_id, wb.created_at, '_boundary' AS state
FROM tasks t
JOIN workspace_builds wb ON wb.workspace_id = t.workspace_id
WHERE t.deleted_at IS NULL
AND t.workspace_id IS NOT NULL
AND wb.build_number > 1
),
task_event_data AS (
SELECT
t.id AS task_id,
t.workspace_id,
twa.workspace_app_id,
-- Latest stop build.
stop_build.created_at AS stop_build_created_at,
stop_build.reason AS stop_build_reason,
-- Latest start build (task_resume only).
start_build.created_at AS start_build_created_at,
start_build.reason AS start_build_reason,
start_build.build_number AS start_build_number,
-- Last "working" app status (for idle duration).
lws.created_at AS last_working_status_at,
-- First app status after resume (for resume-to-status duration).
-- Only populated for workspaces in an active phase (started more
-- recently than stopped).
fsar.created_at AS first_status_after_resume_at,
-- Cumulative time spent in "working" state.
active_dur.total_working_ms AS active_duration_ms
FROM tasks t
LEFT JOIN LATERAL (
SELECT task_app.workspace_app_id
FROM task_workspace_apps task_app
WHERE task_app.task_id = t.id
ORDER BY task_app.workspace_build_number DESC
LIMIT 1
) twa ON TRUE
LEFT JOIN LATERAL (
SELECT wb.created_at, wb.reason, wb.build_number
FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.transition = 'stop'
ORDER BY wb.build_number DESC
LIMIT 1
) stop_build ON TRUE
LEFT JOIN LATERAL (
SELECT wb.created_at, wb.reason, wb.build_number
FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.transition = 'start'
ORDER BY wb.build_number DESC
LIMIT 1
) start_build ON TRUE
LEFT JOIN LATERAL (
SELECT tst.created_at
FROM task_status_timeline tst
WHERE tst.task_id = t.id
AND tst.state = 'working'
-- Only consider status before the latest pause so that
-- post-resume statuses don't mask pre-pause idle time.
AND (stop_build.created_at IS NULL
OR tst.created_at <= stop_build.created_at)
ORDER BY tst.created_at DESC
LIMIT 1
) lws ON TRUE
LEFT JOIN LATERAL (
SELECT was.created_at
FROM workspace_app_statuses was
WHERE was.app_id = twa.workspace_app_id
AND was.created_at > start_build.created_at
ORDER BY was.created_at ASC
LIMIT 1
) fsar ON twa.workspace_app_id IS NOT NULL
AND start_build.created_at IS NOT NULL
AND (stop_build.created_at IS NULL
OR start_build.created_at > stop_build.created_at)
-- Active duration: cumulative time spent in "working" state across all
-- historical app IDs for this task. Uses LEAD() to convert point-in-time
-- statuses into intervals, then sums intervals where state='working'. For
-- the last status, falls back to stop_build time (if paused) or @now (if
-- still running).
LEFT JOIN LATERAL (
SELECT COALESCE(
SUM(EXTRACT(EPOCH FROM (interval_end - interval_start)) * 1000)::bigint,
0
)::bigint AS total_working_ms
FROM (
SELECT
tst.created_at AS interval_start,
COALESCE(
LEAD(tst.created_at) OVER (ORDER BY tst.created_at ASC, CASE WHEN tst.state = '_boundary' THEN 1 ELSE 0 END ASC),
CASE WHEN stop_build.created_at IS NOT NULL
AND (start_build.created_at IS NULL
OR stop_build.created_at > start_build.created_at)
THEN stop_build.created_at
ELSE @now::timestamptz
END
) AS interval_end,
tst.state
FROM task_status_timeline tst
WHERE tst.task_id = t.id
) intervals
WHERE intervals.state = 'working'
) active_dur ON TRUE
WHERE t.deleted_at IS NULL
AND t.workspace_id IS NOT NULL
AND EXISTS (
SELECT 1 FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.created_at > @created_after
)
)
SELECT * FROM task_event_data
ORDER BY task_id;
@@ -87,4 +87,3 @@ SELECT DISTINCT ON (workspace_id)
FROM workspace_app_statuses
WHERE workspace_id = ANY(@ids :: uuid[])
ORDER BY workspace_id, created_at DESC;
+1 -21
View File
@@ -268,26 +268,6 @@ JOIN workspaces w ON wb.workspace_id = w.id
JOIN templates t ON w.template_id = t.id
JOIN organizations o ON t.organization_id = o.id
JOIN workspace_resources wr ON wr.job_id = wb.job_id
JOIN workspace_agents wa ON wa.resource_id = wr.id AND wa.parent_id IS NULL
JOIN workspace_agents wa ON wa.resource_id = wr.id
WHERE wb.job_id = (SELECT job_id FROM workspace_resources WHERE workspace_resources.id = $1)
GROUP BY wb.created_at, wb.transition, t.name, o.name, w.owner_id;
-- name: GetWorkspaceBuildProvisionerStateByID :one
-- Fetches the provisioner state of a workspace build, joined through to the
-- template so that dbauthz can enforce policy.ActionUpdate on the template.
-- Provisioner state contains sensitive Terraform state and should only be
-- accessible to template administrators.
SELECT
workspace_builds.provisioner_state,
templates.id AS template_id,
templates.organization_id AS template_organization_id,
templates.user_acl,
templates.group_acl
FROM
workspace_builds
INNER JOIN
workspaces ON workspaces.id = workspace_builds.workspace_id
INNER JOIN
templates ON templates.id = workspaces.template_id
WHERE
workspace_builds.id = @workspace_build_id;
-18
View File
@@ -124,24 +124,6 @@ sql:
- column: "tasks_with_status.workspace_app_health"
go_type:
type: "NullWorkspaceAppHealth"
# Workaround for sqlc not interpreting the left join correctly
# in the combined telemetry query.
- column: "task_event_data.start_build_number"
go_type: "database/sql.NullInt32"
- column: "task_event_data.stop_build_created_at"
go_type: "database/sql.NullTime"
- column: "task_event_data.stop_build_reason"
go_type:
type: "NullBuildReason"
- column: "task_event_data.start_build_created_at"
go_type: "database/sql.NullTime"
- column: "task_event_data.start_build_reason"
go_type:
type: "NullBuildReason"
- column: "task_event_data.last_working_status_at"
go_type: "database/sql.NullTime"
- column: "task_event_data.first_status_after_resume_at"
go_type: "database/sql.NullTime"
rename:
group_member: GroupMemberTable
group_members_expanded: GroupMember
+4 -3
View File
@@ -228,11 +228,12 @@ func (p *QueryParamParser) RedirectURL(vals url.Values, base *url.URL, queryPara
})
}
// OAuth 2.1 requires exact redirect URI matching.
if v.String() != base.String() {
// It can be a sub-directory but not a sub-domain, as we have apps on
// sub-domains and that seems too dangerous.
if v.Host != base.Host || !strings.HasPrefix(v.Path, base.Path) {
p.Errors = append(p.Errors, codersdk.ValidationError{
Field: queryParam,
Detail: fmt.Sprintf("Query param %q must exactly match %s", queryParam, base),
Detail: fmt.Sprintf("Query param %q must be a subset of %s", queryParam, base),
})
}
+2 -6
View File
@@ -62,12 +62,8 @@ func CSRF(cookieCfg codersdk.HTTPCookieConfig) func(next http.Handler) http.Hand
mw.ExemptRegexp(regexp.MustCompile("/organizations/[^/]+/provisionerdaemons/*"))
mw.ExemptFunc(func(r *http.Request) bool {
// Enforce CSRF on API routes and the OAuth2 authorize
// endpoint. The authorize endpoint serves a browser consent
// form whose POST must be CSRF-protected to prevent
// cross-site authorization code theft (coder/security#121).
if !strings.HasPrefix(r.URL.Path, "/api") &&
!strings.HasPrefix(r.URL.Path, "/oauth2/authorize") {
// Only enforce CSRF on API routes.
if !strings.HasPrefix(r.URL.Path, "/api") {
return true
}
-20
View File
@@ -51,26 +51,6 @@ func TestCSRFExemptList(t *testing.T) {
URL: "https://coder.com/api/v2/me",
Exempt: false,
},
{
Name: "OAuth2Authorize",
URL: "https://coder.com/oauth2/authorize",
Exempt: false,
},
{
Name: "OAuth2AuthorizeQuery",
URL: "https://coder.com/oauth2/authorize?client_id=test",
Exempt: false,
},
{
Name: "OAuth2Tokens",
URL: "https://coder.com/oauth2/tokens",
Exempt: true,
},
{
Name: "OAuth2Register",
URL: "https://coder.com/oauth2/register",
Exempt: true,
},
}
mw := httpmw.CSRF(codersdk.HTTPCookieConfig{})
+3 -11
View File
@@ -348,12 +348,8 @@ func reapJob(ctx context.Context, log slog.Logger, db database.Store, pub pubsub
// Only copy the provisioner state if there's no state in
// the current build.
currentStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
if err != nil {
return xerrors.Errorf("get workspace build provisioner state: %w", err)
}
if len(currentStateRow.ProvisionerState) == 0 {
// Get the previous build's state if it exists.
if len(build.ProvisionerState) == 0 {
// Get the previous build if it exists.
prevBuild, err := db.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
WorkspaceID: build.WorkspaceID,
BuildNumber: build.BuildNumber - 1,
@@ -362,14 +358,10 @@ func reapJob(ctx context.Context, log slog.Logger, db database.Store, pub pubsub
return xerrors.Errorf("get previous workspace build: %w", err)
}
if err == nil {
prevStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, prevBuild.ID)
if err != nil {
return xerrors.Errorf("get previous workspace build provisioner state: %w", err)
}
err = db.UpdateWorkspaceBuildProvisionerStateByID(ctx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
ID: build.ID,
UpdatedAt: dbtime.Now(),
ProvisionerState: prevStateRow.ProvisionerState,
ProvisionerState: prevBuild.ProvisionerState,
})
if err != nil {
return xerrors.Errorf("update workspace build by id: %w", err)
+22 -29
View File
@@ -126,9 +126,9 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
previousBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
OwnerID: user.ID,
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
ProvisionerState(expectedWorkspaceBuildState).
Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
ProvisionerState: expectedWorkspaceBuildState,
}).Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
Do()
// Current build (hung - running job with UpdatedAt > 5 min ago).
@@ -163,9 +163,7 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
// Check that the provisioner state was copied.
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
require.NoError(t, err)
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
@@ -196,9 +194,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
previousBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
OwnerID: user.ID,
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
ProvisionerState([]byte(`{"dean":"NOT cool","colin":"also NOT cool"}`)).
Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
ProvisionerState: []byte(`{"dean":"NOT cool","colin":"also NOT cool"}`),
}).Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
Do()
// Current build (hung - running job with UpdatedAt > 5 min ago).
@@ -206,8 +204,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
currentBuild := dbfake.WorkspaceBuild(t, db, previousBuild.Workspace).
Pubsub(pubsub).
Seed(database.WorkspaceBuild{
BuildNumber: 2,
}).ProvisionerState(expectedWorkspaceBuildState).
BuildNumber: 2,
ProvisionerState: expectedWorkspaceBuildState,
}).
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
Do()
@@ -236,9 +235,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
// Check that the provisioner state was NOT copied.
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
require.NoError(t, err)
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
@@ -269,9 +266,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T
currentBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
OwnerID: user.ID,
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
ProvisionerState(expectedWorkspaceBuildState).
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
ProvisionerState: expectedWorkspaceBuildState,
}).Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
Do()
t.Log("current job ID: ", currentBuild.Build.JobID)
@@ -298,9 +295,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T
// Check that the provisioner state was NOT updated.
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
require.NoError(t, err)
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
@@ -330,9 +325,9 @@ func TestDetectorPendingWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testin
currentBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
OwnerID: user.ID,
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
ProvisionerState(expectedWorkspaceBuildState).
Pending(dbfake.WithJobCreatedAt(thirtyFiveMinAgo), dbfake.WithJobUpdatedAt(thirtyFiveMinAgo)).
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
ProvisionerState: expectedWorkspaceBuildState,
}).Pending(dbfake.WithJobCreatedAt(thirtyFiveMinAgo), dbfake.WithJobUpdatedAt(thirtyFiveMinAgo)).
Do()
t.Log("current job ID: ", currentBuild.Build.JobID)
@@ -361,9 +356,7 @@ func TestDetectorPendingWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testin
// Check that the provisioner state was NOT updated.
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
require.NoError(t, err)
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
@@ -405,9 +398,9 @@ func TestDetectorWorkspaceBuildForDormantWorkspace(t *testing.T) {
Time: now.Add(-time.Hour),
Valid: true,
},
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
ProvisionerState(expectedWorkspaceBuildState).
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
ProvisionerState: expectedWorkspaceBuildState,
}).Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
Do()
t.Log("current job ID: ", currentBuild.Build.JobID)
+22 -40
View File
@@ -2,8 +2,6 @@ package mcp_test
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -12,7 +10,6 @@ import (
"strings"
"testing"
"github.com/google/uuid"
mcpclient "github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
@@ -27,15 +24,6 @@ import (
"github.com/coder/coder/v2/testutil"
)
// mcpGeneratePKCE creates a PKCE verifier and S256 challenge for MCP
// e2e tests.
func mcpGeneratePKCE() (verifier, challenge string) {
verifier = uuid.NewString() + uuid.NewString()
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return verifier, challenge
}
func TestMCPHTTP_E2E_ClientIntegration(t *testing.T) {
t.Parallel()
@@ -565,32 +553,31 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
// In a real flow, this would be done through the browser consent page
// For testing, we'll create the code directly using the internal API
// First, we need to authorize the app (simulating user consent).
staticVerifier, staticChallenge := mcpGeneratePKCE()
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=test_state&code_challenge=%s&code_challenge_method=S256",
api.AccessURL.String(), app.ID, "http://localhost:3000/callback", staticChallenge)
// First, we need to authorize the app (simulating user consent)
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=test_state",
api.AccessURL.String(), app.ID, "http://localhost:3000/callback")
// Create an HTTP client that follows redirects but captures the final redirect.
// Create an HTTP client that follows redirects but captures the final redirect
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse // Stop following redirects
},
}
// Make the authorization request (this would normally be done in a browser).
// Make the authorization request (this would normally be done in a browser)
req, err := http.NewRequestWithContext(ctx, "GET", authURL, nil)
require.NoError(t, err)
// Use RFC 6750 Bearer token for authentication.
// Use RFC 6750 Bearer token for authentication
req.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
// The response should be a redirect to the consent page or directly to callback.
// For testing purposes, let's simulate the POST consent approval.
// The response should be a redirect to the consent page or directly to callback
// For testing purposes, let's simulate the POST consent approval
if resp.StatusCode == http.StatusOK {
// This means we got the consent page, now we need to POST consent.
// This means we got the consent page, now we need to POST consent
consentReq, err := http.NewRequestWithContext(ctx, "POST", authURL, nil)
require.NoError(t, err)
consentReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
@@ -601,7 +588,7 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
defer resp.Body.Close()
}
// Extract authorization code from redirect URL.
// Extract authorization code from redirect URL
require.True(t, resp.StatusCode >= 300 && resp.StatusCode < 400, "Expected redirect response")
location := resp.Header.Get("Location")
require.NotEmpty(t, location, "Expected Location header in redirect")
@@ -613,14 +600,13 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
t.Logf("Successfully obtained authorization code: %s", authCode[:10]+"...")
// Step 2: Exchange authorization code for access token and refresh token.
// Step 2: Exchange authorization code for access token and refresh token
tokenRequestBody := url.Values{
"grant_type": {"authorization_code"},
"client_id": {app.ID.String()},
"client_secret": {secret.ClientSecretFull},
"code": {authCode},
"redirect_uri": {"http://localhost:3000/callback"},
"code_verifier": {staticVerifier},
}
tokenReq, err := http.NewRequestWithContext(ctx, "POST", api.AccessURL.String()+"/oauth2/tokens",
@@ -882,44 +868,41 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
t.Logf("Successfully registered dynamic client: %s", clientID)
// Step 3: Perform OAuth2 authorization code flow with dynamically registered client.
dynamicVerifier, dynamicChallenge := mcpGeneratePKCE()
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=dynamic_state&code_challenge=%s&code_challenge_method=S256",
api.AccessURL.String(), clientID, "http://localhost:3000/callback", dynamicChallenge)
// Step 3: Perform OAuth2 authorization code flow with dynamically registered client
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=dynamic_state",
api.AccessURL.String(), clientID, "http://localhost:3000/callback")
// Create an HTTP client that captures redirects.
// Create an HTTP client that captures redirects
authClient := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse // Stop following redirects
},
}
// Make the authorization request with authentication.
// Make the authorization request with authentication
authReq, err := http.NewRequestWithContext(ctx, "GET", authURL, nil)
require.NoError(t, err)
authReq.Header.Set("Cookie", fmt.Sprintf("coder_session_token=%s", coderClient.SessionToken()))
authReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
authResp, err := authClient.Do(authReq)
require.NoError(t, err)
defer authResp.Body.Close()
// Handle the response - check for error first.
// Handle the response - check for error first
if authResp.StatusCode == http.StatusBadRequest {
// Read error response for debugging.
// Read error response for debugging
bodyBytes, err := io.ReadAll(authResp.Body)
require.NoError(t, err)
t.Logf("OAuth2 authorization error: %s", string(bodyBytes))
t.FailNow()
}
// Handle consent flow if needed.
// Handle consent flow if needed
if authResp.StatusCode == http.StatusOK {
// This means we got the consent page, now we need to POST consent.
// This means we got the consent page, now we need to POST consent
consentReq, err := http.NewRequestWithContext(ctx, "POST", authURL, nil)
require.NoError(t, err)
consentReq.Header.Set("Cookie", fmt.Sprintf("coder_session_token=%s", coderClient.SessionToken()))
consentReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
consentReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
authResp, err = authClient.Do(consentReq)
@@ -927,7 +910,7 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
defer authResp.Body.Close()
}
// Extract authorization code from redirect.
// Extract authorization code from redirect
require.True(t, authResp.StatusCode >= 300 && authResp.StatusCode < 400,
"Expected redirect response, got %d", authResp.StatusCode)
location := authResp.Header.Get("Location")
@@ -940,14 +923,13 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
t.Logf("Successfully obtained authorization code: %s", authCode[:10]+"...")
// Step 4: Exchange authorization code for access token.
// Step 4: Exchange authorization code for access token
tokenRequestBody := url.Values{
"grant_type": {"authorization_code"},
"client_id": {clientID},
"client_secret": {clientSecret},
"code": {authCode},
"redirect_uri": {"http://localhost:3000/callback"},
"code_verifier": {dynamicVerifier},
}
tokenReq, err := http.NewRequestWithContext(ctx, "POST", api.AccessURL.String()+"/oauth2/tokens",
-9
View File
@@ -287,18 +287,9 @@ func (api *API) paginatedMembers(rw http.ResponseWriter, r *http.Request) {
memberRows = append(memberRows, row)
}
if len(paginatedMemberRows) == 0 {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.PaginatedMembersResponse{
Members: []codersdk.OrganizationMemberWithUserData{},
Count: 0,
})
return
}
members, err := convertOrganizationMembersWithUserData(ctx, api.Database, memberRows)
if err != nil {
httpapi.InternalServerError(rw, err)
return
}
resp := codersdk.PaginatedMembersResponse{
+41 -92
View File
@@ -2,8 +2,6 @@ package coderd_test
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
@@ -291,6 +289,7 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
authError: "Invalid query params:",
},
{
// TODO: This is valid for now, but should it be?
name: "DifferentProtocol",
app: apps.Default,
preAuth: func(valid *oauth2.Config) {
@@ -298,7 +297,6 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
newURL.Scheme = "https"
valid.RedirectURL = newURL.String()
},
authError: "Invalid query params:",
},
{
name: "NestedPath",
@@ -308,7 +306,6 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
newURL.Path = path.Join(newURL.Path, "nested")
valid.RedirectURL = newURL.String()
},
authError: "Invalid query params:",
},
{
// Some oauth implementations allow this, but our users can host
@@ -484,12 +481,11 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
}
var code string
var verifier string
if test.defaultCode != nil {
code = *test.defaultCode
} else {
var err error
code, verifier, err = authorizationFlow(ctx, userClient, valid)
code, err = authorizationFlow(ctx, userClient, valid)
if test.authError != "" {
require.Error(t, err)
require.ErrorContains(t, err, test.authError)
@@ -504,12 +500,8 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
test.preToken(valid)
}
// Do the actual exchange. Include PKCE code_verifier when
// we obtained a code through the authorization flow.
exchangeOpts := append([]oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", verifier),
}, test.exchangeMutate...)
token, err := valid.Exchange(ctx, code, exchangeOpts...)
// Do the actual exchange.
token, err := valid.Exchange(ctx, code, test.exchangeMutate...)
if test.tokenError != "" {
require.Error(t, err)
require.ErrorContains(t, err, test.tokenError)
@@ -691,11 +683,10 @@ func TestOAuth2ProviderTokenRefresh(t *testing.T) {
}
type exchangeSetup struct {
cfg *oauth2.Config
app codersdk.OAuth2ProviderApp
secret codersdk.OAuth2ProviderAppSecretFull
code string
verifier string
cfg *oauth2.Config
app codersdk.OAuth2ProviderApp
secret codersdk.OAuth2ProviderAppSecretFull
code string
}
func TestOAuth2ProviderRevoke(t *testing.T) {
@@ -739,13 +730,11 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
name: "OverrideCodeAndToken",
fn: func(ctx context.Context, client *codersdk.Client, s exchangeSetup) {
// Generating a new code should wipe out the old code.
code, verifier, err := authorizationFlow(ctx, client, s.cfg)
code, err := authorizationFlow(ctx, client, s.cfg)
require.NoError(t, err)
// Generating a new token should wipe out the old token.
_, err = s.cfg.Exchange(ctx, code,
oauth2.SetAuthURLParam("code_verifier", verifier),
)
_, err = s.cfg.Exchange(ctx, code)
require.NoError(t, err)
},
replacesToken: true,
@@ -781,15 +770,14 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
}
// Go through the auth flow to get a code.
code, verifier, err := authorizationFlow(ctx, testClient, cfg)
code, err := authorizationFlow(ctx, testClient, cfg)
require.NoError(t, err)
return exchangeSetup{
cfg: cfg,
app: app,
secret: secret,
code: code,
verifier: verifier,
cfg: cfg,
app: app,
secret: secret,
code: code,
}
}
@@ -806,16 +794,12 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
test.fn(ctx, testClient, testEntities)
// Exchange should fail because the code should be gone.
_, err := testEntities.cfg.Exchange(ctx, testEntities.code,
oauth2.SetAuthURLParam("code_verifier", testEntities.verifier),
)
_, err := testEntities.cfg.Exchange(ctx, testEntities.code)
require.Error(t, err)
// Try again, this time letting the exchange complete first.
testEntities = setup(ctx, testClient, test.name+"-2")
token, err := testEntities.cfg.Exchange(ctx, testEntities.code,
oauth2.SetAuthURLParam("code_verifier", testEntities.verifier),
)
token, err := testEntities.cfg.Exchange(ctx, testEntities.code)
require.NoError(t, err)
// Validate the returned access token and that the app is listed.
@@ -888,38 +872,25 @@ func generateApps(ctx context.Context, t *testing.T, client *codersdk.Client, su
}
}
// generatePKCE creates a PKCE verifier and S256 challenge for testing.
func generatePKCE() (verifier, challenge string) {
verifier = uuid.NewString() + uuid.NewString()
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return verifier, challenge
}
func authorizationFlow(ctx context.Context, client *codersdk.Client, cfg *oauth2.Config) (code, codeVerifier string, err error) {
func authorizationFlow(ctx context.Context, client *codersdk.Client, cfg *oauth2.Config) (string, error) {
state := uuid.NewString()
codeVerifier, challenge := generatePKCE()
authURL := cfg.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
)
authURL := cfg.AuthCodeURL(state)
// Make a POST request to simulate clicking "Allow" on the authorization page.
// This bypasses the HTML consent page and directly processes the authorization.
code, err = oidctest.OAuth2GetCode(
// Make a POST request to simulate clicking "Allow" on the authorization page
// This bypasses the HTML consent page and directly processes the authorization
return oidctest.OAuth2GetCode(
authURL,
func(req *http.Request) (*http.Response, error) {
// Change to POST to simulate the form submission.
// Change to POST to simulate the form submission
req.Method = http.MethodPost
// Prevent automatic redirect following.
// Prevent automatic redirect following
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
return client.Request(ctx, req.Method, req.URL.String(), nil)
},
)
return code, codeVerifier, err
}
func must[T any](value T, err error) T {
@@ -1026,15 +997,11 @@ func TestOAuth2ProviderResourceIndicators(t *testing.T) {
Scopes: []string{},
}
// Step 1: Authorization with resource parameter and PKCE.
// Step 1: Authorization with resource parameter
state := uuid.NewString()
verifier, challenge := generatePKCE()
authURL := cfg.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
)
authURL := cfg.AuthCodeURL(state)
if test.authResource != "" {
// Add resource parameter to auth URL.
// Add resource parameter to auth URL
parsedURL, err := url.Parse(authURL)
require.NoError(t, err)
query := parsedURL.Query()
@@ -1063,7 +1030,7 @@ func TestOAuth2ProviderResourceIndicators(t *testing.T) {
// Step 2: Token exchange with resource parameter
// Use custom token exchange since golang.org/x/oauth2 doesn't support resource parameter in token requests
token, err := customTokenExchange(ctx, ownerClient.URL.String(), apps.Default.ID.String(), secret.ClientSecretFull, code, apps.Default.CallbackURL, test.tokenResource, verifier)
token, err := customTokenExchange(ctx, ownerClient.URL.String(), apps.Default.ID.String(), secret.ClientSecretFull, code, apps.Default.CallbackURL, test.tokenResource)
if test.expectTokenError {
require.Error(t, err)
require.Contains(t, err.Error(), "invalid_target")
@@ -1160,13 +1127,9 @@ func TestOAuth2ProviderCrossResourceAudienceValidation(t *testing.T) {
Scopes: []string{},
}
// Authorization with resource parameter for server1 and PKCE.
// Authorization with resource parameter for server1
state := uuid.NewString()
verifier, challenge := generatePKCE()
authURL := cfg.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
)
authURL := cfg.AuthCodeURL(state)
parsedURL, err := url.Parse(authURL)
require.NoError(t, err)
query := parsedURL.Query()
@@ -1186,11 +1149,8 @@ func TestOAuth2ProviderCrossResourceAudienceValidation(t *testing.T) {
)
require.NoError(t, err)
// Exchange code for token with resource parameter and PKCE verifier.
token, err := cfg.Exchange(ctx, code,
oauth2.SetAuthURLParam("resource", resource1),
oauth2.SetAuthURLParam("code_verifier", verifier),
)
// Exchange code for token with resource parameter
token, err := cfg.Exchange(ctx, code, oauth2.SetAuthURLParam("resource", resource1))
require.NoError(t, err)
require.NotEmpty(t, token.AccessToken)
@@ -1266,11 +1226,9 @@ func TestOAuth2RefreshExpiryOutlivesAccess(t *testing.T) {
}
// Authorization and token exchange
code, verifier, err := authorizationFlow(ctx, ownerClient, cfg)
code, err := authorizationFlow(ctx, ownerClient, cfg)
require.NoError(t, err)
tok, err := cfg.Exchange(ctx, code,
oauth2.SetAuthURLParam("code_verifier", verifier),
)
tok, err := cfg.Exchange(ctx, code)
require.NoError(t, err)
require.NotEmpty(t, tok.AccessToken)
require.NotEmpty(t, tok.RefreshToken)
@@ -1295,7 +1253,7 @@ func TestOAuth2RefreshExpiryOutlivesAccess(t *testing.T) {
// customTokenExchange performs a custom OAuth2 token exchange with support for resource parameter
// This is needed because golang.org/x/oauth2 doesn't support custom parameters in token requests
func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, code, redirectURI, resource, codeVerifier string) (*oauth2.Token, error) {
func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, code, redirectURI, resource string) (*oauth2.Token, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
@@ -1305,9 +1263,6 @@ func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, c
if resource != "" {
data.Set("resource", resource)
}
if codeVerifier != "" {
data.Set("code_verifier", codeVerifier)
}
req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/oauth2/tokens", strings.NewReader(data.Encode()))
if err != nil {
@@ -1682,21 +1637,17 @@ func TestOAuth2CoderClient(t *testing.T) {
// Make a new user
client, user := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID)
// Do an OAuth2 token exchange and get a new client with an oauth token.
// Do an OAuth2 token exchange and get a new client with an oauth token
state := uuid.NewString()
verifier, challenge := generatePKCE()
// Get an OAuth2 code for a token exchange.
// Get an OAuth2 code for a token exchange
code, err := oidctest.OAuth2GetCode(
cfg.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
),
cfg.AuthCodeURL(state),
func(req *http.Request) (*http.Response, error) {
// Change to POST to simulate the form submission.
// Change to POST to simulate the form submission
req.Method = http.MethodPost
// Prevent automatic redirect following.
// Prevent automatic redirect following
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
@@ -1705,9 +1656,7 @@ func TestOAuth2CoderClient(t *testing.T) {
)
require.NoError(t, err)
token, err := cfg.Exchange(ctx, code,
oauth2.SetAuthURLParam("code_verifier", verifier),
)
token, err := cfg.Exchange(ctx, code)
require.NoError(t, err)
// Use the oauth client's authentication
+10 -64
View File
@@ -1,9 +1,7 @@
package oauth2provider
import (
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"net/http"
"net/url"
@@ -11,7 +9,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/justinas/nosurf"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
@@ -25,7 +22,6 @@ import (
type authorizeParams struct {
clientID string
redirectURL *url.URL
redirectURIProvided bool
responseType codersdk.OAuth2ProviderResponseType
scope []string
state string
@@ -38,13 +34,11 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
p := httpapi.NewQueryParamParser()
vals := r.URL.Query()
// response_type and client_id are always required.
p.RequiredNotEmpty("response_type", "client_id")
params := authorizeParams{
clientID: p.String(vals, "", "client_id"),
redirectURL: p.RedirectURL(vals, callbackURL, "redirect_uri"),
redirectURIProvided: vals.Get("redirect_uri") != "",
responseType: httpapi.ParseCustom(p, vals, "", "response_type", httpapi.ParseEnum[codersdk.OAuth2ProviderResponseType]),
scope: strings.Fields(strings.TrimSpace(p.String(vals, "", "scope"))),
state: p.String(vals, "", "state"),
@@ -52,15 +46,6 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
codeChallenge: p.String(vals, "", "code_challenge"),
codeChallengeMethod: p.String(vals, "", "code_challenge_method"),
}
// PKCE is required for authorization code flow requests.
if params.responseType == codersdk.OAuth2ProviderResponseTypeCode && params.codeChallenge == "" {
p.Errors = append(p.Errors, codersdk.ValidationError{
Field: "code_challenge",
Detail: `Query param "code_challenge" is required and cannot be empty`,
})
}
// Validate resource indicator syntax (RFC 8707): must be absolute URI without fragment
if err := validateResourceParameter(params.resource); err != nil {
p.Errors = append(p.Errors, codersdk.ValidationError{
@@ -127,22 +112,6 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
return
}
if params.responseType != codersdk.OAuth2ProviderResponseTypeCode {
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
Status: http.StatusBadRequest,
HideStatus: false,
Title: "Unsupported Response Type",
Description: "Only response_type=code is supported.",
Actions: []site.Action{
{
URL: accessURL.String(),
Text: "Back to site",
},
},
})
return
}
cancel := params.redirectURL
cancelQuery := params.redirectURL.Query()
cancelQuery.Add("error", "access_denied")
@@ -153,7 +122,6 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
AppName: app.Name,
CancelURI: cancel.String(),
RedirectURI: r.URL.String(),
CSRFToken: nosurf.Token(r),
Username: ua.FriendlyName,
})
}
@@ -179,23 +147,16 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
return
}
// OAuth 2.1 removes the implicit grant. Only
// authorization code flow is supported.
if params.responseType != codersdk.OAuth2ProviderResponseTypeCode {
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest,
codersdk.OAuth2ErrorCodeUnsupportedResponseType,
"Only response_type=code is supported")
return
}
// code_challenge is required (enforced by RequiredNotEmpty above),
// but default the method to S256 if omitted.
if params.codeChallengeMethod == "" {
params.codeChallengeMethod = string(codersdk.OAuth2PKCECodeChallengeMethodS256)
}
if err := codersdk.ValidatePKCECodeChallengeMethod(params.codeChallengeMethod); err != nil {
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
return
// Validate PKCE for public clients (MCP requirement)
if params.codeChallenge != "" {
// If code_challenge is provided but method is not, default to S256
if params.codeChallengeMethod == "" {
params.codeChallengeMethod = string(codersdk.OAuth2PKCECodeChallengeMethodS256)
}
if err := codersdk.ValidatePKCECodeChallengeMethod(params.codeChallengeMethod); err != nil {
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
return
}
}
// TODO: Ignoring scope for now, but should look into implementing.
@@ -233,8 +194,6 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
ResourceUri: sql.NullString{String: params.resource, Valid: params.resource != ""},
CodeChallenge: sql.NullString{String: params.codeChallenge, Valid: params.codeChallenge != ""},
CodeChallengeMethod: sql.NullString{String: params.codeChallengeMethod, Valid: params.codeChallengeMethod != ""},
StateHash: hashOAuth2State(params.state),
RedirectUri: sql.NullString{String: params.redirectURL.String(), Valid: params.redirectURIProvided},
})
if err != nil {
return xerrors.Errorf("insert oauth2 authorization code: %w", err)
@@ -259,16 +218,3 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
http.Redirect(rw, r, params.redirectURL.String(), http.StatusFound)
}
}
// hashOAuth2State returns a SHA-256 hash of the OAuth2 state parameter. If
// the state is empty, it returns a null string.
func hashOAuth2State(state string) sql.NullString {
if state == "" {
return sql.NullString{}
}
hash := sha256.Sum256([]byte(state))
return sql.NullString{
String: hex.EncodeToString(hash[:]),
Valid: true,
}
}
@@ -1,53 +0,0 @@
//nolint:testpackage // Internal test for unexported hashOAuth2State helper.
package oauth2provider
import (
"crypto/sha256"
"encoding/hex"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHashOAuth2State(t *testing.T) {
t.Parallel()
t.Run("EmptyState", func(t *testing.T) {
t.Parallel()
result := hashOAuth2State("")
assert.False(t, result.Valid, "empty state should return invalid NullString")
assert.Empty(t, result.String, "empty state should return empty string")
})
t.Run("NonEmptyState", func(t *testing.T) {
t.Parallel()
state := "test-state-value"
result := hashOAuth2State(state)
require.True(t, result.Valid, "non-empty state should return valid NullString")
// Verify it's a proper SHA-256 hash.
expected := sha256.Sum256([]byte(state))
assert.Equal(t, hex.EncodeToString(expected[:]), result.String,
"state hash should be SHA-256 hex digest")
})
t.Run("DifferentStatesProduceDifferentHashes", func(t *testing.T) {
t.Parallel()
hash1 := hashOAuth2State("state-a")
hash2 := hashOAuth2State("state-b")
require.True(t, hash1.Valid)
require.True(t, hash2.Valid)
assert.NotEqual(t, hash1.String, hash2.String,
"different states should produce different hashes")
})
t.Run("SameStateProducesSameHash", func(t *testing.T) {
t.Parallel()
hash1 := hashOAuth2State("deterministic")
hash2 := hashOAuth2State("deterministic")
require.True(t, hash1.Valid)
assert.Equal(t, hash1.String, hash2.String,
"same state should produce identical hash")
})
}
-32
View File
@@ -1,32 +0,0 @@
package oauth2provider_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/site"
)
func TestOAuthConsentFormIncludesCSRFToken(t *testing.T) {
t.Parallel()
const csrfFieldValue = "csrf-field-value"
req := httptest.NewRequest(http.MethodGet, "https://coder.com/oauth2/authorize", nil)
rec := httptest.NewRecorder()
site.RenderOAuthAllowPage(rec, req, site.RenderOAuthAllowData{
AppName: "Test OAuth App",
CancelURI: "https://coder.com/cancel",
RedirectURI: "https://coder.com/oauth2/authorize?client_id=test",
CSRFToken: csrfFieldValue,
Username: "test-user",
})
require.Equal(t, http.StatusOK, rec.Result().StatusCode)
assert.Contains(t, rec.Body.String(), `name="csrf_token"`)
assert.Contains(t, rec.Body.String(), `value="`+csrfFieldValue+`"`)
}
@@ -158,9 +158,7 @@ func TestOAuth2InvalidPKCE(t *testing.T) {
)
}
// TestOAuth2WithoutPKCEIsRejected verifies that authorization requests without
// a code_challenge are rejected now that PKCE is mandatory.
func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
func TestOAuth2WithoutPKCE(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
@@ -168,15 +166,15 @@ func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
})
_ = coderdtest.CreateFirstUser(t, client)
// Create OAuth2 app.
app, _ := oauth2providertest.CreateTestOAuth2App(t, client)
// Create OAuth2 app
app, clientSecret := oauth2providertest.CreateTestOAuth2App(t, client)
t.Cleanup(func() {
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
})
state := oauth2providertest.GenerateState(t)
// Authorization without code_challenge should be rejected.
// Perform authorization without PKCE
authParams := oauth2providertest.AuthorizeParams{
ClientID: app.ID.String(),
ResponseType: "code",
@@ -184,9 +182,21 @@ func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
State: state,
}
oauth2providertest.AuthorizeOAuth2AppExpectingError(
t, client, client.URL.String(), authParams, http.StatusBadRequest,
)
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
require.NotEmpty(t, code, "should receive authorization code")
// Exchange code for token without PKCE
tokenParams := oauth2providertest.TokenExchangeParams{
GrantType: "authorization_code",
Code: code,
ClientID: app.ID.String(),
ClientSecret: clientSecret,
RedirectURI: oauth2providertest.TestRedirectURI,
}
token := oauth2providertest.ExchangeCodeForToken(t, client.URL.String(), tokenParams)
require.NotEmpty(t, token.AccessToken, "should receive access token")
require.NotEmpty(t, token.RefreshToken, "should receive refresh token")
}
func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
@@ -202,16 +212,13 @@ func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
})
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
state := oauth2providertest.GenerateState(t)
authParams := oauth2providertest.AuthorizeParams{
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
CodeChallenge: codeChallenge,
CodeChallengeMethod: "S256",
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
}
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
@@ -222,7 +229,6 @@ func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", oauth2providertest.TestRedirectURI)
data.Set("code_verifier", codeVerifier)
req, err := http.NewRequestWithContext(ctx, "POST", client.URL.String()+"/oauth2/tokens", strings.NewReader(data.Encode()))
require.NoError(t, err, "failed to create token request")
@@ -259,16 +265,13 @@ func TestOAuth2TokenExchangeClientSecretBasicInvalidSecret(t *testing.T) {
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
})
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
state := oauth2providertest.GenerateState(t)
authParams := oauth2providertest.AuthorizeParams{
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
CodeChallenge: codeChallenge,
CodeChallengeMethod: "S256",
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
}
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
@@ -279,7 +282,6 @@ func TestOAuth2TokenExchangeClientSecretBasicInvalidSecret(t *testing.T) {
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", oauth2providertest.TestRedirectURI)
data.Set("code_verifier", codeVerifier)
wrongSecret := clientSecret + "x"
@@ -347,30 +349,26 @@ func TestOAuth2ResourceParameter(t *testing.T) {
})
state := oauth2providertest.GenerateState(t)
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
// Perform authorization with resource parameter.
// Perform authorization with resource parameter
authParams := oauth2providertest.AuthorizeParams{
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
CodeChallenge: codeChallenge,
CodeChallengeMethod: "S256",
Resource: oauth2providertest.TestResourceURI,
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
Resource: oauth2providertest.TestResourceURI,
}
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
require.NotEmpty(t, code, "should receive authorization code")
// Exchange code for token with resource parameter.
// Exchange code for token with resource parameter
tokenParams := oauth2providertest.TokenExchangeParams{
GrantType: "authorization_code",
Code: code,
ClientID: app.ID.String(),
ClientSecret: clientSecret,
RedirectURI: oauth2providertest.TestRedirectURI,
CodeVerifier: codeVerifier,
Resource: oauth2providertest.TestResourceURI,
}
@@ -394,16 +392,13 @@ func TestOAuth2TokenRefresh(t *testing.T) {
})
state := oauth2providertest.GenerateState(t)
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
// Get initial token.
// Get initial token
authParams := oauth2providertest.AuthorizeParams{
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
CodeChallenge: codeChallenge,
CodeChallengeMethod: "S256",
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
}
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
@@ -414,7 +409,6 @@ func TestOAuth2TokenRefresh(t *testing.T) {
ClientID: app.ID.String(),
ClientSecret: clientSecret,
RedirectURI: oauth2providertest.TestRedirectURI,
CodeVerifier: codeVerifier,
}
initialToken := oauth2providertest.ExchangeCodeForToken(t, client.URL.String(), tokenParams)
+7 -20
View File
@@ -254,27 +254,14 @@ func authorizationCodeGrant(ctx context.Context, db database.Store, app database
return codersdk.OAuth2TokenResponse{}, errBadCode
}
// Verify redirect_uri matches the one used during authorization
// (RFC 6749 §4.1.3).
if dbCode.RedirectUri.Valid && dbCode.RedirectUri.String != "" {
if req.RedirectURI != dbCode.RedirectUri.String {
return codersdk.OAuth2TokenResponse{}, errBadCode
// Verify PKCE challenge if present
if dbCode.CodeChallenge.Valid && dbCode.CodeChallenge.String != "" {
if req.CodeVerifier == "" {
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
}
if !VerifyPKCE(dbCode.CodeChallenge.String, req.CodeVerifier) {
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
}
}
// PKCE is mandatory for all authorization code flows
// (OAuth 2.1). Verify the code verifier against the stored
// challenge.
if req.CodeVerifier == "" {
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
}
if !dbCode.CodeChallenge.Valid || dbCode.CodeChallenge.String == "" {
// Code was issued without a challenge — should not happen
// with authorize endpoint enforcement, but defend in depth.
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
}
if !VerifyPKCE(dbCode.CodeChallenge.String, req.CodeVerifier) {
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
}
// Verify resource parameter consistency (RFC 8707)
@@ -318,7 +318,6 @@ func TestExtractAuthorizeParams_Scopes(t *testing.T) {
query.Set("response_type", "code")
query.Set("client_id", "test-client")
query.Set("redirect_uri", "http://localhost:3000/callback")
query.Set("code_challenge", "test-challenge")
if tc.scopeParam != "" {
query.Set("scope", tc.scopeParam)
}
@@ -342,34 +341,6 @@ func TestExtractAuthorizeParams_Scopes(t *testing.T) {
}
}
// TestExtractAuthorizeParams_TokenResponseTypeDoesNotRequirePKCE ensures
// response_type=token is parsed without requiring PKCE fields so callers can
// return unsupported_response_type instead of invalid_request.
func TestExtractAuthorizeParams_TokenResponseTypeDoesNotRequirePKCE(t *testing.T) {
t.Parallel()
callbackURL, err := url.Parse("http://localhost:3000/callback")
require.NoError(t, err)
query := url.Values{}
query.Set("response_type", string(codersdk.OAuth2ProviderResponseTypeToken))
query.Set("client_id", "test-client")
query.Set("redirect_uri", "http://localhost:3000/callback")
reqURL, err := url.Parse("http://localhost:8080/oauth2/authorize?" + query.Encode())
require.NoError(t, err)
req := &http.Request{
Method: http.MethodGet,
URL: reqURL,
}
params, validationErrs, err := extractAuthorizeParams(req, callbackURL)
require.NoError(t, err)
require.Empty(t, validationErrs)
require.Equal(t, codersdk.OAuth2ProviderResponseTypeToken, params.responseType)
}
// TestRefreshTokenGrant_Scopes tests that scopes can be requested during refresh
func TestRefreshTokenGrant_Scopes(t *testing.T) {
t.Parallel()
@@ -19,9 +19,9 @@ import (
)
var (
templatesActiveUsersDesc = prometheus.NewDesc("coderd_insights_templates_active_users", "The number of active users of the template.", []string{"template_name", "organization_name"}, nil)
applicationsUsageSecondsDesc = prometheus.NewDesc("coderd_insights_applications_usage_seconds", "The application usage per template.", []string{"template_name", "application_name", "slug", "organization_name"}, nil)
parametersDesc = prometheus.NewDesc("coderd_insights_parameters", "The parameter usage per template.", []string{"template_name", "parameter_name", "parameter_type", "parameter_value", "organization_name"}, nil)
templatesActiveUsersDesc = prometheus.NewDesc("coderd_insights_templates_active_users", "The number of active users of the template.", []string{"template_name"}, nil)
applicationsUsageSecondsDesc = prometheus.NewDesc("coderd_insights_applications_usage_seconds", "The application usage per template.", []string{"template_name", "application_name", "slug"}, nil)
parametersDesc = prometheus.NewDesc("coderd_insights_parameters", "The parameter usage per template.", []string{"template_name", "parameter_name", "parameter_type", "parameter_value"}, nil)
)
type MetricsCollector struct {
@@ -38,8 +38,7 @@ type insightsData struct {
apps []database.GetTemplateAppInsightsByTemplateRow
params []parameterRow
templateNames map[uuid.UUID]string
organizationNames map[uuid.UUID]string // template ID → org name
templateNames map[uuid.UUID]string
}
type parameterRow struct {
@@ -138,7 +137,6 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
templateIDs := uniqueTemplateIDs(templateInsights, appInsights, paramInsights)
templateNames := make(map[uuid.UUID]string, len(templateIDs))
organizationNames := make(map[uuid.UUID]string, len(templateIDs))
if len(templateIDs) > 0 {
templates, err := mc.database.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{
IDs: templateIDs,
@@ -148,31 +146,6 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
return
}
templateNames = onlyTemplateNames(templates)
// Build org name lookup so that metrics can
// distinguish templates with the same name across
// different organizations.
orgIDs := make([]uuid.UUID, 0, len(templates))
for _, t := range templates {
orgIDs = append(orgIDs, t.OrganizationID)
}
orgIDs = slice.Unique(orgIDs)
orgs, err := mc.database.GetOrganizations(ctx, database.GetOrganizationsParams{
IDs: orgIDs,
})
if err != nil {
mc.logger.Error(ctx, "unable to fetch organizations from database", slog.Error(err))
return
}
orgNameByID := make(map[uuid.UUID]string, len(orgs))
for _, o := range orgs {
orgNameByID[o.ID] = o.Name
}
organizationNames = make(map[uuid.UUID]string, len(templates))
for _, t := range templates {
organizationNames[t.ID] = orgNameByID[t.OrganizationID]
}
}
// Refresh the collector state
@@ -181,8 +154,7 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
apps: appInsights,
params: paramInsights,
templateNames: templateNames,
organizationNames: organizationNames,
templateNames: templateNames,
})
}
@@ -222,46 +194,44 @@ func (mc *MetricsCollector) Collect(metricsCh chan<- prometheus.Metric) {
// Custom apps
for _, appRow := range data.apps {
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue, float64(appRow.UsageSeconds), data.templateNames[appRow.TemplateID],
appRow.DisplayName, appRow.SlugOrPort, data.organizationNames[appRow.TemplateID])
appRow.DisplayName, appRow.SlugOrPort)
}
// Built-in apps
for _, templateRow := range data.templates {
orgName := data.organizationNames[templateRow.TemplateID]
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
float64(templateRow.UsageVscodeSeconds),
data.templateNames[templateRow.TemplateID],
codersdk.TemplateBuiltinAppDisplayNameVSCode,
"", orgName)
"")
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
float64(templateRow.UsageJetbrainsSeconds),
data.templateNames[templateRow.TemplateID],
codersdk.TemplateBuiltinAppDisplayNameJetBrains,
"", orgName)
"")
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
float64(templateRow.UsageReconnectingPtySeconds),
data.templateNames[templateRow.TemplateID],
codersdk.TemplateBuiltinAppDisplayNameWebTerminal,
"", orgName)
"")
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
float64(templateRow.UsageSshSeconds),
data.templateNames[templateRow.TemplateID],
codersdk.TemplateBuiltinAppDisplayNameSSH,
"", orgName)
"")
}
// Templates
for _, templateRow := range data.templates {
metricsCh <- prometheus.MustNewConstMetric(templatesActiveUsersDesc, prometheus.GaugeValue, float64(templateRow.ActiveUsers), data.templateNames[templateRow.TemplateID], data.organizationNames[templateRow.TemplateID])
metricsCh <- prometheus.MustNewConstMetric(templatesActiveUsersDesc, prometheus.GaugeValue, float64(templateRow.ActiveUsers), data.templateNames[templateRow.TemplateID])
}
// Parameters
for _, parameterRow := range data.params {
metricsCh <- prometheus.MustNewConstMetric(parametersDesc, prometheus.GaugeValue, float64(parameterRow.count), data.templateNames[parameterRow.templateID], parameterRow.name, parameterRow.aType, parameterRow.value, data.organizationNames[parameterRow.templateID])
metricsCh <- prometheus.MustNewConstMetric(parametersDesc, prometheus.GaugeValue, float64(parameterRow.count), data.templateNames[parameterRow.templateID], parameterRow.name, parameterRow.aType, parameterRow.value)
}
}
@@ -1,13 +1,13 @@
{
"coderd_insights_applications_usage_seconds[application_name=JetBrains,organization_name=coder,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Visual Studio Code,organization_name=coder,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Web Terminal,organization_name=coder,slug=,template_name=golden-template]": 0,
"coderd_insights_applications_usage_seconds[application_name=SSH,organization_name=coder,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Golden Slug,organization_name=coder,slug=golden-slug,template_name=golden-template]": 180,
"coderd_insights_parameters[organization_name=coder,parameter_name=first_parameter,parameter_type=string,parameter_value=Foobar,template_name=golden-template]": 1,
"coderd_insights_parameters[organization_name=coder,parameter_name=first_parameter,parameter_type=string,parameter_value=Baz,template_name=golden-template]": 1,
"coderd_insights_parameters[organization_name=coder,parameter_name=second_parameter,parameter_type=bool,parameter_value=true,template_name=golden-template]": 2,
"coderd_insights_parameters[organization_name=coder,parameter_name=third_parameter,parameter_type=number,parameter_value=789,template_name=golden-template]": 1,
"coderd_insights_parameters[organization_name=coder,parameter_name=third_parameter,parameter_type=number,parameter_value=999,template_name=golden-template]": 1,
"coderd_insights_templates_active_users[organization_name=coder,template_name=golden-template]": 1
"coderd_insights_applications_usage_seconds[application_name=JetBrains,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Visual Studio Code,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Web Terminal,slug=,template_name=golden-template]": 0,
"coderd_insights_applications_usage_seconds[application_name=SSH,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Golden Slug,slug=golden-slug,template_name=golden-template]": 180,
"coderd_insights_parameters[parameter_name=first_parameter,parameter_type=string,parameter_value=Foobar,template_name=golden-template]": 1,
"coderd_insights_parameters[parameter_name=first_parameter,parameter_type=string,parameter_value=Baz,template_name=golden-template]": 1,
"coderd_insights_parameters[parameter_name=second_parameter,parameter_type=bool,parameter_value=true,template_name=golden-template]": 2,
"coderd_insights_parameters[parameter_name=third_parameter,parameter_type=number,parameter_value=789,template_name=golden-template]": 1,
"coderd_insights_parameters[parameter_name=third_parameter,parameter_type=number,parameter_value=999,template_name=golden-template]": 1,
"coderd_insights_templates_active_users[template_name=golden-template]": 1
}
@@ -564,7 +564,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
// The check `s.OIDCConfig != nil` is not as strict, since it can be an interface
// pointing to a typed nil.
if !reflect.ValueOf(s.OIDCConfig).IsNil() {
workspaceOwnerOIDCAccessToken, err = ObtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID)
workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, s.Database, s.OIDCConfig, owner.ID)
if err != nil {
return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err))
}
@@ -725,16 +725,11 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
}
}
provisionerStateRow, err := s.Database.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuild.ID)
if err != nil {
return nil, failJob(fmt.Sprintf("get workspace build provisioner state: %s", err))
}
protoJob.Type = &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
WorkspaceBuildId: workspaceBuild.ID.String(),
WorkspaceName: workspace.Name,
State: provisionerStateRow.ProvisionerState,
State: workspaceBuild.ProvisionerState,
RichParameterValues: convertRichParameterValues(workspaceBuildParameters),
PreviousParameterValues: convertRichParameterValues(lastWorkspaceBuildParameters),
VariableValues: asVariableValues(templateVariables),
@@ -845,11 +840,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
// Record the time the job spent waiting in the queue.
if s.metrics != nil && job.StartedAt.Valid && job.Provisioner.Valid() {
// These timestamps lose their monotonic clock component after a Postgres
// round-trip, so the subtraction is based purely on wall-clock time. Floor at
// 1ms as a defensive measure against clock adjustments producing a negative
// delta while acknowledging there's a non-zero queue time.
queueWaitSeconds := max(job.StartedAt.Time.Sub(job.CreatedAt).Seconds(), 0.001)
queueWaitSeconds := job.StartedAt.Time.Sub(job.CreatedAt).Seconds()
s.metrics.ObserveJobQueueWait(string(job.Provisioner), string(job.Type), jobTransition, jobBuildReason, queueWaitSeconds)
}
@@ -3075,37 +3066,9 @@ func deleteSessionTokenForUserAndWorkspace(ctx context.Context, db database.Stor
return nil
}
// shouldRefreshOIDCToken reports whether the user's OIDC token should be
// refreshed before a workspace build, and returns the (possibly shortened)
// expiry used for that decision.
//
// The expiry is pulled 10 minutes earlier than the provider's value so that a
// token which would expire mid-build is refreshed up front. Links without a
// refresh token, or with a zero expiry (never expires), are never refreshed.
func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) {
	switch {
	case link.OAuthRefreshToken == "":
		// No refresh token stored; refreshing is impossible even if we
		// wanted to.
		return false, link.OAuthExpiry
	case link.OAuthExpiry.IsZero():
		// A zero expiry means the token never expires, so there is
		// nothing to refresh.
		return false, link.OAuthExpiry
	}

	// Shorten the token's effective lifetime by 10 minutes. A workspace
	// build takes a non-trivial amount of time, and a token expiring
	// mid-build risks failing it, so refresh prematurely. A provider that
	// issues tokens shorter-lived than this window will be refreshed on
	// every build. This mirrors the oauth2 package's own refresh logic.
	earlyExpiry := link.OAuthExpiry.Add(-10 * time.Minute)

	// Treat the token as expired once the shortened expiry has passed.
	return earlyExpiry.Before(dbtime.Now()), earlyExpiry
}
// ObtainOIDCAccessToken returns a valid OpenID Connect access token
// obtainOIDCAccessToken returns a valid OpenID Connect access token
// for the user if it's able to obtain one, otherwise it returns an empty string.
func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
UserID: userID,
LoginType: database.LoginTypeOIDC,
@@ -3117,13 +3080,11 @@ func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.
return "", xerrors.Errorf("get owner oidc link: %w", err)
}
if shouldRefresh, expiresAt := shouldRefreshOIDCToken(link); shouldRefresh {
if link.OAuthExpiry.Before(dbtime.Now()) && !link.OAuthExpiry.IsZero() && link.OAuthRefreshToken != "" {
token, err := oidcConfig.TokenSource(ctx, &oauth2.Token{
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
// Use the expiresAt returned by shouldRefreshOIDCToken.
// It will force a refresh with an expired time.
Expiry: expiresAt,
Expiry: link.OAuthExpiry,
}).Token()
if err != nil {
// If OIDC fails to refresh, we return an empty string and don't fail.
@@ -3148,7 +3109,6 @@ func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.
if err != nil {
return "", xerrors.Errorf("update user link: %w", err)
}
logger.Info(ctx, "refreshed expired OIDC token for user during workspace build", slog.F("user_id", userID))
}
return link.OAuthAccessToken, nil
@@ -16,109 +16,13 @@ import (
"github.com/coder/coder/v2/testutil"
)
// TestShouldRefreshOIDCToken exercises shouldRefreshOIDCToken across the
// refresh-window boundaries. The refresh edge sits 10 minutes before the
// stored OAuthExpiry, so tokens expiring within that window refresh early.
func TestShouldRefreshOIDCToken(t *testing.T) {
	t.Parallel()

	now := dbtime.Now()
	cases := []struct {
		name string
		link database.UserLink
		want bool
	}{
		{
			// No refresh token means refreshing is impossible.
			name: "NoRefreshToken",
			link: database.UserLink{OAuthExpiry: now.Add(-time.Hour)},
			want: false,
		},
		{
			// A zero expiry means the token never expires.
			name: "ZeroExpiry",
			link: database.UserLink{OAuthRefreshToken: "refresh"},
			want: false,
		},
		{
			name: "LongExpired",
			link: database.UserLink{
				OAuthRefreshToken: "refresh",
				OAuthExpiry:       now.Add(-time.Hour),
			},
			want: true,
		},
		{
			// Edge being "+/- 10 minutes"
			name: "EdgeExpired",
			link: database.UserLink{
				OAuthRefreshToken: "refresh",
				OAuthExpiry:       now.Add(-10 * time.Minute),
			},
			want: true,
		},
		{
			name: "Expired",
			link: database.UserLink{
				OAuthRefreshToken: "refresh",
				OAuthExpiry:       now.Add(-time.Minute),
			},
			want: true,
		},
		{
			name: "SoonToBeExpired",
			link: database.UserLink{
				OAuthRefreshToken: "refresh",
				OAuthExpiry:       now.Add(5 * time.Minute),
			},
			want: true,
		},
		{
			name: "SoonToBeExpiredEdge",
			link: database.UserLink{
				OAuthRefreshToken: "refresh",
				OAuthExpiry:       now.Add(9 * time.Minute),
			},
			want: true,
		},
		{
			name: "AfterEdge",
			link: database.UserLink{
				OAuthRefreshToken: "refresh",
				OAuthExpiry:       now.Add(11 * time.Minute),
			},
			want: false,
		},
		{
			name: "NotExpired",
			link: database.UserLink{
				OAuthRefreshToken: "refresh",
				OAuthExpiry:       now.Add(time.Hour),
			},
			want: false,
		},
		{
			name: "NotEvenCloseExpired",
			link: database.UserLink{
				OAuthRefreshToken: "refresh",
				OAuthExpiry:       now.Add(24 * time.Hour),
			},
			want: false,
		},
	}

	for _, tt := range cases {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got, _ := shouldRefreshOIDCToken(tt.link)
			require.Equal(t, tt.want, got)
		})
	}
}
func TestObtainOIDCAccessToken(t *testing.T) {
t.Parallel()
ctx := context.Background()
t.Run("NoToken", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, nil, uuid.Nil)
_, err := obtainOIDCAccessToken(ctx, db, nil, uuid.Nil)
require.NoError(t, err)
})
t.Run("InvalidConfig", func(t *testing.T) {
@@ -131,7 +35,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
LoginType: database.LoginTypeOIDC,
OAuthExpiry: dbtime.Now().Add(-time.Hour),
})
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
_, err := obtainOIDCAccessToken(ctx, db, &oauth2.Config{}, user.ID)
require.NoError(t, err)
})
t.Run("MissingLink", func(t *testing.T) {
@@ -140,7 +44,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
user := dbgen.User(t, db, database.User{
LoginType: database.LoginTypeOIDC,
})
tok, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
tok, err := obtainOIDCAccessToken(ctx, db, &oauth2.Config{}, user.ID)
require.Empty(t, tok)
require.NoError(t, err)
})
@@ -153,7 +57,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
LoginType: database.LoginTypeOIDC,
OAuthExpiry: dbtime.Now().Add(-time.Hour),
})
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &testutil.OAuth2Config{
_, err := obtainOIDCAccessToken(ctx, db, &testutil.OAuth2Config{
Token: &oauth2.Token{
AccessToken: "token",
},
@@ -15,7 +15,6 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
@@ -31,7 +30,6 @@ import (
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
@@ -60,175 +58,6 @@ import (
"github.com/coder/serpent"
)
// TestTokenIsRefreshedEarly verifies, against a full coderd instance, that an
// OIDC access token expiring within the early-refresh window is refreshed
// during a workspace build by `ObtainOIDCAccessToken`. The fake IDP issues
// tokens that expire in 8 minutes, which is inside the window, so exactly one
// refresh is expected per build.
func TestTokenIsRefreshedEarly(t *testing.T) {
	t.Parallel()
	t.Run("WithCoderd", func(t *testing.T) {
		t.Parallel()
		// Incremented by the fake IDP on every token refresh; asserted below.
		tokenRefreshCount := 0
		fake := oidctest.NewFakeIDP(t,
			oidctest.WithServing(),
			// 8-minute expiry falls inside the early-refresh window.
			oidctest.WithDefaultExpire(time.Minute*8),
			oidctest.WithRefresh(func(email string) error {
				tokenRefreshCount++
				return nil
			}),
		)

		cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
			cfg.AllowSignups = true
		})

		db, ps := dbtestutil.NewDB(t)
		owner := coderdtest.New(t, &coderdtest.Options{
			OIDCConfig:               cfg,
			IncludeProvisionerDaemon: true,
			Database:                 db,
			Pubsub:                   ps,
		})
		first := coderdtest.CreateFirstUser(t, owner)

		version := coderdtest.CreateTemplateVersion(t, owner, first.OrganizationID, nil)
		coderdtest.AwaitTemplateVersionJobCompleted(t, owner, version.ID)
		template := coderdtest.CreateTemplate(t, owner, first.OrganizationID, version.ID)

		// Setup an OIDC user.
		client, _ := fake.Login(t, owner, jwt.MapClaims{
			"email":          "user@unauthorized.com",
			"email_verified": true,
			"sub":            uuid.NewString(),
		})

		// Creating a workspace should refresh the oidc early.
		// Reset the counter first so only the build's refresh is counted.
		tokenRefreshCount = 0
		wrk := coderdtest.CreateWorkspace(t, client, template.ID)
		coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wrk.LatestBuild.ID)
		// Exactly one early refresh must have happened during the build.
		require.Equal(t, 1, tokenRefreshCount)
	})
}
// TestTokenIsRefreshedEarlyWithoutCoderd drives ObtainOIDCAccessToken
// directly (no coderd server) against a fake IDP, sweeping the user link's
// OAuthExpiry across the refresh-window boundaries and asserting whether a
// refresh happened via the IDP's refresh counter and the rotated tokens.
// Subtests share one user link and one counter, so they must run in order.
//
//nolint:tparallel,paralleltest // Sub tests need to run sequentially.
func TestTokenIsRefreshedEarlyWithoutCoderd(t *testing.T) {
	t.Parallel()
	// Incremented by the fake IDP on every refresh; reset per subtest.
	tokenRefreshCount := 0
	fake := oidctest.NewFakeIDP(t,
		oidctest.WithServing(),
		// 8-minute default expiry; individual subtests override the link's
		// stored expiry below.
		oidctest.WithDefaultExpire(time.Minute*8),
		oidctest.WithRefresh(func(email string) error {
			tokenRefreshCount++
			return nil
		}),
	)
	cfg := fake.OIDCConfig(t, nil)

	// Fetch a valid token from the fake OIDC provider
	token, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{
		"email":          "user@unauthorized.com",
		"email_verified": true,
		"sub":            uuid.NewString(),
	})
	require.NoError(t, err)

	db, _ := dbtestutil.NewDB(t)
	user := dbgen.User(t, db, database.User{})
	dbgen.UserLink(t, db, database.UserLink{
		UserID:            user.ID,
		LoginType:         database.LoginTypeOIDC,
		LinkedID:          "foo",
		OAuthAccessToken:  token.AccessToken,
		OAuthRefreshToken: token.RefreshToken,
		// The oauth expiry does not really matter, since each test will manually control
		// this value.
		OAuthExpiry: dbtime.Now().Add(time.Hour),
	})

	// setLinkExpiration rewrites the single user link's expiry in the
	// database and returns the updated link for before/after comparisons.
	setLinkExpiration := func(t *testing.T, exp time.Time) database.UserLink {
		ctx := testutil.Context(t, testutil.WaitShort)
		links, err := db.GetUserLinksByUserID(ctx, user.ID)
		require.NoError(t, err)
		require.Len(t, links, 1)
		link := links[0]
		newLink, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
			OAuthAccessToken:       link.OAuthAccessToken,
			OAuthAccessTokenKeyID:  link.OAuthAccessTokenKeyID,
			OAuthRefreshToken:      link.OAuthRefreshToken,
			OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID,
			OAuthExpiry:            exp,
			Claims:                 link.Claims,
			UserID:                 link.UserID,
			LoginType:              link.LoginType,
		})
		require.NoError(t, err)
		return newLink
	}

	for _, c := range []struct {
		name string
		// expires is a function to return a more up to date "now".
		// Because the oauth library is calling `time.Now()`, we cannot use
		// mocked clocks.
		expires         func() time.Time
		refreshExpected bool
	}{
		{
			// Zero expiry means the token never expires: no refresh.
			name:            "ZeroExpiry",
			expires:         func() time.Time { return time.Time{} },
			refreshExpected: false,
		},
		{
			name:            "LongExpired",
			expires:         func() time.Time { return dbtime.Now().Add(-time.Hour) },
			refreshExpected: true,
		},
		{
			name:            "EdgeExpired",
			expires:         func() time.Time { return dbtime.Now().Add(-time.Minute * 10) },
			refreshExpected: true,
		},
		{
			// Note the double negative: -time.Second * -1 is one second in
			// the FUTURE, still inside the early-refresh window.
			name:            "RecentExpired",
			expires:         func() time.Time { return dbtime.Now().Add(-time.Second * -1) },
			refreshExpected: true,
		},
		{
			name:            "Future",
			expires:         func() time.Time { return dbtime.Now().Add(time.Hour) },
			refreshExpected: false,
		},
		{
			// 8 minutes out is within the 10-minute early-refresh window.
			name:            "FutureWithinRefreshWindow",
			expires:         func() time.Time { return dbtime.Now().Add(time.Minute * 8) },
			refreshExpected: true,
		},
	} {
		t.Run(c.name, func(t *testing.T) {
			ctx := testutil.Context(t, testutil.WaitShort)
			oldLink := setLinkExpiration(t, c.expires())

			// Reset the counter so only this subtest's refresh is counted.
			tokenRefreshCount = 0
			_, err := provisionerdserver.ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, cfg, user.ID)
			require.NoError(t, err)

			links, err := db.GetUserLinksByUserID(ctx, user.ID)
			require.NoError(t, err)
			require.Len(t, links, 1)
			newLink := links[0]

			if c.refreshExpected {
				// A refresh must rotate both tokens.
				require.Equal(t, 1, tokenRefreshCount)
				require.NotEqual(t, oldLink.OAuthAccessToken, newLink.OAuthAccessToken)
				require.NotEqual(t, oldLink.OAuthRefreshToken, newLink.OAuthRefreshToken)
			} else {
				// No refresh: tokens stay untouched.
				require.Equal(t, 0, tokenRefreshCount)
				require.Equal(t, oldLink.OAuthAccessToken, newLink.OAuthAccessToken)
				require.Equal(t, oldLink.OAuthRefreshToken, newLink.OAuthRefreshToken)
			}
		})
	}
}
func testTemplateScheduleStore() *atomic.Pointer[schedule.TemplateScheduleStore] {
poitr := &atomic.Pointer[schedule.TemplateScheduleStore]{}
store := schedule.NewAGPLTemplateScheduleStore()
@@ -1492,9 +1321,7 @@ func TestFailJob(t *testing.T) {
<-publishedLogs
build, err := db.GetWorkspaceBuildByID(ctx, buildID)
require.NoError(t, err)
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
require.NoError(t, err)
require.Equal(t, "some state", string(provisionerStateRow.ProvisionerState))
require.Equal(t, "some state", string(build.ProvisionerState))
require.Len(t, auditor.AuditLogs(), 1)
// Assert that the workspace_id field get populated
-1
View File
@@ -81,7 +81,6 @@ const (
SubjectAibridged SubjectType = "aibridged"
SubjectTypeDBPurge SubjectType = "dbpurge"
SubjectTypeBoundaryUsageTracker SubjectType = "boundary_usage_tracker"
SubjectTypeWorkspaceBuilder SubjectType = "workspace_builder"
)
const (
-34
View File
@@ -282,40 +282,6 @@ neq(input.object.owner, "");
p("'10d03e62-7703-4df5-a358-4f76577d4e2f' = id :: text") + " AND " + p("id :: text != ''") + " AND " + p("'' = ''"),
),
},
{
Name: "AuditLogUUID",
Queries: []string{
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
`input.object.org_owner != ""`,
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`,
`"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: p(
p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("audit_logs.organization_id IS NOT NULL") + " OR " +
p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
"(false)"),
VariableConverter: regosql.AuditLogConverter(),
},
{
Name: "ConnectionLogUUID",
Queries: []string{
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
`input.object.org_owner != ""`,
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`,
`"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: p(
p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("connection_logs.organization_id IS NOT NULL") + " OR " +
p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
"(false)"),
VariableConverter: regosql.ConnectionLogConverter(),
},
}
for _, tc := range testCases {
+2 -2
View File
@@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter {
func AuditLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}),
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
// Audit logs have no user owner, only owner by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
@@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
func ConnectionLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}),
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
// Connection logs have no user owner, only owner by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
-114
View File
@@ -1,114 +0,0 @@
package sqltypes
import (
"fmt"
"strings"
"github.com/open-policy-agent/opa/ast"
"golang.org/x/xerrors"
)
// Compile-time interface satisfaction checks.
var (
	_ VariableMatcher  = astUUIDVar{}
	_ Node             = astUUIDVar{}
	_ SupportsEquality = astUUIDVar{}
)

// astUUIDVar is a variable that represents a UUID column. Unlike
// astStringVar it emits native UUID comparisons (column = 'val'::uuid)
// instead of text-based ones (COALESCE(column::text, '') = 'val').
// This allows PostgreSQL to use indexes on UUID columns.
type astUUIDVar struct {
	// Source is the original rego reference this variable was converted
	// from (for example "input.object.org_owner").
	Source RegoSource
	// FieldPath is the rego path this matcher matches against.
	FieldPath []string
	// ColumnString is the SQL column expression emitted for this variable.
	ColumnString string
}
// UUIDVarMatcher returns a VariableMatcher that maps the given rego path to
// a UUID SQL column, emitting native uuid comparisons rather than text ones.
func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher {
	return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn}
}
func (astUUIDVar) UseAs() Node { return astUUIDVar{} }
// ConvertVariable reports whether the rego reference matches this matcher's
// configured field path exactly (no remaining path segments). On a match it
// returns a copy of the matcher carrying the reference text as its source.
func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) {
	remaining, err := RegoVarPath(u.FieldPath, rego)
	if err != nil || len(remaining) != 0 {
		return nil, false
	}
	matched := astUUIDVar{
		Source:       RegoSource(rego.String()),
		FieldPath:    u.FieldPath,
		ColumnString: u.ColumnString,
	}
	return matched, true
}
// SQLString renders the variable as its configured SQL column expression.
func (u astUUIDVar) SQLString(_ *SQLGenerator) string {
	return u.ColumnString
}
// EqualsSQLString handles equality comparisons for UUID columns.
// Rego always produces string literals, so we accept AstString and
// cast the literal to ::uuid in the output SQL. This lets PG use
// native UUID indexes instead of falling back to text comparisons.
//
// Note that dispatch is on other.UseAs()'s type (the node kind), while the
// value itself is extracted with a concrete-type assertion below — the two
// are intentionally distinct steps.
// nolint:revive
func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
	switch other.UseAs().(type) {
	case AstString:
		// The other side is a rego string literal like
		// "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison
		// that casts the literal to uuid so PG can use indexes:
		// column = 'val'::uuid
		// instead of the text-based:
		// 'val' = COALESCE(column::text, '')
		s, ok := other.(AstString)
		if !ok {
			return "", xerrors.Errorf("expected AstString, got %T", other)
		}
		if s.Value == "" {
			// Empty string in rego means "no value". Compare the
			// column against NULL since UUID columns represent
			// absent values as NULL, not empty strings.
			op := "IS NULL"
			if not {
				op = "IS NOT NULL"
			}
			return fmt.Sprintf("%s %s", u.ColumnString, op), nil
		}
		// NOTE(review): s.Value is interpolated into SQL without escaping or
		// UUID-format validation; presumably values originate from trusted
		// policy input — confirm callers never pass untrusted strings here.
		return fmt.Sprintf("%s %s '%s'::uuid",
			u.ColumnString, equalsOp(not), s.Value), nil
	case astUUIDVar:
		// Column-to-column comparison; delegate to the generic equality.
		return basicSQLEquality(cfg, not, u, other), nil
	default:
		return "", xerrors.Errorf("unsupported equality: %T %s %T",
			u, equalsOp(not), other)
	}
}
// ContainedInSQL implements SupportsContainedIn so that a UUID column can
// appear in membership checks of the form `col = ANY(ARRAY[...])`. The
// haystack elements arrive as rego string literals and are each cast to
// ::uuid. An empty haystack renders as the SQL literal "false".
func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) {
	arr, ok := haystack.(ASTArray)
	if !ok {
		return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack)
	}
	if len(arr.Value) == 0 {
		// Membership in an empty set is always false.
		return "false", nil
	}

	// Render each element as a uuid-cast literal: 'val'::uuid.
	elems := make([]string, len(arr.Value))
	for i, elem := range arr.Value {
		s, ok := elem.(AstString)
		if !ok {
			return "", xerrors.Errorf("expected AstString array element, got %T", elem)
		}
		elems[i] = "'" + s.Value + "'::uuid"
	}

	return fmt.Sprintf("%s = ANY(ARRAY [%s])",
		u.ColumnString,
		strings.Join(elems, ",")), nil
}
+18 -38
View File
@@ -244,7 +244,6 @@ func SystemRoleName(name string) bool {
type RoleOptions struct {
NoOwnerWorkspaceExec bool
NoWorkspaceSharing bool
}
// ReservedRoleName exists because the database should only allow unique role
@@ -268,23 +267,12 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
opts = &RoleOptions{}
}
denyPermissions := []Permission{}
if opts.NoWorkspaceSharing {
denyPermissions = append(denyPermissions, Permission{
Negate: true,
ResourceType: ResourceWorkspace.Type,
Action: policy.ActionShare,
})
}
ownerWorkspaceActions := ResourceWorkspace.AvailableActions()
if opts.NoOwnerWorkspaceExec {
// Remove ssh and application connect from the owner role. This
// prevents owners from have exec access to all workspaces.
ownerWorkspaceActions = slice.Omit(
ownerWorkspaceActions,
policy.ActionApplicationConnect, policy.ActionSSH,
)
ownerWorkspaceActions = slice.Omit(ownerWorkspaceActions,
policy.ActionApplicationConnect, policy.ActionSSH)
}
// Static roles that never change should be allocated in a closure.
@@ -307,8 +295,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
// Explicitly setting PrebuiltWorkspace permissions for clarity.
// Note: even without PrebuiltWorkspace permissions, access is still granted via Workspace permissions.
ResourcePrebuiltWorkspace.Type: {policy.ActionUpdate, policy.ActionDelete},
})...,
),
})...),
User: []Permission{},
ByOrgID: map[string]OrgPermissions{},
}.withCachedRegoValue()
@@ -316,17 +303,13 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
memberRole := Role{
Identifier: RoleMember(),
DisplayName: "Member",
Site: append(
Permissions(map[string][]policy.Action{
ResourceAssignRole.Type: {policy.ActionRead},
// All users can see OAuth2 provider applications.
ResourceOauth2App.Type: {policy.ActionRead},
ResourceWorkspaceProxy.Type: {policy.ActionRead},
}),
denyPermissions...,
),
User: append(
allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember, ResourceBoundaryUsage),
Site: Permissions(map[string][]policy.Action{
ResourceAssignRole.Type: {policy.ActionRead},
// All users can see OAuth2 provider applications.
ResourceOauth2App.Type: {policy.ActionRead},
ResourceWorkspaceProxy.Type: {policy.ActionRead},
}),
User: append(allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember, ResourceBoundaryUsage),
Permissions(map[string][]policy.Action{
// Users cannot do create/update/delete on themselves, but they
// can read their own details.
@@ -450,17 +433,14 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
ByOrgID: map[string]OrgPermissions{
// Org admins should not have workspace exec perms.
organizationID.String(): {
Org: append(
allPermsExcept(ResourceWorkspace, ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceAssignRole, ResourceUserSecret, ResourceBoundaryUsage),
Permissions(map[string][]policy.Action{
ResourceWorkspace.Type: slice.Omit(ResourceWorkspace.AvailableActions(), policy.ActionApplicationConnect, policy.ActionSSH),
ResourceWorkspaceDormant.Type: {policy.ActionRead, policy.ActionDelete, policy.ActionCreate, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent, policy.ActionDeleteAgent, policy.ActionUpdateAgent},
// PrebuiltWorkspaces are a subset of Workspaces.
// Explicitly setting PrebuiltWorkspace permissions for clarity.
// Note: even without PrebuiltWorkspace permissions, access is still granted via Workspace permissions.
ResourcePrebuiltWorkspace.Type: {policy.ActionUpdate, policy.ActionDelete},
})...,
),
Org: append(allPermsExcept(ResourceWorkspace, ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceAssignRole, ResourceUserSecret, ResourceBoundaryUsage), Permissions(map[string][]policy.Action{
ResourceWorkspaceDormant.Type: {policy.ActionRead, policy.ActionDelete, policy.ActionCreate, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent, policy.ActionDeleteAgent, policy.ActionUpdateAgent},
ResourceWorkspace.Type: slice.Omit(ResourceWorkspace.AvailableActions(), policy.ActionApplicationConnect, policy.ActionSSH),
// PrebuiltWorkspaces are a subset of Workspaces.
// Explicitly setting PrebuiltWorkspace permissions for clarity.
// Note: even without PrebuiltWorkspace permissions, access is still granted via Workspace permissions.
ResourcePrebuiltWorkspace.Type: {policy.ActionUpdate, policy.ActionDelete},
})...),
Member: []Permission{},
},
},

Some files were not shown because too many files have changed in this diff Show More