Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9764926f92 | |||
| 10d4e42fc1 | |||
| 217ddf46c4 | |||
| 0d3d493eae | |||
| 89b060e245 | |||
| 820d53b66a | |||
| f550028052 | |||
| e6873c8d61 | |||
| 8c0bfcb570 | |||
| c322b92ab0 | |||
| 216a5ac562 | |||
| 86447126d5 | |||
| 55c5b707fb | |||
| 4616c82f3c | |||
| 9ca30e28d6 | |||
| 34c1370090 | |||
| 851c4f907c | |||
| e3dfe45f35 |
@@ -4,7 +4,7 @@ description: |
|
||||
inputs:
|
||||
version:
|
||||
description: "The Go version to use."
|
||||
default: "1.24.10"
|
||||
default: "1.24.6"
|
||||
use-preinstalled-go:
|
||||
description: "Whether to use preinstalled Go."
|
||||
default: "false"
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
app = "sao-paulo-coder"
|
||||
primary_region = "gru"
|
||||
|
||||
[experimental]
|
||||
entrypoint = ["/bin/sh", "-c", "CODER_DERP_SERVER_RELAY_URL=\"http://[${FLY_PRIVATE_IP}]:3000\" /opt/coder wsproxy server"]
|
||||
auto_rollback = true
|
||||
|
||||
[build]
|
||||
image = "ghcr.io/coder/coder-preview:main"
|
||||
|
||||
[env]
|
||||
CODER_ACCESS_URL = "https://sao-paulo.fly.dev.coder.com"
|
||||
CODER_HTTP_ADDRESS = "0.0.0.0:3000"
|
||||
CODER_PRIMARY_ACCESS_URL = "https://dev.coder.com"
|
||||
CODER_WILDCARD_ACCESS_URL = "*--apps.sao-paulo.fly.dev.coder.com"
|
||||
CODER_VERBOSE = "true"
|
||||
|
||||
[http_service]
|
||||
internal_port = 3000
|
||||
force_https = true
|
||||
auto_stop_machines = true
|
||||
auto_start_machines = true
|
||||
min_machines_running = 0
|
||||
|
||||
# Ref: https://fly.io/docs/reference/configuration/#http_service-concurrency
|
||||
[http_service.concurrency]
|
||||
type = "requests"
|
||||
soft_limit = 50
|
||||
hard_limit = 100
|
||||
|
||||
[[vm]]
|
||||
cpu_kind = "shared"
|
||||
cpus = 2
|
||||
memory_mb = 512
|
||||
@@ -376,6 +376,13 @@ jobs:
|
||||
id: go-paths
|
||||
uses: ./.github/actions/setup-go-paths
|
||||
|
||||
- name: Download Go Build Cache
|
||||
id: download-go-build-cache
|
||||
uses: ./.github/actions/test-cache/download
|
||||
with:
|
||||
key-prefix: test-go-build-${{ runner.os }}-${{ runner.arch }}
|
||||
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
|
||||
|
||||
- name: Setup Go
|
||||
uses: ./.github/actions/setup-go
|
||||
with:
|
||||
@@ -383,7 +390,8 @@ jobs:
|
||||
# download the toolchain configured in go.mod, so we don't
|
||||
# need to reinstall it. It's faster on Windows runners.
|
||||
use-preinstalled-go: ${{ runner.os == 'Windows' }}
|
||||
use-cache: true
|
||||
# Cache is already downloaded above
|
||||
use-cache: false
|
||||
|
||||
- name: Setup Terraform
|
||||
uses: ./.github/actions/setup-tf
|
||||
@@ -492,11 +500,17 @@ jobs:
|
||||
make test
|
||||
|
||||
- name: Upload failed test db dumps
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: failed-test-db-dump-${{matrix.os}}
|
||||
path: "**/*.test.sql"
|
||||
|
||||
- name: Upload Go Build Cache
|
||||
uses: ./.github/actions/test-cache/upload
|
||||
with:
|
||||
cache-key: ${{ steps.download-go-build-cache.outputs.cache-key }}
|
||||
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
|
||||
|
||||
- name: Upload Test Cache
|
||||
uses: ./.github/actions/test-cache/upload
|
||||
with:
|
||||
@@ -748,7 +762,7 @@ jobs:
|
||||
|
||||
- name: Upload Playwright Failed Tests
|
||||
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: failed-test-videos${{ matrix.variant.premium && '-premium' || '' }}
|
||||
path: ./site/test-results/**/*.webm
|
||||
@@ -756,7 +770,7 @@ jobs:
|
||||
|
||||
- name: Upload pprof dumps
|
||||
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: debug-pprof-dumps${{ matrix.variant.premium && '-premium' || '' }}
|
||||
path: ./site/test-results/**/debug-pprof-*.txt
|
||||
@@ -792,7 +806,7 @@ jobs:
|
||||
# the check to pass. This is desired in PRs, but not in mainline.
|
||||
- name: Publish to Chromatic (non-mainline)
|
||||
if: github.ref != 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
|
||||
uses: chromaui/action@4ffe736a2a8262ea28067ff05a13b635ba31ec05 # v13.3.0
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
@@ -824,7 +838,7 @@ jobs:
|
||||
# infinitely "in progress" in mainline unless we re-review each build.
|
||||
- name: Publish to Chromatic (mainline)
|
||||
if: github.ref == 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
|
||||
uses: chromaui/action@4ffe736a2a8262ea28067ff05a13b635ba31ec05 # v13.3.0
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
@@ -1022,7 +1036,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifacts
|
||||
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: dylibs
|
||||
path: |
|
||||
@@ -1187,7 +1201,7 @@ jobs:
|
||||
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
|
||||
|
||||
- name: Download dylibs
|
||||
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
|
||||
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
|
||||
with:
|
||||
name: dylibs
|
||||
path: ./build
|
||||
@@ -1454,7 +1468,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifacts
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: coder
|
||||
path: |
|
||||
|
||||
@@ -163,10 +163,12 @@ jobs:
|
||||
run: |
|
||||
flyctl deploy --image "$IMAGE" --app paris-coder --config ./.github/fly-wsproxies/paris-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_PARIS" --yes
|
||||
flyctl deploy --image "$IMAGE" --app sydney-coder --config ./.github/fly-wsproxies/sydney-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SYDNEY" --yes
|
||||
flyctl deploy --image "$IMAGE" --app sao-paulo-coder --config ./.github/fly-wsproxies/sao-paulo-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SAO_PAULO" --yes
|
||||
flyctl deploy --image "$IMAGE" --app jnb-coder --config ./.github/fly-wsproxies/jnb-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_JNB" --yes
|
||||
env:
|
||||
FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
|
||||
IMAGE: ${{ inputs.image }}
|
||||
TOKEN_PARIS: ${{ secrets.FLY_PARIS_CODER_PROXY_SESSION_TOKEN }}
|
||||
TOKEN_SYDNEY: ${{ secrets.FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN }}
|
||||
TOKEN_SAO_PAULO: ${{ secrets.FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN }}
|
||||
TOKEN_JNB: ${{ secrets.FLY_JNB_CODER_PROXY_SESSION_TOKEN }}
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
- name: Setup Node
|
||||
uses: ./.github/actions/setup-node
|
||||
|
||||
- uses: tj-actions/changed-files@dbf178ceecb9304128c8e0648591d71208c6e2c9 # v45.0.7
|
||||
- uses: tj-actions/changed-files@d03a93c0dbfac6d6dd6a0d8a5e7daff992b07449 # v45.0.7
|
||||
id: changed-files
|
||||
with:
|
||||
files: |
|
||||
|
||||
@@ -36,11 +36,11 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Nix
|
||||
uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
|
||||
uses: nixbuild/nix-quick-install-action@1f095fee853b33114486cfdeae62fa099cda35a9 # v33
|
||||
with:
|
||||
# Pinning to 2.28 here, as Nix gets a "error: [json.exception.type_error.302] type must be array, but is string"
|
||||
# on version 2.29 and above.
|
||||
nix_version: "2.28.5"
|
||||
nix_version: "2.28.4"
|
||||
|
||||
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
|
||||
with:
|
||||
|
||||
@@ -131,7 +131,7 @@ jobs:
|
||||
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
|
||||
|
||||
- name: Upload build artifacts
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: dylibs
|
||||
path: |
|
||||
@@ -327,7 +327,7 @@ jobs:
|
||||
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
|
||||
|
||||
- name: Download dylibs
|
||||
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
|
||||
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
|
||||
with:
|
||||
name: dylibs
|
||||
path: ./build
|
||||
@@ -761,7 +761,7 @@ jobs:
|
||||
|
||||
- name: Upload artifacts to actions (if dry-run)
|
||||
if: ${{ inputs.dry_run }}
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: release-artifacts
|
||||
path: |
|
||||
@@ -777,7 +777,7 @@ jobs:
|
||||
|
||||
- name: Upload latest sbom artifact to actions (if dry-run)
|
||||
if: inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true'
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: latest-sbom-artifact
|
||||
path: ./coder_latest_sbom.spdx.json
|
||||
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
|
||||
# Upload the results as artifacts.
|
||||
- name: "Upload artifact"
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: SARIF file
|
||||
path: results.sarif
|
||||
@@ -47,6 +47,6 @@ jobs:
|
||||
|
||||
# Upload the results to GitHub's code scanning dashboard.
|
||||
- name: "Upload to code-scanning"
|
||||
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
|
||||
uses: github/codeql-action/init@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
with:
|
||||
languages: go, javascript
|
||||
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
rm Makefile
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
|
||||
uses: github/codeql-action/analyze@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
|
||||
- name: Send Slack notification on failure
|
||||
if: ${{ failure() }}
|
||||
@@ -154,13 +154,13 @@ jobs:
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
- name: Upload Trivy scan results to GitHub Security tab
|
||||
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
with:
|
||||
sarif_file: trivy-results.sarif
|
||||
category: "Trivy"
|
||||
|
||||
- name: Upload Trivy scan results as an artifact
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: trivy
|
||||
path: trivy-results.sarif
|
||||
|
||||
@@ -125,7 +125,7 @@ jobs:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Delete PR Cleanup workflow runs
|
||||
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
|
||||
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
|
||||
with:
|
||||
token: ${{ github.token }}
|
||||
repository: ${{ github.repository }}
|
||||
@@ -134,7 +134,7 @@ jobs:
|
||||
delete_workflow_pattern: pr-cleanup.yaml
|
||||
|
||||
- name: Delete PR Deploy workflow skipped runs
|
||||
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
|
||||
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
|
||||
with:
|
||||
token: ${{ github.token }}
|
||||
repository: ${{ github.repository }}
|
||||
|
||||
@@ -89,5 +89,3 @@ result
|
||||
__debug_bin*
|
||||
|
||||
**/.claude/settings.local.json
|
||||
|
||||
/.env
|
||||
|
||||
+12
@@ -18,6 +18,18 @@ coderd/rbac/ @Emyrk
|
||||
scripts/apitypings/ @Emyrk
|
||||
scripts/gensite/ @aslilac
|
||||
|
||||
site/ @aslilac @Parkreiner
|
||||
site/src/hooks/ @Parkreiner
|
||||
# These rules intentionally do not specify any owners. More specific rules
|
||||
# override less specific rules, so these files are "ignored" by the site/ rule.
|
||||
site/e2e/google/protobuf/timestampGenerated.ts
|
||||
site/e2e/provisionerGenerated.ts
|
||||
site/src/api/countriesGenerated.ts
|
||||
site/src/api/rbacresourcesGenerated.ts
|
||||
site/src/api/typesGenerated.ts
|
||||
site/src/testHelpers/entities.ts
|
||||
site/CLAUDE.md
|
||||
|
||||
# The blood and guts of the autostop algorithm, which is quite complex and
|
||||
# requires elite ball knowledge of most of the scheduling code to make changes
|
||||
# without inadvertently affecting other parts of the codebase.
|
||||
|
||||
@@ -636,16 +636,17 @@ TAILNETTEST_MOCKS := \
|
||||
tailnet/tailnettest/subscriptionmock.go
|
||||
|
||||
AIBRIDGED_MOCKS := \
|
||||
enterprise/aibridged/aibridgedmock/clientmock.go \
|
||||
enterprise/aibridged/aibridgedmock/poolmock.go
|
||||
enterprise/x/aibridged/aibridgedmock/clientmock.go \
|
||||
enterprise/x/aibridged/aibridgedmock/poolmock.go
|
||||
|
||||
GEN_FILES := \
|
||||
tailnet/proto/tailnet.pb.go \
|
||||
agent/proto/agent.pb.go \
|
||||
agent/agentsocket/proto/agentsocket.pb.go \
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
enterprise/aibridged/proto/aibridged.pb.go \
|
||||
enterprise/x/aibridged/proto/aibridged.pb.go \
|
||||
$(DB_GEN_FILES) \
|
||||
$(SITE_GEN_FILES) \
|
||||
coderd/rbac/object_gen.go \
|
||||
@@ -697,7 +698,7 @@ gen/mark-fresh:
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
enterprise/aibridged/proto/aibridged.pb.go \
|
||||
enterprise/x/aibridged/proto/aibridged.pb.go \
|
||||
coderd/database/dump.sql \
|
||||
$(DB_GEN_FILES) \
|
||||
site/src/api/typesGenerated.ts \
|
||||
@@ -768,8 +769,8 @@ codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agen
|
||||
go generate ./codersdk/workspacesdk/agentconnmock/
|
||||
touch "$@"
|
||||
|
||||
$(AIBRIDGED_MOCKS): enterprise/aibridged/client.go enterprise/aibridged/pool.go
|
||||
go generate ./enterprise/aibridged/aibridgedmock/
|
||||
$(AIBRIDGED_MOCKS): enterprise/x/aibridged/client.go enterprise/x/aibridged/pool.go
|
||||
go generate ./enterprise/x/aibridged/aibridgedmock/
|
||||
touch "$@"
|
||||
|
||||
agent/agentcontainers/dcspec/dcspec_gen.go: \
|
||||
@@ -800,6 +801,14 @@ agent/proto/agent.pb.go: agent/proto/agent.proto
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./agent/proto/agent.proto
|
||||
|
||||
agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./agent/agentsocket/proto/agentsocket.proto
|
||||
|
||||
provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
@@ -822,13 +831,13 @@ vpn/vpn.pb.go: vpn/vpn.proto
|
||||
--go_opt=paths=source_relative \
|
||||
./vpn/vpn.proto
|
||||
|
||||
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
|
||||
enterprise/x/aibridged/proto/aibridged.pb.go: enterprise/x/aibridged/proto/aibridged.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/aibridged/proto/aibridged.proto
|
||||
./enterprise/x/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
|
||||
# -C sets the directory for the go run command
|
||||
@@ -1182,8 +1191,3 @@ endif
|
||||
|
||||
dogfood/coder/nix.hash: flake.nix flake.lock
|
||||
sha256sum flake.nix flake.lock >./dogfood/coder/nix.hash
|
||||
|
||||
# Count the number of test databases created per test package.
|
||||
count-test-databases:
|
||||
PGPASSWORD=postgres psql -h localhost -U postgres -d coder_testing -P pager=off -c 'SELECT test_package, count(*) as count from test_databases GROUP BY test_package ORDER BY count DESC'
|
||||
.PHONY: count-test-databases
|
||||
|
||||
+40
-1
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentscripts"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
|
||||
@@ -91,6 +92,7 @@ type Options struct {
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
Clock quartz.Clock
|
||||
SocketPath string // Path for the agent socket server
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
@@ -190,6 +192,7 @@ func New(options Options) Agent {
|
||||
|
||||
devcontainers: options.Devcontainers,
|
||||
containerAPIOptions: options.DevcontainerAPIOptions,
|
||||
socketPath: options.SocketPath,
|
||||
}
|
||||
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
|
||||
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
|
||||
@@ -271,6 +274,9 @@ type agent struct {
|
||||
devcontainers bool
|
||||
containerAPIOptions []agentcontainers.Option
|
||||
containerAPI *agentcontainers.API
|
||||
|
||||
socketPath string
|
||||
socketServer *agentsocket.Server
|
||||
}
|
||||
|
||||
func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
@@ -350,9 +356,35 @@ func (a *agent) init() {
|
||||
s.ExperimentalContainers = a.devcontainers
|
||||
},
|
||||
)
|
||||
|
||||
a.initSocketServer()
|
||||
|
||||
go a.runLoop()
|
||||
}
|
||||
|
||||
// initSocketServer initializes server that allows direct communication with a workspace agent using IPC.
|
||||
func (a *agent) initSocketServer() {
|
||||
if a.socketPath == "" {
|
||||
a.logger.Info(a.hardCtx, "socket server disabled (no path configured)")
|
||||
return
|
||||
}
|
||||
|
||||
server, err := agentsocket.NewServer(a.socketPath, a.logger.Named("socket"))
|
||||
if err != nil {
|
||||
a.logger.Warn(a.hardCtx, "failed to create socket server", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
a.logger.Warn(a.hardCtx, "failed to start socket server", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
a.socketServer = server
|
||||
a.logger.Debug(a.hardCtx, "socket server started", slog.F("path", a.socketPath))
|
||||
}
|
||||
|
||||
// runLoop attempts to start the agent in a retry loop.
|
||||
// Coder may be offline temporarily, a connection issue
|
||||
// may be happening, but regardless after the intermittent
|
||||
@@ -1087,7 +1119,7 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fetch metadata: %w", err)
|
||||
}
|
||||
a.logger.Info(ctx, "fetched manifest")
|
||||
a.logger.Info(ctx, "fetched manifest", slog.F("manifest", mp))
|
||||
manifest, err := agentsdk.ManifestFromProto(mp)
|
||||
if err != nil {
|
||||
a.logger.Critical(ctx, "failed to convert manifest", slog.F("manifest", mp), slog.Error(err))
|
||||
@@ -1920,6 +1952,13 @@ func (a *agent) Close() error {
|
||||
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
|
||||
}
|
||||
}
|
||||
|
||||
if a.socketServer != nil {
|
||||
if err := a.socketServer.Stop(); err != nil {
|
||||
a.logger.Error(a.hardCtx, "socket server close", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
a.setLifecycle(lifecycleState)
|
||||
|
||||
err = a.scriptRunner.Close()
|
||||
|
||||
@@ -682,6 +682,8 @@ func (api *API) updaterLoop() {
|
||||
} else {
|
||||
prevErr = nil
|
||||
}
|
||||
default:
|
||||
api.logger.Debug(api.ctx, "updater loop ticker skipped, update in progress")
|
||||
}
|
||||
|
||||
return nil // Always nil to keep the ticker going.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,88 @@
|
||||
syntax = "proto3";
|
||||
option go_package = "github.com/coder/coder/v2/agent/agentsocket/proto";
|
||||
|
||||
package coder.agentsocket.v1;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
message PingRequest {}
|
||||
|
||||
message PingResponse {
|
||||
string message = 1;
|
||||
google.protobuf.Timestamp timestamp = 2;
|
||||
}
|
||||
|
||||
message SyncStartRequest {
|
||||
string unit = 1;
|
||||
}
|
||||
|
||||
message SyncStartResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncWantRequest {
|
||||
string unit = 1;
|
||||
string depends_on = 2;
|
||||
}
|
||||
|
||||
message SyncWantResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncCompleteRequest {
|
||||
string unit = 1;
|
||||
}
|
||||
|
||||
message SyncCompleteResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncReadyRequest {
|
||||
string unit = 1;
|
||||
}
|
||||
|
||||
message SyncReadyResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncStatusRequest {
|
||||
string unit = 1;
|
||||
bool recursive = 2;
|
||||
}
|
||||
|
||||
message DependencyInfo {
|
||||
string depends_on = 1;
|
||||
string required_status = 2;
|
||||
string current_status = 3;
|
||||
bool is_satisfied = 4;
|
||||
}
|
||||
|
||||
message SyncStatusResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
string unit = 3;
|
||||
string status = 4;
|
||||
bool is_ready = 5;
|
||||
repeated DependencyInfo dependencies = 6;
|
||||
string dot = 7;
|
||||
}
|
||||
|
||||
// AgentSocket provides direct access to the agent over local IPC.
|
||||
service AgentSocket {
|
||||
// Ping the agent to check if it is alive.
|
||||
rpc Ping(PingRequest) returns (PingResponse);
|
||||
// Report the start of a unit.
|
||||
rpc SyncStart(SyncStartRequest) returns (SyncStartResponse);
|
||||
// Declare a dependency between units.
|
||||
rpc SyncWant(SyncWantRequest) returns (SyncWantResponse);
|
||||
// Report the completion of a unit.
|
||||
rpc SyncComplete(SyncCompleteRequest) returns (SyncCompleteResponse);
|
||||
// Request whether a unit is ready to be started. That is, all dependencies are satisfied.
|
||||
rpc SyncReady(SyncReadyRequest) returns (SyncReadyResponse);
|
||||
// Get the status of a unit and list its dependencies.
|
||||
rpc SyncStatus(SyncStatusRequest) returns (SyncStatusResponse);
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.34
|
||||
// source: agent/agentsocket/proto/agentsocket.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
errors "errors"
|
||||
protojson "google.golang.org/protobuf/encoding/protojson"
|
||||
proto "google.golang.org/protobuf/proto"
|
||||
drpc "storj.io/drpc"
|
||||
drpcerr "storj.io/drpc/drpcerr"
|
||||
)
|
||||
|
||||
type drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto struct{}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Marshal(msg drpc.Message) ([]byte, error) {
|
||||
return proto.Marshal(msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) {
|
||||
return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Unmarshal(buf []byte, msg drpc.Message) error {
|
||||
return proto.Unmarshal(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONMarshal(msg drpc.Message) ([]byte, error) {
|
||||
return protojson.Marshal(msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error {
|
||||
return protojson.Unmarshal(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
type DRPCAgentSocketClient interface {
|
||||
DRPCConn() drpc.Conn
|
||||
|
||||
Ping(ctx context.Context, in *PingRequest) (*PingResponse, error)
|
||||
SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error)
|
||||
SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error)
|
||||
SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error)
|
||||
SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error)
|
||||
SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error)
|
||||
}
|
||||
|
||||
type drpcAgentSocketClient struct {
|
||||
cc drpc.Conn
|
||||
}
|
||||
|
||||
func NewDRPCAgentSocketClient(cc drpc.Conn) DRPCAgentSocketClient {
|
||||
return &drpcAgentSocketClient{cc}
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) DRPCConn() drpc.Conn { return c.cc }
|
||||
|
||||
func (c *drpcAgentSocketClient) Ping(ctx context.Context, in *PingRequest) (*PingResponse, error) {
|
||||
out := new(PingResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error) {
|
||||
out := new(SyncStartResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error) {
|
||||
out := new(SyncWantResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error) {
|
||||
out := new(SyncCompleteResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error) {
|
||||
out := new(SyncReadyResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error) {
|
||||
out := new(SyncStatusResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type DRPCAgentSocketServer interface {
|
||||
Ping(context.Context, *PingRequest) (*PingResponse, error)
|
||||
SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error)
|
||||
SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error)
|
||||
SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error)
|
||||
SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error)
|
||||
SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error)
|
||||
}
|
||||
|
||||
type DRPCAgentSocketUnimplementedServer struct{}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) Ping(context.Context, *PingRequest) (*PingResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
type DRPCAgentSocketDescription struct{}
|
||||
|
||||
func (DRPCAgentSocketDescription) NumMethods() int { return 6 }
|
||||
|
||||
func (DRPCAgentSocketDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
case 0:
|
||||
return "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
Ping(
|
||||
ctx,
|
||||
in1.(*PingRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.Ping, true
|
||||
case 1:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncStart(
|
||||
ctx,
|
||||
in1.(*SyncStartRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncStart, true
|
||||
case 2:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncWant(
|
||||
ctx,
|
||||
in1.(*SyncWantRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncWant, true
|
||||
case 3:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncComplete(
|
||||
ctx,
|
||||
in1.(*SyncCompleteRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncComplete, true
|
||||
case 4:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncReady(
|
||||
ctx,
|
||||
in1.(*SyncReadyRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncReady, true
|
||||
case 5:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncStatus(
|
||||
ctx,
|
||||
in1.(*SyncStatusRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncStatus, true
|
||||
default:
|
||||
return "", nil, nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func DRPCRegisterAgentSocket(mux drpc.Mux, impl DRPCAgentSocketServer) error {
|
||||
return mux.Register(impl, DRPCAgentSocketDescription{})
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_PingStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*PingResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_PingStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_PingStream) SendAndClose(m *PingResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncStartStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncStartResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncStartStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncStartStream) SendAndClose(m *SyncStartResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncWantStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncWantResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncWantStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncWantStream) SendAndClose(m *SyncWantResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncCompleteStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncCompleteResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncCompleteStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncCompleteStream) SendAndClose(m *SyncCompleteResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncReadyStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncReadyResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncReadyStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncReadyStream) SendAndClose(m *SyncReadyResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncStatusStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncStatusResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncStatusStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncStatusStream) SendAndClose(m *SyncStatusResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package proto
|
||||
|
||||
import "github.com/coder/coder/v2/apiversion"
|
||||
|
||||
// Version history:
|
||||
//
|
||||
// API v1.0:
|
||||
// - Initial release
|
||||
// - Ping
|
||||
// - Sync operations: SyncStart, SyncWant, SyncComplete, SyncWait, SyncStatus
|
||||
|
||||
const (
|
||||
CurrentMajor = 1
|
||||
CurrentMinor = 0
|
||||
)
|
||||
|
||||
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor)
|
||||
@@ -0,0 +1,185 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket/proto"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
)
|
||||
|
||||
// Server provides access to the DRPCAgentSocketService via a Unix domain socket.
|
||||
// Do not invoke Server{} directly. Use NewServer() instead.
|
||||
type Server struct {
|
||||
logger slog.Logger
|
||||
path string
|
||||
listener net.Listener
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
drpcServer *drpcserver.Server
|
||||
service *DRPCAgentSocketService
|
||||
}
|
||||
|
||||
func NewServer(path string, logger slog.Logger) (*Server, error) {
|
||||
logger = logger.Named("agentsocket")
|
||||
server := &Server{
|
||||
logger: logger,
|
||||
path: path,
|
||||
service: &DRPCAgentSocketService{
|
||||
logger: logger,
|
||||
unitManager: unit.NewManager[string, string](),
|
||||
},
|
||||
}
|
||||
|
||||
mux := drpcmux.New()
|
||||
err := proto.DRPCRegisterAgentSocket(mux, server.service)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to register drpc service: %w", err)
|
||||
}
|
||||
|
||||
server.drpcServer = drpcserver.NewWithOptions(mux, drpcserver.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
Log: func(err error) {
|
||||
if errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
}
|
||||
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
|
||||
},
|
||||
})
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
var ErrServerAlreadyStarted = xerrors.New("server already started")
|
||||
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener != nil {
|
||||
return ErrServerAlreadyStarted
|
||||
}
|
||||
|
||||
// This context is canceled by s.Stop() when the server is stopped.
|
||||
// canceling it will close all connections.
|
||||
s.ctx, s.cancel = context.WithCancel(context.Background())
|
||||
|
||||
if s.path == "" {
|
||||
var err error
|
||||
s.path, err = getDefaultSocketPath()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get default socket path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
listener, err := createSocket(s.path)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create socket: %w", err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
|
||||
s.logger.Info(s.ctx, "agent socket server started", slog.F("path", s.path))
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.acceptConnections()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Info(s.ctx, "stopping agent socket server")
|
||||
|
||||
s.cancel()
|
||||
|
||||
if err := s.listener.Close(); err != nil {
|
||||
s.logger.Warn(s.ctx, "error closing socket listener", slog.Error(err))
|
||||
}
|
||||
|
||||
// Wait for all connections to finish
|
||||
s.wg.Wait()
|
||||
|
||||
if err := cleanupSocket(s.path); err != nil {
|
||||
s.logger.Warn(s.ctx, "error cleaning up socket file", slog.Error(err))
|
||||
}
|
||||
|
||||
s.listener = nil
|
||||
s.logger.Info(s.ctx, "agent socket server stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) acceptConnections() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
s.logger.Warn(s.ctx, "error accepting connection", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.handleConnection(conn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
s.logger.Warn(s.ctx, "failed to set connection deadline", slog.Error(err))
|
||||
}
|
||||
|
||||
s.logger.Debug(s.ctx, "new connection accepted", slog.F("remote_addr", conn.RemoteAddr()))
|
||||
|
||||
config := yamux.DefaultConfig()
|
||||
config.Logger = nil
|
||||
session, err := yamux.Server(conn, config)
|
||||
if err != nil {
|
||||
s.logger.Warn(s.ctx, "failed to create yamux session", slog.Error(err))
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
err = s.drpcServer.Serve(s.ctx, session)
|
||||
if err != nil {
|
||||
s.logger.Debug(s.ctx, "drpc server finished", slog.Error(err))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package agentsocket_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
)
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("StartStop", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server, err := agentsocket.NewServer(socketPath, logger)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, server.Start())
|
||||
require.NoError(t, server.Stop())
|
||||
})
|
||||
|
||||
t.Run("AlreadyStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server, err := agentsocket.NewServer(socketPath, logger)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, server.Start())
|
||||
require.ErrorIs(t, server.Start(), agentsocket.ErrServerAlreadyStarted)
|
||||
})
|
||||
|
||||
t.Run("AutoSocketPath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server, err := agentsocket.NewServer(socketPath, logger)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, server.Start())
|
||||
require.NoError(t, server.Stop())
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,262 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket/proto"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
)
|
||||
|
||||
var _ proto.DRPCAgentSocketServer = (*DRPCAgentSocketService)(nil)
|
||||
|
||||
type DRPCAgentSocketService struct {
|
||||
mu sync.RWMutex
|
||||
unitManager *unit.Manager[string, string]
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
func (*DRPCAgentSocketService) Ping(_ context.Context, _ *proto.PingRequest) (*proto.PingResponse, error) {
|
||||
return &proto.PingResponse{
|
||||
Message: "pong",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncStart(_ context.Context, req *proto.SyncStartRequest) (*proto.SyncStartResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "dependency tracker not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.Register(req.Unit); err != nil {
|
||||
// If already registered, that's okay - we can still update status
|
||||
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Failed to register unit: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
isReady, err := s.unitManager.IsReady(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Failed to check readiness: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
if !isReady {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Unit is not ready",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.UpdateStatus(req.Unit, unit.StatusStarted); err != nil {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Failed to update status: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncStartResponse{
|
||||
Success: true,
|
||||
Message: "Unit " + req.Unit + " started successfully",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncWant(_ context.Context, req *proto.SyncWantRequest) (*proto.SyncWantResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" || req.DependsOn == "" {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "unit and depends_on are required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.Register(req.Unit); err != nil {
|
||||
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "failed to register unit: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.unitManager.Register(req.DependsOn); err != nil {
|
||||
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "failed to register dependency unit: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.unitManager.AddDependency(req.Unit, req.DependsOn, unit.StatusComplete); err != nil {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "failed to add dependency: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncWantResponse{
|
||||
Success: true,
|
||||
Message: "Unit " + req.Unit + " now depends on " + req.DependsOn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncComplete(_ context.Context, req *proto.SyncCompleteRequest) (*proto.SyncCompleteResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: false,
|
||||
Message: "unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.UpdateStatus(req.Unit, unit.StatusComplete); err != nil {
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: false,
|
||||
Message: "failed to update status: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: true,
|
||||
Message: "unit " + req.Unit + " completed successfully",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncReady(_ context.Context, req *proto.SyncReadyRequest) (*proto.SyncReadyResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: "unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
isReady, err := s.unitManager.IsReady(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: "failed to check readiness: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if !isReady {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: unit.ErrDependenciesNotSatisfied.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: true,
|
||||
Message: "unit " + req.Unit + " dependencies are satisfied",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncStatus(_ context.Context, req *proto.SyncStatusRequest) (*proto.SyncStatusResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
status, err := s.unitManager.GetStatus(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to get unit status: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
isReady, err := s.unitManager.IsReady(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to check readiness: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
dependencies, err := s.unitManager.GetAllDependencies(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to get dependencies: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
var depInfos []*proto.DependencyInfo
|
||||
for _, dep := range dependencies {
|
||||
depInfos = append(depInfos, &proto.DependencyInfo{
|
||||
DependsOn: dep.DependsOn,
|
||||
RequiredStatus: dep.RequiredStatus,
|
||||
CurrentStatus: dep.CurrentStatus,
|
||||
IsSatisfied: dep.IsSatisfied,
|
||||
})
|
||||
}
|
||||
|
||||
var dotStr string
|
||||
if req.Recursive {
|
||||
dotStr, err = s.unitManager.ExportDOT("dependency_graph")
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to export DOT: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: true,
|
||||
Message: "unit status retrieved successfully",
|
||||
Unit: req.Unit,
|
||||
Status: status,
|
||||
IsReady: isReady,
|
||||
Dependencies: depInfos,
|
||||
Dot: dotStr,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
package agentsocket_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
response, err := client.Ping(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "pong", response.Message)
|
||||
})
|
||||
|
||||
t.Run("SyncStart", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NewUnit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
})
|
||||
|
||||
t.Run("UnitAlreadyStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
// First Start
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Second Start
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.ErrorContains(t, err, unit.ErrSameStatusAlreadySet.Error())
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
})
|
||||
|
||||
t.Run("UnitAlreadyCompleted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// First start
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Complete the unit
|
||||
err = client.SyncComplete(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "completed", status.Status)
|
||||
|
||||
// Second start
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
})
|
||||
|
||||
t.Run("UnitNotReady", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.ErrorContains(t, err, "Unit is not ready")
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", status.Status)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("SyncWant", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NewUnits", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// If units are not registered, they are registered automatically
|
||||
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
|
||||
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
|
||||
})
|
||||
|
||||
t.Run("DependencyAlreadyRegistered", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// Start the dependency unit
|
||||
err = client.SyncStart(context.Background(), "dependency-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "dependency-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Add the dependency after the dependency unit has already started
|
||||
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
|
||||
// Dependencies can be added even if the dependency unit has already started
|
||||
require.NoError(t, err)
|
||||
|
||||
// The dependency is now reflected in the test unit's status
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
|
||||
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
|
||||
})
|
||||
|
||||
t.Run("DependencyAddedAfterDependentStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// Start the dependent unit
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Add the dependency after the dependency unit has already started
|
||||
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
|
||||
// Dependencies can be added even if the dependent unit has already started.
|
||||
// The dependency applies the next time a unit is started. The current status is not updated.
|
||||
// This is to allow flexible dependency management. It does mean that users of this API should
|
||||
// take care to add dependencies before they start their dependent units.
|
||||
require.NoError(t, err)
|
||||
|
||||
// The dependency is now reflected in the test unit's status
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
|
||||
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
//go:build !windows
|
||||
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// createSocket creates a Unix domain socket listener
|
||||
func createSocket(path string) (net.Listener, error) {
|
||||
if !isSocketAvailable(path) {
|
||||
return nil, xerrors.Errorf("socket path %s is not available", path)
|
||||
}
|
||||
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return nil, xerrors.Errorf("remove existing socket: %w", err)
|
||||
}
|
||||
|
||||
// Create parent directory if it doesn't exist
|
||||
parentDir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(parentDir, 0o700); err != nil {
|
||||
return nil, xerrors.Errorf("create socket directory: %w", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listen on unix socket: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(path, 0o600); err != nil {
|
||||
_ = listener.Close()
|
||||
return nil, xerrors.Errorf("set socket permissions: %w", err)
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// getDefaultSocketPath returns the default socket path for Unix-like systems
|
||||
func getDefaultSocketPath() (string, error) {
|
||||
// Try XDG_RUNTIME_DIR first
|
||||
if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" {
|
||||
return filepath.Join(runtimeDir, "coder-agent.sock"), nil
|
||||
}
|
||||
|
||||
// Fall back to /tmp with user-specific path
|
||||
uid := os.Getuid()
|
||||
return filepath.Join("/tmp", fmt.Sprintf("coder-agent-%d.sock", uid)), nil
|
||||
}
|
||||
|
||||
// CleanupSocket removes the socket file
|
||||
func cleanupSocket(path string) error {
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
// isSocketAvailable checks if a socket path is available for use
|
||||
func isSocketAvailable(path string) bool {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Try to connect to see if it's actually listening
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
// If we can't connect, the socket is not in use
|
||||
// Socket is available for use
|
||||
return true
|
||||
}
|
||||
_ = conn.Close()
|
||||
// Socket is in use
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
//go:build windows
|
||||
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// createSocket creates a Unix domain socket listener on Windows
|
||||
// Falls back to named pipe if Unix sockets are not supported
|
||||
func CreateSocket(path string) (net.Listener, error) {
|
||||
// Try Unix domain socket first (Windows 10 build 17063+)
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err == nil {
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// Fall back to named pipe
|
||||
pipePath := `\\.\pipe\coder-agent`
|
||||
listener, err = net.Listen("tcp", pipePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// getDefaultSocketPath returns the default socket path for Windows
|
||||
func GetDefaultSocketPath() (string, error) {
|
||||
// Try to use a temporary directory
|
||||
tempDir := os.TempDir()
|
||||
if tempDir == "" {
|
||||
tempDir = "C:\\temp"
|
||||
}
|
||||
|
||||
// Create a user-specific subdirectory
|
||||
uid := os.Getuid()
|
||||
userDir := filepath.Join(tempDir, "coder-agent", strconv.Itoa(uid))
|
||||
|
||||
if err := os.MkdirAll(userDir, 0o700); err != nil {
|
||||
return "", fmt.Errorf("create user directory: %w", err)
|
||||
}
|
||||
|
||||
return filepath.Join(userDir, "agent.sock"), nil
|
||||
}
|
||||
|
||||
// cleanupSocket removes the socket file
|
||||
func CleanupSocket(path string) error {
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
// isSocketAvailable checks if a socket path is available for use
|
||||
func IsSocketAvailable(path string, logger slog.Logger) bool {
|
||||
logger.Debug(context.Background(), "Checking socket availability on Windows", slog.F("path", path))
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
logger.Debug(context.Background(), "Socket file does not exist, path is available", slog.F("path", path))
|
||||
return true
|
||||
}
|
||||
logger.Debug(context.Background(), "Socket file exists, checking if it's listening", slog.F("path", path))
|
||||
|
||||
// Try to connect to see if it's actually listening
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
// If we can't connect, the socket is not in use
|
||||
logger.Debug(context.Background(), "Cannot connect to socket, path is available", slog.F("path", path), slog.Error(err))
|
||||
return true
|
||||
}
|
||||
_ = conn.Close()
|
||||
logger.Debug(context.Background(), "Socket is listening, path is not available", slog.F("path", path))
|
||||
return false
|
||||
}
|
||||
|
||||
// getSocketInfo returns information about the socket file
|
||||
func GetSocketInfo(path string) (*SocketInfo, error) {
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// On Windows, we'll use a simplified approach for now
|
||||
// In a real implementation, you'd get the security descriptor
|
||||
return &SocketInfo{
|
||||
Path: path,
|
||||
UID: 0, // Simplified for now
|
||||
GID: 0, // Simplified for now
|
||||
Mode: stat.Mode(),
|
||||
ModTime: stat.ModTime(),
|
||||
Owner: "unknown",
|
||||
Group: "unknown",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SocketInfo contains information about a socket file
|
||||
type SocketInfo struct {
|
||||
Path string
|
||||
UID int
|
||||
GID int
|
||||
Mode os.FileMode
|
||||
ModTime time.Time
|
||||
Owner string // Windows SID string
|
||||
Group string // Windows SID string
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
package unit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// ErrConsumerNotFound is returned when a consumer ID is not registered.
|
||||
var ErrConsumerNotFound = xerrors.New("consumer not found")
|
||||
|
||||
// ErrConsumerAlreadyRegistered is returned when a consumer ID is already registered.
|
||||
var ErrConsumerAlreadyRegistered = xerrors.New("consumer already registered")
|
||||
|
||||
// ErrCannotUpdateOtherConsumer is returned when attempting to update another consumer's status.
|
||||
var ErrCannotUpdateOtherConsumer = xerrors.New("cannot update other consumer's status")
|
||||
|
||||
// ErrDependenciesNotSatisfied is returned when a consumer's dependencies are not satisfied.
|
||||
var ErrDependenciesNotSatisfied = xerrors.New("unit dependencies not satisfied")
|
||||
|
||||
// ErrSameStatusAlreadySet is returned when attempting to set the same status as the current status.
|
||||
var ErrSameStatusAlreadySet = xerrors.New("same status already set")
|
||||
|
||||
// Status constants for dependency tracking
|
||||
const (
|
||||
StatusStarted = "started"
|
||||
StatusComplete = "completed"
|
||||
)
|
||||
|
||||
// dependencyVertex represents a vertex in the dependency graph that is associated with a consumer.
|
||||
type dependencyVertex[ConsumerID comparable] struct {
|
||||
ID ConsumerID
|
||||
}
|
||||
|
||||
// Dependency represents a dependency relationship between consumers.
|
||||
type Dependency[StatusType, ConsumerID comparable] struct {
|
||||
Consumer ConsumerID
|
||||
DependsOn ConsumerID
|
||||
RequiredStatus StatusType
|
||||
CurrentStatus StatusType
|
||||
IsSatisfied bool
|
||||
}
|
||||
|
||||
// Manager provides reactive dependency tracking over a Graph.
|
||||
// It manages consumer registration, dependency relationships, and status updates
|
||||
// with automatic recalculation of readiness when dependencies are satisfied.
|
||||
type Manager[StatusType, ConsumerID comparable] struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// The underlying graph that stores dependency relationships
|
||||
graph *Graph[StatusType, *dependencyVertex[ConsumerID]]
|
||||
|
||||
// Track current status of each consumer
|
||||
consumerStatus map[ConsumerID]StatusType
|
||||
|
||||
// Track readiness state (cached to avoid repeated graph traversal)
|
||||
consumerReadiness map[ConsumerID]bool
|
||||
|
||||
// Track which consumers are registered
|
||||
registeredConsumers map[ConsumerID]bool
|
||||
|
||||
// Store vertex instances for each consumer to ensure consistent references
|
||||
consumerVertices map[ConsumerID]*dependencyVertex[ConsumerID]
|
||||
}
|
||||
|
||||
// NewManager creates a new Manager instance.
|
||||
func NewManager[StatusType, ConsumerID comparable]() *Manager[StatusType, ConsumerID] {
|
||||
return &Manager[StatusType, ConsumerID]{
|
||||
graph: &Graph[StatusType, *dependencyVertex[ConsumerID]]{},
|
||||
consumerStatus: make(map[ConsumerID]StatusType),
|
||||
consumerReadiness: make(map[ConsumerID]bool),
|
||||
registeredConsumers: make(map[ConsumerID]bool),
|
||||
consumerVertices: make(map[ConsumerID]*dependencyVertex[ConsumerID]),
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a new consumer as a vertex in the dependency graph.
|
||||
func (dt *Manager[StatusType, ConsumerID]) Register(id ConsumerID) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if dt.registeredConsumers[id] {
|
||||
return ErrConsumerAlreadyRegistered
|
||||
}
|
||||
|
||||
// Create and store the vertex for this consumer
|
||||
vertex := &dependencyVertex[ConsumerID]{ID: id}
|
||||
dt.consumerVertices[id] = vertex
|
||||
dt.registeredConsumers[id] = true
|
||||
dt.consumerReadiness[id] = true // New consumers start as ready (no dependencies)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDependency adds a dependency relationship between consumers.
|
||||
// The consumer depends on the dependsOn consumer reaching the requiredStatus.
|
||||
func (dt *Manager[StatusType, ConsumerID]) AddDependency(consumer ConsumerID, dependsOn ConsumerID, requiredStatus StatusType) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return xerrors.Errorf("consumer %v is not registered", consumer)
|
||||
}
|
||||
if !dt.registeredConsumers[dependsOn] {
|
||||
return xerrors.Errorf("consumer %v is not registered", dependsOn)
|
||||
}
|
||||
|
||||
// Get the stored vertices for both consumers
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
dependsOnVertex := dt.consumerVertices[dependsOn]
|
||||
|
||||
// Add the dependency edge to the graph
|
||||
// The edge goes from consumer to dependsOn, representing the dependency
|
||||
err := dt.graph.AddEdge(consumerVertex, dependsOnVertex, requiredStatus)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to add dependency: %w", err)
|
||||
}
|
||||
|
||||
// Recalculate readiness for the consumer since it now has a dependency
|
||||
dt.recalculateReadinessUnsafe(consumer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates a consumer's status and recalculates readiness for affected dependents.
|
||||
func (dt *Manager[StatusType, ConsumerID]) UpdateStatus(consumer ConsumerID, newStatus StatusType) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return ErrConsumerNotFound
|
||||
}
|
||||
|
||||
// Update the consumer's status
|
||||
if dt.consumerStatus[consumer] == newStatus {
|
||||
return ErrSameStatusAlreadySet
|
||||
}
|
||||
dt.consumerStatus[consumer] = newStatus
|
||||
|
||||
// Get all consumers that depend on this one (reverse adjacent vertices)
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
dependentEdges := dt.graph.GetReverseAdjacentVertices(consumerVertex)
|
||||
|
||||
// Recalculate readiness for all dependents
|
||||
for _, edge := range dependentEdges {
|
||||
dt.recalculateReadinessUnsafe(edge.From.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsReady checks if all dependencies for a consumer are satisfied.
|
||||
func (dt *Manager[StatusType, ConsumerID]) IsReady(consumer ConsumerID) (bool, error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return false, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
return dt.consumerReadiness[consumer], nil
|
||||
}
|
||||
|
||||
// GetUnmetDependencies returns a list of unsatisfied dependencies for a consumer.
|
||||
func (dt *Manager[StatusType, ConsumerID]) GetUnmetDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return nil, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
var unmetDependencies []Dependency[StatusType, ConsumerID]
|
||||
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists {
|
||||
// If the dependency consumer has no status, it's not satisfied
|
||||
var zeroStatus StatusType
|
||||
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: zeroStatus, // Zero value
|
||||
IsSatisfied: false,
|
||||
})
|
||||
} else {
|
||||
isSatisfied := currentStatus == requiredStatus
|
||||
if !isSatisfied {
|
||||
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: currentStatus,
|
||||
IsSatisfied: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return unmetDependencies, nil
|
||||
}
|
||||
|
||||
// recalculateReadinessUnsafe recalculates the readiness state for a consumer.
|
||||
// This method assumes the caller holds the write lock.
|
||||
func (dt *Manager[StatusType, ConsumerID]) recalculateReadinessUnsafe(consumer ConsumerID) {
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
// If there are no dependencies, the consumer is ready
|
||||
if len(forwardEdges) == 0 {
|
||||
dt.consumerReadiness[consumer] = true
|
||||
return
|
||||
}
|
||||
|
||||
// Check if all dependencies are satisfied
|
||||
allSatisfied := true
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists || currentStatus != requiredStatus {
|
||||
allSatisfied = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
dt.consumerReadiness[consumer] = allSatisfied
|
||||
}
|
||||
|
||||
// GetGraph returns the underlying graph for visualization and debugging.
|
||||
// This should be used carefully as it exposes the internal graph structure.
|
||||
func (dt *Manager[StatusType, ConsumerID]) GetGraph() *Graph[StatusType, *dependencyVertex[ConsumerID]] {
|
||||
return dt.graph
|
||||
}
|
||||
|
||||
// GetStatus returns the current status of a consumer.
|
||||
func (dt *Manager[StatusType, ConsumerID]) GetStatus(consumer ConsumerID) (StatusType, error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
var zeroStatus StatusType
|
||||
return zeroStatus, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
status, exists := dt.consumerStatus[consumer]
|
||||
if !exists {
|
||||
var zeroStatus StatusType
|
||||
return zeroStatus, nil
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// GetAllDependencies returns all dependencies for a consumer, both satisfied and unsatisfied.
|
||||
func (dt *Manager[StatusType, ConsumerID]) GetAllDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return nil, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
var allDependencies []Dependency[StatusType, ConsumerID]
|
||||
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists {
|
||||
// If the dependency consumer has no status, it's not satisfied
|
||||
var zeroStatus StatusType
|
||||
allDependencies = append(allDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: zeroStatus, // Zero value
|
||||
IsSatisfied: false,
|
||||
})
|
||||
} else {
|
||||
isSatisfied := currentStatus == requiredStatus
|
||||
allDependencies = append(allDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: currentStatus,
|
||||
IsSatisfied: isSatisfied,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return allDependencies, nil
|
||||
}
|
||||
|
||||
// ExportDOT exports the dependency graph to DOT format for visualization.
|
||||
func (dt *Manager[StatusType, ConsumerID]) ExportDOT(name string) (string, error) {
|
||||
return dt.graph.ToDOT(name)
|
||||
}
|
||||
@@ -0,0 +1,691 @@
|
||||
package unit_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
)
|
||||
|
||||
type testStatus string
|
||||
|
||||
const (
|
||||
statusStarted testStatus = "started"
|
||||
statusRunning testStatus = "running"
|
||||
statusCompleted testStatus = "completed"
|
||||
)
|
||||
|
||||
type testConsumerID string
|
||||
|
||||
const (
|
||||
consumerA testConsumerID = "serviceA"
|
||||
consumerB testConsumerID = "serviceB"
|
||||
consumerC testConsumerID = "serviceC"
|
||||
consumerD testConsumerID = "serviceD"
|
||||
)
|
||||
|
||||
func TestDependencyTracker_Register(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
t.Run("RegisterNewConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Consumer should be ready initially (no dependencies)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("RegisterDuplicateConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tracker.Register(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already registered")
|
||||
})
|
||||
|
||||
t.Run("RegisterMultipleConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// All should be ready initially
|
||||
for _, consumer := range consumers {
|
||||
ready, err := tracker.IsReady(consumer)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_AddDependency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("AddDependencyBetweenRegisteredConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A should no longer be ready (depends on B)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// B should still be ready (no dependencies)
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("AddDependencyWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to add dependency to unregistered consumer
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
|
||||
t.Run("AddDependencyFromUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to add dependency from unregistered consumer
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_UpdateStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UpdateStatusTriggersReadinessRecalculation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially A is not ready
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should become ready
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
err := tracker.UpdateStatus(consumerA, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("LinearChainDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create chain: A depends on B being "started", B depends on C being "completed"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially only C is ready
|
||||
ready, err := tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "completed" - B should become ready
|
||||
err = tracker.UpdateStatus(consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "started" - A should become ready
|
||||
err = tracker.UpdateStatus(consumerB, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_GetUnmetDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithNoDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithUnsatisfiedDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, unmet, 1)
|
||||
|
||||
assert.Equal(t, consumerA, unmet[0].Consumer)
|
||||
assert.Equal(t, consumerB, unmet[0].DependsOn)
|
||||
assert.Equal(t, statusRunning, unmet[0].RequiredStatus)
|
||||
assert.False(t, unmet[0].IsSatisfied)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithSatisfiedDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update B to "running"
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.Nil(t, unmet)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ConcurrentOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConcurrentStatusUpdates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create dependencies: A depends on B, B depends on C, C depends on D
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 10
|
||||
|
||||
// Launch goroutines that update statuses
|
||||
errors := make([]error, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Update D to completed (should make C ready)
|
||||
err := tracker.UpdateStatus(consumerD, statusCompleted)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
|
||||
// Update C to started (should make B ready)
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
|
||||
// Update B to running (should make A ready)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check for any errors in goroutines
|
||||
for i, err := range errors {
|
||||
require.NoError(t, err, "goroutine %d had error", i)
|
||||
}
|
||||
|
||||
// All consumers should be ready after the updates
|
||||
for _, consumer := range consumers {
|
||||
ready, err := tracker.IsReady(consumer)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConcurrentReadinessChecks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 20
|
||||
|
||||
// Launch goroutines that check readiness
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Check readiness multiple times
|
||||
for j := 0; j < 10; j++ {
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
// Initially should be false, then true after B is updated
|
||||
_ = ready
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
// B should always be ready (no dependencies)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Update B to "running" in the middle of readiness checks
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_MultipleDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConsumerWithMultipleDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// A depends on B being "running" AND C being "started"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A should not be ready (depends on both B and C)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should still not be ready (needs C too)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "started" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("ComplexDependencyChain", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create complex dependency graph:
|
||||
// A depends on B being "running" AND C being "started"
|
||||
// B depends on D being "completed"
|
||||
// C depends on D being "completed"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially only D is ready
|
||||
ready, err := tracker.IsReady(consumerD)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update D to "completed" - B and C should become ready
|
||||
err = tracker.UpdateStatus(consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should still not be ready (needs C)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "started" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("DifferentStatusTypes", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerC)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running" AND C being "completed"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update B to "running" but not C - A should not be ready
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "completed" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ErrorCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
err := tracker.UpdateStatus(consumerA, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("IsReadyWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.False(t, ready)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.Nil(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("AddDependencyWithUnregisteredConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Try to add dependency with unregistered consumers
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
|
||||
t.Run("CyclicDependencyDetection", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to make B depend on A (creates cycle)
|
||||
err = tracker.AddDependency(consumerB, consumerA, statusStarted)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "would create a cycle")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ToDOT(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ExportSimpleGraph", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add dependency
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
dot, err := tracker.ExportDOT("test")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dot)
|
||||
assert.Contains(t, dot, "digraph")
|
||||
})
|
||||
|
||||
t.Run("ExportComplexGraph", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create complex dependency graph
|
||||
// A depends on B and C, B depends on D, C depends on D
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
dot, err := tracker.ExportDOT("complex")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dot)
|
||||
assert.Contains(t, dot, "digraph")
|
||||
})
|
||||
}
|
||||
@@ -56,6 +56,7 @@ func workspaceAgent() *serpent.Command {
|
||||
devcontainers bool
|
||||
devcontainerProjectDiscovery bool
|
||||
devcontainerDiscoveryAutostart bool
|
||||
socketPath string
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
@@ -297,6 +298,7 @@ func workspaceAgent() *serpent.Command {
|
||||
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
|
||||
agentcontainers.WithDiscoveryAutostart(devcontainerDiscoveryAutostart),
|
||||
},
|
||||
SocketPath: socketPath,
|
||||
})
|
||||
|
||||
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
|
||||
@@ -449,6 +451,12 @@ func workspaceAgent() *serpent.Command {
|
||||
Description: "Allow the agent to autostart devcontainer projects it discovers based on their configuration.",
|
||||
Value: serpent.BoolOf(&devcontainerDiscoveryAutostart),
|
||||
},
|
||||
{
|
||||
Flag: "socket-path",
|
||||
Env: "CODER_AGENT_SOCKET_PATH",
|
||||
Description: "Specify the path for the agent socket.",
|
||||
Value: serpent.StringOf(&socketPath),
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
var (
|
||||
_ pflag.SliceValue = &AllowListFlag{}
|
||||
_ pflag.Value = &AllowListFlag{}
|
||||
)
|
||||
|
||||
// AllowListFlag implements pflag.SliceValue for codersdk.APIAllowListTarget entries.
|
||||
type AllowListFlag []codersdk.APIAllowListTarget
|
||||
|
||||
func AllowListFlagOf(al *[]codersdk.APIAllowListTarget) *AllowListFlag {
|
||||
return (*AllowListFlag)(al)
|
||||
}
|
||||
|
||||
func (a AllowListFlag) String() string {
|
||||
return strings.Join(a.GetSlice(), ",")
|
||||
}
|
||||
|
||||
func (a AllowListFlag) Value() []codersdk.APIAllowListTarget {
|
||||
return []codersdk.APIAllowListTarget(a)
|
||||
}
|
||||
|
||||
func (AllowListFlag) Type() string { return "allow-list" }
|
||||
|
||||
func (a *AllowListFlag) Set(set string) error {
|
||||
values, err := csv.NewReader(strings.NewReader(set)).Read()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse allow list entries as csv: %w", err)
|
||||
}
|
||||
for _, v := range values {
|
||||
if err := a.Append(v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AllowListFlag) Append(value string) error {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return xerrors.New("allow list entry cannot be empty")
|
||||
}
|
||||
var target codersdk.APIAllowListTarget
|
||||
if err := target.UnmarshalText([]byte(value)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*a = append(*a, target)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AllowListFlag) Replace(items []string) error {
|
||||
*a = []codersdk.APIAllowListTarget{}
|
||||
for _, item := range items {
|
||||
if err := a.Append(item); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AllowListFlag) GetSlice() []string {
|
||||
out := make([]string, len(*a))
|
||||
for i, entry := range *a {
|
||||
out[i] = entry.String()
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -144,6 +144,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
|
||||
r.mcpCommand(),
|
||||
r.promptExample(),
|
||||
r.rptyCommand(),
|
||||
r.syncCommand(),
|
||||
r.tasksCommand(),
|
||||
r.boundary(),
|
||||
}
|
||||
|
||||
-47
@@ -109,51 +109,6 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
}
|
||||
},
|
||||
),
|
||||
CompletionHandler: func(inv *serpent.Invocation) []string {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
res, err := client.Workspaces(inv.Context(), codersdk.WorkspaceFilter{
|
||||
Owner: codersdk.Me,
|
||||
})
|
||||
if err != nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var completions []string
|
||||
var wg sync.WaitGroup
|
||||
for _, ws := range res.Workspaces {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resources, err := client.TemplateVersionResources(inv.Context(), ws.LatestBuild.TemplateVersionID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var agents []codersdk.WorkspaceAgent
|
||||
for _, resource := range resources {
|
||||
agents = append(agents, resource.Agents...)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(agents) == 1 {
|
||||
completions = append(completions, ws.Name)
|
||||
} else {
|
||||
for _, agent := range agents {
|
||||
completions = append(completions, fmt.Sprintf("%s.%s", ws.Name, agent.Name))
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
slices.Sort(completions)
|
||||
return completions
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) (retErr error) {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
@@ -951,8 +906,6 @@ func GetWorkspaceAndAgent(ctx context.Context, inv *serpent.Invocation, client *
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with active template version: %w", err)
|
||||
}
|
||||
_, _ = fmt.Fprintln(inv.Stdout, "Unable to start the workspace with template version from last build. Your workspace has been updated to the current active template version.")
|
||||
default:
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with current template version: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with current template version: %w", err)
|
||||
|
||||
@@ -2447,99 +2447,3 @@ func tempDirUnixSocket(t *testing.T) string {
|
||||
|
||||
return t.TempDir()
|
||||
}
|
||||
|
||||
func TestSSH_Completion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("SingleAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, workspace, agentToken := setupWorkspaceForAgent(t)
|
||||
_ = agenttest.New(t, client.URL, agentToken)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
|
||||
var stdout bytes.Buffer
|
||||
inv, root := clitest.New(t, "ssh", "")
|
||||
inv.Stdout = &stdout
|
||||
inv.Environ.Set("COMPLETION_MODE", "1")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// For single-agent workspaces, the only completion should be the
|
||||
// bare workspace name.
|
||||
output := stdout.String()
|
||||
t.Logf("Completion output: %q", output)
|
||||
require.Contains(t, output, workspace.Name)
|
||||
})
|
||||
|
||||
t.Run("MultiAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, store := coderdtest.NewWithDatabase(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
userClient, user := coderdtest.CreateAnotherUserMutators(t, client, first.OrganizationID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
|
||||
r.Username = "multiuser"
|
||||
})
|
||||
|
||||
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
Name: "multiworkspace",
|
||||
OrganizationID: first.OrganizationID,
|
||||
OwnerID: user.ID,
|
||||
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
|
||||
return []*proto.Agent{
|
||||
{
|
||||
Name: "agent1",
|
||||
Auth: &proto.Agent_Token{},
|
||||
},
|
||||
{
|
||||
Name: "agent2",
|
||||
Auth: &proto.Agent_Token{},
|
||||
},
|
||||
}
|
||||
}).Do()
|
||||
|
||||
var stdout bytes.Buffer
|
||||
inv, root := clitest.New(t, "ssh", "")
|
||||
inv.Stdout = &stdout
|
||||
inv.Environ.Set("COMPLETION_MODE", "1")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// For multi-agent workspaces, completions should include the
|
||||
// workspace.agent format but NOT the bare workspace name.
|
||||
output := stdout.String()
|
||||
t.Logf("Completion output: %q", output)
|
||||
lines := strings.Split(strings.TrimSpace(output), "\n")
|
||||
require.NotContains(t, lines, r.Workspace.Name)
|
||||
require.Contains(t, output, r.Workspace.Name+".agent1")
|
||||
require.Contains(t, output, r.Workspace.Name+".agent2")
|
||||
})
|
||||
|
||||
t.Run("NetworkError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var stdout bytes.Buffer
|
||||
inv, _ := clitest.New(t, "ssh", "")
|
||||
inv.Stdout = &stdout
|
||||
inv.Environ.Set("COMPLETION_MODE", "1")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
output := stdout.String()
|
||||
require.Empty(t, output)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -87,7 +87,6 @@ func buildNumberOption(n *int64) serpent.Option {
|
||||
|
||||
func (r *RootCmd) statePush() *serpent.Command {
|
||||
var buildNumber int64
|
||||
var noBuild bool
|
||||
cmd := &serpent.Command{
|
||||
Use: "push <workspace> <file>",
|
||||
Short: "Push a Terraform state file to a workspace.",
|
||||
@@ -127,16 +126,6 @@ func (r *RootCmd) statePush() *serpent.Command {
|
||||
return err
|
||||
}
|
||||
|
||||
if noBuild {
|
||||
// Update state directly without triggering a build.
|
||||
err = client.UpdateWorkspaceBuildState(inv.Context(), build.ID, state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = fmt.Fprintln(inv.Stdout, "State updated successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
build, err = client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
|
||||
TemplateVersionID: build.TemplateVersionID,
|
||||
Transition: build.Transition,
|
||||
@@ -150,12 +139,6 @@ func (r *RootCmd) statePush() *serpent.Command {
|
||||
}
|
||||
cmd.Options = serpent.OptionSet{
|
||||
buildNumberOption(&buildNumber),
|
||||
{
|
||||
Flag: "no-build",
|
||||
FlagShorthand: "n",
|
||||
Description: "Update the state without triggering a workspace build. Useful for state-only migrations.",
|
||||
Value: serpent.BoolOf(&noBuild),
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -11,7 +10,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -160,49 +158,4 @@ func TestStatePush(t *testing.T) {
|
||||
err := inv.Run()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("NoBuild", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, store := coderdtest.NewWithDatabase(t, nil)
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
templateAdmin, taUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.RoleTemplateAdmin())
|
||||
initialState := []byte("initial state")
|
||||
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
|
||||
Do()
|
||||
wantState := []byte("updated state")
|
||||
stateFile, err := os.CreateTemp(t.TempDir(), "")
|
||||
require.NoError(t, err)
|
||||
_, err = stateFile.Write(wantState)
|
||||
require.NoError(t, err)
|
||||
err = stateFile.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
inv, root := clitest.New(t, "state", "push", "--no-build", r.Workspace.Name, stateFile.Name())
|
||||
clitest.SetupConfig(t, templateAdmin, root)
|
||||
var stdout bytes.Buffer
|
||||
inv.Stdout = &stdout
|
||||
err = inv.Run()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, stdout.String(), "State updated successfully")
|
||||
|
||||
// Verify the state was updated by pulling it.
|
||||
inv, root = clitest.New(t, "state", "pull", r.Workspace.Name)
|
||||
var gotState bytes.Buffer
|
||||
inv.Stdout = &gotState
|
||||
clitest.SetupConfig(t, templateAdmin, root)
|
||||
err = inv.Run()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wantState, bytes.TrimSpace(gotState.Bytes()))
|
||||
|
||||
// Verify no new build was created.
|
||||
builds, err := store.GetWorkspaceBuildsByWorkspaceID(dbauthz.AsSystemRestricted(context.Background()), database.GetWorkspaceBuildsByWorkspaceIDParams{
|
||||
WorkspaceID: r.Workspace.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, builds, 1, "expected only the initial build, no new build should be created")
|
||||
})
|
||||
}
|
||||
|
||||
+25
@@ -0,0 +1,25 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncCommand() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "sync",
|
||||
Short: "Synchronize with the local agent socket",
|
||||
Long: "Commands for interacting with the local Coder agent via socket communication.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.syncPing(),
|
||||
r.syncStart(),
|
||||
r.syncWant(),
|
||||
r.syncComplete(),
|
||||
r.syncWait(),
|
||||
r.syncStatus(),
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncComplete() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "complete <unit>",
|
||||
Short: "Mark a unit as complete in the dependency graph",
|
||||
Long: "Set a unit's status to complete in the dependency graph.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unit := i.Args[0]
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Completing unit '%s'...\n", unit)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Complete the unit
|
||||
if err := client.SyncComplete(ctx, unit); err != nil {
|
||||
return xerrors.Errorf("complete unit failed: %w", err)
|
||||
}
|
||||
|
||||
// Display success message
|
||||
fmt.Printf("Unit '%s' completed successfully\n", unit)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncPing() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "ping",
|
||||
Short: "Ping the local agent socket",
|
||||
Long: "Test connectivity to the local Coder agent via socket communication.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Show initial message
|
||||
fmt.Println("Pinging agent socket...")
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Measure round-trip time
|
||||
start := time.Now()
|
||||
resp, err := client.Ping(ctx)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
return xerrors.Errorf("ping failed: %w", err)
|
||||
}
|
||||
|
||||
// Display results
|
||||
fmt.Printf("Response: %s\n", resp.Message)
|
||||
fmt.Printf("Timestamp: %s\n", resp.Timestamp.Format(time.RFC3339))
|
||||
fmt.Printf("Round-trip time: %s\n", duration.Round(time.Microsecond))
|
||||
fmt.Println("Status: healthy")
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// SyncPollInterval is the interval between dependency checks for sync start
|
||||
SyncPollInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncStart() *serpent.Command {
|
||||
var timeout time.Duration
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "start <unit>",
|
||||
Short: "Start a unit in the dependency graph",
|
||||
Long: "Register a unit in the dependency graph and set its status to started. Waits for all dependencies to be satisfied before marking as started.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unitName := i.Args[0]
|
||||
|
||||
// Set up context with timeout if specified
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Starting unit '%s'...\n", unitName)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Check if dependencies are satisfied first
|
||||
err = client.SyncReady(ctx, unitName)
|
||||
if err != nil {
|
||||
// Check if it's a "not ready" error (expected if dependencies exist)
|
||||
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
|
||||
// Dependencies exist but aren't satisfied, start polling
|
||||
fmt.Printf("Waiting for dependencies of unit '%s' to be satisfied...\n", unitName)
|
||||
|
||||
// Poll until dependencies are satisfied
|
||||
ticker := time.NewTicker(SyncPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
pollLoop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return xerrors.Errorf("timeout waiting for dependencies of unit '%s'", unitName)
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
// Check if dependencies are satisfied
|
||||
err := client.SyncReady(ctx, unitName)
|
||||
if err == nil {
|
||||
// Dependencies are satisfied
|
||||
fmt.Printf("Dependencies satisfied, marking unit '%s' as started\n", unitName)
|
||||
break pollLoop
|
||||
}
|
||||
|
||||
// Check if it's still a "not ready" error (expected while waiting)
|
||||
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
|
||||
// Still waiting, continue polling
|
||||
continue
|
||||
}
|
||||
|
||||
// Some other error occurred
|
||||
return xerrors.Errorf("error checking dependencies: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Some other error occurred
|
||||
return xerrors.Errorf("error checking dependencies: %w", err)
|
||||
}
|
||||
} else {
|
||||
// No dependencies or already satisfied
|
||||
fmt.Printf("Dependencies satisfied, marking unit '%s' as started\n", unitName)
|
||||
}
|
||||
|
||||
// Start the unit
|
||||
if err := client.SyncStart(ctx, unitName); err != nil {
|
||||
return xerrors.Errorf("start unit failed: %w", err)
|
||||
}
|
||||
|
||||
// Display success message
|
||||
fmt.Printf("Unit '%s' started successfully\n", unitName)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = append(cmd.Options, serpent.Option{
|
||||
Flag: "timeout",
|
||||
Description: "Maximum time to wait for dependencies (e.g., 30s, 5m). No timeout by default.",
|
||||
Value: serpent.DurationOf(&timeout),
|
||||
})
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
type outputFormat string
|
||||
|
||||
const (
|
||||
outputFormatHuman outputFormat = "human"
|
||||
outputFormatJSON outputFormat = "json"
|
||||
outputFormatDOT outputFormat = "dot"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncStatus() *serpent.Command {
|
||||
var (
|
||||
output string
|
||||
recursive bool
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "status <unit>",
|
||||
Short: "Show the status of a unit and its dependencies",
|
||||
Long: "Display the current status of a unit and information about its dependencies. Supports multiple output formats.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unit := i.Args[0]
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Get status information
|
||||
statusResp, err := client.SyncStatus(ctx, unit, recursive)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get status failed: %w", err)
|
||||
}
|
||||
|
||||
// Output based on format
|
||||
switch outputFormat(output) {
|
||||
case outputFormatJSON:
|
||||
return outputJSON(statusResp)
|
||||
case outputFormatDOT:
|
||||
return outputDOT(statusResp)
|
||||
default: // outputFormatHuman
|
||||
return outputHuman(statusResp)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = append(cmd.Options,
|
||||
serpent.Option{
|
||||
Flag: "output",
|
||||
FlagShorthand: "o",
|
||||
Description: "Output format: human, json, or dot.",
|
||||
Value: serpent.EnumOf(&output, "human", "json", "dot"),
|
||||
},
|
||||
serpent.Option{
|
||||
Flag: "recursive",
|
||||
FlagShorthand: "r",
|
||||
Description: "Show transitive dependencies and include DOT graph.",
|
||||
Value: serpent.BoolOf(&recursive),
|
||||
},
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func outputJSON(statusResp *agentsdk.SyncStatusResponse) error {
|
||||
encoder := json.NewEncoder(os.Stdout)
|
||||
encoder.SetIndent("", " ")
|
||||
return encoder.Encode(statusResp)
|
||||
}
|
||||
|
||||
func outputDOT(statusResp *agentsdk.SyncStatusResponse) error {
|
||||
if statusResp.DOT == "" {
|
||||
return xerrors.New("DOT output requires --recursive flag")
|
||||
}
|
||||
fmt.Println(statusResp.DOT)
|
||||
return nil
|
||||
}
|
||||
|
||||
func outputHuman(statusResp *agentsdk.SyncStatusResponse) error {
|
||||
// Unit status
|
||||
fmt.Printf("Unit: %s\n", statusResp.Unit)
|
||||
fmt.Printf("Status: %s\n", statusResp.Status)
|
||||
fmt.Printf("Ready: %t\n", statusResp.IsReady)
|
||||
fmt.Println()
|
||||
|
||||
// Dependencies
|
||||
if len(statusResp.Dependencies) == 0 {
|
||||
fmt.Println("No dependencies")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println("Dependencies:")
|
||||
fmt.Println(strings.Repeat("-", 80))
|
||||
fmt.Printf("%-20s %-15s %-15s %-10s\n", "Depends On", "Required", "Current", "Satisfied")
|
||||
fmt.Println(strings.Repeat("-", 80))
|
||||
|
||||
for _, dep := range statusResp.Dependencies {
|
||||
satisfied := "✓"
|
||||
if !dep.IsSatisfied {
|
||||
satisfied = "✗"
|
||||
}
|
||||
fmt.Printf("%-20s %-15s %-15s %-10s\n",
|
||||
dep.DependsOn,
|
||||
dep.RequiredStatus,
|
||||
dep.CurrentStatus,
|
||||
satisfied,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,359 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
)
|
||||
|
||||
// mockAgentSocketServer simulates the agent socket server for testing
|
||||
type mockAgentSocketServer struct {
|
||||
listener net.Listener
|
||||
handlers map[string]func(string) (string, error)
|
||||
}
|
||||
|
||||
func newMockAgentSocketServer(t *testing.T, socketPath string) *mockAgentSocketServer {
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := &mockAgentSocketServer{
|
||||
listener: listener,
|
||||
handlers: make(map[string]func(string) (string, error)),
|
||||
}
|
||||
|
||||
// Set up default handlers
|
||||
server.handlers["sync.wait"] = func(unitName string) (string, error) {
|
||||
// Always return dependencies not satisfied to trigger polling
|
||||
return "", unit.ErrDependenciesNotSatisfied
|
||||
}
|
||||
|
||||
server.handlers["sync.start"] = func(unitName string) (string, error) {
|
||||
return "Unit " + unitName + " started successfully", nil
|
||||
}
|
||||
|
||||
go server.serve(t)
|
||||
return server
|
||||
}
|
||||
|
||||
func (s *mockAgentSocketServer) serve(t *testing.T) {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
t.Logf("Accept error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go s.handleConnection(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *mockAgentSocketServer) handleConnection(t *testing.T, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
// Simple JSON-RPC-like protocol simulation
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
t.Logf("Read error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
request := string(buf[:n])
|
||||
|
||||
// Parse method from request (simplified)
|
||||
var method string
|
||||
if strings.Contains(request, "sync.wait") {
|
||||
method = "sync.wait"
|
||||
} else if strings.Contains(request, "sync.start") {
|
||||
method = "sync.start"
|
||||
}
|
||||
|
||||
handler, exists := s.handlers[method]
|
||||
if !exists {
|
||||
response := `{"error": {"code": -32601, "message": "Method not found"}}`
|
||||
_, _ = conn.Write([]byte(response))
|
||||
return
|
||||
}
|
||||
|
||||
// Extract unit name from request (simplified)
|
||||
unitName := "test-unit"
|
||||
if strings.Contains(request, "test-unit") {
|
||||
unitName = "test-unit"
|
||||
}
|
||||
|
||||
message, err := handler(unitName)
|
||||
if err != nil {
|
||||
response := fmt.Sprintf(`{"error": {"code": -32603, "message": %q}}`, err.Error())
|
||||
_, _ = conn.Write([]byte(response))
|
||||
return
|
||||
}
|
||||
|
||||
response := fmt.Sprintf(`{"result": {"success": true, "message": %q}}`, message)
|
||||
_, _ = conn.Write([]byte(response))
|
||||
}
|
||||
|
||||
func (s *mockAgentSocketServer) setHandler(method string, handler func(string) (string, error)) {
|
||||
s.handlers[method] = handler
|
||||
}
|
||||
|
||||
func (s *mockAgentSocketServer) close() {
|
||||
_ = s.listener.Close()
|
||||
}
|
||||
|
||||
func TestSyncStartTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with a short timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--timeout", "100ms")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately 100ms
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range (100ms + some buffer for test execution)
|
||||
assert.True(t, duration >= 100*time.Millisecond, "Duration should be at least 100ms, got %v", duration)
|
||||
assert.True(t, duration < 2*time.Second, "Duration should be less than 2s, got %v", duration)
|
||||
}
|
||||
|
||||
func TestSyncWaitTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with a short timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit", "--timeout", "100ms")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately 100ms
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range (100ms + some buffer for test execution)
|
||||
assert.True(t, duration >= 100*time.Millisecond, "Duration should be at least 100ms, got %v", duration)
|
||||
assert.True(t, duration < 2*time.Second, "Duration should be less than 2s, got %v", duration)
|
||||
}
|
||||
|
||||
func TestSyncStartNoTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Set up handler that will eventually succeed
|
||||
callCount := 0
|
||||
server.setHandler("sync.wait", func(unitName string) (string, error) {
|
||||
callCount++
|
||||
if callCount >= 3 {
|
||||
// After 3 calls, dependencies are satisfied
|
||||
return "Dependencies satisfied", nil
|
||||
}
|
||||
return "", unit.ErrDependenciesNotSatisfied
|
||||
})
|
||||
|
||||
// Test without timeout - should eventually succeed
|
||||
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should succeed after a few polling cycles
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should take at least 2 seconds (2 polling cycles at 1s interval)
|
||||
assert.True(t, duration >= 2*time.Second, "Duration should be at least 2s, got %v", duration)
|
||||
assert.True(t, callCount >= 3, "Should have made at least 3 calls, got %d", callCount)
|
||||
}
|
||||
|
||||
func TestSyncWaitNoTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Set up handler that will eventually succeed
|
||||
callCount := 0
|
||||
server.setHandler("sync.wait", func(unitName string) (string, error) {
|
||||
callCount++
|
||||
if callCount >= 3 {
|
||||
// After 3 calls, dependencies are satisfied
|
||||
return "Dependencies satisfied", nil
|
||||
}
|
||||
return "", unit.ErrDependenciesNotSatisfied
|
||||
})
|
||||
|
||||
// Test without timeout - should eventually succeed
|
||||
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should succeed after a few polling cycles
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should take at least 2 seconds (2 polling cycles at 1s interval)
|
||||
assert.True(t, duration >= 2*time.Second, "Duration should be at least 2s, got %v", duration)
|
||||
assert.True(t, callCount >= 3, "Should have made at least 3 calls, got %d", callCount)
|
||||
}
|
||||
|
||||
func TestSyncStartTimeoutWithDifferentValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
timeout string
|
||||
expected time.Duration
|
||||
}{
|
||||
{"50ms", "50ms", 50 * time.Millisecond},
|
||||
{"200ms", "200ms", 200 * time.Millisecond},
|
||||
{"500ms", "500ms", 500 * time.Millisecond},
|
||||
{"1s", "1s", 1 * time.Second},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with specified timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--timeout", tc.timeout)
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately the specified duration
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range
|
||||
assert.True(t, duration >= tc.expected, "Duration should be at least %v, got %v", tc.expected, duration)
|
||||
assert.True(t, duration < tc.expected+2*time.Second, "Duration should be less than %v, got %v", tc.expected+2*time.Second, duration)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncWaitTimeoutWithDifferentValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
timeout string
|
||||
expected time.Duration
|
||||
}{
|
||||
{"50ms", "50ms", 50 * time.Millisecond},
|
||||
{"200ms", "200ms", 200 * time.Millisecond},
|
||||
{"500ms", "500ms", 500 * time.Millisecond},
|
||||
{"1s", "1s", 1 * time.Second},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with specified timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit", "--timeout", tc.timeout)
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately the specified duration
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range
|
||||
assert.True(t, duration >= tc.expected, "Duration should be at least %v, got %v", tc.expected, duration)
|
||||
assert.True(t, duration < tc.expected+2*time.Second, "Duration should be less than %v, got %v", tc.expected+2*time.Second, duration)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// PollInterval is the interval between dependency checks
|
||||
PollInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncWait() *serpent.Command {
|
||||
var timeout time.Duration
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "wait <unit>",
|
||||
Short: "Wait for a unit's dependencies to be satisfied",
|
||||
Long: "Poll until all dependencies for a unit are met. Exits when dependencies are satisfied or timeout is reached.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unitName := i.Args[0]
|
||||
|
||||
// Set up context with timeout if specified
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Waiting for dependencies of unit '%s' to be satisfied...\n", unitName)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Poll until dependencies are satisfied
|
||||
ticker := time.NewTicker(PollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return xerrors.Errorf("timeout waiting for dependencies of unit '%s'", unitName)
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
// Check if dependencies are satisfied
|
||||
err := client.SyncReady(ctx, unitName)
|
||||
if err == nil {
|
||||
// Dependencies are satisfied
|
||||
fmt.Printf("Dependencies for unit '%s' are now satisfied\n", unitName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if it's a "not ready" error (expected while waiting)
|
||||
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
|
||||
// Still waiting, continue polling
|
||||
continue
|
||||
}
|
||||
|
||||
// Some other error occurred
|
||||
return xerrors.Errorf("error checking dependencies: %w", err)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = append(cmd.Options, serpent.Option{
|
||||
Flag: "timeout",
|
||||
Description: "Maximum time to wait for dependencies (e.g., 30s, 5m). No timeout by default.",
|
||||
Value: serpent.DurationOf(&timeout),
|
||||
})
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncWant() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "want <unit> <depends-on>",
|
||||
Short: "Declare a dependency between units",
|
||||
Long: "Declare that a unit depends on another unit reaching complete status.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 2 {
|
||||
return xerrors.New("exactly two arguments are required: unit and depends-on")
|
||||
}
|
||||
unit := i.Args[0]
|
||||
dependsOn := i.Args[1]
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Declaring dependency: '%s' depends on '%s'...\n", unit, dependsOn)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Declare the dependency
|
||||
if err := client.SyncWant(ctx, unit, dependsOn); err != nil {
|
||||
return xerrors.Errorf("declare dependency failed: %w", err)
|
||||
}
|
||||
|
||||
// Display success message
|
||||
fmt.Printf("Dependency declared: '%s' now depends on '%s'\n", unit, dependsOn)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSyncWant(t *testing.T) {
|
||||
}
|
||||
+3
@@ -67,6 +67,9 @@ OPTIONS:
|
||||
--script-data-dir string, $CODER_AGENT_SCRIPT_DATA_DIR (default: /tmp)
|
||||
Specify the location for storing script data.
|
||||
|
||||
--socket-path string, $CODER_AGENT_SOCKET_PATH
|
||||
Specify the path for the agent socket.
|
||||
|
||||
--ssh-max-timeout duration, $CODER_AGENT_SSH_MAX_TIMEOUT (default: 72h)
|
||||
Specify the max timeout for a SSH connection, it is advisable to set
|
||||
it to a minimum of 60s, but no more than 72h.
|
||||
|
||||
+1
-2
@@ -90,7 +90,6 @@
|
||||
"allow_renames": false,
|
||||
"favorite": false,
|
||||
"next_start_at": "====[timestamp]=====",
|
||||
"is_prebuild": false,
|
||||
"task_id": null
|
||||
"is_prebuild": false
|
||||
}
|
||||
]
|
||||
|
||||
-35
@@ -80,41 +80,6 @@ OPTIONS:
|
||||
Periodically check for new releases of Coder and inform the owner. The
|
||||
check is performed once per day.
|
||||
|
||||
AIBRIDGE OPTIONS:
|
||||
--aibridge-anthropic-base-url string, $CODER_AIBRIDGE_ANTHROPIC_BASE_URL (default: https://api.anthropic.com/)
|
||||
The base URL of the Anthropic API.
|
||||
|
||||
--aibridge-anthropic-key string, $CODER_AIBRIDGE_ANTHROPIC_KEY
|
||||
The key to authenticate against the Anthropic API.
|
||||
|
||||
--aibridge-bedrock-access-key string, $CODER_AIBRIDGE_BEDROCK_ACCESS_KEY
|
||||
The access key to authenticate against the AWS Bedrock API.
|
||||
|
||||
--aibridge-bedrock-access-key-secret string, $CODER_AIBRIDGE_BEDROCK_ACCESS_KEY_SECRET
|
||||
The access key secret to use with the access key to authenticate
|
||||
against the AWS Bedrock API.
|
||||
|
||||
--aibridge-bedrock-model string, $CODER_AIBRIDGE_BEDROCK_MODEL (default: global.anthropic.claude-sonnet-4-5-20250929-v1:0)
|
||||
The model to use when making requests to the AWS Bedrock API.
|
||||
|
||||
--aibridge-bedrock-region string, $CODER_AIBRIDGE_BEDROCK_REGION
|
||||
The AWS Bedrock API region.
|
||||
|
||||
--aibridge-bedrock-small-fastmodel string, $CODER_AIBRIDGE_BEDROCK_SMALL_FAST_MODEL (default: global.anthropic.claude-haiku-4-5-20251001-v1:0)
|
||||
The small fast model to use when making requests to the AWS Bedrock
|
||||
API. Claude Code uses Haiku-class models to perform background tasks.
|
||||
See
|
||||
https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
|
||||
|
||||
--aibridge-enabled bool, $CODER_AIBRIDGE_ENABLED (default: false)
|
||||
Whether to start an in-memory aibridged instance.
|
||||
|
||||
--aibridge-openai-base-url string, $CODER_AIBRIDGE_OPENAI_BASE_URL (default: https://api.openai.com/v1/)
|
||||
The base URL of the OpenAI API.
|
||||
|
||||
--aibridge-openai-key string, $CODER_AIBRIDGE_OPENAI_KEY
|
||||
The key to authenticate against the OpenAI API.
|
||||
|
||||
CLIENT OPTIONS:
|
||||
These options change the behavior of how clients interact with the Coder.
|
||||
Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI.
|
||||
|
||||
-4
@@ -9,9 +9,5 @@ OPTIONS:
|
||||
-b, --build int
|
||||
Specify a workspace build to target by name. Defaults to latest.
|
||||
|
||||
-n, --no-build bool
|
||||
Update the state without triggering a workspace build. Useful for
|
||||
state-only migrations.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
-5
@@ -16,10 +16,6 @@ USAGE:
|
||||
|
||||
$ coder tokens ls
|
||||
|
||||
- Create a scoped token:
|
||||
|
||||
$ coder tokens create --scope workspace:read --allow workspace:<uuid>
|
||||
|
||||
- Remove a token by ID:
|
||||
|
||||
$ coder tokens rm WuoWs4ZsMX
|
||||
@@ -28,7 +24,6 @@ SUBCOMMANDS:
|
||||
create Create a token
|
||||
list List tokens
|
||||
remove Delete a token
|
||||
view Display detailed information about a token
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
+1
-9
@@ -6,20 +6,12 @@ USAGE:
|
||||
Create a token
|
||||
|
||||
OPTIONS:
|
||||
--allow allow-list
|
||||
Repeatable allow-list entry (<type>:<uuid>, e.g. workspace:1234-...).
|
||||
|
||||
--lifetime string, $CODER_TOKEN_LIFETIME
|
||||
Duration for the token lifetime. Supports standard Go duration units
|
||||
(ns, us, ms, s, m, h) plus d (days) and y (years). Examples: 8h, 30d,
|
||||
1y, 1d12h30m.
|
||||
Specify a duration for the lifetime of the token.
|
||||
|
||||
-n, --name string, $CODER_TOKEN_NAME
|
||||
Specify a human-readable name.
|
||||
|
||||
--scope string-array
|
||||
Repeatable scope to attach to the token (e.g. workspace:read).
|
||||
|
||||
-u, --user string, $CODER_TOKEN_USER
|
||||
Specify the user to create the token for (Only works if logged in user
|
||||
is admin).
|
||||
|
||||
+1
-1
@@ -12,7 +12,7 @@ OPTIONS:
|
||||
Specifies whether all users' tokens will be listed or not (must have
|
||||
Owner role to see all tokens).
|
||||
|
||||
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at)
|
||||
-c, --column [id|name|last used|expires at|created at|owner] (default: id,name,last used,expires at,created at)
|
||||
Columns to display in table output.
|
||||
|
||||
-o, --output table|json (default: table)
|
||||
|
||||
-16
@@ -1,16 +0,0 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder tokens view [flags] <name|id>
|
||||
|
||||
Display detailed information about a token
|
||||
|
||||
OPTIONS:
|
||||
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at,owner)
|
||||
Columns to display in table output.
|
||||
|
||||
-o, --output table|json (default: table)
|
||||
Output format.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
+4
-21
@@ -714,7 +714,8 @@ workspace_prebuilds:
|
||||
# (default: 3, type: int)
|
||||
failure_hard_limit: 3
|
||||
aibridge:
|
||||
# Whether to start an in-memory aibridged instance.
|
||||
# Whether to start an in-memory aibridged instance ("aibridge" experiment must be
|
||||
# enabled, too).
|
||||
# (default: false, type: bool)
|
||||
enabled: false
|
||||
# The base URL of the OpenAI API.
|
||||
@@ -725,25 +726,7 @@ aibridge:
|
||||
openai_key: ""
|
||||
# The base URL of the Anthropic API.
|
||||
# (default: https://api.anthropic.com/, type: string)
|
||||
anthropic_base_url: https://api.anthropic.com/
|
||||
base_url: https://api.anthropic.com/
|
||||
# The key to authenticate against the Anthropic API.
|
||||
# (default: <unset>, type: string)
|
||||
anthropic_key: ""
|
||||
# The AWS Bedrock API region.
|
||||
# (default: <unset>, type: string)
|
||||
bedrock_region: ""
|
||||
# The access key to authenticate against the AWS Bedrock API.
|
||||
# (default: <unset>, type: string)
|
||||
bedrock_access_key: ""
|
||||
# The access key secret to use with the access key to authenticate against the AWS
|
||||
# Bedrock API.
|
||||
# (default: <unset>, type: string)
|
||||
bedrock_access_key_secret: ""
|
||||
# The model to use when making requests to the AWS Bedrock API.
|
||||
# (default: global.anthropic.claude-sonnet-4-5-20250929-v1:0, type: string)
|
||||
bedrock_model: global.anthropic.claude-sonnet-4-5-20250929-v1:0
|
||||
# The small fast model to use when making requests to the AWS Bedrock API. Claude
|
||||
# Code uses Haiku-class models to perform background tasks. See
|
||||
# https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
|
||||
# (default: global.anthropic.claude-haiku-4-5-20251001-v1:0, type: string)
|
||||
bedrock_small_fast_model: global.anthropic.claude-haiku-4-5-20251001-v1:0
|
||||
key: ""
|
||||
|
||||
+6
-104
@@ -4,14 +4,12 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
@@ -29,10 +27,6 @@ func (r *RootCmd) tokens() *serpent.Command {
|
||||
Description: "List your tokens",
|
||||
Command: "coder tokens ls",
|
||||
},
|
||||
Example{
|
||||
Description: "Create a scoped token",
|
||||
Command: "coder tokens create --scope workspace:read --allow workspace:<uuid>",
|
||||
},
|
||||
Example{
|
||||
Description: "Remove a token by ID",
|
||||
Command: "coder tokens rm WuoWs4ZsMX",
|
||||
@@ -45,7 +39,6 @@ func (r *RootCmd) tokens() *serpent.Command {
|
||||
Children: []*serpent.Command{
|
||||
r.createToken(),
|
||||
r.listTokens(),
|
||||
r.viewToken(),
|
||||
r.removeToken(),
|
||||
},
|
||||
}
|
||||
@@ -57,8 +50,6 @@ func (r *RootCmd) createToken() *serpent.Command {
|
||||
tokenLifetime string
|
||||
name string
|
||||
user string
|
||||
scopes []string
|
||||
allowList []codersdk.APIAllowListTarget
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "create",
|
||||
@@ -97,18 +88,10 @@ func (r *RootCmd) createToken() *serpent.Command {
|
||||
}
|
||||
}
|
||||
|
||||
req := codersdk.CreateTokenRequest{
|
||||
res, err := client.CreateToken(inv.Context(), userID, codersdk.CreateTokenRequest{
|
||||
Lifetime: parsedLifetime,
|
||||
TokenName: name,
|
||||
}
|
||||
if len(req.Scopes) == 0 {
|
||||
req.Scopes = slice.StringEnums[codersdk.APIKeyScope](scopes)
|
||||
}
|
||||
if len(allowList) > 0 {
|
||||
req.AllowList = append([]codersdk.APIAllowListTarget(nil), allowList...)
|
||||
}
|
||||
|
||||
res, err := client.CreateToken(inv.Context(), userID, req)
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create tokens: %w", err)
|
||||
}
|
||||
@@ -123,7 +106,7 @@ func (r *RootCmd) createToken() *serpent.Command {
|
||||
{
|
||||
Flag: "lifetime",
|
||||
Env: "CODER_TOKEN_LIFETIME",
|
||||
Description: "Duration for the token lifetime. Supports standard Go duration units (ns, us, ms, s, m, h) plus d (days) and y (years). Examples: 8h, 30d, 1y, 1d12h30m.",
|
||||
Description: "Specify a duration for the lifetime of the token.",
|
||||
Value: serpent.StringOf(&tokenLifetime),
|
||||
},
|
||||
{
|
||||
@@ -140,16 +123,6 @@ func (r *RootCmd) createToken() *serpent.Command {
|
||||
Description: "Specify the user to create the token for (Only works if logged in user is admin).",
|
||||
Value: serpent.StringOf(&user),
|
||||
},
|
||||
{
|
||||
Flag: "scope",
|
||||
Description: "Repeatable scope to attach to the token (e.g. workspace:read).",
|
||||
Value: serpent.StringArrayOf(&scopes),
|
||||
},
|
||||
{
|
||||
Flag: "allow",
|
||||
Description: "Repeatable allow-list entry (<type>:<uuid>, e.g. workspace:1234-...).",
|
||||
Value: AllowListFlagOf(&allowList),
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
@@ -163,8 +136,6 @@ type tokenListRow struct {
|
||||
// For table format:
|
||||
ID string `json:"-" table:"id,default_sort"`
|
||||
TokenName string `json:"token_name" table:"name"`
|
||||
Scopes string `json:"-" table:"scopes"`
|
||||
Allow string `json:"-" table:"allow list"`
|
||||
LastUsed time.Time `json:"-" table:"last used"`
|
||||
ExpiresAt time.Time `json:"-" table:"expires at"`
|
||||
CreatedAt time.Time `json:"-" table:"created at"`
|
||||
@@ -172,47 +143,20 @@ type tokenListRow struct {
|
||||
}
|
||||
|
||||
func tokenListRowFromToken(token codersdk.APIKeyWithOwner) tokenListRow {
|
||||
return tokenListRowFromKey(token.APIKey, token.Username)
|
||||
}
|
||||
|
||||
func tokenListRowFromKey(token codersdk.APIKey, owner string) tokenListRow {
|
||||
return tokenListRow{
|
||||
APIKey: token,
|
||||
APIKey: token.APIKey,
|
||||
ID: token.ID,
|
||||
TokenName: token.TokenName,
|
||||
Scopes: joinScopes(token.Scopes),
|
||||
Allow: joinAllowList(token.AllowList),
|
||||
LastUsed: token.LastUsed,
|
||||
ExpiresAt: token.ExpiresAt,
|
||||
CreatedAt: token.CreatedAt,
|
||||
Owner: owner,
|
||||
Owner: token.Username,
|
||||
}
|
||||
}
|
||||
|
||||
func joinScopes(scopes []codersdk.APIKeyScope) string {
|
||||
if len(scopes) == 0 {
|
||||
return ""
|
||||
}
|
||||
vals := slice.ToStrings(scopes)
|
||||
sort.Strings(vals)
|
||||
return strings.Join(vals, ", ")
|
||||
}
|
||||
|
||||
func joinAllowList(entries []codersdk.APIAllowListTarget) string {
|
||||
if len(entries) == 0 {
|
||||
return ""
|
||||
}
|
||||
vals := make([]string, len(entries))
|
||||
for i, entry := range entries {
|
||||
vals[i] = entry.String()
|
||||
}
|
||||
sort.Strings(vals)
|
||||
return strings.Join(vals, ", ")
|
||||
}
|
||||
|
||||
func (r *RootCmd) listTokens() *serpent.Command {
|
||||
// we only display the 'owner' column if the --all argument is passed in
|
||||
defaultCols := []string{"id", "name", "scopes", "allow list", "last used", "expires at", "created at"}
|
||||
defaultCols := []string{"id", "name", "last used", "expires at", "created at"}
|
||||
if slices.Contains(os.Args, "-a") || slices.Contains(os.Args, "--all") {
|
||||
defaultCols = append(defaultCols, "owner")
|
||||
}
|
||||
@@ -282,48 +226,6 @@ func (r *RootCmd) listTokens() *serpent.Command {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) viewToken() *serpent.Command {
|
||||
formatter := cliui.NewOutputFormatter(
|
||||
cliui.TableFormat([]tokenListRow{}, []string{"id", "name", "scopes", "allow list", "last used", "expires at", "created at", "owner"}),
|
||||
cliui.JSONFormat(),
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "view <name|id>",
|
||||
Short: "Display detailed information about a token",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenName := inv.Args[0]
|
||||
token, err := client.APIKeyByName(inv.Context(), codersdk.Me, tokenName)
|
||||
if err != nil {
|
||||
maybeID := strings.Split(tokenName, "-")[0]
|
||||
token, err = client.APIKeyByID(inv.Context(), codersdk.Me, maybeID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fetch api key by name or id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
row := tokenListRowFromKey(*token, "")
|
||||
out, err := formatter.Format(inv.Context(), []tokenListRow{row})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintln(inv.Stdout, out)
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
formatter.AttachOptions(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) removeToken() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "remove <name|id|token>",
|
||||
|
||||
+3
-56
@@ -4,13 +4,10 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -49,18 +46,6 @@ func TestTokens(t *testing.T) {
|
||||
require.NotEmpty(t, res)
|
||||
id := res[:10]
|
||||
|
||||
allowWorkspaceID := uuid.New()
|
||||
allowSpec := fmt.Sprintf("workspace:%s", allowWorkspaceID.String())
|
||||
inv, root = clitest.New(t, "tokens", "create", "--name", "scoped-token", "--scope", string(codersdk.APIKeyScopeWorkspaceRead), "--allow", allowSpec)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
scopedTokenID := res[:10]
|
||||
|
||||
// Test creating a token for second user from first user's (admin) session
|
||||
inv, root = clitest.New(t, "tokens", "create", "--name", "token-two", "--user", secondUser.ID.String())
|
||||
clitest.SetupConfig(t, client, root)
|
||||
@@ -82,7 +67,7 @@ func TestTokens(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
// Result should only contain the tokens created for the admin user
|
||||
// Result should only contain the token created for the admin user
|
||||
require.Contains(t, res, "ID")
|
||||
require.Contains(t, res, "EXPIRES AT")
|
||||
require.Contains(t, res, "CREATED AT")
|
||||
@@ -91,16 +76,6 @@ func TestTokens(t *testing.T) {
|
||||
// Result should not contain the token created for the second user
|
||||
require.NotContains(t, res, secondTokenID)
|
||||
|
||||
inv, root = clitest.New(t, "tokens", "view", "scoped-token")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.Contains(t, res, string(codersdk.APIKeyScopeWorkspaceRead))
|
||||
require.Contains(t, res, allowSpec)
|
||||
|
||||
// Test listing tokens from the second user's session
|
||||
inv, root = clitest.New(t, "tokens", "ls")
|
||||
clitest.SetupConfig(t, secondUserClient, root)
|
||||
@@ -126,14 +101,6 @@ func TestTokens(t *testing.T) {
|
||||
// User (non-admin) should not be able to create a token for another user
|
||||
require.Error(t, err)
|
||||
|
||||
inv, root = clitest.New(t, "tokens", "create", "--name", "invalid-allow", "--allow", "badvalue")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid allow_list entry")
|
||||
|
||||
inv, root = clitest.New(t, "tokens", "ls", "--output=json")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
@@ -143,17 +110,8 @@ func TestTokens(t *testing.T) {
|
||||
|
||||
var tokens []codersdk.APIKey
|
||||
require.NoError(t, json.Unmarshal(buf.Bytes(), &tokens))
|
||||
require.Len(t, tokens, 2)
|
||||
tokenByName := make(map[string]codersdk.APIKey, len(tokens))
|
||||
for _, tk := range tokens {
|
||||
tokenByName[tk.TokenName] = tk
|
||||
}
|
||||
require.Contains(t, tokenByName, "token-one")
|
||||
require.Contains(t, tokenByName, "scoped-token")
|
||||
scopedToken := tokenByName["scoped-token"]
|
||||
require.Contains(t, scopedToken.Scopes, codersdk.APIKeyScopeWorkspaceRead)
|
||||
require.Len(t, scopedToken.AllowList, 1)
|
||||
require.Equal(t, allowSpec, scopedToken.AllowList[0].String())
|
||||
require.Len(t, tokens, 1)
|
||||
require.Equal(t, id, tokens[0].ID)
|
||||
|
||||
// Delete by name
|
||||
inv, root = clitest.New(t, "tokens", "rm", "token-one")
|
||||
@@ -177,17 +135,6 @@ func TestTokens(t *testing.T) {
|
||||
require.NotEmpty(t, res)
|
||||
require.Contains(t, res, "deleted")
|
||||
|
||||
// Delete scoped token by ID
|
||||
inv, root = clitest.New(t, "tokens", "rm", scopedTokenID)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
require.Contains(t, res, "deleted")
|
||||
|
||||
// Create third token
|
||||
inv, root = clitest.New(t, "tokens", "create", "--name", "token-three")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
@@ -239,10 +239,6 @@ func (a *API) Serve(ctx context.Context, l net.Listener) error {
|
||||
return xerrors.Errorf("create agent API server: %w", err)
|
||||
}
|
||||
|
||||
if err := a.ResourcesMonitoringAPI.InitMonitors(ctx); err != nil {
|
||||
return xerrors.Errorf("initialize resource monitoring: %w", err)
|
||||
}
|
||||
|
||||
return server.Serve(ctx, l)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
@@ -34,60 +33,42 @@ type ResourcesMonitoringAPI struct {
|
||||
|
||||
Debounce time.Duration
|
||||
Config resourcesmonitor.Config
|
||||
|
||||
// Cache resource monitors on first call to avoid millions of DB queries per day.
|
||||
memoryMonitor database.WorkspaceAgentMemoryResourceMonitor
|
||||
volumeMonitors []database.WorkspaceAgentVolumeResourceMonitor
|
||||
monitorsLock sync.RWMutex
|
||||
}
|
||||
|
||||
// InitMonitors fetches resource monitors from the database and caches them.
|
||||
// This must be called once after creating a ResourcesMonitoringAPI, the context should be
|
||||
// the agent per-RPC connection context. If fetching fails with a real error (not sql.ErrNoRows), the
|
||||
// connection should be torn down.
|
||||
func (a *ResourcesMonitoringAPI) InitMonitors(ctx context.Context) error {
|
||||
memMon, err := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("fetch memory resource monitor: %w", err)
|
||||
}
|
||||
// If sql.ErrNoRows, memoryMonitor stays as zero value (CreatedAt.IsZero() = true).
|
||||
// Otherwise, store the fetched monitor.
|
||||
if err == nil {
|
||||
a.memoryMonitor = memMon
|
||||
func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(ctx context.Context, _ *proto.GetResourcesMonitoringConfigurationRequest) (*proto.GetResourcesMonitoringConfigurationResponse, error) {
|
||||
memoryMonitor, memoryErr := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
if memoryErr != nil && !errors.Is(memoryErr, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("failed to fetch memory resource monitor: %w", memoryErr)
|
||||
}
|
||||
|
||||
volMons, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
volumeMonitors, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fetch volume resource monitors: %w", err)
|
||||
return nil, xerrors.Errorf("failed to fetch volume resource monitors: %w", err)
|
||||
}
|
||||
// 0 length is valid, indicating none configured, since the volume monitors in the DB can be many.
|
||||
a.volumeMonitors = volMons
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(_ context.Context, _ *proto.GetResourcesMonitoringConfigurationRequest) (*proto.GetResourcesMonitoringConfigurationResponse, error) {
|
||||
return &proto.GetResourcesMonitoringConfigurationResponse{
|
||||
Config: &proto.GetResourcesMonitoringConfigurationResponse_Config{
|
||||
CollectionIntervalSeconds: int32(a.Config.CollectionInterval.Seconds()),
|
||||
NumDatapoints: a.Config.NumDatapoints,
|
||||
},
|
||||
Memory: func() *proto.GetResourcesMonitoringConfigurationResponse_Memory {
|
||||
if a.memoryMonitor.CreatedAt.IsZero() {
|
||||
if memoryErr != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &proto.GetResourcesMonitoringConfigurationResponse_Memory{
|
||||
Enabled: a.memoryMonitor.Enabled,
|
||||
Enabled: memoryMonitor.Enabled,
|
||||
}
|
||||
}(),
|
||||
Volumes: func() []*proto.GetResourcesMonitoringConfigurationResponse_Volume {
|
||||
volumes := make([]*proto.GetResourcesMonitoringConfigurationResponse_Volume, 0, len(a.volumeMonitors))
|
||||
for _, monitor := range a.volumeMonitors {
|
||||
volumes := make([]*proto.GetResourcesMonitoringConfigurationResponse_Volume, 0, len(volumeMonitors))
|
||||
for _, monitor := range volumeMonitors {
|
||||
volumes = append(volumes, &proto.GetResourcesMonitoringConfigurationResponse_Volume{
|
||||
Enabled: monitor.Enabled,
|
||||
Path: monitor.Path,
|
||||
})
|
||||
}
|
||||
|
||||
return volumes
|
||||
}(),
|
||||
}, nil
|
||||
@@ -96,10 +77,6 @@ func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(_ context.C
|
||||
func (a *ResourcesMonitoringAPI) PushResourcesMonitoringUsage(ctx context.Context, req *proto.PushResourcesMonitoringUsageRequest) (*proto.PushResourcesMonitoringUsageResponse, error) {
|
||||
var err error
|
||||
|
||||
// Lock for the entire push operation since calls are sequential from the agent
|
||||
a.monitorsLock.Lock()
|
||||
defer a.monitorsLock.Unlock()
|
||||
|
||||
if memoryErr := a.monitorMemory(ctx, req.Datapoints); memoryErr != nil {
|
||||
err = errors.Join(err, xerrors.Errorf("monitor memory: %w", memoryErr))
|
||||
}
|
||||
@@ -112,7 +89,18 @@ func (a *ResourcesMonitoringAPI) PushResourcesMonitoringUsage(ctx context.Contex
|
||||
}
|
||||
|
||||
func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints []*proto.PushResourcesMonitoringUsageRequest_Datapoint) error {
|
||||
if !a.memoryMonitor.Enabled {
|
||||
monitor, err := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
if err != nil {
|
||||
// It is valid for an agent to not have a memory monitor, so we
|
||||
// do not want to treat it as an error.
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return xerrors.Errorf("fetch memory resource monitor: %w", err)
|
||||
}
|
||||
|
||||
if !monitor.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -121,15 +109,15 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
|
||||
usageDatapoints = append(usageDatapoints, datapoint.Memory)
|
||||
}
|
||||
|
||||
usageStates := resourcesmonitor.CalculateMemoryUsageStates(a.memoryMonitor, usageDatapoints)
|
||||
usageStates := resourcesmonitor.CalculateMemoryUsageStates(monitor, usageDatapoints)
|
||||
|
||||
oldState := a.memoryMonitor.State
|
||||
oldState := monitor.State
|
||||
newState := resourcesmonitor.NextState(a.Config, oldState, usageStates)
|
||||
|
||||
debouncedUntil, shouldNotify := a.memoryMonitor.Debounce(a.Debounce, a.Clock.Now(), oldState, newState)
|
||||
debouncedUntil, shouldNotify := monitor.Debounce(a.Debounce, a.Clock.Now(), oldState, newState)
|
||||
|
||||
//nolint:gocritic // We need to be able to update the resource monitor here.
|
||||
err := a.Database.UpdateMemoryResourceMonitor(dbauthz.AsResourceMonitor(ctx), database.UpdateMemoryResourceMonitorParams{
|
||||
err = a.Database.UpdateMemoryResourceMonitor(dbauthz.AsResourceMonitor(ctx), database.UpdateMemoryResourceMonitorParams{
|
||||
AgentID: a.AgentID,
|
||||
State: newState,
|
||||
UpdatedAt: dbtime.Time(a.Clock.Now()),
|
||||
@@ -139,11 +127,6 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
|
||||
return xerrors.Errorf("update workspace monitor: %w", err)
|
||||
}
|
||||
|
||||
// Update cached state
|
||||
a.memoryMonitor.State = newState
|
||||
a.memoryMonitor.DebouncedUntil = dbtime.Time(debouncedUntil)
|
||||
a.memoryMonitor.UpdatedAt = dbtime.Time(a.Clock.Now())
|
||||
|
||||
if !shouldNotify {
|
||||
return nil
|
||||
}
|
||||
@@ -160,7 +143,7 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
|
||||
notifications.TemplateWorkspaceOutOfMemory,
|
||||
map[string]string{
|
||||
"workspace": workspace.Name,
|
||||
"threshold": fmt.Sprintf("%d%%", a.memoryMonitor.Threshold),
|
||||
"threshold": fmt.Sprintf("%d%%", monitor.Threshold),
|
||||
},
|
||||
map[string]any{
|
||||
// NOTE(DanielleMaywood):
|
||||
@@ -186,9 +169,14 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
|
||||
}
|
||||
|
||||
func (a *ResourcesMonitoringAPI) monitorVolumes(ctx context.Context, datapoints []*proto.PushResourcesMonitoringUsageRequest_Datapoint) error {
|
||||
volumeMonitors, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get or insert volume monitor: %w", err)
|
||||
}
|
||||
|
||||
outOfDiskVolumes := make([]map[string]any, 0)
|
||||
|
||||
for i, monitor := range a.volumeMonitors {
|
||||
for _, monitor := range volumeMonitors {
|
||||
if !monitor.Enabled {
|
||||
continue
|
||||
}
|
||||
@@ -231,11 +219,6 @@ func (a *ResourcesMonitoringAPI) monitorVolumes(ctx context.Context, datapoints
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update workspace monitor: %w", err)
|
||||
}
|
||||
|
||||
// Update cached state
|
||||
a.volumeMonitors[i].State = newState
|
||||
a.volumeMonitors[i].DebouncedUntil = dbtime.Time(debouncedUntil)
|
||||
a.volumeMonitors[i].UpdatedAt = dbtime.Time(a.Clock.Now())
|
||||
}
|
||||
|
||||
if len(outOfDiskVolumes) == 0 {
|
||||
|
||||
@@ -101,9 +101,6 @@ func TestMemoryResourceMonitorDebounce(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: The monitor is given a state that will trigger NOK
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
|
||||
@@ -307,9 +304,6 @@ func TestMemoryResourceMonitor(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
clock.Set(collectedAt)
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: datapoints,
|
||||
@@ -343,8 +337,6 @@ func TestMemoryResourceMonitorMissingData(t *testing.T) {
|
||||
State: database.WorkspaceAgentMonitorStateOK,
|
||||
Threshold: 80,
|
||||
})
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: A datapoint is missing, surrounded by two NOK datapoints.
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
@@ -395,9 +387,6 @@ func TestMemoryResourceMonitorMissingData(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: A datapoint is missing, surrounded by two OK datapoints.
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
|
||||
@@ -477,9 +466,6 @@ func TestVolumeResourceMonitorDebounce(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When:
|
||||
// - First monitor is in a NOK state
|
||||
// - Second monitor is in an OK state
|
||||
@@ -756,9 +742,6 @@ func TestVolumeResourceMonitor(t *testing.T) {
|
||||
Threshold: tt.thresholdPercent,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
clock.Set(collectedAt)
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: datapoints,
|
||||
@@ -797,9 +780,6 @@ func TestVolumeResourceMonitorMultiple(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: both of them move to a NOK state
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
|
||||
@@ -852,9 +832,6 @@ func TestVolumeResourceMonitorMissingData(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: A datapoint is missing, surrounded by two NOK datapoints.
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
|
||||
@@ -914,9 +891,6 @@ func TestVolumeResourceMonitorMissingData(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: A datapoint is missing, surrounded by two OK datapoints.
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
|
||||
|
||||
+3
-4
@@ -143,7 +143,7 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: `Template does not have a valid "coder_ai_task" resource.`,
|
||||
Message: fmt.Sprintf(`Template does not have required parameter %q`, codersdk.AITaskPromptParameterName),
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -241,7 +241,6 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
|
||||
// Create task record in the database before creating the workspace so that
|
||||
// we can request that the workspace be linked to it after creation.
|
||||
dbTaskTable, err = tx.InsertTask(ctx, database.InsertTaskParams{
|
||||
ID: uuid.New(),
|
||||
OrganizationID: templateVersion.OrganizationID,
|
||||
OwnerID: owner.ID,
|
||||
Name: taskName,
|
||||
@@ -339,8 +338,8 @@ func taskFromDBTaskAndWorkspace(dbTask database.Task, ws codersdk.Workspace) cod
|
||||
ID: dbTask.ID,
|
||||
OrganizationID: dbTask.OrganizationID,
|
||||
OwnerID: dbTask.OwnerID,
|
||||
OwnerName: dbTask.OwnerUsername,
|
||||
OwnerAvatarURL: dbTask.OwnerAvatarUrl,
|
||||
OwnerName: ws.OwnerName,
|
||||
OwnerAvatarURL: ws.OwnerAvatarURL,
|
||||
Name: dbTask.Name,
|
||||
TemplateID: ws.TemplateID,
|
||||
TemplateVersionID: dbTask.TemplateVersionID,
|
||||
|
||||
+10
-85
@@ -1,7 +1,6 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"io"
|
||||
@@ -259,9 +258,6 @@ func TestTasks(t *testing.T) {
|
||||
// Wait for the workspace to be built.
|
||||
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
if assert.True(t, workspace.TaskID.Valid, "task id should be set on workspace") {
|
||||
assert.Equal(t, task.ID, workspace.TaskID.UUID, "workspace task id should match")
|
||||
}
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// List tasks via experimental API and verify the prompt and status mapping.
|
||||
@@ -300,9 +296,6 @@ func TestTasks(t *testing.T) {
|
||||
// Get the workspace and wait for it to be ready.
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
if assert.True(t, ws.TaskID.Valid, "task id should be set on workspace") {
|
||||
assert.Equal(t, task.ID, ws.TaskID.UUID, "workspace task id should match")
|
||||
}
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
ws = coderdtest.MustWorkspace(t, client, task.WorkspaceID.UUID)
|
||||
// Assert invariant: the workspace has exactly one resource with one agent with one app.
|
||||
@@ -377,23 +370,24 @@ func TestTasks(t *testing.T) {
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
if assert.True(t, ws.TaskID.Valid, "task id should be set on workspace") {
|
||||
assert.Equal(t, task.ID, ws.TaskID.UUID, "workspace task id should match")
|
||||
}
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
err = exp.DeleteTask(ctx, "me", task.ID)
|
||||
require.NoError(t, err, "delete task request should be accepted")
|
||||
|
||||
// Poll until the workspace is deleted.
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
|
||||
for {
|
||||
dws, derr := client.DeletedWorkspace(ctx, task.WorkspaceID.UUID)
|
||||
if !assert.NoError(t, derr, "expected to fetch deleted workspace before deadline") {
|
||||
return false
|
||||
if derr == nil && dws.LatestBuild.Status == codersdk.WorkspaceStatusDeleted {
|
||||
break
|
||||
}
|
||||
t.Logf("workspace latest_build status: %q", dws.LatestBuild.Status)
|
||||
return dws.LatestBuild.Status == codersdk.WorkspaceStatusDeleted
|
||||
}, testutil.IntervalMedium, "workspace should be deleted before deadline")
|
||||
if ctx.Err() != nil {
|
||||
require.NoError(t, derr, "expected to fetch deleted workspace before deadline")
|
||||
require.Equal(t, codersdk.WorkspaceStatusDeleted, dws.LatestBuild.Status, "workspace should be deleted before deadline")
|
||||
break
|
||||
}
|
||||
time.Sleep(testutil.IntervalMedium)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
@@ -426,9 +420,6 @@ func TestTasks(t *testing.T) {
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
ws := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
if assert.False(t, ws.TaskID.Valid, "task id should not be set on non-task workspace") {
|
||||
assert.Zero(t, ws.TaskID, "non-task workspace task id should be empty")
|
||||
}
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
@@ -477,72 +468,6 @@ func TestTasks(t *testing.T) {
|
||||
t.Fatalf("unexpected status code: %d (expected 403 or 404)", authErr.StatusCode())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DeletedWorkspace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
template := createAITemplate(t, client, user)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: "delete me",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
// Mark the workspace as deleted directly in the database, bypassing provisionerd.
|
||||
require.NoError(t, db.UpdateWorkspaceDeletedByID(dbauthz.AsProvisionerd(ctx), database.UpdateWorkspaceDeletedByIDParams{
|
||||
ID: ws.ID,
|
||||
Deleted: true,
|
||||
}))
|
||||
// We should still be able to fetch the task if its workspace was deleted.
|
||||
// Provisionerdserver will attempt delete the related task when deleting a workspace.
|
||||
// This test ensures that we can still handle the case where, for some reason, the
|
||||
// task has not been marked as deleted, but the workspace has.
|
||||
task, err = exp.TaskByID(ctx, task.ID)
|
||||
require.NoError(t, err, "fetching a task should still work if its related workspace is deleted")
|
||||
err = exp.DeleteTask(ctx, task.OwnerID.String(), task.ID)
|
||||
require.NoError(t, err, "should be possible to delete a task with no workspace")
|
||||
})
|
||||
|
||||
t.Run("DeletingTaskWorkspaceDeletesTask", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
template := createAITemplate(t, client, user)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: "delete me",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
if assert.True(t, ws.TaskID.Valid, "task id should be set on workspace") {
|
||||
assert.Equal(t, task.ID, ws.TaskID.UUID, "workspace task id should match")
|
||||
}
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
// When; the task workspace is deleted
|
||||
coderdtest.MustTransitionWorkspace(t, client, ws.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionDelete)
|
||||
// Then: the task associated with the workspace is also deleted
|
||||
_, err = exp.TaskByID(ctx, task.ID)
|
||||
require.Error(t, err, "expected an error fetching the task")
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr, "expected a codersdk.Error")
|
||||
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Send", func(t *testing.T) {
|
||||
|
||||
Generated
+11
-110
@@ -85,7 +85,7 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/aibridge/interceptions": {
|
||||
"/api/experimental/aibridge/interceptions": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
@@ -10008,45 +10008,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Builds"
|
||||
],
|
||||
"summary": "Update workspace build state",
|
||||
"operationId": "update-workspace-build-state",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Workspace build ID",
|
||||
"name": "workspacebuild",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Request body",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateWorkspaceBuildStateRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/workspacebuilds/{workspacebuild}/timings": {
|
||||
@@ -11707,35 +11668,12 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeBedrockConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"access_key": {
|
||||
"type": "string"
|
||||
},
|
||||
"access_key_secret": {
|
||||
"type": "string"
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"region": {
|
||||
"type": "string"
|
||||
},
|
||||
"small_fast_model": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"anthropic": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig"
|
||||
},
|
||||
"bedrock": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeBedrockConfig"
|
||||
},
|
||||
"enabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -11747,10 +11685,6 @@ const docTemplate = `{
|
||||
"codersdk.AIBridgeInterception": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -12562,13 +12496,6 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
@@ -13792,13 +13719,6 @@ const docTemplate = `{
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific to the organization the role belongs to.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific to the organization the role belongs to.",
|
||||
"type": "array",
|
||||
@@ -14355,9 +14275,11 @@ const docTemplate = `{
|
||||
"web-push",
|
||||
"oauth2",
|
||||
"mcp-server-http",
|
||||
"workspace-sharing"
|
||||
"workspace-sharing",
|
||||
"aibridge"
|
||||
],
|
||||
"x-enum-comments": {
|
||||
"ExperimentAIBridge": "Enables AI Bridge functionality.",
|
||||
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
|
||||
"ExperimentExample": "This isn't used for anything.",
|
||||
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
|
||||
@@ -14375,7 +14297,8 @@ const docTemplate = `{
|
||||
"ExperimentWebPush",
|
||||
"ExperimentOAuth2",
|
||||
"ExperimentMCPServerHTTP",
|
||||
"ExperimentWorkspaceSharing"
|
||||
"ExperimentWorkspaceSharing",
|
||||
"ExperimentAIBridge"
|
||||
]
|
||||
},
|
||||
"codersdk.ExternalAPIKeyScopes": {
|
||||
@@ -17563,13 +17486,6 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
@@ -19198,17 +19114,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceBuildStateRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"state": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceDormancy": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19762,14 +19667,6 @@ const docTemplate = `{
|
||||
"description": "OwnerName is the username of the owner of the workspace.",
|
||||
"type": "string"
|
||||
},
|
||||
"task_id": {
|
||||
"description": "TaskID, if set, indicates that the workspace is relevant to the given codersdk.Task.",
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/uuid.NullUUID"
|
||||
}
|
||||
]
|
||||
},
|
||||
"template_active_version_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -20577,7 +20474,7 @@ const docTemplate = `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ai_task_sidebar_app_id": {
|
||||
"description": "Deprecated: This field has been replaced with ` + "`" + `Task.WorkspaceAppID` + "`" + `",
|
||||
"description": "Deprecated: This field has been replaced with ` + "`" + `TaskAppID` + "`" + `",
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
@@ -20659,6 +20556,10 @@ const docTemplate = `{
|
||||
}
|
||||
]
|
||||
},
|
||||
"task_app_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"template_version_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
|
||||
Generated
+11
-106
@@ -65,7 +65,7 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/aibridge/interceptions": {
|
||||
"/api/experimental/aibridge/interceptions": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
@@ -8870,41 +8870,6 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": ["application/json"],
|
||||
"tags": ["Builds"],
|
||||
"summary": "Update workspace build state",
|
||||
"operationId": "update-workspace-build-state",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Workspace build ID",
|
||||
"name": "workspacebuild",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Request body",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateWorkspaceBuildStateRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/workspacebuilds/{workspacebuild}/timings": {
|
||||
@@ -10399,35 +10364,12 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeBedrockConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"access_key": {
|
||||
"type": "string"
|
||||
},
|
||||
"access_key_secret": {
|
||||
"type": "string"
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"region": {
|
||||
"type": "string"
|
||||
},
|
||||
"small_fast_model": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"anthropic": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig"
|
||||
},
|
||||
"bedrock": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeBedrockConfig"
|
||||
},
|
||||
"enabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -10439,10 +10381,6 @@
|
||||
"codersdk.AIBridgeInterception": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -11240,13 +11178,6 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
@@ -12402,13 +12333,6 @@
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific to the organization the role belongs to.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific to the organization the role belongs to.",
|
||||
"type": "array",
|
||||
@@ -12958,9 +12882,11 @@
|
||||
"web-push",
|
||||
"oauth2",
|
||||
"mcp-server-http",
|
||||
"workspace-sharing"
|
||||
"workspace-sharing",
|
||||
"aibridge"
|
||||
],
|
||||
"x-enum-comments": {
|
||||
"ExperimentAIBridge": "Enables AI Bridge functionality.",
|
||||
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
|
||||
"ExperimentExample": "This isn't used for anything.",
|
||||
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
|
||||
@@ -12978,7 +12904,8 @@
|
||||
"ExperimentWebPush",
|
||||
"ExperimentOAuth2",
|
||||
"ExperimentMCPServerHTTP",
|
||||
"ExperimentWorkspaceSharing"
|
||||
"ExperimentWorkspaceSharing",
|
||||
"ExperimentAIBridge"
|
||||
]
|
||||
},
|
||||
"codersdk.ExternalAPIKeyScopes": {
|
||||
@@ -16051,13 +15978,6 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
@@ -17616,17 +17536,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceBuildStateRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"state": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceDormancy": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -18144,14 +18053,6 @@
|
||||
"description": "OwnerName is the username of the owner of the workspace.",
|
||||
"type": "string"
|
||||
},
|
||||
"task_id": {
|
||||
"description": "TaskID, if set, indicates that the workspace is relevant to the given codersdk.Task.",
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/uuid.NullUUID"
|
||||
}
|
||||
]
|
||||
},
|
||||
"template_active_version_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -18907,7 +18808,7 @@
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ai_task_sidebar_app_id": {
|
||||
"description": "Deprecated: This field has been replaced with `Task.WorkspaceAppID`",
|
||||
"description": "Deprecated: This field has been replaced with `TaskAppID`",
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
@@ -18985,6 +18886,10 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
"task_app_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"template_version_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
|
||||
+2
-2
@@ -509,11 +509,11 @@ func (api *API) auditLogResourceLink(ctx context.Context, alog database.GetAudit
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
user, err := api.Database.GetUserByID(ctx, task.OwnerID)
|
||||
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("/tasks/%s/%s", user.Username, task.ID)
|
||||
return fmt.Sprintf("/tasks/%s/%s", workspace.OwnerName, task.Name)
|
||||
|
||||
default:
|
||||
return ""
|
||||
|
||||
+10
-11
@@ -50,13 +50,6 @@ func TestCheckPermissions(t *testing.T) {
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
readOrgWorkspaces: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: codersdk.ResourceWorkspace,
|
||||
OrganizationID: adminUser.OrganizationID.String(),
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
readMyself: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: codersdk.ResourceUser,
|
||||
@@ -65,10 +58,16 @@ func TestCheckPermissions(t *testing.T) {
|
||||
Action: "read",
|
||||
},
|
||||
readOwnWorkspaces: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: codersdk.ResourceWorkspace,
|
||||
OwnerID: "me",
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
readOrgWorkspaces: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: codersdk.ResourceWorkspace,
|
||||
OrganizationID: adminUser.OrganizationID.String(),
|
||||
OwnerID: "me",
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
@@ -93,9 +92,9 @@ func TestCheckPermissions(t *testing.T) {
|
||||
UserID: adminUser.UserID,
|
||||
Check: map[string]bool{
|
||||
readAllUsers: true,
|
||||
readOrgWorkspaces: true,
|
||||
readMyself: true,
|
||||
readOwnWorkspaces: true,
|
||||
readOrgWorkspaces: true,
|
||||
updateSpecificTemplate: true,
|
||||
},
|
||||
},
|
||||
@@ -105,9 +104,9 @@ func TestCheckPermissions(t *testing.T) {
|
||||
UserID: orgAdminUser.ID,
|
||||
Check: map[string]bool{
|
||||
readAllUsers: true,
|
||||
readOrgWorkspaces: true,
|
||||
readMyself: true,
|
||||
readOwnWorkspaces: true,
|
||||
readOrgWorkspaces: true,
|
||||
updateSpecificTemplate: true,
|
||||
},
|
||||
},
|
||||
@@ -117,9 +116,9 @@ func TestCheckPermissions(t *testing.T) {
|
||||
UserID: memberUser.ID,
|
||||
Check: map[string]bool{
|
||||
readAllUsers: false,
|
||||
readOrgWorkspaces: false,
|
||||
readMyself: true,
|
||||
readOwnWorkspaces: true,
|
||||
readOrgWorkspaces: false,
|
||||
updateSpecificTemplate: false,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1764,175 +1764,3 @@ func TestExecutorAutostartSkipsWhenNoProvisionersAvailable(t *testing.T) {
|
||||
|
||||
assert.Len(t, stats.Transitions, 1, "should create builds when provisioners are available")
|
||||
}
|
||||
|
||||
func TestExecutorTaskWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
createTaskTemplate := func(t *testing.T, client *codersdk.Client, orgID uuid.UUID, ctx context.Context, defaultTTL time.Duration) codersdk.Template {
|
||||
t.Helper()
|
||||
|
||||
taskAppID := uuid.New()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{
|
||||
Type: &proto.Response_Plan{
|
||||
Plan: &proto.PlanComplete{HasAiTasks: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
ProvisionApply: []*proto.Response{
|
||||
{
|
||||
Type: &proto.Response_Apply{
|
||||
Apply: &proto.ApplyComplete{
|
||||
Resources: []*proto.Resource{
|
||||
{
|
||||
Agents: []*proto.Agent{
|
||||
{
|
||||
Id: uuid.NewString(),
|
||||
Name: "dev",
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: uuid.NewString(),
|
||||
},
|
||||
Apps: []*proto.App{
|
||||
{
|
||||
Id: taskAppID.String(),
|
||||
Slug: "task-app",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
AiTasks: []*proto.AITask{
|
||||
{
|
||||
AppId: taskAppID.String(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
|
||||
|
||||
if defaultTTL > 0 {
|
||||
_, err := client.UpdateTemplateMeta(ctx, template.ID, codersdk.UpdateTemplateMeta{
|
||||
DefaultTTLMillis: defaultTTL.Milliseconds(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
createTaskWorkspace := func(t *testing.T, client *codersdk.Client, template codersdk.Template, ctx context.Context, input string) codersdk.Workspace {
|
||||
t.Helper()
|
||||
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: input,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace")
|
||||
|
||||
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
return workspace
|
||||
}
|
||||
|
||||
t.Run("Autostart", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
sched = mustSchedule(t, "CRON_TZ=UTC 0 * * * *")
|
||||
tickCh = make(chan time.Time)
|
||||
statsCh = make(chan autobuild.Stats)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
AutobuildTicker: tickCh,
|
||||
IncludeProvisionerDaemon: true,
|
||||
AutobuildStats: statsCh,
|
||||
})
|
||||
admin = coderdtest.CreateFirstUser(t, client)
|
||||
)
|
||||
|
||||
// Given: A task workspace
|
||||
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 0)
|
||||
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostart")
|
||||
|
||||
// Given: The task workspace has an autostart schedule
|
||||
err := client.UpdateWorkspaceAutostart(ctx, workspace.ID, codersdk.UpdateWorkspaceAutostartRequest{
|
||||
Schedule: ptr.Ref(sched.String()),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Given: That the workspace is in a stopped state.
|
||||
workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop)
|
||||
|
||||
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// When: the autobuild executor ticks after the scheduled time
|
||||
go func() {
|
||||
tickTime := sched.Next(workspace.LatestBuild.CreatedAt)
|
||||
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
|
||||
tickCh <- tickTime
|
||||
close(tickCh)
|
||||
}()
|
||||
|
||||
// Then: We expect to see a start transition
|
||||
stats := <-statsCh
|
||||
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
|
||||
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
|
||||
assert.Equal(t, database.WorkspaceTransitionStart, stats.Transitions[workspace.ID], "should autostart the workspace")
|
||||
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
|
||||
})
|
||||
|
||||
t.Run("Autostop", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
tickCh = make(chan time.Time)
|
||||
statsCh = make(chan autobuild.Stats)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
AutobuildTicker: tickCh,
|
||||
IncludeProvisionerDaemon: true,
|
||||
AutobuildStats: statsCh,
|
||||
})
|
||||
admin = coderdtest.CreateFirstUser(t, client)
|
||||
)
|
||||
|
||||
// Given: A task workspace with an 8 hour deadline
|
||||
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 8*time.Hour)
|
||||
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostop")
|
||||
|
||||
// Given: The workspace is currently running
|
||||
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
|
||||
require.Equal(t, codersdk.WorkspaceTransitionStart, workspace.LatestBuild.Transition)
|
||||
require.NotZero(t, workspace.LatestBuild.Deadline, "workspace should have a deadline for autostop")
|
||||
|
||||
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// When: the autobuild executor ticks after the deadline
|
||||
go func() {
|
||||
tickTime := workspace.LatestBuild.Deadline.Time.Add(time.Minute)
|
||||
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
|
||||
tickCh <- tickTime
|
||||
close(tickCh)
|
||||
}()
|
||||
|
||||
// Then: We expect to see a stop transition
|
||||
stats := <-statsCh
|
||||
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
|
||||
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
|
||||
assert.Equal(t, database.WorkspaceTransitionStop, stats.Transitions[workspace.ID], "should autostop the workspace")
|
||||
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1496,7 +1496,6 @@ func New(options *Options) *API {
|
||||
r.Get("/parameters", api.workspaceBuildParameters)
|
||||
r.Get("/resources", api.workspaceBuildResourcesDeprecated)
|
||||
r.Get("/state", api.workspaceBuildState)
|
||||
r.Put("/state", api.workspaceBuildUpdateState)
|
||||
r.Get("/timings", api.workspaceBuildTimings)
|
||||
})
|
||||
r.Route("/authcheck", func(r chi.Router) {
|
||||
|
||||
@@ -14,7 +14,6 @@ const (
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsAiTaskSidebarAppIDRequired CheckConstraint = "workspace_builds_ai_task_sidebar_app_id_required" // workspace_builds
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
)
|
||||
|
||||
@@ -714,13 +714,12 @@ func RBACRole(role rbac.Role) codersdk.Role {
|
||||
|
||||
orgPerms := role.ByOrgID[slim.OrganizationID]
|
||||
return codersdk.Role{
|
||||
Name: slim.Name,
|
||||
OrganizationID: slim.OrganizationID,
|
||||
DisplayName: slim.DisplayName,
|
||||
SitePermissions: List(role.Site, RBACPermission),
|
||||
UserPermissions: List(role.User, RBACPermission),
|
||||
OrganizationPermissions: List(orgPerms.Org, RBACPermission),
|
||||
OrganizationMemberPermissions: List(orgPerms.Member, RBACPermission),
|
||||
Name: slim.Name,
|
||||
OrganizationID: slim.OrganizationID,
|
||||
DisplayName: slim.DisplayName,
|
||||
SitePermissions: List(role.Site, RBACPermission),
|
||||
OrganizationPermissions: List(orgPerms.Org, RBACPermission),
|
||||
UserPermissions: List(role.User, RBACPermission),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -735,8 +734,8 @@ func Role(role database.CustomRole) codersdk.Role {
|
||||
OrganizationID: orgID,
|
||||
DisplayName: role.DisplayName,
|
||||
SitePermissions: List(role.SitePermissions, Permission),
|
||||
UserPermissions: List(role.UserPermissions, Permission),
|
||||
OrganizationPermissions: List(role.OrgPermissions, Permission),
|
||||
UserPermissions: List(role.UserPermissions, Permission),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -963,7 +962,7 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
|
||||
// created_at ASC
|
||||
return sdkToolUsages[i].CreatedAt.Before(sdkToolUsages[j].CreatedAt)
|
||||
})
|
||||
intc := codersdk.AIBridgeInterception{
|
||||
return codersdk.AIBridgeInterception{
|
||||
ID: interception.ID,
|
||||
Initiator: MinimalUserFromVisibleUser(initiator),
|
||||
Provider: interception.Provider,
|
||||
@@ -974,10 +973,6 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
|
||||
UserPrompts: sdkUserPrompts,
|
||||
ToolUsages: sdkToolUsages,
|
||||
}
|
||||
if interception.EndedAt.Valid {
|
||||
intc.EndedAt = &interception.EndedAt.Time
|
||||
}
|
||||
return intc
|
||||
}
|
||||
|
||||
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
|
||||
|
||||
@@ -217,10 +217,10 @@ var (
|
||||
rbac.ResourceTemplate.Type: {policy.ActionRead, policy.ActionUpdate},
|
||||
// Unsure why provisionerd needs update and read personal
|
||||
rbac.ResourceUser.Type: {policy.ActionRead, policy.ActionReadPersonal, policy.ActionUpdatePersonal},
|
||||
rbac.ResourceWorkspaceDormant.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent},
|
||||
rbac.ResourceWorkspaceDormant.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStop},
|
||||
rbac.ResourceWorkspace.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStart, policy.ActionWorkspaceStop, policy.ActionCreateAgent},
|
||||
// Provisionerd needs to read, update, and delete tasks associated with workspaces.
|
||||
rbac.ResourceTask.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
// Provisionerd needs to read and update tasks associated with workspaces.
|
||||
rbac.ResourceTask.Type: {policy.ActionRead, policy.ActionUpdate},
|
||||
rbac.ResourceApiKey.Type: {policy.WildcardSymbol},
|
||||
// When org scoped provisioner credentials are implemented,
|
||||
// this can be reduced to read a specific org.
|
||||
@@ -254,7 +254,6 @@ var (
|
||||
rbac.ResourceFile.Type: {policy.ActionRead}, // Required to read terraform files
|
||||
rbac.ResourceNotificationMessage.Type: {policy.ActionCreate, policy.ActionRead},
|
||||
rbac.ResourceSystem.Type: {policy.WildcardSymbol},
|
||||
rbac.ResourceTask.Type: {policy.ActionRead, policy.ActionUpdate},
|
||||
rbac.ResourceTemplate.Type: {policy.ActionRead, policy.ActionUpdate},
|
||||
rbac.ResourceUser.Type: {policy.ActionRead},
|
||||
rbac.ResourceWorkspace.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStart, policy.ActionWorkspaceStop},
|
||||
@@ -396,13 +395,11 @@ var (
|
||||
Identifier: rbac.RoleIdentifier{Name: "subagentapi"},
|
||||
DisplayName: "Sub Agent API",
|
||||
Site: []rbac.Permission{},
|
||||
User: []rbac.Permission{},
|
||||
User: rbac.Permissions(map[string][]policy.Action{
|
||||
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreateAgent, policy.ActionDeleteAgent},
|
||||
}),
|
||||
ByOrgID: map[string]rbac.OrgPermissions{
|
||||
orgID.String(): {
|
||||
Member: rbac.Permissions(map[string][]policy.Action{
|
||||
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreateAgent, policy.ActionDeleteAgent},
|
||||
}),
|
||||
},
|
||||
orgID.String(): {},
|
||||
},
|
||||
},
|
||||
}),
|
||||
@@ -1293,17 +1290,14 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
|
||||
return xerrors.Errorf("invalid role: %w", err)
|
||||
}
|
||||
|
||||
if len(rbacRole.ByOrgID) > 0 && (len(rbacRole.Site) > 0 || len(rbacRole.User) > 0) {
|
||||
// This is a choice to keep roles simple. If we allow mixing site and org
|
||||
// scoped perms, then knowing who can do what gets more complicated. Roles
|
||||
// should either be entirely org-scoped or entirely unrelated to
|
||||
// organizations.
|
||||
return xerrors.Errorf("invalid custom role, cannot assign both org-scoped and site/user permissions at the same time")
|
||||
if len(rbacRole.ByOrgID) > 0 && len(rbacRole.Site) > 0 {
|
||||
// This is a choice to keep roles simple. If we allow mixing site and org scoped perms, then knowing who can
|
||||
// do what gets more complicated.
|
||||
return xerrors.Errorf("invalid custom role, cannot assign both org and site permissions at the same time")
|
||||
}
|
||||
|
||||
if len(rbacRole.ByOrgID) > 1 {
|
||||
// Again to avoid more complexity in our roles. Roles are limited to one
|
||||
// organization.
|
||||
// Again to avoid more complexity in our roles
|
||||
return xerrors.Errorf("invalid custom role, cannot assign permissions to more than 1 org at a time")
|
||||
}
|
||||
|
||||
@@ -1319,18 +1313,7 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
|
||||
for _, orgPerm := range perms.Org {
|
||||
err := q.customRoleEscalationCheck(ctx, act, orgPerm, rbac.Object{OrgID: orgID, Type: orgPerm.ResourceType})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("org=%q: org: %w", orgID, err)
|
||||
}
|
||||
}
|
||||
for _, memberPerm := range perms.Member {
|
||||
// The person giving the permission should still be required to have
|
||||
// the permissions throughout the org in order to give individuals the
|
||||
// same permission among their own resources, since the role can be given
|
||||
// to anyone. The `Owner` is intentionally omitted from the `Object` to
|
||||
// enforce this.
|
||||
err := q.customRoleEscalationCheck(ctx, act, memberPerm, rbac.Object{OrgID: orgID, Type: memberPerm.ResourceType})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("org=%q: member: %w", orgID, err)
|
||||
return xerrors.Errorf("org=%q: %w", orgID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1348,8 +1331,8 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
|
||||
func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.ProvisionerJob) error {
|
||||
switch job.Type {
|
||||
case database.ProvisionerJobTypeWorkspaceBuild:
|
||||
// Authorized call to get workspace build. If we can read the build, we can
|
||||
// read the job.
|
||||
// Authorized call to get workspace build. If we can read the build, we
|
||||
// can read the job.
|
||||
_, err := q.GetWorkspaceBuildByJobID(ctx, job.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fetch related workspace build: %w", err)
|
||||
@@ -1392,8 +1375,8 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.Activi
|
||||
}
|
||||
|
||||
func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
|
||||
// Although this technically only reads users, only system-related functions
|
||||
// should be allowed to call this.
|
||||
// Although this technically only reads users, only system-related functions should be
|
||||
// allowed to call this.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1412,8 +1395,8 @@ func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg databas
|
||||
}
|
||||
|
||||
func (q *querier) BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error {
|
||||
// Could be any workspace and checking auth to each workspace is overkill for
|
||||
// the purpose of this function.
|
||||
// Could be any workspace and checking auth to each workspace is overkill for the purpose
|
||||
// of this function.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceWorkspace.All()); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1441,13 +1424,6 @@ func (q *querier) BulkMarkNotificationMessagesSent(ctx context.Context, arg data
|
||||
return q.db.BulkMarkNotificationMessagesSent(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
|
||||
return database.CalculateAIBridgeInterceptionsTelemetrySummaryRow{}, err
|
||||
}
|
||||
return q.db.CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
|
||||
empty := database.ClaimPrebuiltWorkspaceRow{}
|
||||
|
||||
@@ -1747,13 +1723,6 @@ func (q *querier) DeleteOldProvisionerDaemons(ctx context.Context) error {
|
||||
return q.db.DeleteOldProvisionerDaemons(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldTelemetryLocks(ctx context.Context, beforeTime time.Time) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteOldTelemetryLocks(ctx, beforeTime)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldWorkspaceAgentLogs(ctx context.Context, threshold time.Time) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
@@ -2649,13 +2618,6 @@ func (q *querier) GetOrganizationsByUserID(ctx context.Context, userID database.
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationsByUserID)(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOrganization.All()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetOrganizationsWithPrebuildStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
version, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
|
||||
if err != nil {
|
||||
@@ -4250,13 +4212,6 @@ func (q *querier) InsertTelemetryItemIfNotExists(ctx context.Context, arg databa
|
||||
return q.db.InsertTelemetryItemIfNotExists(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.InsertTelemetryLock(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
|
||||
obj := rbac.ResourceTemplate.InOrg(arg.OrganizationID)
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, obj); err != nil {
|
||||
@@ -4568,13 +4523,6 @@ func (q *querier) ListAIBridgeInterceptions(ctx context.Context, arg database.Li
|
||||
return q.db.ListAuthorizedAIBridgeInterceptions(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
// This function is a system function until we implement a join for aibridge interceptions.
|
||||
// Matches the behavior of the workspaces listing endpoint.
|
||||
@@ -4763,13 +4711,6 @@ func (q *querier) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error {
|
||||
return update(q.log, q.auth, fetch, q.db.UnfavoriteWorkspace)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateAIBridgeInterceptionEnded(ctx context.Context, params database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, params.ID); err != nil {
|
||||
return database.AIBridgeInterception{}, err
|
||||
}
|
||||
return q.db.UpdateAIBridgeInterceptionEnded(ctx, params)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
|
||||
fetch := func(ctx context.Context, arg database.UpdateAPIKeyByIDParams) (database.APIKey, error) {
|
||||
return q.db.GetAPIKeyByID(ctx, arg.ID)
|
||||
@@ -4941,10 +4882,10 @@ func (q *querier) UpdateOrganizationDeletedByID(ctx context.Context, arg databas
|
||||
return deleteQ(q.log, q.auth, q.db.GetOrganizationByID, deleteF)(ctx, arg.ID)
|
||||
}
|
||||
|
||||
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
|
||||
// Prebuild operation for canceling pending prebuild jobs from non-active template versions
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourcePrebuiltWorkspace); err != nil {
|
||||
return []database.UpdatePrebuildProvisionerJobWithCancelRow{}, err
|
||||
return []uuid.UUID{}, err
|
||||
}
|
||||
return q.db.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -646,13 +646,10 @@ func (s *MethodTestSuite) TestProvisionerJob() {
|
||||
PresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
Now: dbtime.Now(),
|
||||
}
|
||||
canceledJobs := []database.UpdatePrebuildProvisionerJobWithCancelRow{
|
||||
{ID: uuid.New(), WorkspaceID: uuid.New(), TemplateID: uuid.New(), TemplateVersionPresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
{ID: uuid.New(), WorkspaceID: uuid.New(), TemplateID: uuid.New(), TemplateVersionPresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
}
|
||||
jobIDs := []uuid.UUID{uuid.New(), uuid.New()}
|
||||
|
||||
dbm.EXPECT().UpdatePrebuildProvisionerJobWithCancel(gomock.Any(), arg).Return(canceledJobs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourcePrebuiltWorkspace, policy.ActionUpdate).Returns(canceledJobs)
|
||||
dbm.EXPECT().UpdatePrebuildProvisionerJobWithCancel(gomock.Any(), arg).Return(jobIDs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourcePrebuiltWorkspace, policy.ActionUpdate).Returns(jobIDs)
|
||||
}))
|
||||
s.Run("GetProvisionerJobsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
org := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
@@ -3759,14 +3756,6 @@ func (s *MethodTestSuite) TestPrebuilds() {
|
||||
dbm.EXPECT().GetPrebuildMetrics(gomock.Any()).Return([]database.GetPrebuildMetricsRow{}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetOrganizationsWithPrebuildStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetOrganizationsWithPrebuildStatusParams{
|
||||
UserID: uuid.New(),
|
||||
GroupName: "test",
|
||||
}
|
||||
dbm.EXPECT().GetOrganizationsWithPrebuildStatus(gomock.Any(), arg).Return([]database.GetOrganizationsWithPrebuildStatusRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceOrganization.All(), policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPrebuildsSettings", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetPrebuildsSettings(gomock.Any()).Return("{}", nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
@@ -4628,35 +4617,4 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
db.EXPECT().ListAIBridgeToolUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeToolUsage{}, nil).AnyTimes()
|
||||
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
|
||||
}))
|
||||
|
||||
s.Run("UpdateAIBridgeInterceptionEnded", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
intcID := uuid.UUID{1}
|
||||
params := database.UpdateAIBridgeInterceptionEndedParams{ID: intcID}
|
||||
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intcID})
|
||||
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intcID).Return(intc, nil).AnyTimes() // Validation.
|
||||
db.EXPECT().UpdateAIBridgeInterceptionEnded(gomock.Any(), params).Return(intc, nil).AnyTimes()
|
||||
check.Args(params).Asserts(intc, policy.ActionUpdate).Returns(intc)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestTelemetry() {
|
||||
s.Run("InsertTelemetryLock", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
db.EXPECT().InsertTelemetryLock(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
check.Args(database.InsertTelemetryLockParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate)
|
||||
}))
|
||||
|
||||
s.Run("DeleteOldTelemetryLocks", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
db.EXPECT().DeleteOldTelemetryLocks(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
check.Args(time.Time{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeInterceptionsTelemetrySummaries", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
db.EXPECT().ListAIBridgeInterceptionsTelemetrySummaries(gomock.Any(), gomock.Any()).Return([]database.ListAIBridgeInterceptionsTelemetrySummariesRow{}, nil).AnyTimes()
|
||||
check.Args(database.ListAIBridgeInterceptionsTelemetrySummariesParams{}).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead)
|
||||
}))
|
||||
|
||||
s.Run("CalculateAIBridgeInterceptionsTelemetrySummary", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
db.EXPECT().CalculateAIBridgeInterceptionsTelemetrySummary(gomock.Any(), gomock.Any()).Return(database.CalculateAIBridgeInterceptionsTelemetrySummaryRow{}, nil).AnyTimes()
|
||||
check.Args(database.CalculateAIBridgeInterceptionsTelemetrySummaryParams{}).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead)
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -41,7 +41,6 @@ type WorkspaceResponse struct {
|
||||
Build database.WorkspaceBuild
|
||||
AgentToken string
|
||||
TemplateVersionResponse
|
||||
Task database.Task
|
||||
}
|
||||
|
||||
// WorkspaceBuildBuilder generates workspace builds and associated
|
||||
@@ -58,7 +57,6 @@ type WorkspaceBuildBuilder struct {
|
||||
agentToken string
|
||||
jobStatus database.ProvisionerJobStatus
|
||||
taskAppID uuid.UUID
|
||||
taskSeed database.TaskTable
|
||||
}
|
||||
|
||||
// WorkspaceBuild generates a workspace build for the provided workspace.
|
||||
@@ -117,28 +115,25 @@ func (b WorkspaceBuildBuilder) WithAgent(mutations ...func([]*sdkproto.Agent) []
|
||||
return b
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) WithTask(taskSeed database.TaskTable, appSeed *sdkproto.App) WorkspaceBuildBuilder {
|
||||
//nolint:revive // returns modified struct
|
||||
b.taskSeed = taskSeed
|
||||
|
||||
if appSeed == nil {
|
||||
appSeed = &sdkproto.App{}
|
||||
func (b WorkspaceBuildBuilder) WithTask(seed *sdkproto.App) WorkspaceBuildBuilder {
|
||||
if seed == nil {
|
||||
seed = &sdkproto.App{}
|
||||
}
|
||||
|
||||
var err error
|
||||
//nolint: revive // returns modified struct
|
||||
b.taskAppID, err = uuid.Parse(takeFirst(appSeed.Id, uuid.NewString()))
|
||||
b.taskAppID, err = uuid.Parse(takeFirst(seed.Id, uuid.NewString()))
|
||||
require.NoError(b.t, err)
|
||||
|
||||
return b.Params(database.WorkspaceBuildParameter{
|
||||
Name: codersdk.AITaskPromptParameterName,
|
||||
Value: b.taskSeed.Prompt,
|
||||
Value: "list me",
|
||||
}).WithAgent(func(a []*sdkproto.Agent) []*sdkproto.Agent {
|
||||
a[0].Apps = []*sdkproto.App{
|
||||
{
|
||||
Id: b.taskAppID.String(),
|
||||
Slug: takeFirst(appSeed.Slug, "task-app"),
|
||||
Url: takeFirst(appSeed.Url, ""),
|
||||
Slug: takeFirst(seed.Slug, "task-app"),
|
||||
Url: takeFirst(seed.Url, ""),
|
||||
},
|
||||
}
|
||||
return a
|
||||
@@ -166,19 +161,6 @@ func (b WorkspaceBuildBuilder) Canceled() WorkspaceBuildBuilder {
|
||||
// Workspace will be optionally populated if no ID is set on the provided
|
||||
// workspace.
|
||||
func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
|
||||
var resp WorkspaceResponse
|
||||
// Use transaction, like real wsbuilder.
|
||||
err := b.db.InTx(func(tx database.Store) error {
|
||||
//nolint:revive // calls do on modified struct
|
||||
b.db = tx
|
||||
resp = b.doInTX()
|
||||
return nil
|
||||
}, nil)
|
||||
require.NoError(b.t, err)
|
||||
return resp
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
|
||||
b.t.Helper()
|
||||
jobID := uuid.New()
|
||||
b.seed.ID = uuid.New()
|
||||
@@ -230,37 +212,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
|
||||
b.seed.WorkspaceID = b.ws.ID
|
||||
b.seed.InitiatorID = takeFirst(b.seed.InitiatorID, b.ws.OwnerID)
|
||||
|
||||
// If a task was requested, ensure it exists and is associated with this
|
||||
// workspace.
|
||||
if b.taskAppID != uuid.Nil {
|
||||
b.logger.Debug(context.Background(), "creating or updating task", "task_id", b.taskSeed.ID)
|
||||
b.taskSeed.OrganizationID = takeFirst(b.taskSeed.OrganizationID, b.ws.OrganizationID)
|
||||
b.taskSeed.OwnerID = takeFirst(b.taskSeed.OwnerID, b.ws.OwnerID)
|
||||
b.taskSeed.Name = takeFirst(b.taskSeed.Name, b.ws.Name)
|
||||
b.taskSeed.WorkspaceID = uuid.NullUUID{UUID: takeFirst(b.taskSeed.WorkspaceID.UUID, b.ws.ID), Valid: true}
|
||||
b.taskSeed.TemplateVersionID = takeFirst(b.taskSeed.TemplateVersionID, b.seed.TemplateVersionID)
|
||||
|
||||
// Try to fetch existing task and update its workspace ID.
|
||||
if task, err := b.db.GetTaskByID(ownerCtx, b.taskSeed.ID); err == nil {
|
||||
if !task.WorkspaceID.Valid {
|
||||
b.logger.Info(context.Background(), "updating task workspace id", "task_id", b.taskSeed.ID, "workspace_id", b.ws.ID)
|
||||
_, err = b.db.UpdateTaskWorkspaceID(ownerCtx, database.UpdateTaskWorkspaceIDParams{
|
||||
ID: b.taskSeed.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: b.ws.ID, Valid: true},
|
||||
})
|
||||
require.NoError(b.t, err, "update task workspace id")
|
||||
} else if task.WorkspaceID.UUID != b.ws.ID {
|
||||
require.Fail(b.t, "task already has a workspace id, mismatch", task.WorkspaceID.UUID, b.ws.ID)
|
||||
}
|
||||
} else if errors.Is(err, sql.ErrNoRows) {
|
||||
task := dbgen.Task(b.t, b.db, b.taskSeed)
|
||||
b.taskSeed.ID = task.ID
|
||||
b.logger.Info(context.Background(), "created new task", "task_id", b.taskSeed.ID)
|
||||
} else {
|
||||
require.NoError(b.t, err, "get task by id")
|
||||
}
|
||||
}
|
||||
|
||||
// Create a provisioner job for the build!
|
||||
payload, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{
|
||||
WorkspaceBuildID: b.seed.ID,
|
||||
@@ -373,11 +324,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
|
||||
b.logger.Debug(context.Background(), "linked task to workspace build",
|
||||
slog.F("task_id", task.ID),
|
||||
slog.F("build_number", resp.Build.BuildNumber))
|
||||
|
||||
// Update task after linking.
|
||||
task, err = b.db.GetTaskByID(ownerCtx, task.ID)
|
||||
require.NoError(b.t, err, "get task by id")
|
||||
resp.Task = task
|
||||
}
|
||||
|
||||
for i := range b.params {
|
||||
|
||||
@@ -1495,7 +1495,7 @@ func ClaimPrebuild(
|
||||
return claimedWorkspace
|
||||
}
|
||||
|
||||
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams, endedAt *time.Time) database.AIBridgeInterception {
|
||||
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams) database.AIBridgeInterception {
|
||||
interception, err := db.InsertAIBridgeInterception(genCtx, database.InsertAIBridgeInterceptionParams{
|
||||
ID: takeFirst(seed.ID, uuid.New()),
|
||||
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
|
||||
@@ -1504,13 +1504,6 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
|
||||
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
|
||||
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
|
||||
})
|
||||
if endedAt != nil {
|
||||
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: interception.ID,
|
||||
EndedAt: *endedAt,
|
||||
})
|
||||
require.NoError(t, err, "insert aibridge interception")
|
||||
}
|
||||
require.NoError(t, err, "insert aibridge interception")
|
||||
return interception
|
||||
}
|
||||
@@ -1576,7 +1569,6 @@ func Task(t testing.TB, db database.Store, orig database.TaskTable) database.Tas
|
||||
}
|
||||
|
||||
task, err := db.InsertTask(genCtx, database.InsertTaskParams{
|
||||
ID: takeFirst(orig.ID, uuid.New()),
|
||||
OrganizationID: orig.OrganizationID,
|
||||
OwnerID: orig.OwnerID,
|
||||
Name: takeFirst(orig.Name, taskname.GenerateFallback()),
|
||||
|
||||
@@ -158,13 +158,6 @@ func (m queryMetricsStore) BulkMarkNotificationMessagesSent(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("CalculateAIBridgeInterceptionsTelemetrySummary").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ClaimPrebuiltWorkspace(ctx, arg)
|
||||
@@ -410,13 +403,6 @@ func (m queryMetricsStore) DeleteOldProvisionerDaemons(ctx context.Context) erro
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteOldTelemetryLocks(ctx, periodEndingAtBefore)
|
||||
m.queryLatencies.WithLabelValues("DeleteOldTelemetryLocks").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldWorkspaceAgentLogs(ctx context.Context, arg time.Time) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteOldWorkspaceAgentLogs(ctx, arg)
|
||||
@@ -1243,13 +1229,6 @@ func (m queryMetricsStore) GetOrganizationsByUserID(ctx context.Context, userID
|
||||
return organizations, err
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetOrganizationsWithPrebuildStatus(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetOrganizationsWithPrebuildStatus").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
start := time.Now()
|
||||
schemas, err := m.s.GetParameterSchemasByJobID(ctx, jobID)
|
||||
@@ -2538,13 +2517,6 @@ func (m queryMetricsStore) InsertTelemetryItemIfNotExists(ctx context.Context, a
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.InsertTelemetryLock(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertTelemetryLock").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
|
||||
start := time.Now()
|
||||
err := m.s.InsertTemplate(ctx, arg)
|
||||
@@ -2762,13 +2734,6 @@ func (m queryMetricsStore) ListAIBridgeInterceptions(ctx context.Context, arg da
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("ListAIBridgeInterceptionsTelemetrySummaries").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
|
||||
@@ -2923,13 +2888,6 @@ func (m queryMetricsStore) UnfavoriteWorkspace(ctx context.Context, arg uuid.UUI
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, id database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateAIBridgeInterceptionEnded(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("UpdateAIBridgeInterceptionEnded").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
|
||||
start := time.Now()
|
||||
err := m.s.UpdateAPIKeyByID(ctx, arg)
|
||||
@@ -3049,7 +3007,7 @@ func (m queryMetricsStore) UpdateOrganizationDeletedByID(ctx context.Context, ar
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
func (m queryMetricsStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdatePrebuildProvisionerJobWithCancel").Observe(time.Since(start).Seconds())
|
||||
|
||||
@@ -190,21 +190,6 @@ func (mr *MockStoreMockRecorder) BulkMarkNotificationMessagesSent(ctx, arg any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkMarkNotificationMessagesSent", reflect.TypeOf((*MockStore)(nil).BulkMarkNotificationMessagesSent), ctx, arg)
|
||||
}
|
||||
|
||||
// CalculateAIBridgeInterceptionsTelemetrySummary mocks base method.
|
||||
func (m *MockStore) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CalculateAIBridgeInterceptionsTelemetrySummary", ctx, arg)
|
||||
ret0, _ := ret[0].(database.CalculateAIBridgeInterceptionsTelemetrySummaryRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CalculateAIBridgeInterceptionsTelemetrySummary indicates an expected call of CalculateAIBridgeInterceptionsTelemetrySummary.
|
||||
func (mr *MockStoreMockRecorder) CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CalculateAIBridgeInterceptionsTelemetrySummary", reflect.TypeOf((*MockStore)(nil).CalculateAIBridgeInterceptionsTelemetrySummary), ctx, arg)
|
||||
}
|
||||
|
||||
// ClaimPrebuiltWorkspace mocks base method.
|
||||
func (m *MockStore) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -751,20 +736,6 @@ func (mr *MockStoreMockRecorder) DeleteOldProvisionerDaemons(ctx any) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldProvisionerDaemons", reflect.TypeOf((*MockStore)(nil).DeleteOldProvisionerDaemons), ctx)
|
||||
}
|
||||
|
||||
// DeleteOldTelemetryLocks mocks base method.
|
||||
func (m *MockStore) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldTelemetryLocks", ctx, periodEndingAtBefore)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteOldTelemetryLocks indicates an expected call of DeleteOldTelemetryLocks.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldTelemetryLocks(ctx, periodEndingAtBefore any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldTelemetryLocks", reflect.TypeOf((*MockStore)(nil).DeleteOldTelemetryLocks), ctx, periodEndingAtBefore)
|
||||
}
|
||||
|
||||
// DeleteOldWorkspaceAgentLogs mocks base method.
|
||||
func (m *MockStore) DeleteOldWorkspaceAgentLogs(ctx context.Context, threshold time.Time) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2622,21 +2593,6 @@ func (mr *MockStoreMockRecorder) GetOrganizationsByUserID(ctx, arg any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsByUserID", reflect.TypeOf((*MockStore)(nil).GetOrganizationsByUserID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetOrganizationsWithPrebuildStatus mocks base method.
|
||||
func (m *MockStore) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOrganizationsWithPrebuildStatus", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetOrganizationsWithPrebuildStatusRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrganizationsWithPrebuildStatus indicates an expected call of GetOrganizationsWithPrebuildStatus.
|
||||
func (mr *MockStoreMockRecorder) GetOrganizationsWithPrebuildStatus(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsWithPrebuildStatus", reflect.TypeOf((*MockStore)(nil).GetOrganizationsWithPrebuildStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// GetParameterSchemasByJobID mocks base method.
|
||||
func (m *MockStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5436,20 +5392,6 @@ func (mr *MockStoreMockRecorder) InsertTelemetryItemIfNotExists(ctx, arg any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTelemetryItemIfNotExists", reflect.TypeOf((*MockStore)(nil).InsertTelemetryItemIfNotExists), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertTelemetryLock mocks base method.
|
||||
func (m *MockStore) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertTelemetryLock", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// InsertTelemetryLock indicates an expected call of InsertTelemetryLock.
|
||||
func (mr *MockStoreMockRecorder) InsertTelemetryLock(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTelemetryLock", reflect.TypeOf((*MockStore)(nil).InsertTelemetryLock), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertTemplate mocks base method.
|
||||
func (m *MockStore) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5905,21 +5847,6 @@ func (mr *MockStoreMockRecorder) ListAIBridgeInterceptions(ctx, arg any) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptions), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeInterceptionsTelemetrySummaries mocks base method.
|
||||
func (m *MockStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAIBridgeInterceptionsTelemetrySummaries", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeInterceptionsTelemetrySummariesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAIBridgeInterceptionsTelemetrySummaries indicates an expected call of ListAIBridgeInterceptionsTelemetrySummaries.
|
||||
func (mr *MockStoreMockRecorder) ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptionsTelemetrySummaries", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptionsTelemetrySummaries), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
|
||||
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6289,21 +6216,6 @@ func (mr *MockStoreMockRecorder) UnfavoriteWorkspace(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnfavoriteWorkspace", reflect.TypeOf((*MockStore)(nil).UnfavoriteWorkspace), ctx, id)
|
||||
}
|
||||
|
||||
// UpdateAIBridgeInterceptionEnded mocks base method.
|
||||
func (m *MockStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAIBridgeInterceptionEnded", ctx, arg)
|
||||
ret0, _ := ret[0].(database.AIBridgeInterception)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateAIBridgeInterceptionEnded indicates an expected call of UpdateAIBridgeInterceptionEnded.
|
||||
func (mr *MockStoreMockRecorder) UpdateAIBridgeInterceptionEnded(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAIBridgeInterceptionEnded", reflect.TypeOf((*MockStore)(nil).UpdateAIBridgeInterceptionEnded), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateAPIKeyByID mocks base method.
|
||||
func (m *MockStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6555,10 +6467,10 @@ func (mr *MockStoreMockRecorder) UpdateOrganizationDeletedByID(ctx, arg any) *go
|
||||
}
|
||||
|
||||
// UpdatePrebuildProvisionerJobWithCancel mocks base method.
|
||||
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdatePrebuildProvisionerJobWithCancel", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.UpdatePrebuildProvisionerJobWithCancelRow)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
@@ -24,12 +24,6 @@ const (
|
||||
// but we won't touch the `connection_logs` table.
|
||||
maxAuditLogConnectionEventAge = 90 * 24 * time.Hour // 90 days
|
||||
auditLogConnectionEventBatchSize = 1000
|
||||
// Telemetry heartbeats are used to deduplicate events across replicas. We
|
||||
// don't need to persist heartbeat rows for longer than 24 hours, as they
|
||||
// are only used for deduplication across replicas. The time needs to be
|
||||
// long enough to cover the maximum interval of a heartbeat event (currently
|
||||
// 1 hour) plus some buffer.
|
||||
maxTelemetryHeartbeatAge = 24 * time.Hour
|
||||
)
|
||||
|
||||
// New creates a new periodically purging database instance.
|
||||
@@ -77,10 +71,6 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, clk quartz.
|
||||
if err := tx.ExpirePrebuildsAPIKeys(ctx, dbtime.Time(start)); err != nil {
|
||||
return xerrors.Errorf("failed to expire prebuilds user api keys: %w", err)
|
||||
}
|
||||
deleteOldTelemetryLocksBefore := start.Add(-maxTelemetryHeartbeatAge)
|
||||
if err := tx.DeleteOldTelemetryLocks(ctx, deleteOldTelemetryLocksBefore); err != nil {
|
||||
return xerrors.Errorf("failed to delete old telemetry locks: %w", err)
|
||||
}
|
||||
|
||||
deleteOldAuditLogConnectionEventsBefore := start.Add(-maxAuditLogConnectionEventAge)
|
||||
if err := tx.DeleteOldAuditLogConnectionEvents(ctx, database.DeleteOldAuditLogConnectionEventsParams{
|
||||
|
||||
@@ -704,56 +704,3 @@ func TestExpireOldAPIKeys(t *testing.T) {
|
||||
// Out of an abundance of caution, we do not expire explicitly named prebuilds API keys.
|
||||
assertKeyActive(namedPrebuildsAPIKey.ID)
|
||||
}
|
||||
|
||||
func TestDeleteOldTelemetryHeartbeats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
clk := quartz.NewMock(t)
|
||||
now := clk.Now().UTC()
|
||||
|
||||
// Insert telemetry heartbeats.
|
||||
err := db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
|
||||
EventType: "aibridge_interceptions_summary",
|
||||
PeriodEndingAt: now.Add(-25 * time.Hour), // should be purged
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
|
||||
EventType: "aibridge_interceptions_summary",
|
||||
PeriodEndingAt: now.Add(-23 * time.Hour), // should be kept
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
|
||||
EventType: "aibridge_interceptions_summary",
|
||||
PeriodEndingAt: now, // should be kept
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, clk)
|
||||
defer closer.Close()
|
||||
<-done // doTick() has now run.
|
||||
|
||||
require.Eventuallyf(t, func() bool {
|
||||
// We use an SQL queries directly here because we don't expose queries
|
||||
// for deleting heartbeats in the application code.
|
||||
var totalCount int
|
||||
err := sqlDB.QueryRowContext(ctx, `
|
||||
SELECT COUNT(*) FROM telemetry_locks;
|
||||
`).Scan(&totalCount)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var oldCount int
|
||||
err = sqlDB.QueryRowContext(ctx, `
|
||||
SELECT COUNT(*) FROM telemetry_locks WHERE period_ending_at < $1;
|
||||
`, now.Add(-24*time.Hour)).Scan(&oldCount)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Expect 2 heartbeats remaining and none older than 24 hours.
|
||||
t.Logf("eventually: total count: %d, old count: %d", totalCount, oldCount)
|
||||
return totalCount == 2 && oldCount == 0
|
||||
}, testutil.WaitShort, testutil.IntervalFast, "it should delete old telemetry heartbeats")
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -47,8 +45,6 @@ func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error
|
||||
host = defaultConnectionParams.Host
|
||||
port = defaultConnectionParams.Port
|
||||
)
|
||||
packageName := getTestPackageName(t)
|
||||
testName := t.Name()
|
||||
|
||||
// Use a time-based prefix to make it easier to find the database
|
||||
// when debugging.
|
||||
@@ -59,9 +55,9 @@ func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error
|
||||
}
|
||||
dbName := now + "_" + dbSuffix
|
||||
|
||||
// TODO: add package and test name
|
||||
_, err = b.coderTestingDB.Exec(
|
||||
"INSERT INTO test_databases (name, process_uuid, test_package, test_name) VALUES ($1, $2, $3, $4)",
|
||||
dbName, b.uuid, packageName, testName)
|
||||
"INSERT INTO test_databases (name, process_uuid) VALUES ($1, $2)", dbName, b.uuid)
|
||||
if err != nil {
|
||||
return ConnectionParams{}, xerrors.Errorf("insert test_database row: %w", err)
|
||||
}
|
||||
@@ -108,10 +104,10 @@ func (b *Broker) clean(t TBSubset, dbName string) func() {
|
||||
func (b *Broker) init(t TBSubset) error {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
b.refCount++
|
||||
t.Cleanup(b.decRef)
|
||||
if b.coderTestingDB != nil {
|
||||
// already initialized
|
||||
b.refCount++
|
||||
t.Cleanup(b.decRef)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -128,8 +124,8 @@ func (b *Broker) init(t TBSubset) error {
|
||||
return xerrors.Errorf("open postgres connection: %w", err)
|
||||
}
|
||||
|
||||
// coderTestingSQLInit is idempotent, so we can run it every time.
|
||||
_, err = coderTestingDB.Exec(coderTestingSQLInit)
|
||||
// creating the db can succeed even if the database doesn't exist. Ping it to find out.
|
||||
err = coderTestingDB.Ping()
|
||||
var pqErr *pq.Error
|
||||
if xerrors.As(err, &pqErr) && pqErr.Code == "3D000" {
|
||||
// database does not exist.
|
||||
@@ -149,8 +145,6 @@ func (b *Broker) init(t TBSubset) error {
|
||||
return xerrors.Errorf("ping '%s' database: %w", CoderTestingDBName, err)
|
||||
}
|
||||
b.coderTestingDB = coderTestingDB
|
||||
b.refCount++
|
||||
t.Cleanup(b.decRef)
|
||||
|
||||
if b.uuid == uuid.Nil {
|
||||
b.uuid = uuid.New()
|
||||
@@ -192,42 +186,3 @@ func (b *Broker) decRef() {
|
||||
b.coderTestingDB = nil
|
||||
}
|
||||
}
|
||||
|
||||
// getTestPackageName returns the package name of the test that called it.
|
||||
func getTestPackageName(t TBSubset) string {
|
||||
packageName := "unknown"
|
||||
// Ask runtime.Callers for up to 100 program counters, including runtime.Callers itself.
|
||||
pc := make([]uintptr, 100)
|
||||
n := runtime.Callers(0, pc)
|
||||
if n == 0 {
|
||||
// No PCs available. This can happen if the first argument to
|
||||
// runtime.Callers is large.
|
||||
//
|
||||
// Return now to avoid processing the zero Frame that would
|
||||
// otherwise be returned by frames.Next below.
|
||||
t.Logf("could not determine test package name: no PCs available")
|
||||
return packageName
|
||||
}
|
||||
|
||||
pc = pc[:n] // pass only valid pcs to runtime.CallersFrames
|
||||
frames := runtime.CallersFrames(pc)
|
||||
|
||||
// Loop to get frames.
|
||||
// A fixed number of PCs can expand to an indefinite number of Frames.
|
||||
for {
|
||||
frame, more := frames.Next()
|
||||
|
||||
if strings.HasPrefix(frame.Function, "github.com/coder/coder/v2/") {
|
||||
packageName = strings.SplitN(strings.TrimPrefix(frame.Function, "github.com/coder/coder/v2/"), ".", 2)[0]
|
||||
}
|
||||
if strings.HasPrefix(frame.Function, "testing") {
|
||||
break
|
||||
}
|
||||
|
||||
// Check whether there are more frames to process after this one.
|
||||
if !more {
|
||||
break
|
||||
}
|
||||
}
|
||||
return packageName
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package dbtestutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetTestPackageName(t *testing.T) {
|
||||
t.Parallel()
|
||||
packageName := getTestPackageName(t)
|
||||
require.Equal(t, "coderd/database/dbtestutil", packageName)
|
||||
}
|
||||
@@ -1,6 +1,3 @@
|
||||
BEGIN TRANSACTION;
|
||||
SELECT pg_advisory_xact_lock(7283699);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS test_databases (
|
||||
name text PRIMARY KEY,
|
||||
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
@@ -9,10 +6,3 @@ CREATE TABLE IF NOT EXISTS test_databases (
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS test_databases_process_uuid ON test_databases (process_uuid, dropped_at);
|
||||
|
||||
ALTER TABLE test_databases ADD COLUMN IF NOT EXISTS test_name text;
|
||||
COMMENT ON COLUMN test_databases.test_name IS 'Name of the test that created the database';
|
||||
ALTER TABLE test_databases ADD COLUMN IF NOT EXISTS test_package text;
|
||||
COMMENT ON COLUMN test_databases.test_package IS 'Package of the test that created the database';
|
||||
|
||||
COMMIT;
|
||||
|
||||
Generated
+14
-41
@@ -1828,15 +1828,6 @@ CREATE TABLE tasks (
|
||||
deleted_at timestamp with time zone
|
||||
);
|
||||
|
||||
CREATE VIEW visible_users AS
|
||||
SELECT users.id,
|
||||
users.username,
|
||||
users.name,
|
||||
users.avatar_url
|
||||
FROM users;
|
||||
|
||||
COMMENT ON VIEW visible_users IS 'Visible fields of users are allowed to be joined with other tables for including context of other resources.';
|
||||
|
||||
CREATE TABLE workspace_agents (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
@@ -1987,16 +1978,8 @@ CREATE VIEW tasks_with_status AS
|
||||
END AS status,
|
||||
task_app.workspace_build_number,
|
||||
task_app.workspace_agent_id,
|
||||
task_app.workspace_app_id,
|
||||
task_owner.owner_username,
|
||||
task_owner.owner_name,
|
||||
task_owner.owner_avatar_url
|
||||
FROM (((((tasks
|
||||
CROSS JOIN LATERAL ( SELECT vu.username AS owner_username,
|
||||
vu.name AS owner_name,
|
||||
vu.avatar_url AS owner_avatar_url
|
||||
FROM visible_users vu
|
||||
WHERE (vu.id = tasks.owner_id)) task_owner)
|
||||
task_app.workspace_app_id
|
||||
FROM ((((tasks
|
||||
LEFT JOIN LATERAL ( SELECT task_app_1.workspace_build_number,
|
||||
task_app_1.workspace_agent_id,
|
||||
task_app_1.workspace_app_id
|
||||
@@ -2029,18 +2012,6 @@ CREATE TABLE telemetry_items (
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE telemetry_locks (
|
||||
event_type text NOT NULL,
|
||||
period_ending_at timestamp with time zone NOT NULL,
|
||||
CONSTRAINT telemetry_lock_event_type_constraint CHECK ((event_type = 'aibridge_interceptions_summary'::text))
|
||||
);
|
||||
|
||||
COMMENT ON TABLE telemetry_locks IS 'Telemetry lock tracking table for deduplication of heartbeat events across replicas.';
|
||||
|
||||
COMMENT ON COLUMN telemetry_locks.event_type IS 'The type of event that was sent.';
|
||||
|
||||
COMMENT ON COLUMN telemetry_locks.period_ending_at IS 'The heartbeat period end timestamp.';
|
||||
|
||||
CREATE TABLE template_usage_stats (
|
||||
start_time timestamp with time zone NOT NULL,
|
||||
end_time timestamp with time zone NOT NULL,
|
||||
@@ -2227,6 +2198,15 @@ COMMENT ON COLUMN template_versions.external_auth_providers IS 'IDs of External
|
||||
|
||||
COMMENT ON COLUMN template_versions.message IS 'Message describing the changes in this version of the template, similar to a Git commit message. Like a commit message, this should be a short, high-level description of the changes in this version of the template. This message is immutable and should not be updated after the fact.';
|
||||
|
||||
CREATE VIEW visible_users AS
|
||||
SELECT users.id,
|
||||
users.username,
|
||||
users.name,
|
||||
users.avatar_url
|
||||
FROM users;
|
||||
|
||||
COMMENT ON VIEW visible_users IS 'Visible fields of users are allowed to be joined with other tables for including context of other resources.';
|
||||
|
||||
CREATE VIEW template_version_with_user AS
|
||||
SELECT template_versions.id,
|
||||
template_versions.template_id,
|
||||
@@ -2922,13 +2902,11 @@ CREATE VIEW workspaces_expanded AS
|
||||
templates.name AS template_name,
|
||||
templates.display_name AS template_display_name,
|
||||
templates.icon AS template_icon,
|
||||
templates.description AS template_description,
|
||||
tasks.id AS task_id
|
||||
FROM ((((workspaces
|
||||
templates.description AS template_description
|
||||
FROM (((workspaces
|
||||
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
|
||||
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
|
||||
JOIN templates ON ((workspaces.template_id = templates.id)))
|
||||
LEFT JOIN tasks ON ((workspaces.id = tasks.workspace_id)));
|
||||
JOIN templates ON ((workspaces.template_id = templates.id)));
|
||||
|
||||
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
|
||||
|
||||
@@ -3112,9 +3090,6 @@ ALTER TABLE ONLY tasks
|
||||
ALTER TABLE ONLY telemetry_items
|
||||
ADD CONSTRAINT telemetry_items_pkey PRIMARY KEY (key);
|
||||
|
||||
ALTER TABLE ONLY telemetry_locks
|
||||
ADD CONSTRAINT telemetry_locks_pkey PRIMARY KEY (event_type, period_ending_at);
|
||||
|
||||
ALTER TABLE ONLY template_usage_stats
|
||||
ADD CONSTRAINT template_usage_stats_pkey PRIMARY KEY (start_time, template_id, user_id);
|
||||
|
||||
@@ -3340,8 +3315,6 @@ CREATE INDEX idx_tailnet_tunnels_dst_id ON tailnet_tunnels USING hash (dst_id);
|
||||
|
||||
CREATE INDEX idx_tailnet_tunnels_src_id ON tailnet_tunnels USING hash (src_id);
|
||||
|
||||
CREATE INDEX idx_telemetry_locks_period_ending_at ON telemetry_locks USING btree (period_ending_at);
|
||||
|
||||
CREATE UNIQUE INDEX idx_template_version_presets_default ON template_version_presets USING btree (template_version_id) WHERE (is_default = true);
|
||||
|
||||
CREATE INDEX idx_template_versions_has_ai_task ON template_versions USING btree (has_ai_task);
|
||||
|
||||
@@ -141,19 +141,13 @@ ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'workspace_proxy:read';
|
||||
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'workspace_proxy:update';
|
||||
-- End enum extensions
|
||||
|
||||
-- Purge old API keys to speed up the migration for large deployments.
|
||||
-- Note: that problem should be solved in coderd once PR 20863 is released:
|
||||
-- https://github.com/coder/coder/blob/main/coderd/database/dbpurge/dbpurge.go#L85
|
||||
DELETE FROM api_keys WHERE expires_at < NOW() - INTERVAL '7 days';
|
||||
|
||||
-- Add new columns without defaults; backfill; then enforce NOT NULL
|
||||
ALTER TABLE api_keys ADD COLUMN scopes api_key_scope[];
|
||||
ALTER TABLE api_keys ADD COLUMN allow_list text[];
|
||||
|
||||
-- Backfill existing rows for compatibility
|
||||
UPDATE api_keys SET
|
||||
scopes = ARRAY[scope::api_key_scope],
|
||||
allow_list = ARRAY['*:*'];
|
||||
UPDATE api_keys SET scopes = ARRAY[scope::api_key_scope];
|
||||
UPDATE api_keys SET allow_list = ARRAY['*:*'];
|
||||
|
||||
-- Enforce NOT NULL
|
||||
ALTER TABLE api_keys ALTER COLUMN scopes SET NOT NULL;
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
DROP TABLE telemetry_locks;
|
||||
@@ -1,12 +0,0 @@
|
||||
CREATE TABLE telemetry_locks (
|
||||
event_type TEXT NOT NULL CONSTRAINT telemetry_lock_event_type_constraint CHECK (event_type IN ('aibridge_interceptions_summary')),
|
||||
period_ending_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
|
||||
PRIMARY KEY (event_type, period_ending_at)
|
||||
);
|
||||
|
||||
COMMENT ON TABLE telemetry_locks IS 'Telemetry lock tracking table for deduplication of heartbeat events across replicas.';
|
||||
COMMENT ON COLUMN telemetry_locks.event_type IS 'The type of event that was sent.';
|
||||
COMMENT ON COLUMN telemetry_locks.period_ending_at IS 'The heartbeat period end timestamp.';
|
||||
|
||||
CREATE INDEX idx_telemetry_locks_period_ending_at ON telemetry_locks (period_ending_at);
|
||||
@@ -1,74 +0,0 @@
|
||||
-- Drop view from 000390_tasks_with_status_user_fields.up.sql.
|
||||
DROP VIEW IF EXISTS tasks_with_status;
|
||||
|
||||
-- Restore from 000382_add_columns_to_tasks_with_status.up.sql.
|
||||
CREATE VIEW
|
||||
tasks_with_status
|
||||
AS
|
||||
SELECT
|
||||
tasks.*,
|
||||
CASE
|
||||
WHEN tasks.workspace_id IS NULL OR latest_build.job_status IS NULL THEN 'pending'::task_status
|
||||
|
||||
WHEN latest_build.job_status = 'failed' THEN 'error'::task_status
|
||||
|
||||
WHEN latest_build.transition IN ('stop', 'delete')
|
||||
AND latest_build.job_status = 'succeeded' THEN 'paused'::task_status
|
||||
|
||||
WHEN latest_build.transition = 'start'
|
||||
AND latest_build.job_status = 'pending' THEN 'initializing'::task_status
|
||||
|
||||
WHEN latest_build.transition = 'start' AND latest_build.job_status IN ('running', 'succeeded') THEN
|
||||
CASE
|
||||
WHEN agent_status.none THEN 'initializing'::task_status
|
||||
WHEN agent_status.connecting THEN 'initializing'::task_status
|
||||
WHEN agent_status.connected THEN
|
||||
CASE
|
||||
WHEN app_status.any_unhealthy THEN 'error'::task_status
|
||||
WHEN app_status.any_initializing THEN 'initializing'::task_status
|
||||
WHEN app_status.all_healthy_or_disabled THEN 'active'::task_status
|
||||
ELSE 'unknown'::task_status
|
||||
END
|
||||
ELSE 'unknown'::task_status
|
||||
END
|
||||
|
||||
ELSE 'unknown'::task_status
|
||||
END AS status,
|
||||
task_app.*
|
||||
FROM
|
||||
tasks
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT workspace_build_number, workspace_agent_id, workspace_app_id
|
||||
FROM task_workspace_apps task_app
|
||||
WHERE task_id = tasks.id
|
||||
ORDER BY workspace_build_number DESC
|
||||
LIMIT 1
|
||||
) task_app ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
workspace_build.transition,
|
||||
provisioner_job.job_status,
|
||||
workspace_build.job_id
|
||||
FROM workspace_builds workspace_build
|
||||
JOIN provisioner_jobs provisioner_job ON provisioner_job.id = workspace_build.job_id
|
||||
WHERE workspace_build.workspace_id = tasks.workspace_id
|
||||
AND workspace_build.build_number = task_app.workspace_build_number
|
||||
) latest_build ON TRUE
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
COUNT(*) = 0 AS none,
|
||||
bool_or(workspace_agent.lifecycle_state IN ('created', 'starting')) AS connecting,
|
||||
bool_and(workspace_agent.lifecycle_state = 'ready') AS connected
|
||||
FROM workspace_agents workspace_agent
|
||||
WHERE workspace_agent.id = task_app.workspace_agent_id
|
||||
) agent_status
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
bool_or(workspace_app.health = 'unhealthy') AS any_unhealthy,
|
||||
bool_or(workspace_app.health = 'initializing') AS any_initializing,
|
||||
bool_and(workspace_app.health IN ('healthy', 'disabled')) AS all_healthy_or_disabled
|
||||
FROM workspace_apps workspace_app
|
||||
WHERE workspace_app.id = task_app.workspace_app_id
|
||||
) app_status
|
||||
WHERE
|
||||
tasks.deleted_at IS NULL;
|
||||
@@ -1,84 +0,0 @@
|
||||
-- Drop view from 00037_add_columns_to_tasks_with_status.up.sql.
|
||||
DROP VIEW IF EXISTS tasks_with_status;
|
||||
|
||||
-- Add owner_name, owner_avatar_url columns.
|
||||
CREATE VIEW
|
||||
tasks_with_status
|
||||
AS
|
||||
SELECT
|
||||
tasks.*,
|
||||
CASE
|
||||
WHEN tasks.workspace_id IS NULL OR latest_build.job_status IS NULL THEN 'pending'::task_status
|
||||
|
||||
WHEN latest_build.job_status = 'failed' THEN 'error'::task_status
|
||||
|
||||
WHEN latest_build.transition IN ('stop', 'delete')
|
||||
AND latest_build.job_status = 'succeeded' THEN 'paused'::task_status
|
||||
|
||||
WHEN latest_build.transition = 'start'
|
||||
AND latest_build.job_status = 'pending' THEN 'initializing'::task_status
|
||||
|
||||
WHEN latest_build.transition = 'start' AND latest_build.job_status IN ('running', 'succeeded') THEN
|
||||
CASE
|
||||
WHEN agent_status.none THEN 'initializing'::task_status
|
||||
WHEN agent_status.connecting THEN 'initializing'::task_status
|
||||
WHEN agent_status.connected THEN
|
||||
CASE
|
||||
WHEN app_status.any_unhealthy THEN 'error'::task_status
|
||||
WHEN app_status.any_initializing THEN 'initializing'::task_status
|
||||
WHEN app_status.all_healthy_or_disabled THEN 'active'::task_status
|
||||
ELSE 'unknown'::task_status
|
||||
END
|
||||
ELSE 'unknown'::task_status
|
||||
END
|
||||
|
||||
ELSE 'unknown'::task_status
|
||||
END AS status,
|
||||
task_app.*,
|
||||
task_owner.*
|
||||
FROM
|
||||
tasks
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
vu.username AS owner_username,
|
||||
vu.name AS owner_name,
|
||||
vu.avatar_url AS owner_avatar_url
|
||||
FROM visible_users vu
|
||||
WHERE vu.id = tasks.owner_id
|
||||
) task_owner
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT workspace_build_number, workspace_agent_id, workspace_app_id
|
||||
FROM task_workspace_apps task_app
|
||||
WHERE task_id = tasks.id
|
||||
ORDER BY workspace_build_number DESC
|
||||
LIMIT 1
|
||||
) task_app ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
workspace_build.transition,
|
||||
provisioner_job.job_status,
|
||||
workspace_build.job_id
|
||||
FROM workspace_builds workspace_build
|
||||
JOIN provisioner_jobs provisioner_job ON provisioner_job.id = workspace_build.job_id
|
||||
WHERE workspace_build.workspace_id = tasks.workspace_id
|
||||
AND workspace_build.build_number = task_app.workspace_build_number
|
||||
) latest_build ON TRUE
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
COUNT(*) = 0 AS none,
|
||||
bool_or(workspace_agent.lifecycle_state IN ('created', 'starting')) AS connecting,
|
||||
bool_and(workspace_agent.lifecycle_state = 'ready') AS connected
|
||||
FROM workspace_agents workspace_agent
|
||||
WHERE workspace_agent.id = task_app.workspace_agent_id
|
||||
) agent_status
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
bool_or(workspace_app.health = 'unhealthy') AS any_unhealthy,
|
||||
bool_or(workspace_app.health = 'initializing') AS any_initializing,
|
||||
bool_and(workspace_app.health IN ('healthy', 'disabled')) AS all_healthy_or_disabled
|
||||
FROM workspace_apps workspace_app
|
||||
WHERE workspace_app.id = task_app.workspace_app_id
|
||||
) app_status
|
||||
WHERE
|
||||
tasks.deleted_at IS NULL;
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
UPDATE notification_templates
|
||||
SET enabled_by_default = true
|
||||
WHERE id IN (
|
||||
'8c5a4d12-9f7e-4b3a-a1c8-6e4f2d9b5a7c',
|
||||
'3b7e8f1a-4c2d-49a6-b5e9-7f3a1c8d6b4e',
|
||||
'bd4b7168-d05e-4e19-ad0f-3593b77aa90f',
|
||||
'd4a6271c-cced-4ed0-84ad-afd02a9c7799'
|
||||
);
|
||||
@@ -1,8 +0,0 @@
|
||||
UPDATE notification_templates
|
||||
SET enabled_by_default = false
|
||||
WHERE id IN (
|
||||
'8c5a4d12-9f7e-4b3a-a1c8-6e4f2d9b5a7c',
|
||||
'3b7e8f1a-4c2d-49a6-b5e9-7f3a1c8d6b4e',
|
||||
'bd4b7168-d05e-4e19-ad0f-3593b77aa90f',
|
||||
'd4a6271c-cced-4ed0-84ad-afd02a9c7799'
|
||||
);
|
||||
@@ -1,39 +0,0 @@
|
||||
DROP VIEW workspaces_expanded;
|
||||
|
||||
-- Recreate the view from 000354_workspace_acl.up.sql
|
||||
CREATE VIEW workspaces_expanded AS
|
||||
SELECT workspaces.id,
|
||||
workspaces.created_at,
|
||||
workspaces.updated_at,
|
||||
workspaces.owner_id,
|
||||
workspaces.organization_id,
|
||||
workspaces.template_id,
|
||||
workspaces.deleted,
|
||||
workspaces.name,
|
||||
workspaces.autostart_schedule,
|
||||
workspaces.ttl,
|
||||
workspaces.last_used_at,
|
||||
workspaces.dormant_at,
|
||||
workspaces.deleting_at,
|
||||
workspaces.automatic_updates,
|
||||
workspaces.favorite,
|
||||
workspaces.next_start_at,
|
||||
workspaces.group_acl,
|
||||
workspaces.user_acl,
|
||||
visible_users.avatar_url AS owner_avatar_url,
|
||||
visible_users.username AS owner_username,
|
||||
visible_users.name AS owner_name,
|
||||
organizations.name AS organization_name,
|
||||
organizations.display_name AS organization_display_name,
|
||||
organizations.icon AS organization_icon,
|
||||
organizations.description AS organization_description,
|
||||
templates.name AS template_name,
|
||||
templates.display_name AS template_display_name,
|
||||
templates.icon AS template_icon,
|
||||
templates.description AS template_description
|
||||
FROM (((workspaces
|
||||
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
|
||||
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
|
||||
JOIN templates ON ((workspaces.template_id = templates.id)));
|
||||
|
||||
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
|
||||
@@ -1,42 +0,0 @@
|
||||
DROP VIEW workspaces_expanded;
|
||||
|
||||
-- Add nullable task_id to workspaces_expanded view
|
||||
CREATE VIEW workspaces_expanded AS
|
||||
SELECT workspaces.id,
|
||||
workspaces.created_at,
|
||||
workspaces.updated_at,
|
||||
workspaces.owner_id,
|
||||
workspaces.organization_id,
|
||||
workspaces.template_id,
|
||||
workspaces.deleted,
|
||||
workspaces.name,
|
||||
workspaces.autostart_schedule,
|
||||
workspaces.ttl,
|
||||
workspaces.last_used_at,
|
||||
workspaces.dormant_at,
|
||||
workspaces.deleting_at,
|
||||
workspaces.automatic_updates,
|
||||
workspaces.favorite,
|
||||
workspaces.next_start_at,
|
||||
workspaces.group_acl,
|
||||
workspaces.user_acl,
|
||||
visible_users.avatar_url AS owner_avatar_url,
|
||||
visible_users.username AS owner_username,
|
||||
visible_users.name AS owner_name,
|
||||
organizations.name AS organization_name,
|
||||
organizations.display_name AS organization_display_name,
|
||||
organizations.icon AS organization_icon,
|
||||
organizations.description AS organization_description,
|
||||
templates.name AS template_name,
|
||||
templates.display_name AS template_display_name,
|
||||
templates.icon AS template_icon,
|
||||
templates.description AS template_description,
|
||||
tasks.id AS task_id
|
||||
FROM ((((workspaces
|
||||
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
|
||||
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
|
||||
JOIN templates ON ((workspaces.template_id = templates.id)))
|
||||
LEFT JOIN tasks ON ((workspaces.id = tasks.workspace_id)));
|
||||
|
||||
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
|
||||
|
||||
coderd/database/migrations/testdata/fixtures/000371_add_api_key_and_oauth2_provider_app_token.up.sql
Vendored
-57
@@ -1,57 +0,0 @@
|
||||
-- Ensure api_keys and oauth2_provider_app_tokens have live data after
|
||||
-- migration 000371 deletes expired rows.
|
||||
INSERT INTO api_keys (
|
||||
id,
|
||||
hashed_secret,
|
||||
user_id,
|
||||
last_used,
|
||||
expires_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
login_type,
|
||||
lifetime_seconds,
|
||||
ip_address,
|
||||
token_name,
|
||||
scopes,
|
||||
allow_list
|
||||
)
|
||||
VALUES (
|
||||
'fixture-api-key',
|
||||
'\xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
|
||||
'30095c71-380b-457a-8995-97b8ee6e5307',
|
||||
NOW() - INTERVAL '1 hour',
|
||||
NOW() + INTERVAL '30 days',
|
||||
NOW() - INTERVAL '1 day',
|
||||
NOW() - INTERVAL '1 day',
|
||||
'password',
|
||||
86400,
|
||||
'0.0.0.0',
|
||||
'fixture-api-key',
|
||||
ARRAY['workspace:read']::api_key_scope[],
|
||||
ARRAY['*:*']
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
|
||||
INSERT INTO oauth2_provider_app_tokens (
|
||||
id,
|
||||
created_at,
|
||||
expires_at,
|
||||
hash_prefix,
|
||||
refresh_hash,
|
||||
app_secret_id,
|
||||
api_key_id,
|
||||
audience,
|
||||
user_id
|
||||
)
|
||||
VALUES (
|
||||
'9f92f3c9-811f-4f6f-9a1c-3f2eed1f9f15',
|
||||
NOW() - INTERVAL '30 minutes',
|
||||
NOW() + INTERVAL '30 days',
|
||||
CAST('fixture-hash-prefix' AS bytea),
|
||||
CAST('fixture-refresh-hash' AS bytea),
|
||||
'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
|
||||
'fixture-api-key',
|
||||
'https://coder.example.com',
|
||||
'30095c71-380b-457a-8995-97b8ee6e5307'
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
@@ -1,8 +0,0 @@
|
||||
INSERT INTO telemetry_locks (
|
||||
event_type,
|
||||
period_ending_at
|
||||
)
|
||||
VALUES (
|
||||
'aibridge_interceptions_summary',
|
||||
'2025-01-01 00:00:00+00'::timestamptz
|
||||
);
|
||||
@@ -208,7 +208,6 @@ func (s APIKeyScopes) expandRBACScope() (rbac.Scope, error) {
|
||||
for orgID, perms := range expanded.ByOrgID {
|
||||
orgPerms := merged.ByOrgID[orgID]
|
||||
orgPerms.Org = append(orgPerms.Org, perms.Org...)
|
||||
orgPerms.Member = append(orgPerms.Member, perms.Member...)
|
||||
merged.ByOrgID[orgID] = orgPerms
|
||||
}
|
||||
merged.User = append(merged.User, expanded.User...)
|
||||
@@ -221,7 +220,6 @@ func (s APIKeyScopes) expandRBACScope() (rbac.Scope, error) {
|
||||
merged.User = rbac.DeduplicatePermissions(merged.User)
|
||||
for orgID, perms := range merged.ByOrgID {
|
||||
perms.Org = rbac.DeduplicatePermissions(perms.Org)
|
||||
perms.Member = rbac.DeduplicatePermissions(perms.Member)
|
||||
merged.ByOrgID[orgID] = perms
|
||||
}
|
||||
|
||||
|
||||
@@ -321,7 +321,6 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
&i.TemplateVersionID,
|
||||
&i.TemplateVersionName,
|
||||
&i.LatestBuildCompletedAt,
|
||||
|
||||
@@ -4221,9 +4221,6 @@ type Task struct {
|
||||
WorkspaceBuildNumber sql.NullInt32 `db:"workspace_build_number" json:"workspace_build_number"`
|
||||
WorkspaceAgentID uuid.NullUUID `db:"workspace_agent_id" json:"workspace_agent_id"`
|
||||
WorkspaceAppID uuid.NullUUID `db:"workspace_app_id" json:"workspace_app_id"`
|
||||
OwnerUsername string `db:"owner_username" json:"owner_username"`
|
||||
OwnerName string `db:"owner_name" json:"owner_name"`
|
||||
OwnerAvatarUrl string `db:"owner_avatar_url" json:"owner_avatar_url"`
|
||||
}
|
||||
|
||||
type TaskTable struct {
|
||||
@@ -4253,14 +4250,6 @@ type TelemetryItem struct {
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
// Telemetry lock tracking table for deduplication of heartbeat events across replicas.
|
||||
type TelemetryLock struct {
|
||||
// The type of event that was sent.
|
||||
EventType string `db:"event_type" json:"event_type"`
|
||||
// The heartbeat period end timestamp.
|
||||
PeriodEndingAt time.Time `db:"period_ending_at" json:"period_ending_at"`
|
||||
}
|
||||
|
||||
// Joins in the display name information such as username, avatar, and organization name.
|
||||
type Template struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
@@ -4663,7 +4652,6 @@ type Workspace struct {
|
||||
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
|
||||
TemplateIcon string `db:"template_icon" json:"template_icon"`
|
||||
TemplateDescription string `db:"template_description" json:"template_description"`
|
||||
TaskID uuid.NullUUID `db:"task_id" json:"task_id"`
|
||||
}
|
||||
|
||||
type WorkspaceAgent struct {
|
||||
|
||||
@@ -60,9 +60,6 @@ type sqlcQuerier interface {
|
||||
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
|
||||
BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error)
|
||||
BulkMarkNotificationMessagesSent(ctx context.Context, arg BulkMarkNotificationMessagesSentParams) (int64, error)
|
||||
// Calculates the telemetry summary for a given provider, model, and client
|
||||
// combination for telemetry reporting.
|
||||
CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg CalculateAIBridgeInterceptionsTelemetrySummaryParams) (CalculateAIBridgeInterceptionsTelemetrySummaryRow, error)
|
||||
ClaimPrebuiltWorkspace(ctx context.Context, arg ClaimPrebuiltWorkspaceParams) (ClaimPrebuiltWorkspaceRow, error)
|
||||
CleanTailnetCoordinators(ctx context.Context) error
|
||||
CleanTailnetLostPeers(ctx context.Context) error
|
||||
@@ -110,8 +107,6 @@ type sqlcQuerier interface {
|
||||
// A provisioner daemon with "zeroed" last_seen_at column indicates possible
|
||||
// connectivity issues (no provisioner daemon activity since registration).
|
||||
DeleteOldProvisionerDaemons(ctx context.Context) error
|
||||
// Deletes old telemetry locks from the telemetry_locks table.
|
||||
DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error
|
||||
// If an agent hasn't connected in the last 7 days, we purge it's logs.
|
||||
// Exception: if the logs are related to the latest build, we keep those around.
|
||||
// Logs can take up a lot of space, so it's important we clean up frequently.
|
||||
@@ -269,9 +264,6 @@ type sqlcQuerier interface {
|
||||
GetOrganizationResourceCountByID(ctx context.Context, organizationID uuid.UUID) (GetOrganizationResourceCountByIDRow, error)
|
||||
GetOrganizations(ctx context.Context, arg GetOrganizationsParams) ([]Organization, error)
|
||||
GetOrganizationsByUserID(ctx context.Context, arg GetOrganizationsByUserIDParams) ([]Organization, error)
|
||||
// GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
|
||||
// membership status for the prebuilds system user (org membership, group existence, group membership).
|
||||
GetOrganizationsWithPrebuildStatus(ctx context.Context, arg GetOrganizationsWithPrebuildStatusParams) ([]GetOrganizationsWithPrebuildStatusRow, error)
|
||||
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
|
||||
GetPrebuildMetrics(ctx context.Context) ([]GetPrebuildMetricsRow, error)
|
||||
GetPrebuildsSettings(ctx context.Context) (string, error)
|
||||
@@ -567,12 +559,6 @@ type sqlcQuerier interface {
|
||||
InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error)
|
||||
InsertTask(ctx context.Context, arg InsertTaskParams) (TaskTable, error)
|
||||
InsertTelemetryItemIfNotExists(ctx context.Context, arg InsertTelemetryItemIfNotExistsParams) error
|
||||
// Inserts a new lock row into the telemetry_locks table. Replicas should call
|
||||
// this function prior to attempting to generate or publish a heartbeat event to
|
||||
// the telemetry service.
|
||||
// If the query returns a duplicate primary key error, the replica should not
|
||||
// attempt to generate or publish the event to the telemetry service.
|
||||
InsertTelemetryLock(ctx context.Context, arg InsertTelemetryLockParams) error
|
||||
InsertTemplate(ctx context.Context, arg InsertTemplateParams) error
|
||||
InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) error
|
||||
InsertTemplateVersionParameter(ctx context.Context, arg InsertTemplateVersionParameterParams) (TemplateVersionParameter, error)
|
||||
@@ -609,9 +595,6 @@ type sqlcQuerier interface {
|
||||
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
|
||||
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
|
||||
ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]ListAIBridgeInterceptionsRow, error)
|
||||
// Finds all unique AIBridge interception telemetry summaries combinations
|
||||
// (provider, model, client) in the given timeframe for telemetry reporting.
|
||||
ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error)
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
|
||||
@@ -649,7 +632,6 @@ type sqlcQuerier interface {
|
||||
// This will always work regardless of the current state of the template version.
|
||||
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
|
||||
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
|
||||
UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error)
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
|
||||
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
|
||||
@@ -670,7 +652,7 @@ type sqlcQuerier interface {
|
||||
// Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
|
||||
// inactive template version.
|
||||
// This is an optimization to clean up stale pending jobs.
|
||||
UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]UpdatePrebuildProvisionerJobWithCancelRow, error)
|
||||
UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error)
|
||||
UpdatePresetPrebuildStatus(ctx context.Context, arg UpdatePresetPrebuildStatusParams) error
|
||||
UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error
|
||||
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
|
||||
|
||||
@@ -7248,9 +7248,7 @@ func TestTaskNameUniqueness(t *testing.T) {
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
taskID := uuid.New()
|
||||
task, err := db.InsertTask(ctx, database.InsertTaskParams{
|
||||
ID: taskID,
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: tt.ownerID,
|
||||
Name: tt.taskName,
|
||||
@@ -7265,7 +7263,6 @@ func TestTaskNameUniqueness(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, task.ID)
|
||||
require.NotEqual(t, task1.ID, task.ID)
|
||||
require.Equal(t, taskID, task.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -7727,68 +7724,3 @@ func TestUpdateTaskWorkspaceID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
t.Run("NonExistingInterception", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
got, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: uuid.New(),
|
||||
EndedAt: time.Now(),
|
||||
})
|
||||
require.ErrorContains(t, err, "no rows in result set")
|
||||
require.EqualValues(t, database.AIBridgeInterception{}, got)
|
||||
})
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
interceptions := []database.AIBridgeInterception{}
|
||||
|
||||
for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
|
||||
insertParams := database.InsertAIBridgeInterceptionParams{
|
||||
ID: uid,
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
}
|
||||
|
||||
intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uid, intc.ID)
|
||||
require.False(t, intc.EndedAt.Valid)
|
||||
interceptions = append(interceptions, intc)
|
||||
}
|
||||
|
||||
intc0 := interceptions[0]
|
||||
endedAt := time.Now()
|
||||
// Mark first interception as done
|
||||
updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: intc0.ID,
|
||||
EndedAt: endedAt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, updated.ID, intc0.ID)
|
||||
require.True(t, updated.EndedAt.Valid)
|
||||
require.WithinDuration(t, endedAt, updated.EndedAt.Time, 5*time.Second)
|
||||
|
||||
// Updating first interception again should fail
|
||||
updated, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: intc0.ID,
|
||||
EndedAt: endedAt.Add(time.Hour),
|
||||
})
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
// Other interceptions should not have ended_at set
|
||||
for _, intc := range interceptions[1:] {
|
||||
got, err := db.GetAIBridgeInterceptionByID(ctx, intc.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, got.EndedAt.Valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+25
-419
@@ -111,164 +111,6 @@ func (q *sqlQuerier) ActivityBumpWorkspace(ctx context.Context, arg ActivityBump
|
||||
return err
|
||||
}
|
||||
|
||||
const calculateAIBridgeInterceptionsTelemetrySummary = `-- name: CalculateAIBridgeInterceptionsTelemetrySummary :one
|
||||
WITH interceptions_in_range AS (
|
||||
-- Get all matching interceptions in the given timeframe.
|
||||
SELECT
|
||||
id,
|
||||
initiator_id,
|
||||
(ended_at - started_at) AS duration
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
provider = $1::text
|
||||
AND model = $2::text
|
||||
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
|
||||
AND 'unknown' = $3::text
|
||||
AND ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
|
||||
AND ended_at >= $4::timestamptz
|
||||
AND ended_at < $5::timestamptz
|
||||
),
|
||||
interception_counts AS (
|
||||
SELECT
|
||||
COUNT(id) AS interception_count,
|
||||
COUNT(DISTINCT initiator_id) AS unique_initiator_count
|
||||
FROM
|
||||
interceptions_in_range
|
||||
),
|
||||
duration_percentiles AS (
|
||||
SELECT
|
||||
(COALESCE(PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p50_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.90) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p90_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p95_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p99_millis
|
||||
FROM
|
||||
interceptions_in_range
|
||||
),
|
||||
token_aggregates AS (
|
||||
SELECT
|
||||
COALESCE(SUM(tu.input_tokens), 0) AS token_count_input,
|
||||
COALESCE(SUM(tu.output_tokens), 0) AS token_count_output,
|
||||
-- Cached tokens are stored in metadata JSON, extract if available.
|
||||
-- Read tokens may be stored in:
|
||||
-- - cache_read_input (Anthropic)
|
||||
-- - prompt_cached (OpenAI)
|
||||
COALESCE(SUM(
|
||||
COALESCE((tu.metadata->>'cache_read_input')::bigint, 0) +
|
||||
COALESCE((tu.metadata->>'prompt_cached')::bigint, 0)
|
||||
), 0) AS token_count_cached_read,
|
||||
-- Written tokens may be stored in:
|
||||
-- - cache_creation_input (Anthropic)
|
||||
-- Note that cache_ephemeral_5m_input and cache_ephemeral_1h_input on
|
||||
-- Anthropic are included in the cache_creation_input field.
|
||||
COALESCE(SUM(
|
||||
COALESCE((tu.metadata->>'cache_creation_input')::bigint, 0)
|
||||
), 0) AS token_count_cached_written,
|
||||
COUNT(tu.id) AS token_usages_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_token_usages tu ON i.id = tu.interception_id
|
||||
),
|
||||
prompt_aggregates AS (
|
||||
SELECT
|
||||
COUNT(up.id) AS user_prompts_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_user_prompts up ON i.id = up.interception_id
|
||||
),
|
||||
tool_aggregates AS (
|
||||
SELECT
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = true) AS tool_calls_count_injected,
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = false) AS tool_calls_count_non_injected,
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = true AND tu.invocation_error IS NOT NULL) AS injected_tool_call_error_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_tool_usages tu ON i.id = tu.interception_id
|
||||
)
|
||||
SELECT
|
||||
ic.interception_count::bigint AS interception_count,
|
||||
dp.interception_duration_p50_millis::bigint AS interception_duration_p50_millis,
|
||||
dp.interception_duration_p90_millis::bigint AS interception_duration_p90_millis,
|
||||
dp.interception_duration_p95_millis::bigint AS interception_duration_p95_millis,
|
||||
dp.interception_duration_p99_millis::bigint AS interception_duration_p99_millis,
|
||||
ic.unique_initiator_count::bigint AS unique_initiator_count,
|
||||
pa.user_prompts_count::bigint AS user_prompts_count,
|
||||
tok_agg.token_usages_count::bigint AS token_usages_count,
|
||||
tok_agg.token_count_input::bigint AS token_count_input,
|
||||
tok_agg.token_count_output::bigint AS token_count_output,
|
||||
tok_agg.token_count_cached_read::bigint AS token_count_cached_read,
|
||||
tok_agg.token_count_cached_written::bigint AS token_count_cached_written,
|
||||
tool_agg.tool_calls_count_injected::bigint AS tool_calls_count_injected,
|
||||
tool_agg.tool_calls_count_non_injected::bigint AS tool_calls_count_non_injected,
|
||||
tool_agg.injected_tool_call_error_count::bigint AS injected_tool_call_error_count
|
||||
FROM
|
||||
interception_counts ic,
|
||||
duration_percentiles dp,
|
||||
token_aggregates tok_agg,
|
||||
prompt_aggregates pa,
|
||||
tool_aggregates tool_agg
|
||||
`
|
||||
|
||||
type CalculateAIBridgeInterceptionsTelemetrySummaryParams struct {
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
EndedAtAfter time.Time `db:"ended_at_after" json:"ended_at_after"`
|
||||
EndedAtBefore time.Time `db:"ended_at_before" json:"ended_at_before"`
|
||||
}
|
||||
|
||||
type CalculateAIBridgeInterceptionsTelemetrySummaryRow struct {
|
||||
InterceptionCount int64 `db:"interception_count" json:"interception_count"`
|
||||
InterceptionDurationP50Millis int64 `db:"interception_duration_p50_millis" json:"interception_duration_p50_millis"`
|
||||
InterceptionDurationP90Millis int64 `db:"interception_duration_p90_millis" json:"interception_duration_p90_millis"`
|
||||
InterceptionDurationP95Millis int64 `db:"interception_duration_p95_millis" json:"interception_duration_p95_millis"`
|
||||
InterceptionDurationP99Millis int64 `db:"interception_duration_p99_millis" json:"interception_duration_p99_millis"`
|
||||
UniqueInitiatorCount int64 `db:"unique_initiator_count" json:"unique_initiator_count"`
|
||||
UserPromptsCount int64 `db:"user_prompts_count" json:"user_prompts_count"`
|
||||
TokenUsagesCount int64 `db:"token_usages_count" json:"token_usages_count"`
|
||||
TokenCountInput int64 `db:"token_count_input" json:"token_count_input"`
|
||||
TokenCountOutput int64 `db:"token_count_output" json:"token_count_output"`
|
||||
TokenCountCachedRead int64 `db:"token_count_cached_read" json:"token_count_cached_read"`
|
||||
TokenCountCachedWritten int64 `db:"token_count_cached_written" json:"token_count_cached_written"`
|
||||
ToolCallsCountInjected int64 `db:"tool_calls_count_injected" json:"tool_calls_count_injected"`
|
||||
ToolCallsCountNonInjected int64 `db:"tool_calls_count_non_injected" json:"tool_calls_count_non_injected"`
|
||||
InjectedToolCallErrorCount int64 `db:"injected_tool_call_error_count" json:"injected_tool_call_error_count"`
|
||||
}
|
||||
|
||||
// Calculates the telemetry summary for a given provider, model, and client
|
||||
// combination for telemetry reporting.
|
||||
func (q *sqlQuerier) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg CalculateAIBridgeInterceptionsTelemetrySummaryParams) (CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
|
||||
row := q.db.QueryRowContext(ctx, calculateAIBridgeInterceptionsTelemetrySummary,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.EndedAtAfter,
|
||||
arg.EndedAtBefore,
|
||||
)
|
||||
var i CalculateAIBridgeInterceptionsTelemetrySummaryRow
|
||||
err := row.Scan(
|
||||
&i.InterceptionCount,
|
||||
&i.InterceptionDurationP50Millis,
|
||||
&i.InterceptionDurationP90Millis,
|
||||
&i.InterceptionDurationP95Millis,
|
||||
&i.InterceptionDurationP99Millis,
|
||||
&i.UniqueInitiatorCount,
|
||||
&i.UserPromptsCount,
|
||||
&i.TokenUsagesCount,
|
||||
&i.TokenCountInput,
|
||||
&i.TokenCountOutput,
|
||||
&i.TokenCountCachedRead,
|
||||
&i.TokenCountCachedWritten,
|
||||
&i.ToolCallsCountInjected,
|
||||
&i.ToolCallsCountNonInjected,
|
||||
&i.InjectedToolCallErrorCount,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const countAIBridgeInterceptions = `-- name: CountAIBridgeInterceptions :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
@@ -805,57 +647,6 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listAIBridgeInterceptionsTelemetrySummaries = `-- name: ListAIBridgeInterceptionsTelemetrySummaries :many
|
||||
SELECT
|
||||
DISTINCT ON (provider, model, client)
|
||||
provider,
|
||||
model,
|
||||
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
|
||||
'unknown' AS client
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
|
||||
AND ended_at >= $1::timestamptz
|
||||
AND ended_at < $2::timestamptz
|
||||
`
|
||||
|
||||
type ListAIBridgeInterceptionsTelemetrySummariesParams struct {
|
||||
EndedAtAfter time.Time `db:"ended_at_after" json:"ended_at_after"`
|
||||
EndedAtBefore time.Time `db:"ended_at_before" json:"ended_at_before"`
|
||||
}
|
||||
|
||||
type ListAIBridgeInterceptionsTelemetrySummariesRow struct {
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
}
|
||||
|
||||
// Finds all unique AIBridge interception telemetry summaries combinations
|
||||
// (provider, model, client) in the given timeframe for telemetry reporting.
|
||||
func (q *sqlQuerier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listAIBridgeInterceptionsTelemetrySummaries, arg.EndedAtAfter, arg.EndedAtBefore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListAIBridgeInterceptionsTelemetrySummariesRow
|
||||
for rows.Next() {
|
||||
var i ListAIBridgeInterceptionsTelemetrySummariesRow
|
||||
if err := rows.Scan(&i.Provider, &i.Model, &i.Client); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listAIBridgeTokenUsagesByInterceptionIDs = `-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many
|
||||
SELECT
|
||||
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
|
||||
@@ -987,35 +778,6 @@ func (q *sqlQuerier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Contex
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateAIBridgeInterceptionEnded = `-- name: UpdateAIBridgeInterceptionEnded :one
|
||||
UPDATE aibridge_interceptions
|
||||
SET ended_at = $1::timestamptz
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
AND ended_at IS NULL
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at
|
||||
`
|
||||
|
||||
type UpdateAIBridgeInterceptionEndedParams struct {
|
||||
EndedAt time.Time `db:"ended_at" json:"ended_at"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateAIBridgeInterceptionEnded, arg.EndedAt, arg.ID)
|
||||
var i AIBridgeInterception
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.InitiatorID,
|
||||
&i.Provider,
|
||||
&i.Model,
|
||||
&i.StartedAt,
|
||||
&i.Metadata,
|
||||
&i.EndedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteAPIKeyByID = `-- name: DeleteAPIKeyByID :exec
|
||||
DELETE FROM
|
||||
api_keys
|
||||
@@ -8285,93 +8047,6 @@ func (q *sqlQuerier) FindMatchingPresetID(ctx context.Context, arg FindMatchingP
|
||||
return template_version_preset_id, err
|
||||
}
|
||||
|
||||
const getOrganizationsWithPrebuildStatus = `-- name: GetOrganizationsWithPrebuildStatus :many
|
||||
WITH orgs_with_prebuilds AS (
|
||||
-- Get unique organizations that have presets with prebuilds configured
|
||||
SELECT DISTINCT o.id, o.name
|
||||
FROM organizations o
|
||||
INNER JOIN templates t ON t.organization_id = o.id
|
||||
INNER JOIN template_versions tv ON tv.template_id = t.id
|
||||
INNER JOIN template_version_presets tvp ON tvp.template_version_id = tv.id
|
||||
WHERE tvp.desired_instances IS NOT NULL
|
||||
),
|
||||
prebuild_user_membership AS (
|
||||
-- Check if the user is a member of the organizations
|
||||
SELECT om.organization_id
|
||||
FROM organization_members om
|
||||
INNER JOIN orgs_with_prebuilds owp ON owp.id = om.organization_id
|
||||
WHERE om.user_id = $1::uuid
|
||||
),
|
||||
prebuild_groups AS (
|
||||
-- Check if the organizations have the prebuilds group
|
||||
SELECT g.organization_id, g.id as group_id
|
||||
FROM groups g
|
||||
INNER JOIN orgs_with_prebuilds owp ON owp.id = g.organization_id
|
||||
WHERE g.name = $2::text
|
||||
),
|
||||
prebuild_group_membership AS (
|
||||
-- Check if the user is in the prebuilds group
|
||||
SELECT pg.organization_id
|
||||
FROM prebuild_groups pg
|
||||
INNER JOIN group_members gm ON gm.group_id = pg.group_id
|
||||
WHERE gm.user_id = $1::uuid
|
||||
)
|
||||
SELECT
|
||||
owp.id AS organization_id,
|
||||
owp.name AS organization_name,
|
||||
(pum.organization_id IS NOT NULL)::boolean AS has_prebuild_user,
|
||||
pg.group_id AS prebuilds_group_id,
|
||||
(pgm.organization_id IS NOT NULL)::boolean AS has_prebuild_user_in_group
|
||||
FROM orgs_with_prebuilds owp
|
||||
LEFT JOIN prebuild_groups pg ON pg.organization_id = owp.id
|
||||
LEFT JOIN prebuild_user_membership pum ON pum.organization_id = owp.id
|
||||
LEFT JOIN prebuild_group_membership pgm ON pgm.organization_id = owp.id
|
||||
`
|
||||
|
||||
type GetOrganizationsWithPrebuildStatusParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
GroupName string `db:"group_name" json:"group_name"`
|
||||
}
|
||||
|
||||
type GetOrganizationsWithPrebuildStatusRow struct {
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
OrganizationName string `db:"organization_name" json:"organization_name"`
|
||||
HasPrebuildUser bool `db:"has_prebuild_user" json:"has_prebuild_user"`
|
||||
PrebuildsGroupID uuid.NullUUID `db:"prebuilds_group_id" json:"prebuilds_group_id"`
|
||||
HasPrebuildUserInGroup bool `db:"has_prebuild_user_in_group" json:"has_prebuild_user_in_group"`
|
||||
}
|
||||
|
||||
// GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
|
||||
// membership status for the prebuilds system user (org membership, group existence, group membership).
|
||||
func (q *sqlQuerier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg GetOrganizationsWithPrebuildStatusParams) ([]GetOrganizationsWithPrebuildStatusRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getOrganizationsWithPrebuildStatus, arg.UserID, arg.GroupName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetOrganizationsWithPrebuildStatusRow
|
||||
for rows.Next() {
|
||||
var i GetOrganizationsWithPrebuildStatusRow
|
||||
if err := rows.Scan(
|
||||
&i.OrganizationID,
|
||||
&i.OrganizationName,
|
||||
&i.HasPrebuildUser,
|
||||
&i.PrebuildsGroupID,
|
||||
&i.HasPrebuildUserInGroup,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getPrebuildMetrics = `-- name: GetPrebuildMetrics :many
|
||||
SELECT
|
||||
t.name as template_name,
|
||||
@@ -8774,8 +8449,12 @@ func (q *sqlQuerier) GetTemplatePresetsWithPrebuilds(ctx context.Context, templa
|
||||
}
|
||||
|
||||
const updatePrebuildProvisionerJobWithCancel = `-- name: UpdatePrebuildProvisionerJobWithCancel :many
|
||||
WITH jobs_to_cancel AS (
|
||||
SELECT pj.id, w.id AS workspace_id, w.template_id, wpb.template_version_preset_id
|
||||
UPDATE provisioner_jobs
|
||||
SET
|
||||
canceled_at = $1::timestamptz,
|
||||
completed_at = $1::timestamptz
|
||||
WHERE id IN (
|
||||
SELECT pj.id
|
||||
FROM provisioner_jobs pj
|
||||
INNER JOIN workspace_prebuild_builds wpb ON wpb.job_id = pj.id
|
||||
INNER JOIN workspaces w ON w.id = wpb.workspace_id
|
||||
@@ -8794,13 +8473,7 @@ WITH jobs_to_cancel AS (
|
||||
AND pj.canceled_at IS NULL
|
||||
AND pj.completed_at IS NULL
|
||||
)
|
||||
UPDATE provisioner_jobs
|
||||
SET
|
||||
canceled_at = $1::timestamptz,
|
||||
completed_at = $1::timestamptz
|
||||
FROM jobs_to_cancel
|
||||
WHERE provisioner_jobs.id = jobs_to_cancel.id
|
||||
RETURNING jobs_to_cancel.id, jobs_to_cancel.workspace_id, jobs_to_cancel.template_id, jobs_to_cancel.template_version_preset_id
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
type UpdatePrebuildProvisionerJobWithCancelParams struct {
|
||||
@@ -8808,34 +8481,22 @@ type UpdatePrebuildProvisionerJobWithCancelParams struct {
|
||||
PresetID uuid.NullUUID `db:"preset_id" json:"preset_id"`
|
||||
}
|
||||
|
||||
type UpdatePrebuildProvisionerJobWithCancelRow struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
TemplateID uuid.UUID `db:"template_id" json:"template_id"`
|
||||
TemplateVersionPresetID uuid.NullUUID `db:"template_version_preset_id" json:"template_version_preset_id"`
|
||||
}
|
||||
|
||||
// Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
|
||||
// inactive template version.
|
||||
// This is an optimization to clean up stale pending jobs.
|
||||
func (q *sqlQuerier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
func (q *sqlQuerier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
|
||||
rows, err := q.db.QueryContext(ctx, updatePrebuildProvisionerJobWithCancel, arg.Now, arg.PresetID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []UpdatePrebuildProvisionerJobWithCancelRow
|
||||
var items []uuid.UUID
|
||||
for rows.Next() {
|
||||
var i UpdatePrebuildProvisionerJobWithCancelRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.WorkspaceID,
|
||||
&i.TemplateID,
|
||||
&i.TemplateVersionPresetID,
|
||||
); err != nil {
|
||||
var id uuid.UUID
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
items = append(items, id)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
@@ -13065,7 +12726,7 @@ func (q *sqlQuerier) DeleteTask(ctx context.Context, arg DeleteTaskParams) (Task
|
||||
}
|
||||
|
||||
const getTaskByID = `-- name: GetTaskByID :one
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE id = $1::uuid
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status WHERE id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error) {
|
||||
@@ -13086,15 +12747,12 @@ func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error
|
||||
&i.WorkspaceBuildNumber,
|
||||
&i.WorkspaceAgentID,
|
||||
&i.WorkspaceAppID,
|
||||
&i.OwnerUsername,
|
||||
&i.OwnerName,
|
||||
&i.OwnerAvatarUrl,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTaskByWorkspaceID = `-- name: GetTaskByWorkspaceID :one
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE workspace_id = $1::uuid
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status WHERE workspace_id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error) {
|
||||
@@ -13115,9 +12773,6 @@ func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.
|
||||
&i.WorkspaceBuildNumber,
|
||||
&i.WorkspaceAgentID,
|
||||
&i.WorkspaceAppID,
|
||||
&i.OwnerUsername,
|
||||
&i.OwnerName,
|
||||
&i.OwnerAvatarUrl,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -13126,12 +12781,11 @@ const insertTask = `-- name: InsertTask :one
|
||||
INSERT INTO tasks
|
||||
(id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
(gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at
|
||||
`
|
||||
|
||||
type InsertTaskParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
@@ -13144,7 +12798,6 @@ type InsertTaskParams struct {
|
||||
|
||||
func (q *sqlQuerier) InsertTask(ctx context.Context, arg InsertTaskParams) (TaskTable, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertTask,
|
||||
arg.ID,
|
||||
arg.OrganizationID,
|
||||
arg.OwnerID,
|
||||
arg.Name,
|
||||
@@ -13171,7 +12824,7 @@ func (q *sqlQuerier) InsertTask(ctx context.Context, arg InsertTaskParams) (Task
|
||||
}
|
||||
|
||||
const listTasks = `-- name: ListTasks :many
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status tws
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status tws
|
||||
WHERE tws.deleted_at IS NULL
|
||||
AND CASE WHEN $1::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.owner_id = $1::UUID ELSE TRUE END
|
||||
AND CASE WHEN $2::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.organization_id = $2::UUID ELSE TRUE END
|
||||
@@ -13209,9 +12862,6 @@ func (q *sqlQuerier) ListTasks(ctx context.Context, arg ListTasksParams) ([]Task
|
||||
&i.WorkspaceBuildNumber,
|
||||
&i.WorkspaceAgentID,
|
||||
&i.WorkspaceAppID,
|
||||
&i.OwnerUsername,
|
||||
&i.OwnerName,
|
||||
&i.OwnerAvatarUrl,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -13385,41 +13035,6 @@ func (q *sqlQuerier) UpsertTelemetryItem(ctx context.Context, arg UpsertTelemetr
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteOldTelemetryLocks = `-- name: DeleteOldTelemetryLocks :exec
|
||||
DELETE FROM
|
||||
telemetry_locks
|
||||
WHERE
|
||||
period_ending_at < $1::timestamptz
|
||||
`
|
||||
|
||||
// Deletes old telemetry locks from the telemetry_locks table.
|
||||
func (q *sqlQuerier) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteOldTelemetryLocks, periodEndingAtBefore)
|
||||
return err
|
||||
}
|
||||
|
||||
const insertTelemetryLock = `-- name: InsertTelemetryLock :exec
|
||||
INSERT INTO
|
||||
telemetry_locks (event_type, period_ending_at)
|
||||
VALUES
|
||||
($1, $2)
|
||||
`
|
||||
|
||||
type InsertTelemetryLockParams struct {
|
||||
EventType string `db:"event_type" json:"event_type"`
|
||||
PeriodEndingAt time.Time `db:"period_ending_at" json:"period_ending_at"`
|
||||
}
|
||||
|
||||
// Inserts a new lock row into the telemetry_locks table. Replicas should call
|
||||
// this function prior to attempting to generate or publish a heartbeat event to
|
||||
// the telemetry service.
|
||||
// If the query returns a duplicate primary key error, the replica should not
|
||||
// attempt to generate or publish the event to the telemetry service.
|
||||
func (q *sqlQuerier) InsertTelemetryLock(ctx context.Context, arg InsertTelemetryLockParams) error {
|
||||
_, err := q.db.ExecContext(ctx, insertTelemetryLock, arg.EventType, arg.PeriodEndingAt)
|
||||
return err
|
||||
}
|
||||
|
||||
const getTemplateAverageBuildTime = `-- name: GetTemplateAverageBuildTime :one
|
||||
WITH build_times AS (
|
||||
SELECT
|
||||
@@ -21927,7 +21542,7 @@ func (q *sqlQuerier) GetWorkspaceACLByID(ctx context.Context, id uuid.UUID) (Get
|
||||
|
||||
const getWorkspaceByAgentID = `-- name: GetWorkspaceByAgentID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
|
||||
FROM
|
||||
workspaces_expanded as workspaces
|
||||
WHERE
|
||||
@@ -21988,14 +21603,13 @@ func (q *sqlQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUI
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceByID = `-- name: GetWorkspaceByID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
|
||||
FROM
|
||||
workspaces_expanded
|
||||
WHERE
|
||||
@@ -22037,14 +21651,13 @@ func (q *sqlQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Worksp
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceByOwnerIDAndName = `-- name: GetWorkspaceByOwnerIDAndName :one
|
||||
SELECT
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
|
||||
FROM
|
||||
workspaces_expanded as workspaces
|
||||
WHERE
|
||||
@@ -22093,14 +21706,13 @@ func (q *sqlQuerier) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWo
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceByResourceID = `-- name: GetWorkspaceByResourceID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
|
||||
FROM
|
||||
workspaces_expanded as workspaces
|
||||
WHERE
|
||||
@@ -22156,14 +21768,13 @@ func (q *sqlQuerier) GetWorkspaceByResourceID(ctx context.Context, resourceID uu
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceByWorkspaceAppID = `-- name: GetWorkspaceByWorkspaceAppID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
|
||||
FROM
|
||||
workspaces_expanded as workspaces
|
||||
WHERE
|
||||
@@ -22231,7 +21842,6 @@ func (q *sqlQuerier) GetWorkspaceByWorkspaceAppID(ctx context.Context, workspace
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -22281,7 +21891,7 @@ SELECT
|
||||
),
|
||||
filtered_workspaces AS (
|
||||
SELECT
|
||||
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl, workspaces.owner_avatar_url, workspaces.owner_username, workspaces.owner_name, workspaces.organization_name, workspaces.organization_display_name, workspaces.organization_icon, workspaces.organization_description, workspaces.template_name, workspaces.template_display_name, workspaces.template_icon, workspaces.template_description, workspaces.task_id,
|
||||
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl, workspaces.owner_avatar_url, workspaces.owner_username, workspaces.owner_name, workspaces.organization_name, workspaces.organization_display_name, workspaces.organization_icon, workspaces.organization_description, workspaces.template_name, workspaces.template_display_name, workspaces.template_icon, workspaces.template_description,
|
||||
latest_build.template_version_id,
|
||||
latest_build.template_version_name,
|
||||
latest_build.completed_at as latest_build_completed_at,
|
||||
@@ -22572,7 +22182,7 @@ WHERE
|
||||
-- @authorize_filter
|
||||
), filtered_workspaces_order AS (
|
||||
SELECT
|
||||
fw.id, fw.created_at, fw.updated_at, fw.owner_id, fw.organization_id, fw.template_id, fw.deleted, fw.name, fw.autostart_schedule, fw.ttl, fw.last_used_at, fw.dormant_at, fw.deleting_at, fw.automatic_updates, fw.favorite, fw.next_start_at, fw.group_acl, fw.user_acl, fw.owner_avatar_url, fw.owner_username, fw.owner_name, fw.organization_name, fw.organization_display_name, fw.organization_icon, fw.organization_description, fw.template_name, fw.template_display_name, fw.template_icon, fw.template_description, fw.task_id, fw.template_version_id, fw.template_version_name, fw.latest_build_completed_at, fw.latest_build_canceled_at, fw.latest_build_error, fw.latest_build_transition, fw.latest_build_status, fw.latest_build_has_ai_task, fw.latest_build_has_external_agent
|
||||
fw.id, fw.created_at, fw.updated_at, fw.owner_id, fw.organization_id, fw.template_id, fw.deleted, fw.name, fw.autostart_schedule, fw.ttl, fw.last_used_at, fw.dormant_at, fw.deleting_at, fw.automatic_updates, fw.favorite, fw.next_start_at, fw.group_acl, fw.user_acl, fw.owner_avatar_url, fw.owner_username, fw.owner_name, fw.organization_name, fw.organization_display_name, fw.organization_icon, fw.organization_description, fw.template_name, fw.template_display_name, fw.template_icon, fw.template_description, fw.template_version_id, fw.template_version_name, fw.latest_build_completed_at, fw.latest_build_canceled_at, fw.latest_build_error, fw.latest_build_transition, fw.latest_build_status, fw.latest_build_has_ai_task, fw.latest_build_has_external_agent
|
||||
FROM
|
||||
filtered_workspaces fw
|
||||
ORDER BY
|
||||
@@ -22593,7 +22203,7 @@ WHERE
|
||||
$25
|
||||
), filtered_workspaces_order_with_summary AS (
|
||||
SELECT
|
||||
fwo.id, fwo.created_at, fwo.updated_at, fwo.owner_id, fwo.organization_id, fwo.template_id, fwo.deleted, fwo.name, fwo.autostart_schedule, fwo.ttl, fwo.last_used_at, fwo.dormant_at, fwo.deleting_at, fwo.automatic_updates, fwo.favorite, fwo.next_start_at, fwo.group_acl, fwo.user_acl, fwo.owner_avatar_url, fwo.owner_username, fwo.owner_name, fwo.organization_name, fwo.organization_display_name, fwo.organization_icon, fwo.organization_description, fwo.template_name, fwo.template_display_name, fwo.template_icon, fwo.template_description, fwo.task_id, fwo.template_version_id, fwo.template_version_name, fwo.latest_build_completed_at, fwo.latest_build_canceled_at, fwo.latest_build_error, fwo.latest_build_transition, fwo.latest_build_status, fwo.latest_build_has_ai_task, fwo.latest_build_has_external_agent
|
||||
fwo.id, fwo.created_at, fwo.updated_at, fwo.owner_id, fwo.organization_id, fwo.template_id, fwo.deleted, fwo.name, fwo.autostart_schedule, fwo.ttl, fwo.last_used_at, fwo.dormant_at, fwo.deleting_at, fwo.automatic_updates, fwo.favorite, fwo.next_start_at, fwo.group_acl, fwo.user_acl, fwo.owner_avatar_url, fwo.owner_username, fwo.owner_name, fwo.organization_name, fwo.organization_display_name, fwo.organization_icon, fwo.organization_description, fwo.template_name, fwo.template_display_name, fwo.template_icon, fwo.template_description, fwo.template_version_id, fwo.template_version_name, fwo.latest_build_completed_at, fwo.latest_build_canceled_at, fwo.latest_build_error, fwo.latest_build_transition, fwo.latest_build_status, fwo.latest_build_has_ai_task, fwo.latest_build_has_external_agent
|
||||
FROM
|
||||
filtered_workspaces_order fwo
|
||||
-- Return a technical summary row with total count of workspaces.
|
||||
@@ -22629,7 +22239,6 @@ WHERE
|
||||
'', -- template_display_name
|
||||
'', -- template_icon
|
||||
'', -- template_description
|
||||
'00000000-0000-0000-0000-000000000000'::uuid, -- task_id
|
||||
-- Extra columns added to ` + "`" + `filtered_workspaces` + "`" + `
|
||||
'00000000-0000-0000-0000-000000000000'::uuid, -- template_version_id
|
||||
'', -- template_version_name
|
||||
@@ -22649,7 +22258,7 @@ WHERE
|
||||
filtered_workspaces
|
||||
)
|
||||
SELECT
|
||||
fwos.id, fwos.created_at, fwos.updated_at, fwos.owner_id, fwos.organization_id, fwos.template_id, fwos.deleted, fwos.name, fwos.autostart_schedule, fwos.ttl, fwos.last_used_at, fwos.dormant_at, fwos.deleting_at, fwos.automatic_updates, fwos.favorite, fwos.next_start_at, fwos.group_acl, fwos.user_acl, fwos.owner_avatar_url, fwos.owner_username, fwos.owner_name, fwos.organization_name, fwos.organization_display_name, fwos.organization_icon, fwos.organization_description, fwos.template_name, fwos.template_display_name, fwos.template_icon, fwos.template_description, fwos.task_id, fwos.template_version_id, fwos.template_version_name, fwos.latest_build_completed_at, fwos.latest_build_canceled_at, fwos.latest_build_error, fwos.latest_build_transition, fwos.latest_build_status, fwos.latest_build_has_ai_task, fwos.latest_build_has_external_agent,
|
||||
fwos.id, fwos.created_at, fwos.updated_at, fwos.owner_id, fwos.organization_id, fwos.template_id, fwos.deleted, fwos.name, fwos.autostart_schedule, fwos.ttl, fwos.last_used_at, fwos.dormant_at, fwos.deleting_at, fwos.automatic_updates, fwos.favorite, fwos.next_start_at, fwos.group_acl, fwos.user_acl, fwos.owner_avatar_url, fwos.owner_username, fwos.owner_name, fwos.organization_name, fwos.organization_display_name, fwos.organization_icon, fwos.organization_description, fwos.template_name, fwos.template_display_name, fwos.template_icon, fwos.template_description, fwos.template_version_id, fwos.template_version_name, fwos.latest_build_completed_at, fwos.latest_build_canceled_at, fwos.latest_build_error, fwos.latest_build_transition, fwos.latest_build_status, fwos.latest_build_has_ai_task, fwos.latest_build_has_external_agent,
|
||||
tc.count
|
||||
FROM
|
||||
filtered_workspaces_order_with_summary fwos
|
||||
@@ -22717,7 +22326,6 @@ type GetWorkspacesRow struct {
|
||||
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
|
||||
TemplateIcon string `db:"template_icon" json:"template_icon"`
|
||||
TemplateDescription string `db:"template_description" json:"template_description"`
|
||||
TaskID uuid.NullUUID `db:"task_id" json:"task_id"`
|
||||
TemplateVersionID uuid.UUID `db:"template_version_id" json:"template_version_id"`
|
||||
TemplateVersionName sql.NullString `db:"template_version_name" json:"template_version_name"`
|
||||
LatestBuildCompletedAt sql.NullTime `db:"latest_build_completed_at" json:"latest_build_completed_at"`
|
||||
@@ -22800,7 +22408,6 @@ func (q *sqlQuerier) GetWorkspaces(ctx context.Context, arg GetWorkspacesParams)
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
&i.TemplateVersionID,
|
||||
&i.TemplateVersionName,
|
||||
&i.LatestBuildCompletedAt,
|
||||
@@ -23531,7 +23138,6 @@ SET
|
||||
WHERE
|
||||
template_id = $3
|
||||
AND dormant_at IS NOT NULL
|
||||
AND deleted = false
|
||||
-- Prebuilt workspaces (identified by having the prebuilds system user as owner_id)
|
||||
-- should not have their dormant or deleting at set, as these are handled by the
|
||||
-- prebuilds reconciliation loop.
|
||||
|
||||
@@ -6,14 +6,6 @@ INSERT INTO aibridge_interceptions (
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateAIBridgeInterceptionEnded :one
|
||||
UPDATE aibridge_interceptions
|
||||
SET ended_at = @ended_at::timestamptz
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
AND ended_at IS NULL
|
||||
RETURNING *;
|
||||
|
||||
-- name: InsertAIBridgeTokenUsage :one
|
||||
INSERT INTO aibridge_token_usages (
|
||||
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
|
||||
@@ -207,122 +199,3 @@ WHERE
|
||||
ORDER BY
|
||||
created_at ASC,
|
||||
id ASC;
|
||||
|
||||
-- name: ListAIBridgeInterceptionsTelemetrySummaries :many
|
||||
-- Finds all unique AIBridge interception telemetry summaries combinations
|
||||
-- (provider, model, client) in the given timeframe for telemetry reporting.
|
||||
SELECT
|
||||
DISTINCT ON (provider, model, client)
|
||||
provider,
|
||||
model,
|
||||
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
|
||||
'unknown' AS client
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
|
||||
AND ended_at >= @ended_at_after::timestamptz
|
||||
AND ended_at < @ended_at_before::timestamptz;
|
||||
|
||||
-- name: CalculateAIBridgeInterceptionsTelemetrySummary :one
|
||||
-- Calculates the telemetry summary for a given provider, model, and client
|
||||
-- combination for telemetry reporting.
|
||||
WITH interceptions_in_range AS (
|
||||
-- Get all matching interceptions in the given timeframe.
|
||||
SELECT
|
||||
id,
|
||||
initiator_id,
|
||||
(ended_at - started_at) AS duration
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
provider = @provider::text
|
||||
AND model = @model::text
|
||||
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
|
||||
AND 'unknown' = @client::text
|
||||
AND ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
|
||||
AND ended_at >= @ended_at_after::timestamptz
|
||||
AND ended_at < @ended_at_before::timestamptz
|
||||
),
|
||||
interception_counts AS (
|
||||
SELECT
|
||||
COUNT(id) AS interception_count,
|
||||
COUNT(DISTINCT initiator_id) AS unique_initiator_count
|
||||
FROM
|
||||
interceptions_in_range
|
||||
),
|
||||
duration_percentiles AS (
|
||||
SELECT
|
||||
(COALESCE(PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p50_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.90) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p90_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p95_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p99_millis
|
||||
FROM
|
||||
interceptions_in_range
|
||||
),
|
||||
token_aggregates AS (
|
||||
SELECT
|
||||
COALESCE(SUM(tu.input_tokens), 0) AS token_count_input,
|
||||
COALESCE(SUM(tu.output_tokens), 0) AS token_count_output,
|
||||
-- Cached tokens are stored in metadata JSON, extract if available.
|
||||
-- Read tokens may be stored in:
|
||||
-- - cache_read_input (Anthropic)
|
||||
-- - prompt_cached (OpenAI)
|
||||
COALESCE(SUM(
|
||||
COALESCE((tu.metadata->>'cache_read_input')::bigint, 0) +
|
||||
COALESCE((tu.metadata->>'prompt_cached')::bigint, 0)
|
||||
), 0) AS token_count_cached_read,
|
||||
-- Written tokens may be stored in:
|
||||
-- - cache_creation_input (Anthropic)
|
||||
-- Note that cache_ephemeral_5m_input and cache_ephemeral_1h_input on
|
||||
-- Anthropic are included in the cache_creation_input field.
|
||||
COALESCE(SUM(
|
||||
COALESCE((tu.metadata->>'cache_creation_input')::bigint, 0)
|
||||
), 0) AS token_count_cached_written,
|
||||
COUNT(tu.id) AS token_usages_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_token_usages tu ON i.id = tu.interception_id
|
||||
),
|
||||
prompt_aggregates AS (
|
||||
SELECT
|
||||
COUNT(up.id) AS user_prompts_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_user_prompts up ON i.id = up.interception_id
|
||||
),
|
||||
tool_aggregates AS (
|
||||
SELECT
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = true) AS tool_calls_count_injected,
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = false) AS tool_calls_count_non_injected,
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = true AND tu.invocation_error IS NOT NULL) AS injected_tool_call_error_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_tool_usages tu ON i.id = tu.interception_id
|
||||
)
|
||||
SELECT
|
||||
ic.interception_count::bigint AS interception_count,
|
||||
dp.interception_duration_p50_millis::bigint AS interception_duration_p50_millis,
|
||||
dp.interception_duration_p90_millis::bigint AS interception_duration_p90_millis,
|
||||
dp.interception_duration_p95_millis::bigint AS interception_duration_p95_millis,
|
||||
dp.interception_duration_p99_millis::bigint AS interception_duration_p99_millis,
|
||||
ic.unique_initiator_count::bigint AS unique_initiator_count,
|
||||
pa.user_prompts_count::bigint AS user_prompts_count,
|
||||
tok_agg.token_usages_count::bigint AS token_usages_count,
|
||||
tok_agg.token_count_input::bigint AS token_count_input,
|
||||
tok_agg.token_count_output::bigint AS token_count_output,
|
||||
tok_agg.token_count_cached_read::bigint AS token_count_cached_read,
|
||||
tok_agg.token_count_cached_written::bigint AS token_count_cached_written,
|
||||
tool_agg.tool_calls_count_injected::bigint AS tool_calls_count_injected,
|
||||
tool_agg.tool_calls_count_non_injected::bigint AS tool_calls_count_non_injected,
|
||||
tool_agg.injected_tool_call_error_count::bigint AS injected_tool_call_error_count
|
||||
FROM
|
||||
interception_counts ic,
|
||||
duration_percentiles dp,
|
||||
token_aggregates tok_agg,
|
||||
prompt_aggregates pa,
|
||||
tool_aggregates tool_agg
|
||||
;
|
||||
|
||||
@@ -300,8 +300,12 @@ GROUP BY wpb.template_version_preset_id;
|
||||
-- Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
|
||||
-- inactive template version.
|
||||
-- This is an optimization to clean up stale pending jobs.
|
||||
WITH jobs_to_cancel AS (
|
||||
SELECT pj.id, w.id AS workspace_id, w.template_id, wpb.template_version_preset_id
|
||||
UPDATE provisioner_jobs
|
||||
SET
|
||||
canceled_at = @now::timestamptz,
|
||||
completed_at = @now::timestamptz
|
||||
WHERE id IN (
|
||||
SELECT pj.id
|
||||
FROM provisioner_jobs pj
|
||||
INNER JOIN workspace_prebuild_builds wpb ON wpb.job_id = pj.id
|
||||
INNER JOIN workspaces w ON w.id = wpb.workspace_id
|
||||
@@ -320,54 +324,4 @@ WITH jobs_to_cancel AS (
|
||||
AND pj.canceled_at IS NULL
|
||||
AND pj.completed_at IS NULL
|
||||
)
|
||||
UPDATE provisioner_jobs
|
||||
SET
|
||||
canceled_at = @now::timestamptz,
|
||||
completed_at = @now::timestamptz
|
||||
FROM jobs_to_cancel
|
||||
WHERE provisioner_jobs.id = jobs_to_cancel.id
|
||||
RETURNING jobs_to_cancel.id, jobs_to_cancel.workspace_id, jobs_to_cancel.template_id, jobs_to_cancel.template_version_preset_id;
|
||||
|
||||
-- name: GetOrganizationsWithPrebuildStatus :many
|
||||
-- GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
|
||||
-- membership status for the prebuilds system user (org membership, group existence, group membership).
|
||||
WITH orgs_with_prebuilds AS (
|
||||
-- Get unique organizations that have presets with prebuilds configured
|
||||
SELECT DISTINCT o.id, o.name
|
||||
FROM organizations o
|
||||
INNER JOIN templates t ON t.organization_id = o.id
|
||||
INNER JOIN template_versions tv ON tv.template_id = t.id
|
||||
INNER JOIN template_version_presets tvp ON tvp.template_version_id = tv.id
|
||||
WHERE tvp.desired_instances IS NOT NULL
|
||||
),
|
||||
prebuild_user_membership AS (
|
||||
-- Check if the user is a member of the organizations
|
||||
SELECT om.organization_id
|
||||
FROM organization_members om
|
||||
INNER JOIN orgs_with_prebuilds owp ON owp.id = om.organization_id
|
||||
WHERE om.user_id = @user_id::uuid
|
||||
),
|
||||
prebuild_groups AS (
|
||||
-- Check if the organizations have the prebuilds group
|
||||
SELECT g.organization_id, g.id as group_id
|
||||
FROM groups g
|
||||
INNER JOIN orgs_with_prebuilds owp ON owp.id = g.organization_id
|
||||
WHERE g.name = @group_name::text
|
||||
),
|
||||
prebuild_group_membership AS (
|
||||
-- Check if the user is in the prebuilds group
|
||||
SELECT pg.organization_id
|
||||
FROM prebuild_groups pg
|
||||
INNER JOIN group_members gm ON gm.group_id = pg.group_id
|
||||
WHERE gm.user_id = @user_id::uuid
|
||||
)
|
||||
SELECT
|
||||
owp.id AS organization_id,
|
||||
owp.name AS organization_name,
|
||||
(pum.organization_id IS NOT NULL)::boolean AS has_prebuild_user,
|
||||
pg.group_id AS prebuilds_group_id,
|
||||
(pgm.organization_id IS NOT NULL)::boolean AS has_prebuild_user_in_group
|
||||
FROM orgs_with_prebuilds owp
|
||||
LEFT JOIN prebuild_groups pg ON pg.organization_id = owp.id
|
||||
LEFT JOIN prebuild_user_membership pum ON pum.organization_id = owp.id
|
||||
LEFT JOIN prebuild_group_membership pgm ON pgm.organization_id = owp.id;
|
||||
RETURNING id;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
INSERT INTO tasks
|
||||
(id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
(gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateTaskWorkspaceID :one
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
-- name: InsertTelemetryLock :exec
|
||||
-- Inserts a new lock row into the telemetry_locks table. Replicas should call
|
||||
-- this function prior to attempting to generate or publish a heartbeat event to
|
||||
-- the telemetry service.
|
||||
-- If the query returns a duplicate primary key error, the replica should not
|
||||
-- attempt to generate or publish the event to the telemetry service.
|
||||
INSERT INTO
|
||||
telemetry_locks (event_type, period_ending_at)
|
||||
VALUES
|
||||
($1, $2);
|
||||
|
||||
-- name: DeleteOldTelemetryLocks :exec
|
||||
-- Deletes old telemetry locks from the telemetry_locks table.
|
||||
DELETE FROM
|
||||
telemetry_locks
|
||||
WHERE
|
||||
period_ending_at < @period_ending_at_before::timestamptz;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user