Compare commits


18 Commits

Author SHA1 Message Date
Sas Swart 9764926f92 remove defunct test file 2025-10-30 14:26:47 +00:00
Sas Swart 10d4e42fc1 remove defunct files 2025-10-30 13:37:00 +00:00
Sas Swart 217ddf46c4 fix an incomplete refactor 2025-10-30 13:35:49 +00:00
Sas Swart 0d3d493eae fix an incomplete refactor 2025-10-30 13:28:38 +00:00
Sas Swart 89b060e245 hide functions that do not need to be public 2025-10-30 13:19:00 +00:00
Sas Swart 820d53b66a streamline agentsocket server initialization 2025-10-30 12:55:55 +00:00
Sas Swart f550028052 Move unit statuses to the appropriate package 2025-10-30 12:23:54 +00:00
Sas Swart e6873c8d61 rename dependency_tracker.go to manager.go 2025-10-30 12:21:07 +00:00
Sas Swart 8c0bfcb570 Improve agentsocket rpc naming and documentation 2025-10-30 12:17:27 +00:00
Sas Swart c322b92ab0 remove agent socket auth for now 2025-10-30 12:02:48 +00:00
Sas Swart 216a5ac562 document initSocketServer and tweak its log levels 2025-10-30 11:49:54 +00:00
Sas Swart 86447126d5 make the agent socket path configurable 2025-10-30 11:45:12 +00:00
Sas Swart 55c5b707fb Rename unit.DependencyTracker to unit.Manager 2025-10-30 11:33:20 +00:00
Sas Swart 4616c82f3c switch agent socket to drpc. factor components and add tests 2025-10-30 09:01:17 +00:00
Sas Swart 9ca30e28d6 add a prototype cli command that uses the agent socket 2025-10-28 08:27:25 +00:00
Sas Swart 34c1370090 fix agent socket tests 2025-10-28 06:30:29 +00:00
Sas Swart 851c4f907c add a socket to the agent for local IPC 2025-10-28 06:26:49 +00:00
Sas Swart e3dfe45f35 LLM generated implementation of unit status change communication 2025-10-27 11:10:22 +00:00
250 changed files with 8679 additions and 9073 deletions
+1 -1
@@ -4,7 +4,7 @@ description: |
inputs:
version:
description: "The Go version to use."
default: "1.24.10"
default: "1.24.6"
use-preinstalled-go:
description: "Whether to use preinstalled Go."
default: "false"
@@ -0,0 +1,34 @@
app = "sao-paulo-coder"
primary_region = "gru"
[experimental]
entrypoint = ["/bin/sh", "-c", "CODER_DERP_SERVER_RELAY_URL=\"http://[${FLY_PRIVATE_IP}]:3000\" /opt/coder wsproxy server"]
auto_rollback = true
[build]
image = "ghcr.io/coder/coder-preview:main"
[env]
CODER_ACCESS_URL = "https://sao-paulo.fly.dev.coder.com"
CODER_HTTP_ADDRESS = "0.0.0.0:3000"
CODER_PRIMARY_ACCESS_URL = "https://dev.coder.com"
CODER_WILDCARD_ACCESS_URL = "*--apps.sao-paulo.fly.dev.coder.com"
CODER_VERBOSE = "true"
[http_service]
internal_port = 3000
force_https = true
auto_stop_machines = true
auto_start_machines = true
min_machines_running = 0
# Ref: https://fly.io/docs/reference/configuration/#http_service-concurrency
[http_service.concurrency]
type = "requests"
soft_limit = 50
hard_limit = 100
[[vm]]
cpu_kind = "shared"
cpus = 2
memory_mb = 512
+23 -9
@@ -376,6 +376,13 @@ jobs:
id: go-paths
uses: ./.github/actions/setup-go-paths
- name: Download Go Build Cache
id: download-go-build-cache
uses: ./.github/actions/test-cache/download
with:
key-prefix: test-go-build-${{ runner.os }}-${{ runner.arch }}
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
- name: Setup Go
uses: ./.github/actions/setup-go
with:
@@ -383,7 +390,8 @@ jobs:
# download the toolchain configured in go.mod, so we don't
# need to reinstall it. It's faster on Windows runners.
use-preinstalled-go: ${{ runner.os == 'Windows' }}
use-cache: true
# Cache is already downloaded above
use-cache: false
- name: Setup Terraform
uses: ./.github/actions/setup-tf
@@ -492,11 +500,17 @@ jobs:
make test
- name: Upload failed test db dumps
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: failed-test-db-dump-${{matrix.os}}
path: "**/*.test.sql"
- name: Upload Go Build Cache
uses: ./.github/actions/test-cache/upload
with:
cache-key: ${{ steps.download-go-build-cache.outputs.cache-key }}
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
- name: Upload Test Cache
uses: ./.github/actions/test-cache/upload
with:
@@ -748,7 +762,7 @@ jobs:
- name: Upload Playwright Failed Tests
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: failed-test-videos${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/*.webm
@@ -756,7 +770,7 @@ jobs:
- name: Upload pprof dumps
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: debug-pprof-dumps${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/debug-pprof-*.txt
@@ -792,7 +806,7 @@ jobs:
# the check to pass. This is desired in PRs, but not in mainline.
- name: Publish to Chromatic (non-mainline)
if: github.ref != 'refs/heads/main' && github.repository_owner == 'coder'
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
uses: chromaui/action@4ffe736a2a8262ea28067ff05a13b635ba31ec05 # v13.3.0
env:
NODE_OPTIONS: "--max_old_space_size=4096"
STORYBOOK: true
@@ -824,7 +838,7 @@ jobs:
# infinitely "in progress" in mainline unless we re-review each build.
- name: Publish to Chromatic (mainline)
if: github.ref == 'refs/heads/main' && github.repository_owner == 'coder'
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
uses: chromaui/action@4ffe736a2a8262ea28067ff05a13b635ba31ec05 # v13.3.0
env:
NODE_OPTIONS: "--max_old_space_size=4096"
STORYBOOK: true
@@ -1022,7 +1036,7 @@ jobs:
- name: Upload build artifacts
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: dylibs
path: |
@@ -1187,7 +1201,7 @@ jobs:
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Download dylibs
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
with:
name: dylibs
path: ./build
@@ -1454,7 +1468,7 @@ jobs:
- name: Upload build artifacts
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: coder
path: |
+2
@@ -163,10 +163,12 @@ jobs:
run: |
flyctl deploy --image "$IMAGE" --app paris-coder --config ./.github/fly-wsproxies/paris-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_PARIS" --yes
flyctl deploy --image "$IMAGE" --app sydney-coder --config ./.github/fly-wsproxies/sydney-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SYDNEY" --yes
flyctl deploy --image "$IMAGE" --app sao-paulo-coder --config ./.github/fly-wsproxies/sao-paulo-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SAO_PAULO" --yes
flyctl deploy --image "$IMAGE" --app jnb-coder --config ./.github/fly-wsproxies/jnb-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_JNB" --yes
env:
FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
IMAGE: ${{ inputs.image }}
TOKEN_PARIS: ${{ secrets.FLY_PARIS_CODER_PROXY_SESSION_TOKEN }}
TOKEN_SYDNEY: ${{ secrets.FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN }}
TOKEN_SAO_PAULO: ${{ secrets.FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN }}
TOKEN_JNB: ${{ secrets.FLY_JNB_CODER_PROXY_SESSION_TOKEN }}
+1 -1
@@ -30,7 +30,7 @@ jobs:
- name: Setup Node
uses: ./.github/actions/setup-node
- uses: tj-actions/changed-files@dbf178ceecb9304128c8e0648591d71208c6e2c9 # v45.0.7
- uses: tj-actions/changed-files@d03a93c0dbfac6d6dd6a0d8a5e7daff992b07449 # v45.0.7
id: changed-files
with:
files: |
+2 -2
@@ -36,11 +36,11 @@ jobs:
persist-credentials: false
- name: Setup Nix
uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
uses: nixbuild/nix-quick-install-action@1f095fee853b33114486cfdeae62fa099cda35a9 # v33
with:
# Pinning to 2.28 here, as Nix gets a "error: [json.exception.type_error.302] type must be array, but is string"
# on version 2.29 and above.
nix_version: "2.28.5"
nix_version: "2.28.4"
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
with:
+4 -4
@@ -131,7 +131,7 @@ jobs:
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
- name: Upload build artifacts
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: dylibs
path: |
@@ -327,7 +327,7 @@ jobs:
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Download dylibs
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
with:
name: dylibs
path: ./build
@@ -761,7 +761,7 @@ jobs:
- name: Upload artifacts to actions (if dry-run)
if: ${{ inputs.dry_run }}
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: release-artifacts
path: |
@@ -777,7 +777,7 @@ jobs:
- name: Upload latest sbom artifact to actions (if dry-run)
if: inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: latest-sbom-artifact
path: ./coder_latest_sbom.spdx.json
+2 -2
@@ -39,7 +39,7 @@ jobs:
# Upload the results as artifacts.
- name: "Upload artifact"
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: SARIF file
path: results.sarif
@@ -47,6 +47,6 @@ jobs:
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/upload-sarif@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
with:
sarif_file: results.sarif
+4 -4
@@ -40,7 +40,7 @@ jobs:
uses: ./.github/actions/setup-go
- name: Initialize CodeQL
uses: github/codeql-action/init@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/init@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
with:
languages: go, javascript
@@ -50,7 +50,7 @@ jobs:
rm Makefile
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/analyze@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
- name: Send Slack notification on failure
if: ${{ failure() }}
@@ -154,13 +154,13 @@ jobs:
severity: "CRITICAL,HIGH"
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/upload-sarif@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
with:
sarif_file: trivy-results.sarif
category: "Trivy"
- name: Upload Trivy scan results as an artifact
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: trivy
path: trivy-results.sarif
+2 -2
@@ -125,7 +125,7 @@ jobs:
egress-policy: audit
- name: Delete PR Cleanup workflow runs
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
with:
token: ${{ github.token }}
repository: ${{ github.repository }}
@@ -134,7 +134,7 @@ jobs:
delete_workflow_pattern: pr-cleanup.yaml
- name: Delete PR Deploy workflow skipped runs
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
with:
token: ${{ github.token }}
repository: ${{ github.repository }}
-2
@@ -89,5 +89,3 @@ result
__debug_bin*
**/.claude/settings.local.json
/.env
+12
@@ -18,6 +18,18 @@ coderd/rbac/ @Emyrk
scripts/apitypings/ @Emyrk
scripts/gensite/ @aslilac
site/ @aslilac @Parkreiner
site/src/hooks/ @Parkreiner
# These rules intentionally do not specify any owners. More specific rules
# override less specific rules, so these files are "ignored" by the site/ rule.
site/e2e/google/protobuf/timestampGenerated.ts
site/e2e/provisionerGenerated.ts
site/src/api/countriesGenerated.ts
site/src/api/rbacresourcesGenerated.ts
site/src/api/typesGenerated.ts
site/src/testHelpers/entities.ts
site/CLAUDE.md
# The blood and guts of the autostop algorithm, which is quite complex and
# requires elite ball knowledge of most of the scheduling code to make changes
# without inadvertently affecting other parts of the codebase.
+17 -13
@@ -636,16 +636,17 @@ TAILNETTEST_MOCKS := \
tailnet/tailnettest/subscriptionmock.go
AIBRIDGED_MOCKS := \
enterprise/aibridged/aibridgedmock/clientmock.go \
enterprise/aibridged/aibridgedmock/poolmock.go
enterprise/x/aibridged/aibridgedmock/clientmock.go \
enterprise/x/aibridged/aibridgedmock/poolmock.go
GEN_FILES := \
tailnet/proto/tailnet.pb.go \
agent/proto/agent.pb.go \
agent/agentsocket/proto/agentsocket.pb.go \
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
enterprise/x/aibridged/proto/aibridged.pb.go \
$(DB_GEN_FILES) \
$(SITE_GEN_FILES) \
coderd/rbac/object_gen.go \
@@ -697,7 +698,7 @@ gen/mark-fresh:
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
enterprise/x/aibridged/proto/aibridged.pb.go \
coderd/database/dump.sql \
$(DB_GEN_FILES) \
site/src/api/typesGenerated.ts \
@@ -768,8 +769,8 @@ codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agen
go generate ./codersdk/workspacesdk/agentconnmock/
touch "$@"
$(AIBRIDGED_MOCKS): enterprise/aibridged/client.go enterprise/aibridged/pool.go
go generate ./enterprise/aibridged/aibridgedmock/
$(AIBRIDGED_MOCKS): enterprise/x/aibridged/client.go enterprise/x/aibridged/pool.go
go generate ./enterprise/x/aibridged/aibridgedmock/
touch "$@"
agent/agentcontainers/dcspec/dcspec_gen.go: \
@@ -800,6 +801,14 @@ agent/proto/agent.pb.go: agent/proto/agent.proto
--go-drpc_opt=paths=source_relative \
./agent/proto/agent.proto
agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.proto
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
--go-drpc_opt=paths=source_relative \
./agent/agentsocket/proto/agentsocket.proto
provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
protoc \
--go_out=. \
@@ -822,13 +831,13 @@ vpn/vpn.pb.go: vpn/vpn.proto
--go_opt=paths=source_relative \
./vpn/vpn.proto
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
enterprise/x/aibridged/proto/aibridged.pb.go: enterprise/x/aibridged/proto/aibridged.proto
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
--go-drpc_opt=paths=source_relative \
./enterprise/aibridged/proto/aibridged.proto
./enterprise/x/aibridged/proto/aibridged.proto
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
# -C sets the directory for the go run command
@@ -1182,8 +1191,3 @@ endif
dogfood/coder/nix.hash: flake.nix flake.lock
sha256sum flake.nix flake.lock >./dogfood/coder/nix.hash
# Count the number of test databases created per test package.
count-test-databases:
PGPASSWORD=postgres psql -h localhost -U postgres -d coder_testing -P pager=off -c 'SELECT test_package, count(*) as count from test_databases GROUP BY test_package ORDER BY count DESC'
.PHONY: count-test-databases
+40 -1
@@ -40,6 +40,7 @@ import (
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
@@ -91,6 +92,7 @@ type Options struct {
Devcontainers bool
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
Clock quartz.Clock
SocketPath string // Path for the agent socket server
}
type Client interface {
@@ -190,6 +192,7 @@ func New(options Options) Agent {
devcontainers: options.Devcontainers,
containerAPIOptions: options.DevcontainerAPIOptions,
socketPath: options.SocketPath,
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
@@ -271,6 +274,9 @@ type agent struct {
devcontainers bool
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API
socketPath string
socketServer *agentsocket.Server
}
func (a *agent) TailnetConn() *tailnet.Conn {
@@ -350,9 +356,35 @@ func (a *agent) init() {
s.ExperimentalContainers = a.devcontainers
},
)
a.initSocketServer()
go a.runLoop()
}
// initSocketServer initializes the server that allows direct communication with the workspace agent over local IPC.
func (a *agent) initSocketServer() {
if a.socketPath == "" {
a.logger.Info(a.hardCtx, "socket server disabled (no path configured)")
return
}
server, err := agentsocket.NewServer(a.socketPath, a.logger.Named("socket"))
if err != nil {
a.logger.Warn(a.hardCtx, "failed to create socket server", slog.Error(err))
return
}
err = server.Start()
if err != nil {
a.logger.Warn(a.hardCtx, "failed to start socket server", slog.Error(err))
return
}
a.socketServer = server
a.logger.Debug(a.hardCtx, "socket server started", slog.F("path", a.socketPath))
}
// runLoop attempts to start the agent in a retry loop.
// Coder may be offline temporarily, a connection issue
// may be happening, but regardless after the intermittent
@@ -1087,7 +1119,7 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
if err != nil {
return xerrors.Errorf("fetch metadata: %w", err)
}
a.logger.Info(ctx, "fetched manifest")
a.logger.Info(ctx, "fetched manifest", slog.F("manifest", mp))
manifest, err := agentsdk.ManifestFromProto(mp)
if err != nil {
a.logger.Critical(ctx, "failed to convert manifest", slog.F("manifest", mp), slog.Error(err))
@@ -1920,6 +1952,13 @@ func (a *agent) Close() error {
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
}
}
if a.socketServer != nil {
if err := a.socketServer.Stop(); err != nil {
a.logger.Error(a.hardCtx, "socket server close", slog.Error(err))
}
}
a.setLifecycle(lifecycleState)
err = a.scriptRunner.Close()
+2
@@ -682,6 +682,8 @@ func (api *API) updaterLoop() {
} else {
prevErr = nil
}
default:
api.logger.Debug(api.ctx, "updater loop ticker skipped, update in progress")
}
return nil // Always nil to keep the ticker going.
File diff suppressed because it is too large
+88
@@ -0,0 +1,88 @@
syntax = "proto3";
option go_package = "github.com/coder/coder/v2/agent/agentsocket/proto";
package coder.agentsocket.v1;
import "google/protobuf/timestamp.proto";
message PingRequest {}
message PingResponse {
string message = 1;
google.protobuf.Timestamp timestamp = 2;
}
message SyncStartRequest {
string unit = 1;
}
message SyncStartResponse {
bool success = 1;
string message = 2;
}
message SyncWantRequest {
string unit = 1;
string depends_on = 2;
}
message SyncWantResponse {
bool success = 1;
string message = 2;
}
message SyncCompleteRequest {
string unit = 1;
}
message SyncCompleteResponse {
bool success = 1;
string message = 2;
}
message SyncReadyRequest {
string unit = 1;
}
message SyncReadyResponse {
bool success = 1;
string message = 2;
}
message SyncStatusRequest {
string unit = 1;
bool recursive = 2;
}
message DependencyInfo {
string depends_on = 1;
string required_status = 2;
string current_status = 3;
bool is_satisfied = 4;
}
message SyncStatusResponse {
bool success = 1;
string message = 2;
string unit = 3;
string status = 4;
bool is_ready = 5;
repeated DependencyInfo dependencies = 6;
string dot = 7;
}
// AgentSocket provides direct access to the agent over local IPC.
service AgentSocket {
// Ping the agent to check if it is alive.
rpc Ping(PingRequest) returns (PingResponse);
// Report the start of a unit.
rpc SyncStart(SyncStartRequest) returns (SyncStartResponse);
// Declare a dependency between units.
rpc SyncWant(SyncWantRequest) returns (SyncWantResponse);
// Report the completion of a unit.
rpc SyncComplete(SyncCompleteRequest) returns (SyncCompleteResponse);
// Check whether a unit is ready to be started, i.e. whether all of its dependencies are satisfied.
rpc SyncReady(SyncReadyRequest) returns (SyncReadyResponse);
// Get the status of a unit and list its dependencies.
rpc SyncStatus(SyncStatusRequest) returns (SyncStatusResponse);
}
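For context, a minimal sketch of how a local process might drive this service, assuming the same framing the agentsocket server later in this diff uses (yamux over the Unix domain socket, DRPC on a yamux stream). The socket path and unit names are placeholders and error handling is abbreviated:

package main

import (
    "context"
    "fmt"
    "net"

    "github.com/hashicorp/yamux"
    "storj.io/drpc/drpcconn"

    "github.com/coder/coder/v2/agent/agentsocket/proto"
)

func main() {
    ctx := context.Background()

    // Dial the agent's Unix domain socket (placeholder path).
    conn, err := net.Dial("unix", "/tmp/coder-agent.sock")
    if err != nil {
        panic(err)
    }
    defer conn.Close()

    // The server wraps every connection in a yamux session, so the client does too.
    session, err := yamux.Client(conn, yamux.DefaultConfig())
    if err != nil {
        panic(err)
    }
    stream, err := session.Open()
    if err != nil {
        panic(err)
    }

    client := proto.NewDRPCAgentSocketClient(drpcconn.New(stream))

    // Declare that "app" depends on "db" completing, drive "db" through its
    // lifecycle, then ask whether "app" is clear to start.
    _, _ = client.SyncWant(ctx, &proto.SyncWantRequest{Unit: "app", DependsOn: "db"})
    _, _ = client.SyncStart(ctx, &proto.SyncStartRequest{Unit: "db"})
    _, _ = client.SyncComplete(ctx, &proto.SyncCompleteRequest{Unit: "db"})

    ready, err := client.SyncReady(ctx, &proto.SyncReadyRequest{Unit: "app"})
    if err != nil {
        panic(err)
    }
    fmt.Println(ready.Success, ready.Message)
}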
@@ -0,0 +1,311 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.34
// source: agent/agentsocket/proto/agentsocket.proto
package proto
import (
context "context"
errors "errors"
protojson "google.golang.org/protobuf/encoding/protojson"
proto "google.golang.org/protobuf/proto"
drpc "storj.io/drpc"
drpcerr "storj.io/drpc/drpcerr"
)
type drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto struct{}
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Marshal(msg drpc.Message) ([]byte, error) {
return proto.Marshal(msg.(proto.Message))
}
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) {
return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message))
}
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Unmarshal(buf []byte, msg drpc.Message) error {
return proto.Unmarshal(buf, msg.(proto.Message))
}
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONMarshal(msg drpc.Message) ([]byte, error) {
return protojson.Marshal(msg.(proto.Message))
}
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error {
return protojson.Unmarshal(buf, msg.(proto.Message))
}
type DRPCAgentSocketClient interface {
DRPCConn() drpc.Conn
Ping(ctx context.Context, in *PingRequest) (*PingResponse, error)
SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error)
SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error)
SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error)
SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error)
SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error)
}
type drpcAgentSocketClient struct {
cc drpc.Conn
}
func NewDRPCAgentSocketClient(cc drpc.Conn) DRPCAgentSocketClient {
return &drpcAgentSocketClient{cc}
}
func (c *drpcAgentSocketClient) DRPCConn() drpc.Conn { return c.cc }
func (c *drpcAgentSocketClient) Ping(ctx context.Context, in *PingRequest) (*PingResponse, error) {
out := new(PingResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcAgentSocketClient) SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error) {
out := new(SyncStartResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcAgentSocketClient) SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error) {
out := new(SyncWantResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcAgentSocketClient) SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error) {
out := new(SyncCompleteResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcAgentSocketClient) SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error) {
out := new(SyncReadyResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcAgentSocketClient) SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error) {
out := new(SyncStatusResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
type DRPCAgentSocketServer interface {
Ping(context.Context, *PingRequest) (*PingResponse, error)
SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error)
SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error)
SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error)
SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error)
SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error)
}
type DRPCAgentSocketUnimplementedServer struct{}
func (s *DRPCAgentSocketUnimplementedServer) Ping(context.Context, *PingRequest) (*PingResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentSocketUnimplementedServer) SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentSocketUnimplementedServer) SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentSocketUnimplementedServer) SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentSocketUnimplementedServer) SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentSocketUnimplementedServer) SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCAgentSocketDescription struct{}
func (DRPCAgentSocketDescription) NumMethods() int { return 6 }
func (DRPCAgentSocketDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
case 0:
return "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
Ping(
ctx,
in1.(*PingRequest),
)
}, DRPCAgentSocketServer.Ping, true
case 1:
return "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
SyncStart(
ctx,
in1.(*SyncStartRequest),
)
}, DRPCAgentSocketServer.SyncStart, true
case 2:
return "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
SyncWant(
ctx,
in1.(*SyncWantRequest),
)
}, DRPCAgentSocketServer.SyncWant, true
case 3:
return "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
SyncComplete(
ctx,
in1.(*SyncCompleteRequest),
)
}, DRPCAgentSocketServer.SyncComplete, true
case 4:
return "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
SyncReady(
ctx,
in1.(*SyncReadyRequest),
)
}, DRPCAgentSocketServer.SyncReady, true
case 5:
return "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
SyncStatus(
ctx,
in1.(*SyncStatusRequest),
)
}, DRPCAgentSocketServer.SyncStatus, true
default:
return "", nil, nil, nil, false
}
}
func DRPCRegisterAgentSocket(mux drpc.Mux, impl DRPCAgentSocketServer) error {
return mux.Register(impl, DRPCAgentSocketDescription{})
}
type DRPCAgentSocket_PingStream interface {
drpc.Stream
SendAndClose(*PingResponse) error
}
type drpcAgentSocket_PingStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_PingStream) SendAndClose(m *PingResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCAgentSocket_SyncStartStream interface {
drpc.Stream
SendAndClose(*SyncStartResponse) error
}
type drpcAgentSocket_SyncStartStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_SyncStartStream) SendAndClose(m *SyncStartResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCAgentSocket_SyncWantStream interface {
drpc.Stream
SendAndClose(*SyncWantResponse) error
}
type drpcAgentSocket_SyncWantStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_SyncWantStream) SendAndClose(m *SyncWantResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCAgentSocket_SyncCompleteStream interface {
drpc.Stream
SendAndClose(*SyncCompleteResponse) error
}
type drpcAgentSocket_SyncCompleteStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_SyncCompleteStream) SendAndClose(m *SyncCompleteResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCAgentSocket_SyncReadyStream interface {
drpc.Stream
SendAndClose(*SyncReadyResponse) error
}
type drpcAgentSocket_SyncReadyStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_SyncReadyStream) SendAndClose(m *SyncReadyResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCAgentSocket_SyncStatusStream interface {
drpc.Stream
SendAndClose(*SyncStatusResponse) error
}
type drpcAgentSocket_SyncStatusStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_SyncStatusStream) SendAndClose(m *SyncStatusResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}
+17
@@ -0,0 +1,17 @@
package proto
import "github.com/coder/coder/v2/apiversion"
// Version history:
//
// API v1.0:
// - Initial release
// - Ping
// - Sync operations: SyncStart, SyncWant, SyncComplete, SyncReady, SyncStatus
const (
CurrentMajor = 1
CurrentMinor = 0
)
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor)
+185
@@ -0,0 +1,185 @@
package agentsocket
import (
"context"
"errors"
"net"
"sync"
"time"
"golang.org/x/xerrors"
"github.com/hashicorp/yamux"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentsocket/proto"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/codersdk/drpcsdk"
)
// Server provides access to the DRPCAgentSocketService via a Unix domain socket.
// Do not construct Server{} directly. Use NewServer() instead.
type Server struct {
logger slog.Logger
path string
listener net.Listener
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
drpcServer *drpcserver.Server
service *DRPCAgentSocketService
}
func NewServer(path string, logger slog.Logger) (*Server, error) {
logger = logger.Named("agentsocket")
server := &Server{
logger: logger,
path: path,
service: &DRPCAgentSocketService{
logger: logger,
unitManager: unit.NewManager[string, string](),
},
}
mux := drpcmux.New()
err := proto.DRPCRegisterAgentSocket(mux, server.service)
if err != nil {
return nil, xerrors.Errorf("failed to register drpc service: %w", err)
}
server.drpcServer = drpcserver.NewWithOptions(mux, drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
Log: func(err error) {
if errors.Is(err, context.Canceled) ||
errors.Is(err, context.DeadlineExceeded) {
return
}
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
},
})
return server, nil
}
var ErrServerAlreadyStarted = xerrors.New("server already started")
func (s *Server) Start() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.listener != nil {
return ErrServerAlreadyStarted
}
// This context is canceled by s.Stop() when the server is stopped.
// Canceling it will close all connections.
s.ctx, s.cancel = context.WithCancel(context.Background())
if s.path == "" {
var err error
s.path, err = getDefaultSocketPath()
if err != nil {
return xerrors.Errorf("get default socket path: %w", err)
}
}
listener, err := createSocket(s.path)
if err != nil {
return xerrors.Errorf("create socket: %w", err)
}
s.listener = listener
s.logger.Info(s.ctx, "agent socket server started", slog.F("path", s.path))
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.acceptConnections()
}()
return nil
}
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.listener == nil {
return nil
}
s.logger.Info(s.ctx, "stopping agent socket server")
s.cancel()
if err := s.listener.Close(); err != nil {
s.logger.Warn(s.ctx, "error closing socket listener", slog.Error(err))
}
// Wait for all connections to finish
s.wg.Wait()
if err := cleanupSocket(s.path); err != nil {
s.logger.Warn(s.ctx, "error cleaning up socket file", slog.Error(err))
}
s.listener = nil
s.logger.Info(s.ctx, "agent socket server stopped")
return nil
}
func (s *Server) acceptConnections() {
for {
select {
case <-s.ctx.Done():
return
default:
}
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.ctx.Done():
return
default:
s.logger.Warn(s.ctx, "error accepting connection", slog.Error(err))
continue
}
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.handleConnection(conn)
}()
}
}
func (s *Server) handleConnection(conn net.Conn) {
defer conn.Close()
if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
s.logger.Warn(s.ctx, "failed to set connection deadline", slog.Error(err))
}
s.logger.Debug(s.ctx, "new connection accepted", slog.F("remote_addr", conn.RemoteAddr()))
config := yamux.DefaultConfig()
config.Logger = nil
session, err := yamux.Server(conn, config)
if err != nil {
s.logger.Warn(s.ctx, "failed to create yamux session", slog.Error(err))
return
}
defer session.Close()
err = s.drpcServer.Serve(s.ctx, session)
if err != nil {
s.logger.Debug(s.ctx, "drpc server finished", slog.Error(err))
}
}
+48
@@ -0,0 +1,48 @@
package agentsocket_test
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentsocket"
)
func TestServer(t *testing.T) {
t.Parallel()
t.Run("StartStop", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server, err := agentsocket.NewServer(socketPath, logger)
require.NoError(t, err)
require.NoError(t, server.Start())
require.NoError(t, server.Stop())
})
t.Run("AlreadyStarted", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server, err := agentsocket.NewServer(socketPath, logger)
require.NoError(t, err)
require.NoError(t, server.Start())
require.ErrorIs(t, server.Start(), agentsocket.ErrServerAlreadyStarted)
})
t.Run("AutoSocketPath", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server, err := agentsocket.NewServer(socketPath, logger)
require.NoError(t, err)
require.NoError(t, server.Start())
require.NoError(t, server.Stop())
})
}
+262
@@ -0,0 +1,262 @@
package agentsocket
import (
"context"
"errors"
"sync"
"time"
"google.golang.org/protobuf/types/known/timestamppb"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentsocket/proto"
"github.com/coder/coder/v2/agent/unit"
)
var _ proto.DRPCAgentSocketServer = (*DRPCAgentSocketService)(nil)
type DRPCAgentSocketService struct {
mu sync.RWMutex
unitManager *unit.Manager[string, string]
logger slog.Logger
}
func (*DRPCAgentSocketService) Ping(_ context.Context, _ *proto.PingRequest) (*proto.PingResponse, error) {
return &proto.PingResponse{
Message: "pong",
Timestamp: timestamppb.New(time.Now()),
}, nil
}
func (s *DRPCAgentSocketService) SyncStart(_ context.Context, req *proto.SyncStartRequest) (*proto.SyncStartResponse, error) {
if s.unitManager == nil {
return &proto.SyncStartResponse{
Success: false,
Message: "unit manager not available",
}, nil
}
if req.Unit == "" {
return &proto.SyncStartResponse{
Success: false,
Message: "Unit name is required",
}, nil
}
if err := s.unitManager.Register(req.Unit); err != nil {
// If already registered, that's okay - we can still update status
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
return &proto.SyncStartResponse{
Success: false,
Message: "Failed to register unit: " + err.Error(),
}, nil
}
}
isReady, err := s.unitManager.IsReady(req.Unit)
if err != nil {
return &proto.SyncStartResponse{
Success: false,
Message: "Failed to check readiness: " + err.Error(),
}, nil
}
if !isReady {
return &proto.SyncStartResponse{
Success: false,
Message: "Unit is not ready",
}, nil
}
if err := s.unitManager.UpdateStatus(req.Unit, unit.StatusStarted); err != nil {
return &proto.SyncStartResponse{
Success: false,
Message: "Failed to update status: " + err.Error(),
}, nil
}
return &proto.SyncStartResponse{
Success: true,
Message: "Unit " + req.Unit + " started successfully",
}, nil
}
func (s *DRPCAgentSocketService) SyncWant(_ context.Context, req *proto.SyncWantRequest) (*proto.SyncWantResponse, error) {
if s.unitManager == nil {
return &proto.SyncWantResponse{
Success: false,
Message: "unit manager not available",
}, nil
}
if req.Unit == "" || req.DependsOn == "" {
return &proto.SyncWantResponse{
Success: false,
Message: "unit and depends_on are required",
}, nil
}
if err := s.unitManager.Register(req.Unit); err != nil {
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
return &proto.SyncWantResponse{
Success: false,
Message: "failed to register unit: " + err.Error(),
}, nil
}
}
if err := s.unitManager.Register(req.DependsOn); err != nil {
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
return &proto.SyncWantResponse{
Success: false,
Message: "failed to register dependency unit: " + err.Error(),
}, nil
}
}
if err := s.unitManager.AddDependency(req.Unit, req.DependsOn, unit.StatusComplete); err != nil {
return &proto.SyncWantResponse{
Success: false,
Message: "failed to add dependency: " + err.Error(),
}, nil
}
return &proto.SyncWantResponse{
Success: true,
Message: "Unit " + req.Unit + " now depends on " + req.DependsOn,
}, nil
}
func (s *DRPCAgentSocketService) SyncComplete(_ context.Context, req *proto.SyncCompleteRequest) (*proto.SyncCompleteResponse, error) {
if s.unitManager == nil {
return &proto.SyncCompleteResponse{
Success: false,
Message: "unit manager not available",
}, nil
}
if req.Unit == "" {
return &proto.SyncCompleteResponse{
Success: false,
Message: "unit name is required",
}, nil
}
if err := s.unitManager.UpdateStatus(req.Unit, unit.StatusComplete); err != nil {
return &proto.SyncCompleteResponse{
Success: false,
Message: "failed to update status: " + err.Error(),
}, nil
}
return &proto.SyncCompleteResponse{
Success: true,
Message: "unit " + req.Unit + " completed successfully",
}, nil
}
func (s *DRPCAgentSocketService) SyncReady(_ context.Context, req *proto.SyncReadyRequest) (*proto.SyncReadyResponse, error) {
if s.unitManager == nil {
return &proto.SyncReadyResponse{
Success: false,
Message: "unit manager not available",
}, nil
}
if req.Unit == "" {
return &proto.SyncReadyResponse{
Success: false,
Message: "unit name is required",
}, nil
}
isReady, err := s.unitManager.IsReady(req.Unit)
if err != nil {
return &proto.SyncReadyResponse{
Success: false,
Message: "failed to check readiness: " + err.Error(),
}, nil
}
if !isReady {
return &proto.SyncReadyResponse{
Success: false,
Message: unit.ErrDependenciesNotSatisfied.Error(),
}, nil
}
return &proto.SyncReadyResponse{
Success: true,
Message: "unit " + req.Unit + " dependencies are satisfied",
}, nil
}
func (s *DRPCAgentSocketService) SyncStatus(_ context.Context, req *proto.SyncStatusRequest) (*proto.SyncStatusResponse, error) {
if s.unitManager == nil {
return &proto.SyncStatusResponse{
Success: false,
Message: "unit manager not available",
}, nil
}
if req.Unit == "" {
return &proto.SyncStatusResponse{
Success: false,
Message: "unit name is required",
}, nil
}
status, err := s.unitManager.GetStatus(req.Unit)
if err != nil {
return &proto.SyncStatusResponse{
Success: false,
Message: "failed to get unit status: " + err.Error(),
}, nil
}
isReady, err := s.unitManager.IsReady(req.Unit)
if err != nil {
return &proto.SyncStatusResponse{
Success: false,
Message: "failed to check readiness: " + err.Error(),
}, nil
}
dependencies, err := s.unitManager.GetAllDependencies(req.Unit)
if err != nil {
return &proto.SyncStatusResponse{
Success: false,
Message: "failed to get dependencies: " + err.Error(),
}, nil
}
var depInfos []*proto.DependencyInfo
for _, dep := range dependencies {
depInfos = append(depInfos, &proto.DependencyInfo{
DependsOn: dep.DependsOn,
RequiredStatus: dep.RequiredStatus,
CurrentStatus: dep.CurrentStatus,
IsSatisfied: dep.IsSatisfied,
})
}
var dotStr string
if req.Recursive {
dotStr, err = s.unitManager.ExportDOT("dependency_graph")
if err != nil {
return &proto.SyncStatusResponse{
Success: false,
Message: "failed to export DOT: " + err.Error(),
}, nil
}
}
return &proto.SyncStatusResponse{
Success: true,
Message: "unit status retrieved successfully",
Unit: req.Unit,
Status: status,
IsReady: isReady,
Dependencies: depInfos,
Dot: dotStr,
}, nil
}
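The handlers above delegate the actual dependency bookkeeping to unit.Manager. A minimal sketch of that underlying API, using only the calls visible in this file (Register, AddDependency, UpdateStatus, IsReady); the unit names are placeholders and the readiness behaviour follows what the tests below exercise:

package main

import (
    "fmt"

    "github.com/coder/coder/v2/agent/unit"
)

func main() {
    // Statuses and consumer IDs are both strings, mirroring the
    // Manager[string, string] instantiated by the socket service.
    m := unit.NewManager[string, string]()

    _ = m.Register("db")
    _ = m.Register("app")

    // "app" may only proceed once "db" has reached StatusComplete.
    _ = m.AddDependency("app", "db", unit.StatusComplete)

    ready, _ := m.IsReady("app")
    fmt.Println("app ready:", ready) // false while "db" is incomplete

    _ = m.UpdateStatus("db", unit.StatusStarted)
    _ = m.UpdateStatus("db", unit.StatusComplete)

    ready, _ = m.IsReady("app")
    fmt.Println("app ready:", ready) // true once the dependency is satisfied
}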
+311
@@ -0,0 +1,311 @@
package agentsocket_test
import (
"context"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
func TestDRPCAgentSocketService(t *testing.T) {
t.Parallel()
t.Run("Ping", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
server, err := agentsocket.NewServer(
socketPath,
slog.Make().Leveled(slog.LevelDebug),
)
require.NoError(t, err)
err = server.Start()
require.NoError(t, err)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: socketPath,
})
require.NoError(t, err)
defer client.Close()
response, err := client.Ping(context.Background())
require.NoError(t, err)
require.Equal(t, "pong", response.Message)
})
t.Run("SyncStart", func(t *testing.T) {
t.Parallel()
t.Run("NewUnit", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
server, err := agentsocket.NewServer(
socketPath,
slog.Make().Leveled(slog.LevelDebug),
)
require.NoError(t, err)
err = server.Start()
require.NoError(t, err)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: socketPath,
})
require.NoError(t, err)
defer client.Close()
err = client.SyncStart(context.Background(), "test-unit")
require.NoError(t, err)
status, err := client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
})
t.Run("UnitAlreadyStarted", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
server, err := agentsocket.NewServer(
socketPath,
slog.Make().Leveled(slog.LevelDebug),
)
require.NoError(t, err)
err = server.Start()
require.NoError(t, err)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: socketPath,
})
require.NoError(t, err)
defer client.Close()
err = client.SyncStart(context.Background(), "test-unit")
require.NoError(t, err)
// First Start
status, err := client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
status, err = client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
// Second Start
err = client.SyncStart(context.Background(), "test-unit")
require.ErrorContains(t, err, unit.ErrSameStatusAlreadySet.Error())
status, err = client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
})
t.Run("UnitAlreadyCompleted", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
server, err := agentsocket.NewServer(
socketPath,
slog.Make().Leveled(slog.LevelDebug),
)
require.NoError(t, err)
err = server.Start()
require.NoError(t, err)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: socketPath,
})
require.NoError(t, err)
defer client.Close()
// First start
err = client.SyncStart(context.Background(), "test-unit")
require.NoError(t, err)
status, err := client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
// Complete the unit
err = client.SyncComplete(context.Background(), "test-unit")
require.NoError(t, err)
status, err = client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "completed", status.Status)
// Second start
err = client.SyncStart(context.Background(), "test-unit")
require.NoError(t, err)
status, err = client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
})
t.Run("UnitNotReady", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
server, err := agentsocket.NewServer(
socketPath,
slog.Make().Leveled(slog.LevelDebug),
)
require.NoError(t, err)
err = server.Start()
require.NoError(t, err)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: socketPath,
})
require.NoError(t, err)
defer client.Close()
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
require.NoError(t, err)
err = client.SyncStart(context.Background(), "test-unit")
require.ErrorContains(t, err, "Unit is not ready")
status, err := client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "", status.Status)
})
})
t.Run("SyncWant", func(t *testing.T) {
t.Parallel()
t.Run("NewUnits", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
server, err := agentsocket.NewServer(
socketPath,
slog.Make().Leveled(slog.LevelDebug),
)
require.NoError(t, err)
err = server.Start()
require.NoError(t, err)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: socketPath,
})
require.NoError(t, err)
defer client.Close()
// If units are not registered, they are registered automatically
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
require.NoError(t, err)
status, err := client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
})
t.Run("DependencyAlreadyRegistered", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
server, err := agentsocket.NewServer(
socketPath,
slog.Make().Leveled(slog.LevelDebug),
)
require.NoError(t, err)
err = server.Start()
require.NoError(t, err)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: socketPath,
})
require.NoError(t, err)
defer client.Close()
// Start the dependency unit
err = client.SyncStart(context.Background(), "dependency-unit")
require.NoError(t, err)
status, err := client.SyncStatus(context.Background(), "dependency-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
// Add the dependency after the dependency unit has already started
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
// Dependencies can be added even if the dependency unit has already started
require.NoError(t, err)
// The dependency is now reflected in the test unit's status
status, err = client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
})
t.Run("DependencyAddedAfterDependentStarted", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
server, err := agentsocket.NewServer(
socketPath,
slog.Make().Leveled(slog.LevelDebug),
)
require.NoError(t, err)
err = server.Start()
require.NoError(t, err)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: socketPath,
})
require.NoError(t, err)
defer client.Close()
// Start the dependent unit
err = client.SyncStart(context.Background(), "test-unit")
require.NoError(t, err)
status, err := client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
// Add the dependency after the dependency unit has already started
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
// Dependencies can be added even if the dependent unit has already started.
// The dependency applies the next time a unit is started. The current status is not updated.
// This is to allow flexible dependency management. It does mean that users of this API should
// take care to add dependencies before they start their dependent units.
require.NoError(t, err)
// The dependency is now reflected in the test unit's status
status, err = client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
})
})
}
+76
@@ -0,0 +1,76 @@
//go:build !windows
package agentsocket
import (
"fmt"
"net"
"os"
"path/filepath"
"golang.org/x/xerrors"
)
// createSocket creates a Unix domain socket listener
func createSocket(path string) (net.Listener, error) {
if !isSocketAvailable(path) {
return nil, xerrors.Errorf("socket path %s is not available", path)
}
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return nil, xerrors.Errorf("remove existing socket: %w", err)
}
// Create parent directory if it doesn't exist
parentDir := filepath.Dir(path)
if err := os.MkdirAll(parentDir, 0o700); err != nil {
return nil, xerrors.Errorf("create socket directory: %w", err)
}
listener, err := net.Listen("unix", path)
if err != nil {
return nil, xerrors.Errorf("listen on unix socket: %w", err)
}
if err := os.Chmod(path, 0o600); err != nil {
_ = listener.Close()
return nil, xerrors.Errorf("set socket permissions: %w", err)
}
return listener, nil
}
// getDefaultSocketPath returns the default socket path for Unix-like systems
func getDefaultSocketPath() (string, error) {
// Try XDG_RUNTIME_DIR first
if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" {
return filepath.Join(runtimeDir, "coder-agent.sock"), nil
}
// Fall back to /tmp with user-specific path
uid := os.Getuid()
return filepath.Join("/tmp", fmt.Sprintf("coder-agent-%d.sock", uid)), nil
}
// cleanupSocket removes the socket file
func cleanupSocket(path string) error {
return os.Remove(path)
}
// isSocketAvailable checks if a socket path is available for use
func isSocketAvailable(path string) bool {
// Check if file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
return true
}
// Try to connect to see if it's actually listening
conn, err := net.Dial("unix", path)
if err != nil {
// If we can't connect, the socket is not in use
// Socket is available for use
return true
}
_ = conn.Close()
// Socket is in use
return false
}
+111
@@ -0,0 +1,111 @@
//go:build windows
package agentsocket
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
"strconv"
"time"
"cdr.dev/slog"
)
// createSocket creates a Unix domain socket listener on Windows
// Falls back to named pipe if Unix sockets are not supported
func createSocket(path string) (net.Listener, error) {
// Try Unix domain socket first (Windows 10 build 17063+)
listener, err := net.Listen("unix", path)
if err == nil {
return listener, nil
}
// Fall back to named pipe
pipePath := `\\.\pipe\coder-agent`
listener, err = net.Listen("tcp", pipePath)
if err != nil {
return nil, err
}
return listener, nil
}
// getDefaultSocketPath returns the default socket path for Windows
func getDefaultSocketPath() (string, error) {
// Try to use a temporary directory
tempDir := os.TempDir()
if tempDir == "" {
tempDir = "C:\\temp"
}
// Create a user-specific subdirectory
uid := os.Getuid()
userDir := filepath.Join(tempDir, "coder-agent", strconv.Itoa(uid))
if err := os.MkdirAll(userDir, 0o700); err != nil {
return "", fmt.Errorf("create user directory: %w", err)
}
return filepath.Join(userDir, "agent.sock"), nil
}
// cleanupSocket removes the socket file
func cleanupSocket(path string) error {
return os.Remove(path)
}
// isSocketAvailable checks if a socket path is available for use
func isSocketAvailable(path string, logger slog.Logger) bool {
logger.Debug(context.Background(), "Checking socket availability on Windows", slog.F("path", path))
// Check if file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
logger.Debug(context.Background(), "Socket file does not exist, path is available", slog.F("path", path))
return true
}
logger.Debug(context.Background(), "Socket file exists, checking if it's listening", slog.F("path", path))
// Try to connect to see if it's actually listening
conn, err := net.Dial("unix", path)
if err != nil {
// If we can't connect, the socket is not in use
logger.Debug(context.Background(), "Cannot connect to socket, path is available", slog.F("path", path), slog.Error(err))
return true
}
_ = conn.Close()
logger.Debug(context.Background(), "Socket is listening, path is not available", slog.F("path", path))
return false
}
// getSocketInfo returns information about the socket file
func getSocketInfo(path string) (*SocketInfo, error) {
stat, err := os.Stat(path)
if err != nil {
return nil, err
}
// On Windows, we'll use a simplified approach for now
// In a real implementation, you'd get the security descriptor
return &SocketInfo{
Path: path,
UID: 0, // Simplified for now
GID: 0, // Simplified for now
Mode: stat.Mode(),
ModTime: stat.ModTime(),
Owner: "unknown",
Group: "unknown",
}, nil
}
// SocketInfo contains information about a socket file
type SocketInfo struct {
Path string
UID int
GID int
Mode os.FileMode
ModTime time.Time
Owner string // Windows SID string
Group string // Windows SID string
}
+307
@@ -0,0 +1,307 @@
package unit
import (
"sync"
"golang.org/x/xerrors"
)
// ErrConsumerNotFound is returned when a consumer ID is not registered.
var ErrConsumerNotFound = xerrors.New("consumer not found")
// ErrConsumerAlreadyRegistered is returned when a consumer ID is already registered.
var ErrConsumerAlreadyRegistered = xerrors.New("consumer already registered")
// ErrCannotUpdateOtherConsumer is returned when attempting to update another consumer's status.
var ErrCannotUpdateOtherConsumer = xerrors.New("cannot update other consumer's status")
// ErrDependenciesNotSatisfied is returned when a consumer's dependencies are not satisfied.
var ErrDependenciesNotSatisfied = xerrors.New("unit dependencies not satisfied")
// ErrSameStatusAlreadySet is returned when attempting to set the same status as the current status.
var ErrSameStatusAlreadySet = xerrors.New("same status already set")
// Status constants for dependency tracking
const (
StatusStarted = "started"
StatusComplete = "completed"
)
// dependencyVertex represents a vertex in the dependency graph that is associated with a consumer.
type dependencyVertex[ConsumerID comparable] struct {
ID ConsumerID
}
// Dependency represents a dependency relationship between consumers.
type Dependency[StatusType, ConsumerID comparable] struct {
Consumer ConsumerID
DependsOn ConsumerID
RequiredStatus StatusType
CurrentStatus StatusType
IsSatisfied bool
}
// Manager provides reactive dependency tracking over a Graph.
// It manages consumer registration, dependency relationships, and status updates
// with automatic recalculation of readiness when dependencies are satisfied.
type Manager[StatusType, ConsumerID comparable] struct {
mu sync.RWMutex
// The underlying graph that stores dependency relationships
graph *Graph[StatusType, *dependencyVertex[ConsumerID]]
// Track current status of each consumer
consumerStatus map[ConsumerID]StatusType
// Track readiness state (cached to avoid repeated graph traversal)
consumerReadiness map[ConsumerID]bool
// Track which consumers are registered
registeredConsumers map[ConsumerID]bool
// Store vertex instances for each consumer to ensure consistent references
consumerVertices map[ConsumerID]*dependencyVertex[ConsumerID]
}
// NewManager creates a new Manager instance.
func NewManager[StatusType, ConsumerID comparable]() *Manager[StatusType, ConsumerID] {
return &Manager[StatusType, ConsumerID]{
graph: &Graph[StatusType, *dependencyVertex[ConsumerID]]{},
consumerStatus: make(map[ConsumerID]StatusType),
consumerReadiness: make(map[ConsumerID]bool),
registeredConsumers: make(map[ConsumerID]bool),
consumerVertices: make(map[ConsumerID]*dependencyVertex[ConsumerID]),
}
}
// Register registers a new consumer as a vertex in the dependency graph.
func (dt *Manager[StatusType, ConsumerID]) Register(id ConsumerID) error {
dt.mu.Lock()
defer dt.mu.Unlock()
if dt.registeredConsumers[id] {
return ErrConsumerAlreadyRegistered
}
// Create and store the vertex for this consumer
vertex := &dependencyVertex[ConsumerID]{ID: id}
dt.consumerVertices[id] = vertex
dt.registeredConsumers[id] = true
dt.consumerReadiness[id] = true // New consumers start as ready (no dependencies)
return nil
}
// AddDependency adds a dependency relationship between consumers.
// The consumer depends on the dependsOn consumer reaching the requiredStatus.
func (dt *Manager[StatusType, ConsumerID]) AddDependency(consumer ConsumerID, dependsOn ConsumerID, requiredStatus StatusType) error {
dt.mu.Lock()
defer dt.mu.Unlock()
if !dt.registeredConsumers[consumer] {
return xerrors.Errorf("consumer %v is not registered", consumer)
}
if !dt.registeredConsumers[dependsOn] {
return xerrors.Errorf("consumer %v is not registered", dependsOn)
}
// Get the stored vertices for both consumers
consumerVertex := dt.consumerVertices[consumer]
dependsOnVertex := dt.consumerVertices[dependsOn]
// Add the dependency edge to the graph
// The edge goes from consumer to dependsOn, representing the dependency
err := dt.graph.AddEdge(consumerVertex, dependsOnVertex, requiredStatus)
if err != nil {
return xerrors.Errorf("failed to add dependency: %w", err)
}
// Recalculate readiness for the consumer since it now has a dependency
dt.recalculateReadinessUnsafe(consumer)
return nil
}
// UpdateStatus updates a consumer's status and recalculates readiness for affected dependents.
func (dt *Manager[StatusType, ConsumerID]) UpdateStatus(consumer ConsumerID, newStatus StatusType) error {
dt.mu.Lock()
defer dt.mu.Unlock()
if !dt.registeredConsumers[consumer] {
return ErrConsumerNotFound
}
// Reject no-op updates: setting the current status again is reported as an error
if dt.consumerStatus[consumer] == newStatus {
return ErrSameStatusAlreadySet
}
// Update the consumer's status
dt.consumerStatus[consumer] = newStatus
// Get all consumers that depend on this one (reverse adjacent vertices)
consumerVertex := dt.consumerVertices[consumer]
dependentEdges := dt.graph.GetReverseAdjacentVertices(consumerVertex)
// Recalculate readiness for all dependents
for _, edge := range dependentEdges {
dt.recalculateReadinessUnsafe(edge.From.ID)
}
return nil
}
// IsReady checks if all dependencies for a consumer are satisfied.
func (dt *Manager[StatusType, ConsumerID]) IsReady(consumer ConsumerID) (bool, error) {
dt.mu.RLock()
defer dt.mu.RUnlock()
if !dt.registeredConsumers[consumer] {
return false, ErrConsumerNotFound
}
return dt.consumerReadiness[consumer], nil
}
// GetUnmetDependencies returns a list of unsatisfied dependencies for a consumer.
func (dt *Manager[StatusType, ConsumerID]) GetUnmetDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
dt.mu.RLock()
defer dt.mu.RUnlock()
if !dt.registeredConsumers[consumer] {
return nil, ErrConsumerNotFound
}
consumerVertex := dt.consumerVertices[consumer]
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
var unmetDependencies []Dependency[StatusType, ConsumerID]
for _, edge := range forwardEdges {
dependsOnConsumer := edge.To.ID
requiredStatus := edge.Edge
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
if !exists {
// If the dependency consumer has no status, it's not satisfied
var zeroStatus StatusType
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
Consumer: consumer,
DependsOn: dependsOnConsumer,
RequiredStatus: requiredStatus,
CurrentStatus: zeroStatus, // Zero value
IsSatisfied: false,
})
} else {
isSatisfied := currentStatus == requiredStatus
if !isSatisfied {
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
Consumer: consumer,
DependsOn: dependsOnConsumer,
RequiredStatus: requiredStatus,
CurrentStatus: currentStatus,
IsSatisfied: false,
})
}
}
}
return unmetDependencies, nil
}
// recalculateReadinessUnsafe recalculates the readiness state for a consumer.
// This method assumes the caller holds the write lock.
func (dt *Manager[StatusType, ConsumerID]) recalculateReadinessUnsafe(consumer ConsumerID) {
consumerVertex := dt.consumerVertices[consumer]
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
// If there are no dependencies, the consumer is ready
if len(forwardEdges) == 0 {
dt.consumerReadiness[consumer] = true
return
}
// Check if all dependencies are satisfied
allSatisfied := true
for _, edge := range forwardEdges {
dependsOnConsumer := edge.To.ID
requiredStatus := edge.Edge
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
if !exists || currentStatus != requiredStatus {
allSatisfied = false
break
}
}
dt.consumerReadiness[consumer] = allSatisfied
}
// GetGraph returns the underlying graph for visualization and debugging.
// This should be used carefully as it exposes the internal graph structure.
func (dt *Manager[StatusType, ConsumerID]) GetGraph() *Graph[StatusType, *dependencyVertex[ConsumerID]] {
return dt.graph
}
// GetStatus returns the current status of a consumer.
func (dt *Manager[StatusType, ConsumerID]) GetStatus(consumer ConsumerID) (StatusType, error) {
dt.mu.RLock()
defer dt.mu.RUnlock()
if !dt.registeredConsumers[consumer] {
var zeroStatus StatusType
return zeroStatus, ErrConsumerNotFound
}
status, exists := dt.consumerStatus[consumer]
if !exists {
var zeroStatus StatusType
return zeroStatus, nil
}
return status, nil
}
// GetAllDependencies returns all dependencies for a consumer, both satisfied and unsatisfied.
func (dt *Manager[StatusType, ConsumerID]) GetAllDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
dt.mu.RLock()
defer dt.mu.RUnlock()
if !dt.registeredConsumers[consumer] {
return nil, ErrConsumerNotFound
}
consumerVertex := dt.consumerVertices[consumer]
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
var allDependencies []Dependency[StatusType, ConsumerID]
for _, edge := range forwardEdges {
dependsOnConsumer := edge.To.ID
requiredStatus := edge.Edge
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
if !exists {
// If the dependency consumer has no status, it's not satisfied
var zeroStatus StatusType
allDependencies = append(allDependencies, Dependency[StatusType, ConsumerID]{
Consumer: consumer,
DependsOn: dependsOnConsumer,
RequiredStatus: requiredStatus,
CurrentStatus: zeroStatus, // Zero value
IsSatisfied: false,
})
} else {
isSatisfied := currentStatus == requiredStatus
allDependencies = append(allDependencies, Dependency[StatusType, ConsumerID]{
Consumer: consumer,
DependsOn: dependsOnConsumer,
RequiredStatus: requiredStatus,
CurrentStatus: currentStatus,
IsSatisfied: isSatisfied,
})
}
}
return allDependencies, nil
}
// ExportDOT exports the dependency graph to DOT format for visualization.
func (dt *Manager[StatusType, ConsumerID]) ExportDOT(name string) (string, error) {
return dt.graph.ToDOT(name)
}
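As a quick orientation for reviewers, the sketch below shows how the Manager is driven end to end. It is illustrative only and not part of this change; it instantiates the generics with plain strings, whereas the agent supplies its own status and unit ID types.

// exampleManagerUsage is an illustrative sketch showing the
// Register -> AddDependency -> UpdateStatus -> IsReady cycle.
func exampleManagerUsage() error {
    m := NewManager[string, string]()
    if err := m.Register("db"); err != nil {
        return err
    }
    if err := m.Register("app"); err != nil {
        return err
    }
    // "app" may only proceed once "db" reports StatusComplete.
    if err := m.AddDependency("app", "db", StatusComplete); err != nil {
        return err
    }
    ready, err := m.IsReady("app")
    if err != nil {
        return err
    }
    _ = ready // false: "db" has not reported any status yet
    if err := m.UpdateStatus("db", StatusComplete); err != nil {
        return err
    }
    ready, err = m.IsReady("app")
    if err != nil {
        return err
    }
    _ = ready // true: the only dependency is now satisfied
    return nil
}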
+691
View File
@@ -0,0 +1,691 @@
package unit_test
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/unit"
)
type testStatus string
const (
statusStarted testStatus = "started"
statusRunning testStatus = "running"
statusCompleted testStatus = "completed"
)
type testConsumerID string
const (
consumerA testConsumerID = "serviceA"
consumerB testConsumerID = "serviceB"
consumerC testConsumerID = "serviceC"
consumerD testConsumerID = "serviceD"
)
func TestDependencyTracker_Register(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
t.Run("RegisterNewConsumer", func(t *testing.T) {
t.Parallel()
err := tracker.Register(consumerA)
require.NoError(t, err)
// Consumer should be ready initially (no dependencies)
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("RegisterDuplicateConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerA)
require.Error(t, err)
assert.Contains(t, err.Error(), "already registered")
})
t.Run("RegisterMultipleConsumers", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
consumers := []testConsumerID{consumerA, consumerB, consumerC}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// All should be ready initially
for _, consumer := range consumers {
ready, err := tracker.IsReady(consumer)
require.NoError(t, err)
assert.True(t, ready)
}
})
}
func TestDependencyTracker_AddDependency(t *testing.T) {
t.Parallel()
t.Run("AddDependencyBetweenRegisteredConsumers", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B being "running"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
// A should no longer be ready (depends on B)
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// B should still be ready (no dependencies)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("AddDependencyWithUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
// Try to add dependency to unregistered consumer
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.Error(t, err)
assert.Contains(t, err.Error(), "not registered")
})
t.Run("AddDependencyFromUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
err := tracker.Register(consumerB)
require.NoError(t, err)
// Try to add dependency from unregistered consumer
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.Error(t, err)
assert.Contains(t, err.Error(), "not registered")
})
}
func TestDependencyTracker_UpdateStatus(t *testing.T) {
t.Parallel()
t.Run("UpdateStatusTriggersReadinessRecalculation", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B being "running"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
// Initially A is not ready
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update B to "running" - A should become ready
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
err := tracker.UpdateStatus(consumerA, statusRunning)
require.Error(t, err)
assert.Equal(t, unit.ErrConsumerNotFound, err)
})
t.Run("LinearChainDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register all consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// Create chain: A depends on B being "started", B depends on C being "completed"
err := tracker.AddDependency(consumerA, consumerB, statusStarted)
require.NoError(t, err)
err = tracker.AddDependency(consumerB, consumerC, statusCompleted)
require.NoError(t, err)
// Initially only C is ready
ready, err := tracker.IsReady(consumerC)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.False(t, ready)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "completed" - B should become ready
err = tracker.UpdateStatus(consumerC, statusCompleted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update B to "started" - A should become ready
err = tracker.UpdateStatus(consumerB, statusStarted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
}
func TestDependencyTracker_GetUnmetDependencies(t *testing.T) {
t.Parallel()
t.Run("GetUnmetDependenciesForConsumerWithNoDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
unmet, err := tracker.GetUnmetDependencies(consumerA)
require.NoError(t, err)
assert.Empty(t, unmet)
})
t.Run("GetUnmetDependenciesForConsumerWithUnsatisfiedDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B being "running"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
unmet, err := tracker.GetUnmetDependencies(consumerA)
require.NoError(t, err)
require.Len(t, unmet, 1)
assert.Equal(t, consumerA, unmet[0].Consumer)
assert.Equal(t, consumerB, unmet[0].DependsOn)
assert.Equal(t, statusRunning, unmet[0].RequiredStatus)
assert.False(t, unmet[0].IsSatisfied)
})
t.Run("GetUnmetDependenciesForConsumerWithSatisfiedDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B being "running"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
// Update B to "running"
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
unmet, err := tracker.GetUnmetDependencies(consumerA)
require.NoError(t, err)
assert.Empty(t, unmet)
})
t.Run("GetUnmetDependenciesForUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
unmet, err := tracker.GetUnmetDependencies(consumerA)
require.Error(t, err)
assert.Equal(t, unit.ErrConsumerNotFound, err)
assert.Nil(t, unmet)
})
}
func TestDependencyTracker_ConcurrentOperations(t *testing.T) {
t.Parallel()
t.Run("ConcurrentStatusUpdates", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// Create dependencies: A depends on B, B depends on C, C depends on D
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerB, consumerC, statusStarted)
require.NoError(t, err)
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
require.NoError(t, err)
var wg sync.WaitGroup
const numGoroutines = 10
// Launch goroutines that update statuses
errors := make([]error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
// Update D to completed (should make C ready). Concurrent goroutines race to
// set the same status; the manager reports that as ErrSameStatusAlreadySet,
// which is expected here and not treated as a failure.
err := tracker.UpdateStatus(consumerD, statusCompleted)
if err != nil && err != unit.ErrSameStatusAlreadySet {
errors[goroutineID] = err
return
}
// Update C to started (should make B ready)
err = tracker.UpdateStatus(consumerC, statusStarted)
if err != nil && err != unit.ErrSameStatusAlreadySet {
errors[goroutineID] = err
return
}
// Update B to running (should make A ready)
err = tracker.UpdateStatus(consumerB, statusRunning)
if err != nil && err != unit.ErrSameStatusAlreadySet {
errors[goroutineID] = err
return
}
}(i)
}
wg.Wait()
// Check for any errors in goroutines
for i, err := range errors {
require.NoError(t, err, "goroutine %d had error", i)
}
// All consumers should be ready after the updates
for _, consumer := range consumers {
ready, err := tracker.IsReady(consumer)
require.NoError(t, err)
assert.True(t, ready)
}
})
t.Run("ConcurrentReadinessChecks", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register consumers
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B being "running"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
var wg sync.WaitGroup
const numGoroutines = 20
// Launch goroutines that check readiness
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
// Check readiness multiple times
for j := 0; j < 10; j++ {
ready, err := tracker.IsReady(consumerA)
// testify's require must not be called from a goroutine, so use assert here
assert.NoError(t, err)
// Initially should be false, then true after B is updated
_ = ready
ready, err = tracker.IsReady(consumerB)
assert.NoError(t, err)
// B should always be ready (no dependencies)
assert.True(t, ready)
}
}(i)
}
// Update B to "running" in the middle of readiness checks
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
wg.Wait()
})
}
func TestDependencyTracker_MultipleDependencies(t *testing.T) {
t.Parallel()
t.Run("ConsumerWithMultipleDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register all consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// A depends on B being "running" AND C being "started"
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
require.NoError(t, err)
// A should not be ready (depends on both B and C)
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update B to "running" - A should still not be ready (needs C too)
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "started" - A should now be ready
err = tracker.UpdateStatus(consumerC, statusStarted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("ComplexDependencyChain", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register all consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// Create complex dependency graph:
// A depends on B being "running" AND C being "started"
// B depends on D being "completed"
// C depends on D being "completed"
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
require.NoError(t, err)
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
require.NoError(t, err)
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
require.NoError(t, err)
// Initially only D is ready
ready, err := tracker.IsReady(consumerD)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.False(t, ready)
ready, err = tracker.IsReady(consumerC)
require.NoError(t, err)
assert.False(t, ready)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update D to "completed" - B and C should become ready
err = tracker.UpdateStatus(consumerD, statusCompleted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerC)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update B to "running" - A should still not be ready (needs C)
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "started" - A should now be ready
err = tracker.UpdateStatus(consumerC, statusStarted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("DifferentStatusTypes", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register consumers
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
err = tracker.Register(consumerC)
require.NoError(t, err)
// A depends on B being "running" AND C being "completed"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusCompleted)
require.NoError(t, err)
// Update B to "running" but not C - A should not be ready
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "completed" - A should now be ready
err = tracker.UpdateStatus(consumerC, statusCompleted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
}
func TestDependencyTracker_ErrorCases(t *testing.T) {
t.Parallel()
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
err := tracker.UpdateStatus(consumerA, statusRunning)
require.Error(t, err)
assert.Equal(t, unit.ErrConsumerNotFound, err)
})
t.Run("IsReadyWithUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
ready, err := tracker.IsReady(consumerA)
require.Error(t, err)
assert.Equal(t, unit.ErrConsumerNotFound, err)
assert.False(t, ready)
})
t.Run("GetUnmetDependenciesWithUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
unmet, err := tracker.GetUnmetDependencies(consumerA)
require.Error(t, err)
assert.Equal(t, unit.ErrConsumerNotFound, err)
assert.Nil(t, unmet)
})
t.Run("AddDependencyWithUnregisteredConsumers", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Try to add dependency with unregistered consumers
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.Error(t, err)
assert.Contains(t, err.Error(), "not registered")
})
t.Run("CyclicDependencyDetection", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register consumers
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
// Try to make B depend on A (creates cycle)
err = tracker.AddDependency(consumerB, consumerA, statusStarted)
require.Error(t, err)
assert.Contains(t, err.Error(), "would create a cycle")
})
}
func TestDependencyTracker_ToDOT(t *testing.T) {
t.Parallel()
t.Run("ExportSimpleGraph", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register consumers
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// Add dependency
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
dot, err := tracker.ExportDOT("test")
require.NoError(t, err)
assert.NotEmpty(t, dot)
assert.Contains(t, dot, "digraph")
})
t.Run("ExportComplexGraph", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register all consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// Create complex dependency graph
// A depends on B and C, B depends on D, C depends on D
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
require.NoError(t, err)
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
require.NoError(t, err)
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
require.NoError(t, err)
dot, err := tracker.ExportDOT("complex")
require.NoError(t, err)
assert.NotEmpty(t, dot)
assert.Contains(t, dot, "digraph")
})
}
+8
View File
@@ -56,6 +56,7 @@ func workspaceAgent() *serpent.Command {
devcontainers bool
devcontainerProjectDiscovery bool
devcontainerDiscoveryAutostart bool
socketPath string
)
agentAuth := &AgentAuth{}
cmd := &serpent.Command{
@@ -297,6 +298,7 @@ func workspaceAgent() *serpent.Command {
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
agentcontainers.WithDiscoveryAutostart(devcontainerDiscoveryAutostart),
},
SocketPath: socketPath,
})
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
@@ -449,6 +451,12 @@ func workspaceAgent() *serpent.Command {
Description: "Allow the agent to autostart devcontainer projects it discovers based on their configuration.",
Value: serpent.BoolOf(&devcontainerDiscoveryAutostart),
},
{
Flag: "socket-path",
Env: "CODER_AGENT_SOCKET_PATH",
Description: "Specify the path for the agent socket.",
Value: serpent.StringOf(&socketPath),
},
}
agentAuth.AttachOptions(cmd, false)
return cmd
-78
View File
@@ -1,78 +0,0 @@
package cli
import (
"encoding/csv"
"strings"
"github.com/spf13/pflag"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
var (
_ pflag.SliceValue = &AllowListFlag{}
_ pflag.Value = &AllowListFlag{}
)
// AllowListFlag implements pflag.SliceValue for codersdk.APIAllowListTarget entries.
type AllowListFlag []codersdk.APIAllowListTarget
func AllowListFlagOf(al *[]codersdk.APIAllowListTarget) *AllowListFlag {
return (*AllowListFlag)(al)
}
func (a AllowListFlag) String() string {
return strings.Join(a.GetSlice(), ",")
}
func (a AllowListFlag) Value() []codersdk.APIAllowListTarget {
return []codersdk.APIAllowListTarget(a)
}
func (AllowListFlag) Type() string { return "allow-list" }
func (a *AllowListFlag) Set(set string) error {
values, err := csv.NewReader(strings.NewReader(set)).Read()
if err != nil {
return xerrors.Errorf("parse allow list entries as csv: %w", err)
}
for _, v := range values {
if err := a.Append(v); err != nil {
return err
}
}
return nil
}
func (a *AllowListFlag) Append(value string) error {
value = strings.TrimSpace(value)
if value == "" {
return xerrors.New("allow list entry cannot be empty")
}
var target codersdk.APIAllowListTarget
if err := target.UnmarshalText([]byte(value)); err != nil {
return err
}
*a = append(*a, target)
return nil
}
func (a *AllowListFlag) Replace(items []string) error {
*a = []codersdk.APIAllowListTarget{}
for _, item := range items {
if err := a.Append(item); err != nil {
return err
}
}
return nil
}
func (a *AllowListFlag) GetSlice() []string {
out := make([]string, len(*a))
for i, entry := range *a {
out[i] = entry.String()
}
return out
}
+1
View File
@@ -144,6 +144,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
r.mcpCommand(),
r.promptExample(),
r.rptyCommand(),
r.syncCommand(),
r.tasksCommand(),
r.boundary(),
}
-47
View File
@@ -109,51 +109,6 @@ func (r *RootCmd) ssh() *serpent.Command {
}
},
),
CompletionHandler: func(inv *serpent.Invocation) []string {
client, err := r.InitClient(inv)
if err != nil {
return []string{}
}
res, err := client.Workspaces(inv.Context(), codersdk.WorkspaceFilter{
Owner: codersdk.Me,
})
if err != nil {
return []string{}
}
var mu sync.Mutex
var completions []string
var wg sync.WaitGroup
for _, ws := range res.Workspaces {
wg.Add(1)
go func() {
defer wg.Done()
resources, err := client.TemplateVersionResources(inv.Context(), ws.LatestBuild.TemplateVersionID)
if err != nil {
return
}
var agents []codersdk.WorkspaceAgent
for _, resource := range resources {
agents = append(agents, resource.Agents...)
}
mu.Lock()
defer mu.Unlock()
if len(agents) == 1 {
completions = append(completions, ws.Name)
} else {
for _, agent := range agents {
completions = append(completions, fmt.Sprintf("%s.%s", ws.Name, agent.Name))
}
}
}()
}
wg.Wait()
slices.Sort(completions)
return completions
},
Handler: func(inv *serpent.Invocation) (retErr error) {
client, err := r.InitClient(inv)
if err != nil {
@@ -951,8 +906,6 @@ func GetWorkspaceAndAgent(ctx context.Context, inv *serpent.Invocation, client *
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with active template version: %w", err)
}
_, _ = fmt.Fprintln(inv.Stdout, "Unable to start the workspace with template version from last build. Your workspace has been updated to the current active template version.")
default:
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with current template version: %w", err)
}
} else if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with current template version: %w", err)
-96
View File
@@ -2447,99 +2447,3 @@ func tempDirUnixSocket(t *testing.T) string {
return t.TempDir()
}
func TestSSH_Completion(t *testing.T) {
t.Parallel()
t.Run("SingleAgent", func(t *testing.T) {
t.Parallel()
client, workspace, agentToken := setupWorkspaceForAgent(t)
_ = agenttest.New(t, client.URL, agentToken)
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
var stdout bytes.Buffer
inv, root := clitest.New(t, "ssh", "")
inv.Stdout = &stdout
inv.Environ.Set("COMPLETION_MODE", "1")
clitest.SetupConfig(t, client, root)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
// For single-agent workspaces, the only completion should be the
// bare workspace name.
output := stdout.String()
t.Logf("Completion output: %q", output)
require.Contains(t, output, workspace.Name)
})
t.Run("MultiAgent", func(t *testing.T) {
t.Parallel()
client, store := coderdtest.NewWithDatabase(t, nil)
first := coderdtest.CreateFirstUser(t, client)
userClient, user := coderdtest.CreateAnotherUserMutators(t, client, first.OrganizationID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
r.Username = "multiuser"
})
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
Name: "multiworkspace",
OrganizationID: first.OrganizationID,
OwnerID: user.ID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
return []*proto.Agent{
{
Name: "agent1",
Auth: &proto.Agent_Token{},
},
{
Name: "agent2",
Auth: &proto.Agent_Token{},
},
}
}).Do()
var stdout bytes.Buffer
inv, root := clitest.New(t, "ssh", "")
inv.Stdout = &stdout
inv.Environ.Set("COMPLETION_MODE", "1")
clitest.SetupConfig(t, userClient, root)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
// For multi-agent workspaces, completions should include the
// workspace.agent format but NOT the bare workspace name.
output := stdout.String()
t.Logf("Completion output: %q", output)
lines := strings.Split(strings.TrimSpace(output), "\n")
require.NotContains(t, lines, r.Workspace.Name)
require.Contains(t, output, r.Workspace.Name+".agent1")
require.Contains(t, output, r.Workspace.Name+".agent2")
})
t.Run("NetworkError", func(t *testing.T) {
t.Parallel()
var stdout bytes.Buffer
inv, _ := clitest.New(t, "ssh", "")
inv.Stdout = &stdout
inv.Environ.Set("COMPLETION_MODE", "1")
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
output := stdout.String()
require.Empty(t, output)
})
}
-17
View File
@@ -87,7 +87,6 @@ func buildNumberOption(n *int64) serpent.Option {
func (r *RootCmd) statePush() *serpent.Command {
var buildNumber int64
var noBuild bool
cmd := &serpent.Command{
Use: "push <workspace> <file>",
Short: "Push a Terraform state file to a workspace.",
@@ -127,16 +126,6 @@ func (r *RootCmd) statePush() *serpent.Command {
return err
}
if noBuild {
// Update state directly without triggering a build.
err = client.UpdateWorkspaceBuildState(inv.Context(), build.ID, state)
if err != nil {
return err
}
_, _ = fmt.Fprintln(inv.Stdout, "State updated successfully.")
return nil
}
build, err = client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
TemplateVersionID: build.TemplateVersionID,
Transition: build.Transition,
@@ -150,12 +139,6 @@ func (r *RootCmd) statePush() *serpent.Command {
}
cmd.Options = serpent.OptionSet{
buildNumberOption(&buildNumber),
{
Flag: "no-build",
FlagShorthand: "n",
Description: "Update the state without triggering a workspace build. Useful for state-only migrations.",
Value: serpent.BoolOf(&noBuild),
},
}
return cmd
}
-47
View File
@@ -2,7 +2,6 @@ package cli_test
import (
"bytes"
"context"
"fmt"
"os"
"path/filepath"
@@ -11,7 +10,6 @@ import (
"testing"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/stretchr/testify/require"
@@ -160,49 +158,4 @@ func TestStatePush(t *testing.T) {
err := inv.Run()
require.NoError(t, err)
})
t.Run("NoBuild", func(t *testing.T) {
t.Parallel()
client, store := coderdtest.NewWithDatabase(t, nil)
owner := coderdtest.CreateFirstUser(t, client)
templateAdmin, taUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.RoleTemplateAdmin())
initialState := []byte("initial state")
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
Do()
wantState := []byte("updated state")
stateFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
_, err = stateFile.Write(wantState)
require.NoError(t, err)
err = stateFile.Close()
require.NoError(t, err)
inv, root := clitest.New(t, "state", "push", "--no-build", r.Workspace.Name, stateFile.Name())
clitest.SetupConfig(t, templateAdmin, root)
var stdout bytes.Buffer
inv.Stdout = &stdout
err = inv.Run()
require.NoError(t, err)
require.Contains(t, stdout.String(), "State updated successfully")
// Verify the state was updated by pulling it.
inv, root = clitest.New(t, "state", "pull", r.Workspace.Name)
var gotState bytes.Buffer
inv.Stdout = &gotState
clitest.SetupConfig(t, templateAdmin, root)
err = inv.Run()
require.NoError(t, err)
require.Equal(t, wantState, bytes.TrimSpace(gotState.Bytes()))
// Verify no new build was created.
builds, err := store.GetWorkspaceBuildsByWorkspaceID(dbauthz.AsSystemRestricted(context.Background()), database.GetWorkspaceBuildsByWorkspaceIDParams{
WorkspaceID: r.Workspace.ID,
})
require.NoError(t, err)
require.Len(t, builds, 1, "expected only the initial build, no new build should be created")
})
}
+25
View File
@@ -0,0 +1,25 @@
package cli
import (
"github.com/coder/serpent"
)
func (r *RootCmd) syncCommand() *serpent.Command {
cmd := &serpent.Command{
Use: "sync",
Short: "Synchronize with the local agent socket",
Long: "Commands for interacting with the local Coder agent via socket communication.",
Handler: func(i *serpent.Invocation) error {
return i.Command.HelpHandler(i)
},
Children: []*serpent.Command{
r.syncPing(),
r.syncStart(),
r.syncWant(),
r.syncComplete(),
r.syncWait(),
r.syncStatus(),
},
}
return cmd
}
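Taken together, these subcommands let workspace startup scripts coordinate through the agent socket. A typical flow, using two hypothetical units a and b (and the experimental command prefix exercised by the tests later in this change): the script for b declares its dependency with `coder exp sync want b a` and then blocks on `coder exp sync wait b` (or calls `coder exp sync start b`, which waits for dependencies and marks the unit started in one step); the script for a runs `coder exp sync start a`, does its work, and finishes with `coder exp sync complete a`, at which point b's wait returns. `coder exp sync ping` and `coder exp sync status` are diagnostic helpers for checking socket connectivity and inspecting the dependency graph.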
+50
View File
@@ -0,0 +1,50 @@
package cli
import (
"context"
"fmt"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
func (r *RootCmd) syncComplete() *serpent.Command {
return &serpent.Command{
Use: "complete <unit>",
Short: "Mark a unit as complete in the dependency graph",
Long: "Set a unit's status to complete in the dependency graph.",
Handler: func(i *serpent.Invocation) error {
ctx := context.Background()
if len(i.Args) != 1 {
return xerrors.New("exactly one unit name is required")
}
unit := i.Args[0]
// Show initial message
fmt.Printf("Completing unit '%s'...\n", unit)
// Connect to agent socket
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: "/tmp/coder.sock",
})
if err != nil {
return xerrors.Errorf("connect to agent socket: %w", err)
}
defer client.Close()
// Complete the unit
if err := client.SyncComplete(ctx, unit); err != nil {
return xerrors.Errorf("complete unit failed: %w", err)
}
// Display success message
fmt.Printf("Unit '%s' completed successfully\n", unit)
return nil
},
}
}
+53
View File
@@ -0,0 +1,53 @@
package cli
import (
"context"
"fmt"
"time"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
func (r *RootCmd) syncPing() *serpent.Command {
return &serpent.Command{
Use: "ping",
Short: "Ping the local agent socket",
Long: "Test connectivity to the local Coder agent via socket communication.",
Handler: func(i *serpent.Invocation) error {
ctx := context.Background()
// Show initial message
fmt.Println("Pinging agent socket...")
// Connect to agent socket
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: "/tmp/coder.sock",
})
if err != nil {
return xerrors.Errorf("connect to agent socket: %w", err)
}
defer client.Close()
// Measure round-trip time
start := time.Now()
resp, err := client.Ping(ctx)
duration := time.Since(start)
if err != nil {
return xerrors.Errorf("ping failed: %w", err)
}
// Display results
fmt.Printf("Response: %s\n", resp.Message)
fmt.Printf("Timestamp: %s\n", resp.Timestamp.Format(time.RFC3339))
fmt.Printf("Round-trip time: %s\n", duration.Round(time.Microsecond))
fmt.Println("Status: healthy")
return nil
},
}
}
+122
View File
@@ -0,0 +1,122 @@
package cli
import (
"context"
"fmt"
"time"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
const (
// SyncPollInterval is the interval between dependency checks for sync start
SyncPollInterval = 1 * time.Second
)
func (r *RootCmd) syncStart() *serpent.Command {
var timeout time.Duration
cmd := &serpent.Command{
Use: "start <unit>",
Short: "Start a unit in the dependency graph",
Long: "Register a unit in the dependency graph and set its status to started. Waits for all dependencies to be satisfied before marking as started.",
Handler: func(i *serpent.Invocation) error {
ctx := context.Background()
if len(i.Args) != 1 {
return xerrors.New("exactly one unit name is required")
}
unitName := i.Args[0]
// Set up context with timeout if specified
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
// Show initial message
fmt.Printf("Starting unit '%s'...\n", unitName)
// Connect to agent socket
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: "/tmp/coder.sock",
})
if err != nil {
return xerrors.Errorf("connect to agent socket: %w", err)
}
defer client.Close()
// Check if dependencies are satisfied first
err = client.SyncReady(ctx, unitName)
if err != nil {
// Check if it's a "not ready" error (expected if dependencies exist)
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
// Dependencies exist but aren't satisfied, start polling
fmt.Printf("Waiting for dependencies of unit '%s' to be satisfied...\n", unitName)
// Poll until dependencies are satisfied
ticker := time.NewTicker(SyncPollInterval)
defer ticker.Stop()
pollLoop:
for {
select {
case <-ctx.Done():
if ctx.Err() == context.DeadlineExceeded {
return xerrors.Errorf("timeout waiting for dependencies of unit '%s'", unitName)
}
return ctx.Err()
case <-ticker.C:
// Check if dependencies are satisfied
err := client.SyncReady(ctx, unitName)
if err == nil {
// Dependencies are satisfied
fmt.Printf("Dependencies satisfied, marking unit '%s' as started\n", unitName)
break pollLoop
}
// Check if it's still a "not ready" error (expected while waiting)
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
// Still waiting, continue polling
continue
}
// Some other error occurred
return xerrors.Errorf("error checking dependencies: %w", err)
}
}
} else {
// Some other error occurred
return xerrors.Errorf("error checking dependencies: %w", err)
}
} else {
// No dependencies or already satisfied
fmt.Printf("Dependencies satisfied, marking unit '%s' as started\n", unitName)
}
// Start the unit
if err := client.SyncStart(ctx, unitName); err != nil {
return xerrors.Errorf("start unit failed: %w", err)
}
// Display success message
fmt.Printf("Unit '%s' started successfully\n", unitName)
return nil
},
}
cmd.Options = append(cmd.Options, serpent.Option{
Flag: "timeout",
Description: "Maximum time to wait for dependencies (e.g., 30s, 5m). No timeout by default.",
Value: serpent.DurationOf(&timeout),
})
return cmd
}
+134
View File
@@ -0,0 +1,134 @@
package cli
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
type outputFormat string
const (
outputFormatHuman outputFormat = "human"
outputFormatJSON outputFormat = "json"
outputFormatDOT outputFormat = "dot"
)
func (r *RootCmd) syncStatus() *serpent.Command {
var (
output string
recursive bool
)
cmd := &serpent.Command{
Use: "status <unit>",
Short: "Show the status of a unit and its dependencies",
Long: "Display the current status of a unit and information about its dependencies. Supports multiple output formats.",
Handler: func(i *serpent.Invocation) error {
ctx := context.Background()
if len(i.Args) != 1 {
return xerrors.New("exactly one unit name is required")
}
unit := i.Args[0]
// Connect to agent socket
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: "/tmp/coder.sock",
})
if err != nil {
return xerrors.Errorf("connect to agent socket: %w", err)
}
defer client.Close()
// Get status information
statusResp, err := client.SyncStatus(ctx, unit, recursive)
if err != nil {
return xerrors.Errorf("get status failed: %w", err)
}
// Output based on format
switch outputFormat(output) {
case outputFormatJSON:
return outputJSON(statusResp)
case outputFormatDOT:
return outputDOT(statusResp)
default: // outputFormatHuman
return outputHuman(statusResp)
}
},
}
cmd.Options = append(cmd.Options,
serpent.Option{
Flag: "output",
FlagShorthand: "o",
Description: "Output format: human, json, or dot.",
Value: serpent.EnumOf(&output, "human", "json", "dot"),
},
serpent.Option{
Flag: "recursive",
FlagShorthand: "r",
Description: "Show transitive dependencies and include DOT graph.",
Value: serpent.BoolOf(&recursive),
},
)
return cmd
}
func outputJSON(statusResp *agentsdk.SyncStatusResponse) error {
encoder := json.NewEncoder(os.Stdout)
encoder.SetIndent("", " ")
return encoder.Encode(statusResp)
}
func outputDOT(statusResp *agentsdk.SyncStatusResponse) error {
if statusResp.DOT == "" {
return xerrors.New("DOT output requires --recursive flag")
}
fmt.Println(statusResp.DOT)
return nil
}
func outputHuman(statusResp *agentsdk.SyncStatusResponse) error {
// Unit status
fmt.Printf("Unit: %s\n", statusResp.Unit)
fmt.Printf("Status: %s\n", statusResp.Status)
fmt.Printf("Ready: %t\n", statusResp.IsReady)
fmt.Println()
// Dependencies
if len(statusResp.Dependencies) == 0 {
fmt.Println("No dependencies")
return nil
}
fmt.Println("Dependencies:")
fmt.Println(strings.Repeat("-", 80))
fmt.Printf("%-20s %-15s %-15s %-10s\n", "Depends On", "Required", "Current", "Satisfied")
fmt.Println(strings.Repeat("-", 80))
for _, dep := range statusResp.Dependencies {
satisfied := "✓"
if !dep.IsSatisfied {
satisfied = "✗"
}
fmt.Printf("%-20s %-15s %-15s %-10s\n",
dep.DependsOn,
dep.RequiredStatus,
dep.CurrentStatus,
satisfied,
)
}
return nil
}
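For reviewers, the human-readable output above would render roughly as follows for a hypothetical unit named app with one unmet dependency on db. Column widths follow the %-20s/%-15s/%-10s verbs, and the 80-character dashed separators are abbreviated here:

Unit: app
Status: started
Ready: false

Dependencies:
------------------------------------------------------------
Depends On           Required        Current         Satisfied
------------------------------------------------------------
db                   completed       started         ✗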
+359
View File
@@ -0,0 +1,359 @@
package cli_test
import (
"errors"
"fmt"
"net"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/cli/clitest"
)
// mockAgentSocketServer simulates the agent socket server for testing
type mockAgentSocketServer struct {
listener net.Listener
handlers map[string]func(string) (string, error)
}
func newMockAgentSocketServer(t *testing.T, socketPath string) *mockAgentSocketServer {
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
server := &mockAgentSocketServer{
listener: listener,
handlers: make(map[string]func(string) (string, error)),
}
// Set up default handlers
server.handlers["sync.wait"] = func(unitName string) (string, error) {
// Always return dependencies not satisfied to trigger polling
return "", unit.ErrDependenciesNotSatisfied
}
server.handlers["sync.start"] = func(unitName string) (string, error) {
return "Unit " + unitName + " started successfully", nil
}
go server.serve(t)
return server
}
func (s *mockAgentSocketServer) serve(t *testing.T) {
for {
conn, err := s.listener.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
t.Logf("Accept error: %v", err)
}
return
}
go s.handleConnection(t, conn)
}
}
func (s *mockAgentSocketServer) handleConnection(t *testing.T, conn net.Conn) {
defer conn.Close()
// Simple JSON-RPC-like protocol simulation
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
t.Logf("Read error: %v", err)
return
}
request := string(buf[:n])
// Parse method from request (simplified)
var method string
if strings.Contains(request, "sync.wait") {
method = "sync.wait"
} else if strings.Contains(request, "sync.start") {
method = "sync.start"
}
handler, exists := s.handlers[method]
if !exists {
response := `{"error": {"code": -32601, "message": "Method not found"}}`
_, _ = conn.Write([]byte(response))
return
}
// Extract the unit name from the request. The tests only ever use
// "test-unit", so a constant stands in for real parsing here.
unitName := "test-unit"
message, err := handler(unitName)
if err != nil {
response := fmt.Sprintf(`{"error": {"code": -32603, "message": %q}}`, err.Error())
_, _ = conn.Write([]byte(response))
return
}
response := fmt.Sprintf(`{"result": {"success": true, "message": %q}}`, message)
_, _ = conn.Write([]byte(response))
}
func (s *mockAgentSocketServer) setHandler(method string, handler func(string) (string, error)) {
s.handlers[method] = handler
}
func (s *mockAgentSocketServer) close() {
_ = s.listener.Close()
}
func TestSyncStartTimeout(t *testing.T) {
t.Parallel()
// Create a unique temporary socket file
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
// Remove existing socket if it exists
_ = os.Remove(socketPath)
defer func() { _ = os.Remove(socketPath) }()
// Start mock server
server := newMockAgentSocketServer(t, socketPath)
defer server.close()
// Test with a short timeout
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--timeout", "100ms")
// Override the socket path for this test
inv.Args = append(inv.Args, "--agent-socket", socketPath)
start := time.Now()
err := inv.Run()
duration := time.Since(start)
// Should timeout after approximately 100ms
assert.Error(t, err)
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
// Should timeout within a reasonable range (100ms + some buffer for test execution)
assert.True(t, duration >= 100*time.Millisecond, "Duration should be at least 100ms, got %v", duration)
assert.True(t, duration < 2*time.Second, "Duration should be less than 2s, got %v", duration)
}
func TestSyncWaitTimeout(t *testing.T) {
t.Parallel()
// Create a unique temporary socket file
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
// Remove existing socket if it exists
_ = os.Remove(socketPath)
defer func() { _ = os.Remove(socketPath) }()
// Start mock server
server := newMockAgentSocketServer(t, socketPath)
defer server.close()
// Test with a short timeout
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit", "--timeout", "100ms")
// Override the socket path for this test
inv.Args = append(inv.Args, "--agent-socket", socketPath)
start := time.Now()
err := inv.Run()
duration := time.Since(start)
// Should timeout after approximately 100ms
assert.Error(t, err)
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
// Should timeout within a reasonable range (100ms + some buffer for test execution)
assert.True(t, duration >= 100*time.Millisecond, "Duration should be at least 100ms, got %v", duration)
assert.True(t, duration < 2*time.Second, "Duration should be less than 2s, got %v", duration)
}
func TestSyncStartNoTimeout(t *testing.T) {
t.Parallel()
// Create a unique temporary socket file
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
// Remove existing socket if it exists
_ = os.Remove(socketPath)
defer func() { _ = os.Remove(socketPath) }()
// Start mock server
server := newMockAgentSocketServer(t, socketPath)
defer server.close()
// Set up handler that will eventually succeed
callCount := 0
server.setHandler("sync.wait", func(unitName string) (string, error) {
callCount++
if callCount >= 3 {
// After 3 calls, dependencies are satisfied
return "Dependencies satisfied", nil
}
return "", unit.ErrDependenciesNotSatisfied
})
// Test without timeout - should eventually succeed
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit")
// Override the socket path for this test
inv.Args = append(inv.Args, "--agent-socket", socketPath)
start := time.Now()
err := inv.Run()
duration := time.Since(start)
// Should succeed after a few polling cycles
assert.NoError(t, err)
// Should take at least 2 seconds (2 polling cycles at 1s interval)
assert.True(t, duration >= 2*time.Second, "Duration should be at least 2s, got %v", duration)
assert.True(t, callCount >= 3, "Should have made at least 3 calls, got %d", callCount)
}
func TestSyncWaitNoTimeout(t *testing.T) {
t.Parallel()
// Create a unique temporary socket file
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
// Remove existing socket if it exists
_ = os.Remove(socketPath)
defer func() { _ = os.Remove(socketPath) }()
// Start mock server
server := newMockAgentSocketServer(t, socketPath)
defer server.close()
// Set up handler that will eventually succeed
callCount := 0
server.setHandler("sync.wait", func(unitName string) (string, error) {
callCount++
if callCount >= 3 {
// After 3 calls, dependencies are satisfied
return "Dependencies satisfied", nil
}
return "", unit.ErrDependenciesNotSatisfied
})
// Test without timeout - should eventually succeed
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit")
// Override the socket path for this test
inv.Args = append(inv.Args, "--agent-socket", socketPath)
start := time.Now()
err := inv.Run()
duration := time.Since(start)
// Should succeed after a few polling cycles
assert.NoError(t, err)
// Should take at least 2 seconds (2 polling cycles at 1s interval)
assert.True(t, duration >= 2*time.Second, "Duration should be at least 2s, got %v", duration)
assert.True(t, callCount >= 3, "Should have made at least 3 calls, got %d", callCount)
}
func TestSyncStartTimeoutWithDifferentValues(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
timeout string
expected time.Duration
}{
{"50ms", "50ms", 50 * time.Millisecond},
{"200ms", "200ms", 200 * time.Millisecond},
{"500ms", "500ms", 500 * time.Millisecond},
{"1s", "1s", 1 * time.Second},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Create a unique temporary socket file
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
// Remove existing socket if it exists
_ = os.Remove(socketPath)
defer func() { _ = os.Remove(socketPath) }()
// Start mock server
server := newMockAgentSocketServer(t, socketPath)
defer server.close()
// Test with specified timeout
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--timeout", tc.timeout)
// Override the socket path for this test
inv.Args = append(inv.Args, "--agent-socket", socketPath)
start := time.Now()
err := inv.Run()
duration := time.Since(start)
// Should timeout after approximately the specified duration
assert.Error(t, err)
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
// Should timeout within a reasonable range
assert.True(t, duration >= tc.expected, "Duration should be at least %v, got %v", tc.expected, duration)
assert.True(t, duration < tc.expected+2*time.Second, "Duration should be less than %v, got %v", tc.expected+2*time.Second, duration)
})
}
}
func TestSyncWaitTimeoutWithDifferentValues(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
timeout string
expected time.Duration
}{
{"50ms", "50ms", 50 * time.Millisecond},
{"200ms", "200ms", 200 * time.Millisecond},
{"500ms", "500ms", 500 * time.Millisecond},
{"1s", "1s", 1 * time.Second},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Create a unique temporary socket file
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
// Remove existing socket if it exists
_ = os.Remove(socketPath)
defer func() { _ = os.Remove(socketPath) }()
// Start mock server
server := newMockAgentSocketServer(t, socketPath)
defer server.close()
// Test with specified timeout
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit", "--timeout", tc.timeout)
// Override the socket path for this test
inv.Args = append(inv.Args, "--agent-socket", socketPath)
start := time.Now()
err := inv.Run()
duration := time.Since(start)
// Should timeout after approximately the specified duration
assert.Error(t, err)
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
// Should timeout within a reasonable range
assert.True(t, duration >= tc.expected, "Duration should be at least %v, got %v", tc.expected, duration)
assert.True(t, duration < tc.expected+2*time.Second, "Duration should be less than %v, got %v", tc.expected+2*time.Second, duration)
})
}
}
+95
View File
@@ -0,0 +1,95 @@
package cli
import (
"context"
"fmt"
"time"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
const (
// PollInterval is the interval between dependency checks
PollInterval = 1 * time.Second
)
func (r *RootCmd) syncWait() *serpent.Command {
var timeout time.Duration
cmd := &serpent.Command{
Use: "wait <unit>",
Short: "Wait for a unit's dependencies to be satisfied",
Long: "Poll until all dependencies for a unit are met. Exits when dependencies are satisfied or timeout is reached.",
Handler: func(i *serpent.Invocation) error {
ctx := context.Background()
if len(i.Args) != 1 {
return xerrors.New("exactly one unit name is required")
}
unitName := i.Args[0]
// Set up context with timeout if specified
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
// Show initial message
fmt.Printf("Waiting for dependencies of unit '%s' to be satisfied...\n", unitName)
// Connect to agent socket
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: "/tmp/coder.sock",
})
if err != nil {
return xerrors.Errorf("connect to agent socket: %w", err)
}
defer client.Close()
// Poll until dependencies are satisfied
ticker := time.NewTicker(PollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
if ctx.Err() == context.DeadlineExceeded {
return xerrors.Errorf("timeout waiting for dependencies of unit '%s'", unitName)
}
return ctx.Err()
case <-ticker.C:
// Check if dependencies are satisfied
err := client.SyncReady(ctx, unitName)
if err == nil {
// Dependencies are satisfied
fmt.Printf("Dependencies for unit '%s' are now satisfied\n", unitName)
return nil
}
// Check if it's a "not ready" error (expected while waiting)
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
// Still waiting, continue polling
continue
}
// Some other error occurred
return xerrors.Errorf("error checking dependencies: %w", err)
}
}
},
}
cmd.Options = append(cmd.Options, serpent.Option{
Flag: "timeout",
Description: "Maximum time to wait for dependencies (e.g., 30s, 5m). No timeout by default.",
Value: serpent.DurationOf(&timeout),
})
return cmd
}
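The handler above connects to the agent socket and then polls SyncReady once per PollInterval until it returns nil, the deadline passes, or an unexpected error occurs. For readers who want the same loop outside the serpent command plumbing, here is a condensed sketch that reuses only the client calls shown above (NewSocketClient, SyncReady, ErrDependenciesNotSatisfied); the socket path, timeout, and the waitForUnit helper name are illustrative rather than part of this change set:

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/xerrors"

	"github.com/coder/coder/v2/agent/unit"
	"github.com/coder/coder/v2/codersdk/agentsdk"
)

// waitForUnit polls the agent socket until the unit's dependencies are
// satisfied, the context is cancelled, or an unexpected error occurs.
func waitForUnit(ctx context.Context, socketPath, unitName string) error {
	client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{Path: socketPath})
	if err != nil {
		return xerrors.Errorf("connect to agent socket: %w", err)
	}
	defer client.Close()

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
			err := client.SyncReady(ctx, unitName)
			if err == nil {
				// Dependencies are satisfied.
				return nil
			}
			if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
				// Still waiting; keep polling.
				continue
			}
			return xerrors.Errorf("check dependencies: %w", err)
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := waitForUnit(ctx, "/tmp/coder.sock", "test-unit"); err != nil {
		fmt.Println("wait failed:", err)
	}
}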
+51
View File
@@ -0,0 +1,51 @@
package cli
import (
"context"
"fmt"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
func (r *RootCmd) syncWant() *serpent.Command {
return &serpent.Command{
Use: "want <unit> <depends-on>",
Short: "Declare a dependency between units",
Long: "Declare that a unit depends on another unit reaching complete status.",
Handler: func(i *serpent.Invocation) error {
ctx := context.Background()
if len(i.Args) != 2 {
return xerrors.New("exactly two arguments are required: unit and depends-on")
}
unit := i.Args[0]
dependsOn := i.Args[1]
// Show initial message
fmt.Printf("Declaring dependency: '%s' depends on '%s'...\n", unit, dependsOn)
// Connect to agent socket
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: "/tmp/coder.sock",
})
if err != nil {
return xerrors.Errorf("connect to agent socket: %w", err)
}
defer client.Close()
// Declare the dependency
if err := client.SyncWant(ctx, unit, dependsOn); err != nil {
return xerrors.Errorf("declare dependency failed: %w", err)
}
// Display success message
fmt.Printf("Dependency declared: '%s' now depends on '%s'\n", unit, dependsOn)
return nil
},
}
}
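Combined with sync wait, this command gives startup scripts a two-step protocol: declare the edge with SyncWant, then block until SyncReady reports the dependencies as satisfied. A short sketch of the declaration side, again reusing only the client calls shown above (the unit names, socket path, and declareDependency helper are hypothetical):

package main

import (
	"context"
	"log"
	"time"

	"golang.org/x/xerrors"

	"github.com/coder/coder/v2/codersdk/agentsdk"
)

// declareDependency registers that unitName should wait for dependsOn to
// reach complete status, mirroring `coder exp sync want <unit> <depends-on>`.
func declareDependency(ctx context.Context, socketPath, unitName, dependsOn string) error {
	client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{Path: socketPath})
	if err != nil {
		return xerrors.Errorf("connect to agent socket: %w", err)
	}
	defer client.Close()
	return client.SyncWant(ctx, unitName, dependsOn)
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// Hypothetical units: "app" starts only after "db-migrate" completes.
	if err := declareDependency(ctx, "/tmp/coder.sock", "app", "db-migrate"); err != nil {
		log.Fatalf("declare dependency: %v", err)
	}
}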
+8
View File
@@ -0,0 +1,8 @@
package cli_test
import (
"testing"
)
func TestSyncWant(t *testing.T) {
}
+3
View File
@@ -67,6 +67,9 @@ OPTIONS:
--script-data-dir string, $CODER_AGENT_SCRIPT_DATA_DIR (default: /tmp)
Specify the location for storing script data.
--socket-path string, $CODER_AGENT_SOCKET_PATH
Specify the path for the agent socket.
--ssh-max-timeout duration, $CODER_AGENT_SSH_MAX_TIMEOUT (default: 72h)
Specify the max timeout for a SSH connection, it is advisable to set
it to a minimum of 60s, but no more than 72h.
+1 -2
View File
@@ -90,7 +90,6 @@
"allow_renames": false,
"favorite": false,
"next_start_at": "====[timestamp]=====",
"is_prebuild": false,
"task_id": null
"is_prebuild": false
}
]
-35
View File
@@ -80,41 +80,6 @@ OPTIONS:
Periodically check for new releases of Coder and inform the owner. The
check is performed once per day.
AIBRIDGE OPTIONS:
--aibridge-anthropic-base-url string, $CODER_AIBRIDGE_ANTHROPIC_BASE_URL (default: https://api.anthropic.com/)
The base URL of the Anthropic API.
--aibridge-anthropic-key string, $CODER_AIBRIDGE_ANTHROPIC_KEY
The key to authenticate against the Anthropic API.
--aibridge-bedrock-access-key string, $CODER_AIBRIDGE_BEDROCK_ACCESS_KEY
The access key to authenticate against the AWS Bedrock API.
--aibridge-bedrock-access-key-secret string, $CODER_AIBRIDGE_BEDROCK_ACCESS_KEY_SECRET
The access key secret to use with the access key to authenticate
against the AWS Bedrock API.
--aibridge-bedrock-model string, $CODER_AIBRIDGE_BEDROCK_MODEL (default: global.anthropic.claude-sonnet-4-5-20250929-v1:0)
The model to use when making requests to the AWS Bedrock API.
--aibridge-bedrock-region string, $CODER_AIBRIDGE_BEDROCK_REGION
The AWS Bedrock API region.
--aibridge-bedrock-small-fastmodel string, $CODER_AIBRIDGE_BEDROCK_SMALL_FAST_MODEL (default: global.anthropic.claude-haiku-4-5-20251001-v1:0)
The small fast model to use when making requests to the AWS Bedrock
API. Claude Code uses Haiku-class models to perform background tasks.
See
https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
--aibridge-enabled bool, $CODER_AIBRIDGE_ENABLED (default: false)
Whether to start an in-memory aibridged instance.
--aibridge-openai-base-url string, $CODER_AIBRIDGE_OPENAI_BASE_URL (default: https://api.openai.com/v1/)
The base URL of the OpenAI API.
--aibridge-openai-key string, $CODER_AIBRIDGE_OPENAI_KEY
The key to authenticate against the OpenAI API.
CLIENT OPTIONS:
These options change the behavior of how clients interact with the Coder.
Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI.
-4
View File
@@ -9,9 +9,5 @@ OPTIONS:
-b, --build int
Specify a workspace build to target by name. Defaults to latest.
-n, --no-build bool
Update the state without triggering a workspace build. Useful for
state-only migrations.
———
Run `coder --help` for a list of global options.
-5
View File
@@ -16,10 +16,6 @@ USAGE:
$ coder tokens ls
- Create a scoped token:
$ coder tokens create --scope workspace:read --allow workspace:<uuid>
- Remove a token by ID:
$ coder tokens rm WuoWs4ZsMX
@@ -28,7 +24,6 @@ SUBCOMMANDS:
create Create a token
list List tokens
remove Delete a token
view Display detailed information about a token
———
Run `coder --help` for a list of global options.
+1 -9
View File
@@ -6,20 +6,12 @@ USAGE:
Create a token
OPTIONS:
--allow allow-list
Repeatable allow-list entry (<type>:<uuid>, e.g. workspace:1234-...).
--lifetime string, $CODER_TOKEN_LIFETIME
Duration for the token lifetime. Supports standard Go duration units
(ns, us, ms, s, m, h) plus d (days) and y (years). Examples: 8h, 30d,
1y, 1d12h30m.
Specify a duration for the lifetime of the token.
-n, --name string, $CODER_TOKEN_NAME
Specify a human-readable name.
--scope string-array
Repeatable scope to attach to the token (e.g. workspace:read).
-u, --user string, $CODER_TOKEN_USER
Specify the user to create the token for (Only works if logged in user
is admin).
+1 -1
View File
@@ -12,7 +12,7 @@ OPTIONS:
Specifies whether all users' tokens will be listed or not (must have
Owner role to see all tokens).
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at)
-c, --column [id|name|last used|expires at|created at|owner] (default: id,name,last used,expires at,created at)
Columns to display in table output.
-o, --output table|json (default: table)
-16
View File
@@ -1,16 +0,0 @@
coder v0.0.0-devel
USAGE:
coder tokens view [flags] <name|id>
Display detailed information about a token
OPTIONS:
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at,owner)
Columns to display in table output.
-o, --output table|json (default: table)
Output format.
———
Run `coder --help` for a list of global options.
+4 -21
View File
@@ -714,7 +714,8 @@ workspace_prebuilds:
# (default: 3, type: int)
failure_hard_limit: 3
aibridge:
# Whether to start an in-memory aibridged instance.
# Whether to start an in-memory aibridged instance ("aibridge" experiment must be
# enabled, too).
# (default: false, type: bool)
enabled: false
# The base URL of the OpenAI API.
@@ -725,25 +726,7 @@ aibridge:
openai_key: ""
# The base URL of the Anthropic API.
# (default: https://api.anthropic.com/, type: string)
anthropic_base_url: https://api.anthropic.com/
base_url: https://api.anthropic.com/
# The key to authenticate against the Anthropic API.
# (default: <unset>, type: string)
anthropic_key: ""
# The AWS Bedrock API region.
# (default: <unset>, type: string)
bedrock_region: ""
# The access key to authenticate against the AWS Bedrock API.
# (default: <unset>, type: string)
bedrock_access_key: ""
# The access key secret to use with the access key to authenticate against the AWS
# Bedrock API.
# (default: <unset>, type: string)
bedrock_access_key_secret: ""
# The model to use when making requests to the AWS Bedrock API.
# (default: global.anthropic.claude-sonnet-4-5-20250929-v1:0, type: string)
bedrock_model: global.anthropic.claude-sonnet-4-5-20250929-v1:0
# The small fast model to use when making requests to the AWS Bedrock API. Claude
# Code uses Haiku-class models to perform background tasks. See
# https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
# (default: global.anthropic.claude-haiku-4-5-20251001-v1:0, type: string)
bedrock_small_fast_model: global.anthropic.claude-haiku-4-5-20251001-v1:0
key: ""
+6 -104
View File
@@ -4,14 +4,12 @@ import (
"fmt"
"os"
"slices"
"sort"
"strings"
"time"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
@@ -29,10 +27,6 @@ func (r *RootCmd) tokens() *serpent.Command {
Description: "List your tokens",
Command: "coder tokens ls",
},
Example{
Description: "Create a scoped token",
Command: "coder tokens create --scope workspace:read --allow workspace:<uuid>",
},
Example{
Description: "Remove a token by ID",
Command: "coder tokens rm WuoWs4ZsMX",
@@ -45,7 +39,6 @@ func (r *RootCmd) tokens() *serpent.Command {
Children: []*serpent.Command{
r.createToken(),
r.listTokens(),
r.viewToken(),
r.removeToken(),
},
}
@@ -57,8 +50,6 @@ func (r *RootCmd) createToken() *serpent.Command {
tokenLifetime string
name string
user string
scopes []string
allowList []codersdk.APIAllowListTarget
)
cmd := &serpent.Command{
Use: "create",
@@ -97,18 +88,10 @@ func (r *RootCmd) createToken() *serpent.Command {
}
}
req := codersdk.CreateTokenRequest{
res, err := client.CreateToken(inv.Context(), userID, codersdk.CreateTokenRequest{
Lifetime: parsedLifetime,
TokenName: name,
}
if len(req.Scopes) == 0 {
req.Scopes = slice.StringEnums[codersdk.APIKeyScope](scopes)
}
if len(allowList) > 0 {
req.AllowList = append([]codersdk.APIAllowListTarget(nil), allowList...)
}
res, err := client.CreateToken(inv.Context(), userID, req)
})
if err != nil {
return xerrors.Errorf("create tokens: %w", err)
}
@@ -123,7 +106,7 @@ func (r *RootCmd) createToken() *serpent.Command {
{
Flag: "lifetime",
Env: "CODER_TOKEN_LIFETIME",
Description: "Duration for the token lifetime. Supports standard Go duration units (ns, us, ms, s, m, h) plus d (days) and y (years). Examples: 8h, 30d, 1y, 1d12h30m.",
Description: "Specify a duration for the lifetime of the token.",
Value: serpent.StringOf(&tokenLifetime),
},
{
@@ -140,16 +123,6 @@ func (r *RootCmd) createToken() *serpent.Command {
Description: "Specify the user to create the token for (Only works if logged in user is admin).",
Value: serpent.StringOf(&user),
},
{
Flag: "scope",
Description: "Repeatable scope to attach to the token (e.g. workspace:read).",
Value: serpent.StringArrayOf(&scopes),
},
{
Flag: "allow",
Description: "Repeatable allow-list entry (<type>:<uuid>, e.g. workspace:1234-...).",
Value: AllowListFlagOf(&allowList),
},
}
return cmd
@@ -163,8 +136,6 @@ type tokenListRow struct {
// For table format:
ID string `json:"-" table:"id,default_sort"`
TokenName string `json:"token_name" table:"name"`
Scopes string `json:"-" table:"scopes"`
Allow string `json:"-" table:"allow list"`
LastUsed time.Time `json:"-" table:"last used"`
ExpiresAt time.Time `json:"-" table:"expires at"`
CreatedAt time.Time `json:"-" table:"created at"`
@@ -172,47 +143,20 @@ type tokenListRow struct {
}
func tokenListRowFromToken(token codersdk.APIKeyWithOwner) tokenListRow {
return tokenListRowFromKey(token.APIKey, token.Username)
}
func tokenListRowFromKey(token codersdk.APIKey, owner string) tokenListRow {
return tokenListRow{
APIKey: token,
APIKey: token.APIKey,
ID: token.ID,
TokenName: token.TokenName,
Scopes: joinScopes(token.Scopes),
Allow: joinAllowList(token.AllowList),
LastUsed: token.LastUsed,
ExpiresAt: token.ExpiresAt,
CreatedAt: token.CreatedAt,
Owner: owner,
Owner: token.Username,
}
}
func joinScopes(scopes []codersdk.APIKeyScope) string {
if len(scopes) == 0 {
return ""
}
vals := slice.ToStrings(scopes)
sort.Strings(vals)
return strings.Join(vals, ", ")
}
func joinAllowList(entries []codersdk.APIAllowListTarget) string {
if len(entries) == 0 {
return ""
}
vals := make([]string, len(entries))
for i, entry := range entries {
vals[i] = entry.String()
}
sort.Strings(vals)
return strings.Join(vals, ", ")
}
func (r *RootCmd) listTokens() *serpent.Command {
// we only display the 'owner' column if the --all argument is passed in
defaultCols := []string{"id", "name", "scopes", "allow list", "last used", "expires at", "created at"}
defaultCols := []string{"id", "name", "last used", "expires at", "created at"}
if slices.Contains(os.Args, "-a") || slices.Contains(os.Args, "--all") {
defaultCols = append(defaultCols, "owner")
}
@@ -282,48 +226,6 @@ func (r *RootCmd) listTokens() *serpent.Command {
return cmd
}
func (r *RootCmd) viewToken() *serpent.Command {
formatter := cliui.NewOutputFormatter(
cliui.TableFormat([]tokenListRow{}, []string{"id", "name", "scopes", "allow list", "last used", "expires at", "created at", "owner"}),
cliui.JSONFormat(),
)
cmd := &serpent.Command{
Use: "view <name|id>",
Short: "Display detailed information about a token",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
tokenName := inv.Args[0]
token, err := client.APIKeyByName(inv.Context(), codersdk.Me, tokenName)
if err != nil {
maybeID := strings.Split(tokenName, "-")[0]
token, err = client.APIKeyByID(inv.Context(), codersdk.Me, maybeID)
if err != nil {
return xerrors.Errorf("fetch api key by name or id: %w", err)
}
}
row := tokenListRowFromKey(*token, "")
out, err := formatter.Format(inv.Context(), []tokenListRow{row})
if err != nil {
return err
}
_, err = fmt.Fprintln(inv.Stdout, out)
return err
},
}
formatter.AttachOptions(&cmd.Options)
return cmd
}
func (r *RootCmd) removeToken() *serpent.Command {
cmd := &serpent.Command{
Use: "remove <name|id|token>",
+3 -56
View File
@@ -4,13 +4,10 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/google/uuid"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
@@ -49,18 +46,6 @@ func TestTokens(t *testing.T) {
require.NotEmpty(t, res)
id := res[:10]
allowWorkspaceID := uuid.New()
allowSpec := fmt.Sprintf("workspace:%s", allowWorkspaceID.String())
inv, root = clitest.New(t, "tokens", "create", "--name", "scoped-token", "--scope", string(codersdk.APIKeyScopeWorkspaceRead), "--allow", allowSpec)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
scopedTokenID := res[:10]
// Test creating a token for second user from first user's (admin) session
inv, root = clitest.New(t, "tokens", "create", "--name", "token-two", "--user", secondUser.ID.String())
clitest.SetupConfig(t, client, root)
@@ -82,7 +67,7 @@ func TestTokens(t *testing.T) {
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
// Result should only contain the tokens created for the admin user
// Result should only contain the token created for the admin user
require.Contains(t, res, "ID")
require.Contains(t, res, "EXPIRES AT")
require.Contains(t, res, "CREATED AT")
@@ -91,16 +76,6 @@ func TestTokens(t *testing.T) {
// Result should not contain the token created for the second user
require.NotContains(t, res, secondTokenID)
inv, root = clitest.New(t, "tokens", "view", "scoped-token")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.Contains(t, res, string(codersdk.APIKeyScopeWorkspaceRead))
require.Contains(t, res, allowSpec)
// Test listing tokens from the second user's session
inv, root = clitest.New(t, "tokens", "ls")
clitest.SetupConfig(t, secondUserClient, root)
@@ -126,14 +101,6 @@ func TestTokens(t *testing.T) {
// User (non-admin) should not be able to create a token for another user
require.Error(t, err)
inv, root = clitest.New(t, "tokens", "create", "--name", "invalid-allow", "--allow", "badvalue")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.Error(t, err)
require.Contains(t, err.Error(), "invalid allow_list entry")
inv, root = clitest.New(t, "tokens", "ls", "--output=json")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
@@ -143,17 +110,8 @@ func TestTokens(t *testing.T) {
var tokens []codersdk.APIKey
require.NoError(t, json.Unmarshal(buf.Bytes(), &tokens))
require.Len(t, tokens, 2)
tokenByName := make(map[string]codersdk.APIKey, len(tokens))
for _, tk := range tokens {
tokenByName[tk.TokenName] = tk
}
require.Contains(t, tokenByName, "token-one")
require.Contains(t, tokenByName, "scoped-token")
scopedToken := tokenByName["scoped-token"]
require.Contains(t, scopedToken.Scopes, codersdk.APIKeyScopeWorkspaceRead)
require.Len(t, scopedToken.AllowList, 1)
require.Equal(t, allowSpec, scopedToken.AllowList[0].String())
require.Len(t, tokens, 1)
require.Equal(t, id, tokens[0].ID)
// Delete by name
inv, root = clitest.New(t, "tokens", "rm", "token-one")
@@ -177,17 +135,6 @@ func TestTokens(t *testing.T) {
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
// Delete scoped token by ID
inv, root = clitest.New(t, "tokens", "rm", scopedTokenID)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
// Create third token
inv, root = clitest.New(t, "tokens", "create", "--name", "token-three")
clitest.SetupConfig(t, client, root)
-4
View File
@@ -239,10 +239,6 @@ func (a *API) Serve(ctx context.Context, l net.Listener) error {
return xerrors.Errorf("create agent API server: %w", err)
}
if err := a.ResourcesMonitoringAPI.InitMonitors(ctx); err != nil {
return xerrors.Errorf("initialize resource monitoring: %w", err)
}
return server.Serve(ctx, l)
}
+35 -52
View File
@@ -5,7 +5,6 @@ import (
"database/sql"
"errors"
"fmt"
"sync"
"time"
"golang.org/x/xerrors"
@@ -34,60 +33,42 @@ type ResourcesMonitoringAPI struct {
Debounce time.Duration
Config resourcesmonitor.Config
// Cache resource monitors on first call to avoid millions of DB queries per day.
memoryMonitor database.WorkspaceAgentMemoryResourceMonitor
volumeMonitors []database.WorkspaceAgentVolumeResourceMonitor
monitorsLock sync.RWMutex
}
// InitMonitors fetches resource monitors from the database and caches them.
// This must be called once after creating a ResourcesMonitoringAPI; the context should be
// the agent's per-RPC connection context. If fetching fails with a real error (not sql.ErrNoRows), the
// connection should be torn down.
func (a *ResourcesMonitoringAPI) InitMonitors(ctx context.Context) error {
memMon, err := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("fetch memory resource monitor: %w", err)
}
// If sql.ErrNoRows, memoryMonitor stays as zero value (CreatedAt.IsZero() = true).
// Otherwise, store the fetched monitor.
if err == nil {
a.memoryMonitor = memMon
func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(ctx context.Context, _ *proto.GetResourcesMonitoringConfigurationRequest) (*proto.GetResourcesMonitoringConfigurationResponse, error) {
memoryMonitor, memoryErr := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
if memoryErr != nil && !errors.Is(memoryErr, sql.ErrNoRows) {
return nil, xerrors.Errorf("failed to fetch memory resource monitor: %w", memoryErr)
}
volMons, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
volumeMonitors, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil {
return xerrors.Errorf("fetch volume resource monitors: %w", err)
return nil, xerrors.Errorf("failed to fetch volume resource monitors: %w", err)
}
// 0 length is valid, indicating none configured, since the volume monitors in the DB can be many.
a.volumeMonitors = volMons
return nil
}
func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(_ context.Context, _ *proto.GetResourcesMonitoringConfigurationRequest) (*proto.GetResourcesMonitoringConfigurationResponse, error) {
return &proto.GetResourcesMonitoringConfigurationResponse{
Config: &proto.GetResourcesMonitoringConfigurationResponse_Config{
CollectionIntervalSeconds: int32(a.Config.CollectionInterval.Seconds()),
NumDatapoints: a.Config.NumDatapoints,
},
Memory: func() *proto.GetResourcesMonitoringConfigurationResponse_Memory {
if a.memoryMonitor.CreatedAt.IsZero() {
if memoryErr != nil {
return nil
}
return &proto.GetResourcesMonitoringConfigurationResponse_Memory{
Enabled: a.memoryMonitor.Enabled,
Enabled: memoryMonitor.Enabled,
}
}(),
Volumes: func() []*proto.GetResourcesMonitoringConfigurationResponse_Volume {
volumes := make([]*proto.GetResourcesMonitoringConfigurationResponse_Volume, 0, len(a.volumeMonitors))
for _, monitor := range a.volumeMonitors {
volumes := make([]*proto.GetResourcesMonitoringConfigurationResponse_Volume, 0, len(volumeMonitors))
for _, monitor := range volumeMonitors {
volumes = append(volumes, &proto.GetResourcesMonitoringConfigurationResponse_Volume{
Enabled: monitor.Enabled,
Path: monitor.Path,
})
}
return volumes
}(),
}, nil
@@ -96,10 +77,6 @@ func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(_ context.C
func (a *ResourcesMonitoringAPI) PushResourcesMonitoringUsage(ctx context.Context, req *proto.PushResourcesMonitoringUsageRequest) (*proto.PushResourcesMonitoringUsageResponse, error) {
var err error
// Lock for the entire push operation since calls are sequential from the agent
a.monitorsLock.Lock()
defer a.monitorsLock.Unlock()
if memoryErr := a.monitorMemory(ctx, req.Datapoints); memoryErr != nil {
err = errors.Join(err, xerrors.Errorf("monitor memory: %w", memoryErr))
}
@@ -112,7 +89,18 @@ func (a *ResourcesMonitoringAPI) PushResourcesMonitoringUsage(ctx context.Contex
}
func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints []*proto.PushResourcesMonitoringUsageRequest_Datapoint) error {
if !a.memoryMonitor.Enabled {
monitor, err := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil {
// It is valid for an agent to not have a memory monitor, so we
// do not want to treat it as an error.
if errors.Is(err, sql.ErrNoRows) {
return nil
}
return xerrors.Errorf("fetch memory resource monitor: %w", err)
}
if !monitor.Enabled {
return nil
}
@@ -121,15 +109,15 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
usageDatapoints = append(usageDatapoints, datapoint.Memory)
}
usageStates := resourcesmonitor.CalculateMemoryUsageStates(a.memoryMonitor, usageDatapoints)
usageStates := resourcesmonitor.CalculateMemoryUsageStates(monitor, usageDatapoints)
oldState := a.memoryMonitor.State
oldState := monitor.State
newState := resourcesmonitor.NextState(a.Config, oldState, usageStates)
debouncedUntil, shouldNotify := a.memoryMonitor.Debounce(a.Debounce, a.Clock.Now(), oldState, newState)
debouncedUntil, shouldNotify := monitor.Debounce(a.Debounce, a.Clock.Now(), oldState, newState)
//nolint:gocritic // We need to be able to update the resource monitor here.
err := a.Database.UpdateMemoryResourceMonitor(dbauthz.AsResourceMonitor(ctx), database.UpdateMemoryResourceMonitorParams{
err = a.Database.UpdateMemoryResourceMonitor(dbauthz.AsResourceMonitor(ctx), database.UpdateMemoryResourceMonitorParams{
AgentID: a.AgentID,
State: newState,
UpdatedAt: dbtime.Time(a.Clock.Now()),
@@ -139,11 +127,6 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
return xerrors.Errorf("update workspace monitor: %w", err)
}
// Update cached state
a.memoryMonitor.State = newState
a.memoryMonitor.DebouncedUntil = dbtime.Time(debouncedUntil)
a.memoryMonitor.UpdatedAt = dbtime.Time(a.Clock.Now())
if !shouldNotify {
return nil
}
@@ -160,7 +143,7 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
notifications.TemplateWorkspaceOutOfMemory,
map[string]string{
"workspace": workspace.Name,
"threshold": fmt.Sprintf("%d%%", a.memoryMonitor.Threshold),
"threshold": fmt.Sprintf("%d%%", monitor.Threshold),
},
map[string]any{
// NOTE(DanielleMaywood):
@@ -186,9 +169,14 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
}
func (a *ResourcesMonitoringAPI) monitorVolumes(ctx context.Context, datapoints []*proto.PushResourcesMonitoringUsageRequest_Datapoint) error {
volumeMonitors, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil {
return xerrors.Errorf("get or insert volume monitor: %w", err)
}
outOfDiskVolumes := make([]map[string]any, 0)
for i, monitor := range a.volumeMonitors {
for _, monitor := range volumeMonitors {
if !monitor.Enabled {
continue
}
@@ -231,11 +219,6 @@ func (a *ResourcesMonitoringAPI) monitorVolumes(ctx context.Context, datapoints
}); err != nil {
return xerrors.Errorf("update workspace monitor: %w", err)
}
// Update cached state
a.volumeMonitors[i].State = newState
a.volumeMonitors[i].DebouncedUntil = dbtime.Time(debouncedUntil)
a.volumeMonitors[i].UpdatedAt = dbtime.Time(a.Clock.Now())
}
if len(outOfDiskVolumes) == 0 {
@@ -101,9 +101,6 @@ func TestMemoryResourceMonitorDebounce(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: The monitor is given a state that will trigger NOK
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -307,9 +304,6 @@ func TestMemoryResourceMonitor(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
clock.Set(collectedAt)
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: datapoints,
@@ -343,8 +337,6 @@ func TestMemoryResourceMonitorMissingData(t *testing.T) {
State: database.WorkspaceAgentMonitorStateOK,
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two NOK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
@@ -395,9 +387,6 @@ func TestMemoryResourceMonitorMissingData(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two OK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -477,9 +466,6 @@ func TestVolumeResourceMonitorDebounce(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When:
// - First monitor is in a NOK state
// - Second monitor is in an OK state
@@ -756,9 +742,6 @@ func TestVolumeResourceMonitor(t *testing.T) {
Threshold: tt.thresholdPercent,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
clock.Set(collectedAt)
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: datapoints,
@@ -797,9 +780,6 @@ func TestVolumeResourceMonitorMultiple(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: both of them move to a NOK state
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -852,9 +832,6 @@ func TestVolumeResourceMonitorMissingData(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two NOK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -914,9 +891,6 @@ func TestVolumeResourceMonitorMissingData(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two OK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
+3 -4
View File
@@ -143,7 +143,7 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: `Template does not have a valid "coder_ai_task" resource.`,
Message: fmt.Sprintf(`Template does not have required parameter %q`, codersdk.AITaskPromptParameterName),
})
return
}
@@ -241,7 +241,6 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
// Create task record in the database before creating the workspace so that
// we can request that the workspace be linked to it after creation.
dbTaskTable, err = tx.InsertTask(ctx, database.InsertTaskParams{
ID: uuid.New(),
OrganizationID: templateVersion.OrganizationID,
OwnerID: owner.ID,
Name: taskName,
@@ -339,8 +338,8 @@ func taskFromDBTaskAndWorkspace(dbTask database.Task, ws codersdk.Workspace) cod
ID: dbTask.ID,
OrganizationID: dbTask.OrganizationID,
OwnerID: dbTask.OwnerID,
OwnerName: dbTask.OwnerUsername,
OwnerAvatarURL: dbTask.OwnerAvatarUrl,
OwnerName: ws.OwnerName,
OwnerAvatarURL: ws.OwnerAvatarURL,
Name: dbTask.Name,
TemplateID: ws.TemplateID,
TemplateVersionID: dbTask.TemplateVersionID,
+10 -85
View File
@@ -1,7 +1,6 @@
package coderd_test
import (
"context"
"database/sql"
"encoding/json"
"io"
@@ -259,9 +258,6 @@ func TestTasks(t *testing.T) {
// Wait for the workspace to be built.
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
if assert.True(t, workspace.TaskID.Valid, "task id should be set on workspace") {
assert.Equal(t, task.ID, workspace.TaskID.UUID, "workspace task id should match")
}
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// List tasks via experimental API and verify the prompt and status mapping.
@@ -300,9 +296,6 @@ func TestTasks(t *testing.T) {
// Get the workspace and wait for it to be ready.
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
if assert.True(t, ws.TaskID.Valid, "task id should be set on workspace") {
assert.Equal(t, task.ID, ws.TaskID.UUID, "workspace task id should match")
}
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
ws = coderdtest.MustWorkspace(t, client, task.WorkspaceID.UUID)
// Assert invariant: the workspace has exactly one resource with one agent with one app.
@@ -377,23 +370,24 @@ func TestTasks(t *testing.T) {
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
if assert.True(t, ws.TaskID.Valid, "task id should be set on workspace") {
assert.Equal(t, task.ID, ws.TaskID.UUID, "workspace task id should match")
}
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
err = exp.DeleteTask(ctx, "me", task.ID)
require.NoError(t, err, "delete task request should be accepted")
// Poll until the workspace is deleted.
testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
for {
dws, derr := client.DeletedWorkspace(ctx, task.WorkspaceID.UUID)
if !assert.NoError(t, derr, "expected to fetch deleted workspace before deadline") {
return false
if derr == nil && dws.LatestBuild.Status == codersdk.WorkspaceStatusDeleted {
break
}
t.Logf("workspace latest_build status: %q", dws.LatestBuild.Status)
return dws.LatestBuild.Status == codersdk.WorkspaceStatusDeleted
}, testutil.IntervalMedium, "workspace should be deleted before deadline")
if ctx.Err() != nil {
require.NoError(t, derr, "expected to fetch deleted workspace before deadline")
require.Equal(t, codersdk.WorkspaceStatusDeleted, dws.LatestBuild.Status, "workspace should be deleted before deadline")
break
}
time.Sleep(testutil.IntervalMedium)
}
})
t.Run("NotFound", func(t *testing.T) {
@@ -426,9 +420,6 @@ func TestTasks(t *testing.T) {
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
ws := coderdtest.CreateWorkspace(t, client, template.ID)
if assert.False(t, ws.TaskID.Valid, "task id should not be set on non-task workspace") {
assert.Zero(t, ws.TaskID, "non-task workspace task id should be empty")
}
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
exp := codersdk.NewExperimentalClient(client)
@@ -477,72 +468,6 @@ func TestTasks(t *testing.T) {
t.Fatalf("unexpected status code: %d (expected 403 or 404)", authErr.StatusCode())
}
})
t.Run("DeletedWorkspace", func(t *testing.T) {
t.Parallel()
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
template := createAITemplate(t, client, user)
ctx := testutil.Context(t, testutil.WaitLong)
exp := codersdk.NewExperimentalClient(client)
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: "delete me",
})
require.NoError(t, err)
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
// Mark the workspace as deleted directly in the database, bypassing provisionerd.
require.NoError(t, db.UpdateWorkspaceDeletedByID(dbauthz.AsProvisionerd(ctx), database.UpdateWorkspaceDeletedByIDParams{
ID: ws.ID,
Deleted: true,
}))
// We should still be able to fetch the task if its workspace was deleted.
// Provisionerdserver will attempt to delete the related task when deleting a workspace.
// This test ensures that we can still handle the case where, for some reason, the
// task has not been marked as deleted, but the workspace has.
task, err = exp.TaskByID(ctx, task.ID)
require.NoError(t, err, "fetching a task should still work if its related workspace is deleted")
err = exp.DeleteTask(ctx, task.OwnerID.String(), task.ID)
require.NoError(t, err, "should be possible to delete a task with no workspace")
})
t.Run("DeletingTaskWorkspaceDeletesTask", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
template := createAITemplate(t, client, user)
ctx := testutil.Context(t, testutil.WaitLong)
exp := codersdk.NewExperimentalClient(client)
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: "delete me",
})
require.NoError(t, err)
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
if assert.True(t, ws.TaskID.Valid, "task id should be set on workspace") {
assert.Equal(t, task.ID, ws.TaskID.UUID, "workspace task id should match")
}
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
// When: the task workspace is deleted
coderdtest.MustTransitionWorkspace(t, client, ws.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionDelete)
// Then: the task associated with the workspace is also deleted
_, err = exp.TaskByID(ctx, task.ID)
require.Error(t, err, "expected an error fetching the task")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr, "expected a codersdk.Error")
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
})
t.Run("Send", func(t *testing.T) {
+11 -110
View File
@@ -85,7 +85,7 @@ const docTemplate = `{
}
}
},
"/aibridge/interceptions": {
"/api/experimental/aibridge/interceptions": {
"get": {
"security": [
{
@@ -10008,45 +10008,6 @@ const docTemplate = `{
}
}
}
},
"put": {
"security": [
{
"CoderSessionToken": []
}
],
"consumes": [
"application/json"
],
"tags": [
"Builds"
],
"summary": "Update workspace build state",
"operationId": "update-workspace-build-state",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace build ID",
"name": "workspacebuild",
"in": "path",
"required": true
},
{
"description": "Request body",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateWorkspaceBuildStateRequest"
}
}
],
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/workspacebuilds/{workspacebuild}/timings": {
@@ -11707,35 +11668,12 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeBedrockConfig": {
"type": "object",
"properties": {
"access_key": {
"type": "string"
},
"access_key_secret": {
"type": "string"
},
"model": {
"type": "string"
},
"region": {
"type": "string"
},
"small_fast_model": {
"type": "string"
}
}
},
"codersdk.AIBridgeConfig": {
"type": "object",
"properties": {
"anthropic": {
"$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig"
},
"bedrock": {
"$ref": "#/definitions/codersdk.AIBridgeBedrockConfig"
},
"enabled": {
"type": "boolean"
},
@@ -11747,10 +11685,6 @@ const docTemplate = `{
"codersdk.AIBridgeInterception": {
"type": "object",
"properties": {
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
@@ -12562,13 +12496,6 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -13792,13 +13719,6 @@ const docTemplate = `{
"name": {
"type": "string"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific to the organization the role belongs to.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific to the organization the role belongs to.",
"type": "array",
@@ -14355,9 +14275,11 @@ const docTemplate = `{
"web-push",
"oauth2",
"mcp-server-http",
"workspace-sharing"
"workspace-sharing",
"aibridge"
],
"x-enum-comments": {
"ExperimentAIBridge": "Enables AI Bridge functionality.",
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
"ExperimentExample": "This isn't used for anything.",
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
@@ -14375,7 +14297,8 @@ const docTemplate = `{
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceSharing"
"ExperimentWorkspaceSharing",
"ExperimentAIBridge"
]
},
"codersdk.ExternalAPIKeyScopes": {
@@ -17563,13 +17486,6 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -19198,17 +19114,6 @@ const docTemplate = `{
}
}
},
"codersdk.UpdateWorkspaceBuildStateRequest": {
"type": "object",
"properties": {
"state": {
"type": "array",
"items": {
"type": "integer"
}
}
}
},
"codersdk.UpdateWorkspaceDormancy": {
"type": "object",
"properties": {
@@ -19762,14 +19667,6 @@ const docTemplate = `{
"description": "OwnerName is the username of the owner of the workspace.",
"type": "string"
},
"task_id": {
"description": "TaskID, if set, indicates that the workspace is relevant to the given codersdk.Task.",
"allOf": [
{
"$ref": "#/definitions/uuid.NullUUID"
}
]
},
"template_active_version_id": {
"type": "string",
"format": "uuid"
@@ -20577,7 +20474,7 @@ const docTemplate = `{
"type": "object",
"properties": {
"ai_task_sidebar_app_id": {
"description": "Deprecated: This field has been replaced with ` + "`" + `Task.WorkspaceAppID` + "`" + `",
"description": "Deprecated: This field has been replaced with ` + "`" + `TaskAppID` + "`" + `",
"type": "string",
"format": "uuid"
},
@@ -20659,6 +20556,10 @@ const docTemplate = `{
}
]
},
"task_app_id": {
"type": "string",
"format": "uuid"
},
"template_version_id": {
"type": "string",
"format": "uuid"
+11 -106
View File
@@ -65,7 +65,7 @@
}
}
},
"/aibridge/interceptions": {
"/api/experimental/aibridge/interceptions": {
"get": {
"security": [
{
@@ -8870,41 +8870,6 @@
}
}
}
},
"put": {
"security": [
{
"CoderSessionToken": []
}
],
"consumes": ["application/json"],
"tags": ["Builds"],
"summary": "Update workspace build state",
"operationId": "update-workspace-build-state",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace build ID",
"name": "workspacebuild",
"in": "path",
"required": true
},
{
"description": "Request body",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateWorkspaceBuildStateRequest"
}
}
],
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/workspacebuilds/{workspacebuild}/timings": {
@@ -10399,35 +10364,12 @@
}
}
},
"codersdk.AIBridgeBedrockConfig": {
"type": "object",
"properties": {
"access_key": {
"type": "string"
},
"access_key_secret": {
"type": "string"
},
"model": {
"type": "string"
},
"region": {
"type": "string"
},
"small_fast_model": {
"type": "string"
}
}
},
"codersdk.AIBridgeConfig": {
"type": "object",
"properties": {
"anthropic": {
"$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig"
},
"bedrock": {
"$ref": "#/definitions/codersdk.AIBridgeBedrockConfig"
},
"enabled": {
"type": "boolean"
},
@@ -10439,10 +10381,6 @@
"codersdk.AIBridgeInterception": {
"type": "object",
"properties": {
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
@@ -11240,13 +11178,6 @@
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -12402,13 +12333,6 @@
"name": {
"type": "string"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific to the organization the role belongs to.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific to the organization the role belongs to.",
"type": "array",
@@ -12958,9 +12882,11 @@
"web-push",
"oauth2",
"mcp-server-http",
"workspace-sharing"
"workspace-sharing",
"aibridge"
],
"x-enum-comments": {
"ExperimentAIBridge": "Enables AI Bridge functionality.",
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
"ExperimentExample": "This isn't used for anything.",
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
@@ -12978,7 +12904,8 @@
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceSharing"
"ExperimentWorkspaceSharing",
"ExperimentAIBridge"
]
},
"codersdk.ExternalAPIKeyScopes": {
@@ -16051,13 +15978,6 @@
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -17616,17 +17536,6 @@
}
}
},
"codersdk.UpdateWorkspaceBuildStateRequest": {
"type": "object",
"properties": {
"state": {
"type": "array",
"items": {
"type": "integer"
}
}
}
},
"codersdk.UpdateWorkspaceDormancy": {
"type": "object",
"properties": {
@@ -18144,14 +18053,6 @@
"description": "OwnerName is the username of the owner of the workspace.",
"type": "string"
},
"task_id": {
"description": "TaskID, if set, indicates that the workspace is relevant to the given codersdk.Task.",
"allOf": [
{
"$ref": "#/definitions/uuid.NullUUID"
}
]
},
"template_active_version_id": {
"type": "string",
"format": "uuid"
@@ -18907,7 +18808,7 @@
"type": "object",
"properties": {
"ai_task_sidebar_app_id": {
"description": "Deprecated: This field has been replaced with `Task.WorkspaceAppID`",
"description": "Deprecated: This field has been replaced with `TaskAppID`",
"type": "string",
"format": "uuid"
},
@@ -18985,6 +18886,10 @@
}
]
},
"task_app_id": {
"type": "string",
"format": "uuid"
},
"template_version_id": {
"type": "string",
"format": "uuid"
+2 -2
View File
@@ -509,11 +509,11 @@ func (api *API) auditLogResourceLink(ctx context.Context, alog database.GetAudit
if err != nil {
return ""
}
user, err := api.Database.GetUserByID(ctx, task.OwnerID)
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
if err != nil {
return ""
}
return fmt.Sprintf("/tasks/%s/%s", user.Username, task.ID)
return fmt.Sprintf("/tasks/%s/%s", workspace.OwnerName, task.Name)
default:
return ""
+10 -11
View File
@@ -50,13 +50,6 @@ func TestCheckPermissions(t *testing.T) {
},
Action: "read",
},
readOrgWorkspaces: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceWorkspace,
OrganizationID: adminUser.OrganizationID.String(),
},
Action: "read",
},
readMyself: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceUser,
@@ -65,10 +58,16 @@ func TestCheckPermissions(t *testing.T) {
Action: "read",
},
readOwnWorkspaces: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceWorkspace,
OwnerID: "me",
},
Action: "read",
},
readOrgWorkspaces: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceWorkspace,
OrganizationID: adminUser.OrganizationID.String(),
OwnerID: "me",
},
Action: "read",
},
@@ -93,9 +92,9 @@ func TestCheckPermissions(t *testing.T) {
UserID: adminUser.UserID,
Check: map[string]bool{
readAllUsers: true,
readOrgWorkspaces: true,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: true,
updateSpecificTemplate: true,
},
},
@@ -105,9 +104,9 @@ func TestCheckPermissions(t *testing.T) {
UserID: orgAdminUser.ID,
Check: map[string]bool{
readAllUsers: true,
readOrgWorkspaces: true,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: true,
updateSpecificTemplate: true,
},
},
@@ -117,9 +116,9 @@ func TestCheckPermissions(t *testing.T) {
UserID: memberUser.ID,
Check: map[string]bool{
readAllUsers: false,
readOrgWorkspaces: false,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: false,
updateSpecificTemplate: false,
},
},
-172
View File
@@ -1764,175 +1764,3 @@ func TestExecutorAutostartSkipsWhenNoProvisionersAvailable(t *testing.T) {
assert.Len(t, stats.Transitions, 1, "should create builds when provisioners are available")
}
func TestExecutorTaskWorkspace(t *testing.T) {
t.Parallel()
createTaskTemplate := func(t *testing.T, client *codersdk.Client, orgID uuid.UUID, ctx context.Context, defaultTTL time.Duration) codersdk.Template {
t.Helper()
taskAppID := uuid.New()
version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: []*proto.Response{
{
Type: &proto.Response_Plan{
Plan: &proto.PlanComplete{HasAiTasks: true},
},
},
},
ProvisionApply: []*proto.Response{
{
Type: &proto.Response_Apply{
Apply: &proto.ApplyComplete{
Resources: []*proto.Resource{
{
Agents: []*proto.Agent{
{
Id: uuid.NewString(),
Name: "dev",
Auth: &proto.Agent_Token{
Token: uuid.NewString(),
},
Apps: []*proto.App{
{
Id: taskAppID.String(),
Slug: "task-app",
},
},
},
},
},
},
AiTasks: []*proto.AITask{
{
AppId: taskAppID.String(),
},
},
},
},
},
},
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
if defaultTTL > 0 {
_, err := client.UpdateTemplateMeta(ctx, template.ID, codersdk.UpdateTemplateMeta{
DefaultTTLMillis: defaultTTL.Milliseconds(),
})
require.NoError(t, err)
}
return template
}
createTaskWorkspace := func(t *testing.T, client *codersdk.Client, template codersdk.Template, ctx context.Context, input string) codersdk.Workspace {
t.Helper()
exp := codersdk.NewExperimentalClient(client)
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: input,
})
require.NoError(t, err)
require.True(t, task.WorkspaceID.Valid, "task should have a workspace")
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
return workspace
}
t.Run("Autostart", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
sched = mustSchedule(t, "CRON_TZ=UTC 0 * * * *")
tickCh = make(chan time.Time)
statsCh = make(chan autobuild.Stats)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
AutobuildTicker: tickCh,
IncludeProvisionerDaemon: true,
AutobuildStats: statsCh,
})
admin = coderdtest.CreateFirstUser(t, client)
)
// Given: A task workspace
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 0)
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostart")
// Given: The task workspace has an autostart schedule
err := client.UpdateWorkspaceAutostart(ctx, workspace.ID, codersdk.UpdateWorkspaceAutostartRequest{
Schedule: ptr.Ref(sched.String()),
})
require.NoError(t, err)
// Given: The workspace is in a stopped state.
workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop)
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
require.NoError(t, err)
// When: the autobuild executor ticks after the scheduled time
go func() {
tickTime := sched.Next(workspace.LatestBuild.CreatedAt)
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
tickCh <- tickTime
close(tickCh)
}()
// Then: We expect to see a start transition
stats := <-statsCh
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
assert.Equal(t, database.WorkspaceTransitionStart, stats.Transitions[workspace.ID], "should autostart the workspace")
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
})
t.Run("Autostop", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
tickCh = make(chan time.Time)
statsCh = make(chan autobuild.Stats)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
AutobuildTicker: tickCh,
IncludeProvisionerDaemon: true,
AutobuildStats: statsCh,
})
admin = coderdtest.CreateFirstUser(t, client)
)
// Given: A task workspace with an 8 hour deadline
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 8*time.Hour)
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostop")
// Given: The workspace is currently running
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
require.Equal(t, codersdk.WorkspaceTransitionStart, workspace.LatestBuild.Transition)
require.NotZero(t, workspace.LatestBuild.Deadline, "workspace should have a deadline for autostop")
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
require.NoError(t, err)
// When: the autobuild executor ticks after the deadline
go func() {
tickTime := workspace.LatestBuild.Deadline.Time.Add(time.Minute)
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
tickCh <- tickTime
close(tickCh)
}()
// Then: We expect to see a stop transition
stats := <-statsCh
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
assert.Equal(t, database.WorkspaceTransitionStop, stats.Transitions[workspace.ID], "should autostop the workspace")
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
})
}
-1
View File
@@ -1496,7 +1496,6 @@ func New(options *Options) *API {
r.Get("/parameters", api.workspaceBuildParameters)
r.Get("/resources", api.workspaceBuildResourcesDeprecated)
r.Get("/state", api.workspaceBuildState)
r.Put("/state", api.workspaceBuildUpdateState)
r.Get("/timings", api.workspaceBuildTimings)
})
r.Route("/authcheck", func(r chi.Router) {
-1
View File
@@ -14,7 +14,6 @@ const (
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsAiTaskSidebarAppIDRequired CheckConstraint = "workspace_builds_ai_task_sidebar_app_id_required" // workspace_builds
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
)
+8 -13
View File
@@ -714,13 +714,12 @@ func RBACRole(role rbac.Role) codersdk.Role {
orgPerms := role.ByOrgID[slim.OrganizationID]
return codersdk.Role{
Name: slim.Name,
OrganizationID: slim.OrganizationID,
DisplayName: slim.DisplayName,
SitePermissions: List(role.Site, RBACPermission),
UserPermissions: List(role.User, RBACPermission),
OrganizationPermissions: List(orgPerms.Org, RBACPermission),
OrganizationMemberPermissions: List(orgPerms.Member, RBACPermission),
Name: slim.Name,
OrganizationID: slim.OrganizationID,
DisplayName: slim.DisplayName,
SitePermissions: List(role.Site, RBACPermission),
OrganizationPermissions: List(orgPerms.Org, RBACPermission),
UserPermissions: List(role.User, RBACPermission),
}
}
@@ -735,8 +734,8 @@ func Role(role database.CustomRole) codersdk.Role {
OrganizationID: orgID,
DisplayName: role.DisplayName,
SitePermissions: List(role.SitePermissions, Permission),
UserPermissions: List(role.UserPermissions, Permission),
OrganizationPermissions: List(role.OrgPermissions, Permission),
UserPermissions: List(role.UserPermissions, Permission),
}
}
@@ -963,7 +962,7 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
// created_at ASC
return sdkToolUsages[i].CreatedAt.Before(sdkToolUsages[j].CreatedAt)
})
intc := codersdk.AIBridgeInterception{
return codersdk.AIBridgeInterception{
ID: interception.ID,
Initiator: MinimalUserFromVisibleUser(initiator),
Provider: interception.Provider,
@@ -974,10 +973,6 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
UserPrompts: sdkUserPrompts,
ToolUsages: sdkToolUsages,
}
if interception.EndedAt.Valid {
intc.EndedAt = &interception.EndedAt.Time
}
return intc
}
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
+21 -80
View File
@@ -217,10 +217,10 @@ var (
rbac.ResourceTemplate.Type: {policy.ActionRead, policy.ActionUpdate},
// Unsure why provisionerd needs update and read personal
rbac.ResourceUser.Type: {policy.ActionRead, policy.ActionReadPersonal, policy.ActionUpdatePersonal},
rbac.ResourceWorkspaceDormant.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent},
rbac.ResourceWorkspaceDormant.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStop},
rbac.ResourceWorkspace.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStart, policy.ActionWorkspaceStop, policy.ActionCreateAgent},
// Provisionerd needs to read, update, and delete tasks associated with workspaces.
rbac.ResourceTask.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
// Provisionerd needs to read and update tasks associated with workspaces.
rbac.ResourceTask.Type: {policy.ActionRead, policy.ActionUpdate},
rbac.ResourceApiKey.Type: {policy.WildcardSymbol},
// When org scoped provisioner credentials are implemented,
// this can be reduced to read a specific org.
@@ -254,7 +254,6 @@ var (
rbac.ResourceFile.Type: {policy.ActionRead}, // Required to read terraform files
rbac.ResourceNotificationMessage.Type: {policy.ActionCreate, policy.ActionRead},
rbac.ResourceSystem.Type: {policy.WildcardSymbol},
rbac.ResourceTask.Type: {policy.ActionRead, policy.ActionUpdate},
rbac.ResourceTemplate.Type: {policy.ActionRead, policy.ActionUpdate},
rbac.ResourceUser.Type: {policy.ActionRead},
rbac.ResourceWorkspace.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStart, policy.ActionWorkspaceStop},
@@ -396,13 +395,11 @@ var (
Identifier: rbac.RoleIdentifier{Name: "subagentapi"},
DisplayName: "Sub Agent API",
Site: []rbac.Permission{},
User: []rbac.Permission{},
User: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreateAgent, policy.ActionDeleteAgent},
}),
ByOrgID: map[string]rbac.OrgPermissions{
orgID.String(): {
Member: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreateAgent, policy.ActionDeleteAgent},
}),
},
orgID.String(): {},
},
},
}),
@@ -1293,17 +1290,14 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
return xerrors.Errorf("invalid role: %w", err)
}
if len(rbacRole.ByOrgID) > 0 && (len(rbacRole.Site) > 0 || len(rbacRole.User) > 0) {
// This is a choice to keep roles simple. If we allow mixing site and org
// scoped perms, then knowing who can do what gets more complicated. Roles
// should either be entirely org-scoped or entirely unrelated to
// organizations.
return xerrors.Errorf("invalid custom role, cannot assign both org-scoped and site/user permissions at the same time")
if len(rbacRole.ByOrgID) > 0 && len(rbacRole.Site) > 0 {
// This is a choice to keep roles simple. If we allow mixing site and org scoped perms, then knowing who can
// do what gets more complicated.
return xerrors.Errorf("invalid custom role, cannot assign both org and site permissions at the same time")
}
if len(rbacRole.ByOrgID) > 1 {
// Again to avoid more complexity in our roles. Roles are limited to one
// organization.
// Again to avoid more complexity in our roles
return xerrors.Errorf("invalid custom role, cannot assign permissions to more than 1 org at a time")
}
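As a rough illustration of the rule enforced above (illustrative values only, built from the rbac helpers that appear in this file; this is not code from the repository): a custom role must either be scoped entirely to a single organization or carry only site/user permissions, never both.

// Accepted: permissions confined to one organization.
orgScoped := rbac.Role{
	Identifier: rbac.RoleIdentifier{Name: "org-template-viewer"},
	ByOrgID: map[string]rbac.OrgPermissions{
		orgID.String(): {
			Org: rbac.Permissions(map[string][]policy.Action{
				rbac.ResourceTemplate.Type: {policy.ActionRead},
			}),
		},
	},
}

// Rejected by customRoleCheck: site permissions mixed with an org scope.
mixed := rbac.Role{
	Identifier: rbac.RoleIdentifier{Name: "mixed-scope"},
	Site: rbac.Permissions(map[string][]policy.Action{
		rbac.ResourceTemplate.Type: {policy.ActionRead},
	}),
	ByOrgID: map[string]rbac.OrgPermissions{orgID.String(): {}},
}
_, _ = orgScoped, mixed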
@@ -1319,18 +1313,7 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
for _, orgPerm := range perms.Org {
err := q.customRoleEscalationCheck(ctx, act, orgPerm, rbac.Object{OrgID: orgID, Type: orgPerm.ResourceType})
if err != nil {
return xerrors.Errorf("org=%q: org: %w", orgID, err)
}
}
for _, memberPerm := range perms.Member {
// The person giving the permission should still be required to have
// the permissions throughout the org in order to give individuals the
// same permission among their own resources, since the role can be given
// to anyone. The `Owner` is intentionally omitted from the `Object` to
// enforce this.
err := q.customRoleEscalationCheck(ctx, act, memberPerm, rbac.Object{OrgID: orgID, Type: memberPerm.ResourceType})
if err != nil {
return xerrors.Errorf("org=%q: member: %w", orgID, err)
return xerrors.Errorf("org=%q: %w", orgID, err)
}
}
}
@@ -1348,8 +1331,8 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.ProvisionerJob) error {
switch job.Type {
case database.ProvisionerJobTypeWorkspaceBuild:
// Authorized call to get workspace build. If we can read the build, we can
// read the job.
// Authorized call to get workspace build. If we can read the build, we
// can read the job.
_, err := q.GetWorkspaceBuildByJobID(ctx, job.ID)
if err != nil {
return xerrors.Errorf("fetch related workspace build: %w", err)
@@ -1392,8 +1375,8 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.Activi
}
func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
// Although this technically only reads users, only system-related functions
// should be allowed to call this.
// Although this technically only reads users, only system-related functions should be
// allowed to call this.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
@@ -1412,8 +1395,8 @@ func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg databas
}
func (q *querier) BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error {
// Could be any workspace and checking auth to each workspace is overkill for
// the purpose of this function.
// Could be any workspace and checking auth to each workspace is overkill for the purpose
// of this function.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceWorkspace.All()); err != nil {
return err
}
@@ -1441,13 +1424,6 @@ func (q *querier) BulkMarkNotificationMessagesSent(ctx context.Context, arg data
return q.db.BulkMarkNotificationMessagesSent(ctx, arg)
}
func (q *querier) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
return database.CalculateAIBridgeInterceptionsTelemetrySummaryRow{}, err
}
return q.db.CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg)
}
func (q *querier) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
empty := database.ClaimPrebuiltWorkspaceRow{}
@@ -1747,13 +1723,6 @@ func (q *querier) DeleteOldProvisionerDaemons(ctx context.Context) error {
return q.db.DeleteOldProvisionerDaemons(ctx)
}
func (q *querier) DeleteOldTelemetryLocks(ctx context.Context, beforeTime time.Time) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return err
}
return q.db.DeleteOldTelemetryLocks(ctx, beforeTime)
}
func (q *querier) DeleteOldWorkspaceAgentLogs(ctx context.Context, threshold time.Time) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return err
@@ -2649,13 +2618,6 @@ func (q *querier) GetOrganizationsByUserID(ctx context.Context, userID database.
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationsByUserID)(ctx, userID)
}
func (q *querier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOrganization.All()); err != nil {
return nil, err
}
return q.db.GetOrganizationsWithPrebuildStatus(ctx, arg)
}
func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
version, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
if err != nil {
@@ -4250,13 +4212,6 @@ func (q *querier) InsertTelemetryItemIfNotExists(ctx context.Context, arg databa
return q.db.InsertTelemetryItemIfNotExists(ctx, arg)
}
func (q *querier) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.InsertTelemetryLock(ctx, arg)
}
func (q *querier) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
obj := rbac.ResourceTemplate.InOrg(arg.OrganizationID)
if err := q.authorizeContext(ctx, policy.ActionCreate, obj); err != nil {
@@ -4568,13 +4523,6 @@ func (q *querier) ListAIBridgeInterceptions(ctx context.Context, arg database.Li
return q.db.ListAuthorizedAIBridgeInterceptions(ctx, arg, prep)
}
func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
return nil, err
}
return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
}
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
// This function is a system function until we implement a join for aibridge interceptions.
// Matches the behavior of the workspaces listing endpoint.
@@ -4763,13 +4711,6 @@ func (q *querier) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error {
return update(q.log, q.auth, fetch, q.db.UnfavoriteWorkspace)(ctx, id)
}
func (q *querier) UpdateAIBridgeInterceptionEnded(ctx context.Context, params database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, params.ID); err != nil {
return database.AIBridgeInterception{}, err
}
return q.db.UpdateAIBridgeInterceptionEnded(ctx, params)
}
func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
fetch := func(ctx context.Context, arg database.UpdateAPIKeyByIDParams) (database.APIKey, error) {
return q.db.GetAPIKeyByID(ctx, arg.ID)
@@ -4941,10 +4882,10 @@ func (q *querier) UpdateOrganizationDeletedByID(ctx context.Context, arg databas
return deleteQ(q.log, q.auth, q.db.GetOrganizationByID, deleteF)(ctx, arg.ID)
}
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
// Prebuild operation for canceling pending prebuild jobs from non-active template versions
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourcePrebuiltWorkspace); err != nil {
return []database.UpdatePrebuildProvisionerJobWithCancelRow{}, err
return []uuid.UUID{}, err
}
return q.db.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
}
+3 -45
View File
@@ -646,13 +646,10 @@ func (s *MethodTestSuite) TestProvisionerJob() {
PresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
Now: dbtime.Now(),
}
canceledJobs := []database.UpdatePrebuildProvisionerJobWithCancelRow{
{ID: uuid.New(), WorkspaceID: uuid.New(), TemplateID: uuid.New(), TemplateVersionPresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
{ID: uuid.New(), WorkspaceID: uuid.New(), TemplateID: uuid.New(), TemplateVersionPresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
}
jobIDs := []uuid.UUID{uuid.New(), uuid.New()}
dbm.EXPECT().UpdatePrebuildProvisionerJobWithCancel(gomock.Any(), arg).Return(canceledJobs, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourcePrebuiltWorkspace, policy.ActionUpdate).Returns(canceledJobs)
dbm.EXPECT().UpdatePrebuildProvisionerJobWithCancel(gomock.Any(), arg).Return(jobIDs, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourcePrebuiltWorkspace, policy.ActionUpdate).Returns(jobIDs)
}))
s.Run("GetProvisionerJobsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
org := testutil.Fake(s.T(), faker, database.Organization{})
@@ -3759,14 +3756,6 @@ func (s *MethodTestSuite) TestPrebuilds() {
dbm.EXPECT().GetPrebuildMetrics(gomock.Any()).Return([]database.GetPrebuildMetricsRow{}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead)
}))
s.Run("GetOrganizationsWithPrebuildStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.GetOrganizationsWithPrebuildStatusParams{
UserID: uuid.New(),
GroupName: "test",
}
dbm.EXPECT().GetOrganizationsWithPrebuildStatus(gomock.Any(), arg).Return([]database.GetOrganizationsWithPrebuildStatusRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceOrganization.All(), policy.ActionRead)
}))
s.Run("GetPrebuildsSettings", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetPrebuildsSettings(gomock.Any()).Return("{}", nil).AnyTimes()
check.Args().Asserts()
@@ -4628,35 +4617,4 @@ func (s *MethodTestSuite) TestAIBridge() {
db.EXPECT().ListAIBridgeToolUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeToolUsage{}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
}))
s.Run("UpdateAIBridgeInterceptionEnded", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intcID := uuid.UUID{1}
params := database.UpdateAIBridgeInterceptionEndedParams{ID: intcID}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intcID})
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intcID).Return(intc, nil).AnyTimes() // Validation.
db.EXPECT().UpdateAIBridgeInterceptionEnded(gomock.Any(), params).Return(intc, nil).AnyTimes()
check.Args(params).Asserts(intc, policy.ActionUpdate).Returns(intc)
}))
}
func (s *MethodTestSuite) TestTelemetry() {
s.Run("InsertTelemetryLock", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().InsertTelemetryLock(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
check.Args(database.InsertTelemetryLockParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate)
}))
s.Run("DeleteOldTelemetryLocks", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().DeleteOldTelemetryLocks(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
check.Args(time.Time{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
}))
s.Run("ListAIBridgeInterceptionsTelemetrySummaries", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().ListAIBridgeInterceptionsTelemetrySummaries(gomock.Any(), gomock.Any()).Return([]database.ListAIBridgeInterceptionsTelemetrySummariesRow{}, nil).AnyTimes()
check.Args(database.ListAIBridgeInterceptionsTelemetrySummariesParams{}).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead)
}))
s.Run("CalculateAIBridgeInterceptionsTelemetrySummary", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().CalculateAIBridgeInterceptionsTelemetrySummary(gomock.Any(), gomock.Any()).Return(database.CalculateAIBridgeInterceptionsTelemetrySummaryRow{}, nil).AnyTimes()
check.Args(database.CalculateAIBridgeInterceptionsTelemetrySummaryParams{}).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead)
}))
}
+7 -61
View File
@@ -41,7 +41,6 @@ type WorkspaceResponse struct {
Build database.WorkspaceBuild
AgentToken string
TemplateVersionResponse
Task database.Task
}
// WorkspaceBuildBuilder generates workspace builds and associated
@@ -58,7 +57,6 @@ type WorkspaceBuildBuilder struct {
agentToken string
jobStatus database.ProvisionerJobStatus
taskAppID uuid.UUID
taskSeed database.TaskTable
}
// WorkspaceBuild generates a workspace build for the provided workspace.
@@ -117,28 +115,25 @@ func (b WorkspaceBuildBuilder) WithAgent(mutations ...func([]*sdkproto.Agent) []
return b
}
func (b WorkspaceBuildBuilder) WithTask(taskSeed database.TaskTable, appSeed *sdkproto.App) WorkspaceBuildBuilder {
//nolint:revive // returns modified struct
b.taskSeed = taskSeed
if appSeed == nil {
appSeed = &sdkproto.App{}
func (b WorkspaceBuildBuilder) WithTask(seed *sdkproto.App) WorkspaceBuildBuilder {
if seed == nil {
seed = &sdkproto.App{}
}
var err error
//nolint: revive // returns modified struct
b.taskAppID, err = uuid.Parse(takeFirst(appSeed.Id, uuid.NewString()))
b.taskAppID, err = uuid.Parse(takeFirst(seed.Id, uuid.NewString()))
require.NoError(b.t, err)
return b.Params(database.WorkspaceBuildParameter{
Name: codersdk.AITaskPromptParameterName,
Value: b.taskSeed.Prompt,
Value: "list me",
}).WithAgent(func(a []*sdkproto.Agent) []*sdkproto.Agent {
a[0].Apps = []*sdkproto.App{
{
Id: b.taskAppID.String(),
Slug: takeFirst(appSeed.Slug, "task-app"),
Url: takeFirst(appSeed.Url, ""),
Slug: takeFirst(seed.Slug, "task-app"),
Url: takeFirst(seed.Url, ""),
},
}
return a
@@ -166,19 +161,6 @@ func (b WorkspaceBuildBuilder) Canceled() WorkspaceBuildBuilder {
// Workspace will be optionally populated if no ID is set on the provided
// workspace.
func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
var resp WorkspaceResponse
// Use transaction, like real wsbuilder.
err := b.db.InTx(func(tx database.Store) error {
//nolint:revive // calls do on modified struct
b.db = tx
resp = b.doInTX()
return nil
}, nil)
require.NoError(b.t, err)
return resp
}
func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
b.t.Helper()
jobID := uuid.New()
b.seed.ID = uuid.New()
@@ -230,37 +212,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
b.seed.WorkspaceID = b.ws.ID
b.seed.InitiatorID = takeFirst(b.seed.InitiatorID, b.ws.OwnerID)
// If a task was requested, ensure it exists and is associated with this
// workspace.
if b.taskAppID != uuid.Nil {
b.logger.Debug(context.Background(), "creating or updating task", "task_id", b.taskSeed.ID)
b.taskSeed.OrganizationID = takeFirst(b.taskSeed.OrganizationID, b.ws.OrganizationID)
b.taskSeed.OwnerID = takeFirst(b.taskSeed.OwnerID, b.ws.OwnerID)
b.taskSeed.Name = takeFirst(b.taskSeed.Name, b.ws.Name)
b.taskSeed.WorkspaceID = uuid.NullUUID{UUID: takeFirst(b.taskSeed.WorkspaceID.UUID, b.ws.ID), Valid: true}
b.taskSeed.TemplateVersionID = takeFirst(b.taskSeed.TemplateVersionID, b.seed.TemplateVersionID)
// Try to fetch existing task and update its workspace ID.
if task, err := b.db.GetTaskByID(ownerCtx, b.taskSeed.ID); err == nil {
if !task.WorkspaceID.Valid {
b.logger.Info(context.Background(), "updating task workspace id", "task_id", b.taskSeed.ID, "workspace_id", b.ws.ID)
_, err = b.db.UpdateTaskWorkspaceID(ownerCtx, database.UpdateTaskWorkspaceIDParams{
ID: b.taskSeed.ID,
WorkspaceID: uuid.NullUUID{UUID: b.ws.ID, Valid: true},
})
require.NoError(b.t, err, "update task workspace id")
} else if task.WorkspaceID.UUID != b.ws.ID {
require.Fail(b.t, "task already has a workspace id, mismatch", task.WorkspaceID.UUID, b.ws.ID)
}
} else if errors.Is(err, sql.ErrNoRows) {
task := dbgen.Task(b.t, b.db, b.taskSeed)
b.taskSeed.ID = task.ID
b.logger.Info(context.Background(), "created new task", "task_id", b.taskSeed.ID)
} else {
require.NoError(b.t, err, "get task by id")
}
}
// Create a provisioner job for the build!
payload, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: b.seed.ID,
@@ -373,11 +324,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
b.logger.Debug(context.Background(), "linked task to workspace build",
slog.F("task_id", task.ID),
slog.F("build_number", resp.Build.BuildNumber))
// Update task after linking.
task, err = b.db.GetTaskByID(ownerCtx, task.ID)
require.NoError(b.t, err, "get task by id")
resp.Task = task
}
for i := range b.params {
+1 -9
View File
@@ -1495,7 +1495,7 @@ func ClaimPrebuild(
return claimedWorkspace
}
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams, endedAt *time.Time) database.AIBridgeInterception {
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams) database.AIBridgeInterception {
interception, err := db.InsertAIBridgeInterception(genCtx, database.InsertAIBridgeInterceptionParams{
ID: takeFirst(seed.ID, uuid.New()),
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
@@ -1504,13 +1504,6 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
})
if endedAt != nil {
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
ID: interception.ID,
EndedAt: *endedAt,
})
require.NoError(t, err, "insert aibridge interception")
}
require.NoError(t, err, "insert aibridge interception")
return interception
}
@@ -1576,7 +1569,6 @@ func Task(t testing.TB, db database.Store, orig database.TaskTable) database.Tas
}
task, err := db.InsertTask(genCtx, database.InsertTaskParams{
ID: takeFirst(orig.ID, uuid.New()),
OrganizationID: orig.OrganizationID,
OwnerID: orig.OwnerID,
Name: takeFirst(orig.Name, taskname.GenerateFallback()),
+1 -43
View File
@@ -158,13 +158,6 @@ func (m queryMetricsStore) BulkMarkNotificationMessagesSent(ctx context.Context,
return r0, r1
}
func (m queryMetricsStore) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
start := time.Now()
r0, r1 := m.s.CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg)
m.queryLatencies.WithLabelValues("CalculateAIBridgeInterceptionsTelemetrySummary").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
start := time.Now()
r0, r1 := m.s.ClaimPrebuiltWorkspace(ctx, arg)
@@ -410,13 +403,6 @@ func (m queryMetricsStore) DeleteOldProvisionerDaemons(ctx context.Context) erro
return r0
}
func (m queryMetricsStore) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
start := time.Now()
r0 := m.s.DeleteOldTelemetryLocks(ctx, periodEndingAtBefore)
m.queryLatencies.WithLabelValues("DeleteOldTelemetryLocks").Observe(time.Since(start).Seconds())
return r0
}
func (m queryMetricsStore) DeleteOldWorkspaceAgentLogs(ctx context.Context, arg time.Time) error {
start := time.Now()
r0 := m.s.DeleteOldWorkspaceAgentLogs(ctx, arg)
@@ -1243,13 +1229,6 @@ func (m queryMetricsStore) GetOrganizationsByUserID(ctx context.Context, userID
return organizations, err
}
func (m queryMetricsStore) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
start := time.Now()
r0, r1 := m.s.GetOrganizationsWithPrebuildStatus(ctx, arg)
m.queryLatencies.WithLabelValues("GetOrganizationsWithPrebuildStatus").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
start := time.Now()
schemas, err := m.s.GetParameterSchemasByJobID(ctx, jobID)
@@ -2538,13 +2517,6 @@ func (m queryMetricsStore) InsertTelemetryItemIfNotExists(ctx context.Context, a
return r0
}
func (m queryMetricsStore) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
start := time.Now()
r0 := m.s.InsertTelemetryLock(ctx, arg)
m.queryLatencies.WithLabelValues("InsertTelemetryLock").Observe(time.Since(start).Seconds())
return r0
}
func (m queryMetricsStore) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
start := time.Now()
err := m.s.InsertTemplate(ctx, arg)
@@ -2762,13 +2734,6 @@ func (m queryMetricsStore) ListAIBridgeInterceptions(ctx context.Context, arg da
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
m.queryLatencies.WithLabelValues("ListAIBridgeInterceptionsTelemetrySummaries").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
@@ -2923,13 +2888,6 @@ func (m queryMetricsStore) UnfavoriteWorkspace(ctx context.Context, arg uuid.UUI
return r0
}
func (m queryMetricsStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, id database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
start := time.Now()
r0, r1 := m.s.UpdateAIBridgeInterceptionEnded(ctx, id)
m.queryLatencies.WithLabelValues("UpdateAIBridgeInterceptionEnded").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
start := time.Now()
err := m.s.UpdateAPIKeyByID(ctx, arg)
@@ -3049,7 +3007,7 @@ func (m queryMetricsStore) UpdateOrganizationDeletedByID(ctx context.Context, ar
return r0
}
func (m queryMetricsStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
func (m queryMetricsStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
m.queryLatencies.WithLabelValues("UpdatePrebuildProvisionerJobWithCancel").Observe(time.Since(start).Seconds())
+2 -90
View File
@@ -190,21 +190,6 @@ func (mr *MockStoreMockRecorder) BulkMarkNotificationMessagesSent(ctx, arg any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkMarkNotificationMessagesSent", reflect.TypeOf((*MockStore)(nil).BulkMarkNotificationMessagesSent), ctx, arg)
}
// CalculateAIBridgeInterceptionsTelemetrySummary mocks base method.
func (m *MockStore) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CalculateAIBridgeInterceptionsTelemetrySummary", ctx, arg)
ret0, _ := ret[0].(database.CalculateAIBridgeInterceptionsTelemetrySummaryRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CalculateAIBridgeInterceptionsTelemetrySummary indicates an expected call of CalculateAIBridgeInterceptionsTelemetrySummary.
func (mr *MockStoreMockRecorder) CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CalculateAIBridgeInterceptionsTelemetrySummary", reflect.TypeOf((*MockStore)(nil).CalculateAIBridgeInterceptionsTelemetrySummary), ctx, arg)
}
// ClaimPrebuiltWorkspace mocks base method.
func (m *MockStore) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
m.ctrl.T.Helper()
@@ -751,20 +736,6 @@ func (mr *MockStoreMockRecorder) DeleteOldProvisionerDaemons(ctx any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldProvisionerDaemons", reflect.TypeOf((*MockStore)(nil).DeleteOldProvisionerDaemons), ctx)
}
// DeleteOldTelemetryLocks mocks base method.
func (m *MockStore) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteOldTelemetryLocks", ctx, periodEndingAtBefore)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteOldTelemetryLocks indicates an expected call of DeleteOldTelemetryLocks.
func (mr *MockStoreMockRecorder) DeleteOldTelemetryLocks(ctx, periodEndingAtBefore any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldTelemetryLocks", reflect.TypeOf((*MockStore)(nil).DeleteOldTelemetryLocks), ctx, periodEndingAtBefore)
}
// DeleteOldWorkspaceAgentLogs mocks base method.
func (m *MockStore) DeleteOldWorkspaceAgentLogs(ctx context.Context, threshold time.Time) error {
m.ctrl.T.Helper()
@@ -2622,21 +2593,6 @@ func (mr *MockStoreMockRecorder) GetOrganizationsByUserID(ctx, arg any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsByUserID", reflect.TypeOf((*MockStore)(nil).GetOrganizationsByUserID), ctx, arg)
}
// GetOrganizationsWithPrebuildStatus mocks base method.
func (m *MockStore) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOrganizationsWithPrebuildStatus", ctx, arg)
ret0, _ := ret[0].([]database.GetOrganizationsWithPrebuildStatusRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrganizationsWithPrebuildStatus indicates an expected call of GetOrganizationsWithPrebuildStatus.
func (mr *MockStoreMockRecorder) GetOrganizationsWithPrebuildStatus(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsWithPrebuildStatus", reflect.TypeOf((*MockStore)(nil).GetOrganizationsWithPrebuildStatus), ctx, arg)
}
// GetParameterSchemasByJobID mocks base method.
func (m *MockStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
m.ctrl.T.Helper()
@@ -5436,20 +5392,6 @@ func (mr *MockStoreMockRecorder) InsertTelemetryItemIfNotExists(ctx, arg any) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTelemetryItemIfNotExists", reflect.TypeOf((*MockStore)(nil).InsertTelemetryItemIfNotExists), ctx, arg)
}
// InsertTelemetryLock mocks base method.
func (m *MockStore) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertTelemetryLock", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// InsertTelemetryLock indicates an expected call of InsertTelemetryLock.
func (mr *MockStoreMockRecorder) InsertTelemetryLock(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTelemetryLock", reflect.TypeOf((*MockStore)(nil).InsertTelemetryLock), ctx, arg)
}
// InsertTemplate mocks base method.
func (m *MockStore) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
m.ctrl.T.Helper()
@@ -5905,21 +5847,6 @@ func (mr *MockStoreMockRecorder) ListAIBridgeInterceptions(ctx, arg any) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptions), ctx, arg)
}
// ListAIBridgeInterceptionsTelemetrySummaries mocks base method.
func (m *MockStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAIBridgeInterceptionsTelemetrySummaries", ctx, arg)
ret0, _ := ret[0].([]database.ListAIBridgeInterceptionsTelemetrySummariesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAIBridgeInterceptionsTelemetrySummaries indicates an expected call of ListAIBridgeInterceptionsTelemetrySummaries.
func (mr *MockStoreMockRecorder) ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptionsTelemetrySummaries", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptionsTelemetrySummaries), ctx, arg)
}
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
m.ctrl.T.Helper()
@@ -6289,21 +6216,6 @@ func (mr *MockStoreMockRecorder) UnfavoriteWorkspace(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnfavoriteWorkspace", reflect.TypeOf((*MockStore)(nil).UnfavoriteWorkspace), ctx, id)
}
// UpdateAIBridgeInterceptionEnded mocks base method.
func (m *MockStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAIBridgeInterceptionEnded", ctx, arg)
ret0, _ := ret[0].(database.AIBridgeInterception)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateAIBridgeInterceptionEnded indicates an expected call of UpdateAIBridgeInterceptionEnded.
func (mr *MockStoreMockRecorder) UpdateAIBridgeInterceptionEnded(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAIBridgeInterceptionEnded", reflect.TypeOf((*MockStore)(nil).UpdateAIBridgeInterceptionEnded), ctx, arg)
}
// UpdateAPIKeyByID mocks base method.
func (m *MockStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
m.ctrl.T.Helper()
@@ -6555,10 +6467,10 @@ func (mr *MockStoreMockRecorder) UpdateOrganizationDeletedByID(ctx, arg any) *go
}
// UpdatePrebuildProvisionerJobWithCancel mocks base method.
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdatePrebuildProvisionerJobWithCancel", ctx, arg)
ret0, _ := ret[0].([]database.UpdatePrebuildProvisionerJobWithCancelRow)
ret0, _ := ret[0].([]uuid.UUID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-10
View File
@@ -24,12 +24,6 @@ const (
// but we won't touch the `connection_logs` table.
maxAuditLogConnectionEventAge = 90 * 24 * time.Hour // 90 days
auditLogConnectionEventBatchSize = 1000
// Telemetry heartbeats are only used to deduplicate events across replicas,
// so heartbeat rows don't need to be persisted for longer than 24 hours.
// The window just needs to be long enough to cover the maximum interval of a
// heartbeat event (currently 1 hour) plus some buffer.
maxTelemetryHeartbeatAge = 24 * time.Hour
)
// New creates a new periodically purging database instance.
@@ -77,10 +71,6 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, clk quartz.
if err := tx.ExpirePrebuildsAPIKeys(ctx, dbtime.Time(start)); err != nil {
return xerrors.Errorf("failed to expire prebuilds user api keys: %w", err)
}
deleteOldTelemetryLocksBefore := start.Add(-maxTelemetryHeartbeatAge)
if err := tx.DeleteOldTelemetryLocks(ctx, deleteOldTelemetryLocksBefore); err != nil {
return xerrors.Errorf("failed to delete old telemetry locks: %w", err)
}
deleteOldAuditLogConnectionEventsBefore := start.Add(-maxAuditLogConnectionEventAge)
if err := tx.DeleteOldAuditLogConnectionEvents(ctx, database.DeleteOldAuditLogConnectionEventsParams{
-53
View File
@@ -704,56 +704,3 @@ func TestExpireOldAPIKeys(t *testing.T) {
// Out of an abundance of caution, we do not expire explicitly named prebuilds API keys.
assertKeyActive(namedPrebuildsAPIKey.ID)
}
func TestDeleteOldTelemetryHeartbeats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
clk := quartz.NewMock(t)
now := clk.Now().UTC()
// Insert telemetry heartbeats.
err := db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
EventType: "aibridge_interceptions_summary",
PeriodEndingAt: now.Add(-25 * time.Hour), // should be purged
})
require.NoError(t, err)
err = db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
EventType: "aibridge_interceptions_summary",
PeriodEndingAt: now.Add(-23 * time.Hour), // should be kept
})
require.NoError(t, err)
err = db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
EventType: "aibridge_interceptions_summary",
PeriodEndingAt: now, // should be kept
})
require.NoError(t, err)
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, clk)
defer closer.Close()
<-done // doTick() has now run.
require.Eventuallyf(t, func() bool {
// We use SQL queries directly here because we don't expose queries
// for counting heartbeats in the application code.
var totalCount int
err := sqlDB.QueryRowContext(ctx, `
SELECT COUNT(*) FROM telemetry_locks;
`).Scan(&totalCount)
assert.NoError(t, err)
var oldCount int
err = sqlDB.QueryRowContext(ctx, `
SELECT COUNT(*) FROM telemetry_locks WHERE period_ending_at < $1;
`, now.Add(-24*time.Hour)).Scan(&oldCount)
assert.NoError(t, err)
// Expect 2 heartbeats remaining and none older than 24 hours.
t.Logf("eventually: total count: %d, old count: %d", totalCount, oldCount)
return totalCount == 2 && oldCount == 0
}, testutil.WaitShort, testutil.IntervalFast, "it should delete old telemetry heartbeats")
}
+6 -51
View File
@@ -6,8 +6,6 @@ import (
_ "embed"
"fmt"
"os"
"runtime"
"strings"
"sync"
"time"
@@ -47,8 +45,6 @@ func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error
host = defaultConnectionParams.Host
port = defaultConnectionParams.Port
)
packageName := getTestPackageName(t)
testName := t.Name()
// Use a time-based prefix to make it easier to find the database
// when debugging.
@@ -59,9 +55,9 @@ func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error
}
dbName := now + "_" + dbSuffix
// TODO: add package and test name
_, err = b.coderTestingDB.Exec(
"INSERT INTO test_databases (name, process_uuid, test_package, test_name) VALUES ($1, $2, $3, $4)",
dbName, b.uuid, packageName, testName)
"INSERT INTO test_databases (name, process_uuid) VALUES ($1, $2)", dbName, b.uuid)
if err != nil {
return ConnectionParams{}, xerrors.Errorf("insert test_database row: %w", err)
}
@@ -108,10 +104,10 @@ func (b *Broker) clean(t TBSubset, dbName string) func() {
func (b *Broker) init(t TBSubset) error {
b.Lock()
defer b.Unlock()
b.refCount++
t.Cleanup(b.decRef)
if b.coderTestingDB != nil {
// already initialized
b.refCount++
t.Cleanup(b.decRef)
return nil
}
@@ -128,8 +124,8 @@ func (b *Broker) init(t TBSubset) error {
return xerrors.Errorf("open postgres connection: %w", err)
}
// coderTestingSQLInit is idempotent, so we can run it every time.
_, err = coderTestingDB.Exec(coderTestingSQLInit)
// Opening the connection can succeed even if the database doesn't exist. Ping it to find out.
err = coderTestingDB.Ping()
var pqErr *pq.Error
if xerrors.As(err, &pqErr) && pqErr.Code == "3D000" {
// database does not exist.
@@ -149,8 +145,6 @@ func (b *Broker) init(t TBSubset) error {
return xerrors.Errorf("ping '%s' database: %w", CoderTestingDBName, err)
}
b.coderTestingDB = coderTestingDB
b.refCount++
t.Cleanup(b.decRef)
if b.uuid == uuid.Nil {
b.uuid = uuid.New()
@@ -192,42 +186,3 @@ func (b *Broker) decRef() {
b.coderTestingDB = nil
}
}
// getTestPackageName returns the package name of the test that called it.
func getTestPackageName(t TBSubset) string {
packageName := "unknown"
// Ask runtime.Callers for up to 100 program counters, including runtime.Callers itself.
pc := make([]uintptr, 100)
n := runtime.Callers(0, pc)
if n == 0 {
// No PCs available. This can happen if the first argument to
// runtime.Callers is large.
//
// Return now to avoid processing the zero Frame that would
// otherwise be returned by frames.Next below.
t.Logf("could not determine test package name: no PCs available")
return packageName
}
pc = pc[:n] // pass only valid pcs to runtime.CallersFrames
frames := runtime.CallersFrames(pc)
// Loop to get frames.
// A fixed number of PCs can expand to an indefinite number of Frames.
for {
frame, more := frames.Next()
if strings.HasPrefix(frame.Function, "github.com/coder/coder/v2/") {
packageName = strings.SplitN(strings.TrimPrefix(frame.Function, "github.com/coder/coder/v2/"), ".", 2)[0]
}
if strings.HasPrefix(frame.Function, "testing") {
break
}
// Check whether there are more frames to process after this one.
if !more {
break
}
}
return packageName
}
@@ -1,13 +0,0 @@
package dbtestutil
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGetTestPackageName(t *testing.T) {
t.Parallel()
packageName := getTestPackageName(t)
require.Equal(t, "coderd/database/dbtestutil", packageName)
}
@@ -1,6 +1,3 @@
BEGIN TRANSACTION;
SELECT pg_advisory_xact_lock(7283699);
CREATE TABLE IF NOT EXISTS test_databases (
name text PRIMARY KEY,
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
@@ -9,10 +6,3 @@ CREATE TABLE IF NOT EXISTS test_databases (
);
CREATE INDEX IF NOT EXISTS test_databases_process_uuid ON test_databases (process_uuid, dropped_at);
ALTER TABLE test_databases ADD COLUMN IF NOT EXISTS test_name text;
COMMENT ON COLUMN test_databases.test_name IS 'Name of the test that created the database';
ALTER TABLE test_databases ADD COLUMN IF NOT EXISTS test_package text;
COMMENT ON COLUMN test_databases.test_package IS 'Package of the test that created the database';
COMMIT;
+14 -41
View File
@@ -1828,15 +1828,6 @@ CREATE TABLE tasks (
deleted_at timestamp with time zone
);
CREATE VIEW visible_users AS
SELECT users.id,
users.username,
users.name,
users.avatar_url
FROM users;
COMMENT ON VIEW visible_users IS 'Visible fields of users are allowed to be joined with other tables for including context of other resources.';
CREATE TABLE workspace_agents (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
@@ -1987,16 +1978,8 @@ CREATE VIEW tasks_with_status AS
END AS status,
task_app.workspace_build_number,
task_app.workspace_agent_id,
task_app.workspace_app_id,
task_owner.owner_username,
task_owner.owner_name,
task_owner.owner_avatar_url
FROM (((((tasks
CROSS JOIN LATERAL ( SELECT vu.username AS owner_username,
vu.name AS owner_name,
vu.avatar_url AS owner_avatar_url
FROM visible_users vu
WHERE (vu.id = tasks.owner_id)) task_owner)
task_app.workspace_app_id
FROM ((((tasks
LEFT JOIN LATERAL ( SELECT task_app_1.workspace_build_number,
task_app_1.workspace_agent_id,
task_app_1.workspace_app_id
@@ -2029,18 +2012,6 @@ CREATE TABLE telemetry_items (
updated_at timestamp with time zone DEFAULT now() NOT NULL
);
CREATE TABLE telemetry_locks (
event_type text NOT NULL,
period_ending_at timestamp with time zone NOT NULL,
CONSTRAINT telemetry_lock_event_type_constraint CHECK ((event_type = 'aibridge_interceptions_summary'::text))
);
COMMENT ON TABLE telemetry_locks IS 'Telemetry lock tracking table for deduplication of heartbeat events across replicas.';
COMMENT ON COLUMN telemetry_locks.event_type IS 'The type of event that was sent.';
COMMENT ON COLUMN telemetry_locks.period_ending_at IS 'The heartbeat period end timestamp.';
CREATE TABLE template_usage_stats (
start_time timestamp with time zone NOT NULL,
end_time timestamp with time zone NOT NULL,
@@ -2227,6 +2198,15 @@ COMMENT ON COLUMN template_versions.external_auth_providers IS 'IDs of External
COMMENT ON COLUMN template_versions.message IS 'Message describing the changes in this version of the template, similar to a Git commit message. Like a commit message, this should be a short, high-level description of the changes in this version of the template. This message is immutable and should not be updated after the fact.';
CREATE VIEW visible_users AS
SELECT users.id,
users.username,
users.name,
users.avatar_url
FROM users;
COMMENT ON VIEW visible_users IS 'Visible fields of users are allowed to be joined with other tables for including context of other resources.';
CREATE VIEW template_version_with_user AS
SELECT template_versions.id,
template_versions.template_id,
@@ -2922,13 +2902,11 @@ CREATE VIEW workspaces_expanded AS
templates.name AS template_name,
templates.display_name AS template_display_name,
templates.icon AS template_icon,
templates.description AS template_description,
tasks.id AS task_id
FROM ((((workspaces
templates.description AS template_description
FROM (((workspaces
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
JOIN templates ON ((workspaces.template_id = templates.id)))
LEFT JOIN tasks ON ((workspaces.id = tasks.workspace_id)));
JOIN templates ON ((workspaces.template_id = templates.id)));
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
@@ -3112,9 +3090,6 @@ ALTER TABLE ONLY tasks
ALTER TABLE ONLY telemetry_items
ADD CONSTRAINT telemetry_items_pkey PRIMARY KEY (key);
ALTER TABLE ONLY telemetry_locks
ADD CONSTRAINT telemetry_locks_pkey PRIMARY KEY (event_type, period_ending_at);
ALTER TABLE ONLY template_usage_stats
ADD CONSTRAINT template_usage_stats_pkey PRIMARY KEY (start_time, template_id, user_id);
@@ -3340,8 +3315,6 @@ CREATE INDEX idx_tailnet_tunnels_dst_id ON tailnet_tunnels USING hash (dst_id);
CREATE INDEX idx_tailnet_tunnels_src_id ON tailnet_tunnels USING hash (src_id);
CREATE INDEX idx_telemetry_locks_period_ending_at ON telemetry_locks USING btree (period_ending_at);
CREATE UNIQUE INDEX idx_template_version_presets_default ON template_version_presets USING btree (template_version_id) WHERE (is_default = true);
CREATE INDEX idx_template_versions_has_ai_task ON template_versions USING btree (has_ai_task);
@@ -141,19 +141,13 @@ ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'workspace_proxy:read';
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'workspace_proxy:update';
-- End enum extensions
-- Purge old API keys to speed up the migration for large deployments.
-- Note: that problem should be solved in coderd once PR 20863 is released:
-- https://github.com/coder/coder/blob/main/coderd/database/dbpurge/dbpurge.go#L85
DELETE FROM api_keys WHERE expires_at < NOW() - INTERVAL '7 days';
-- Add new columns without defaults; backfill; then enforce NOT NULL
ALTER TABLE api_keys ADD COLUMN scopes api_key_scope[];
ALTER TABLE api_keys ADD COLUMN allow_list text[];
-- Backfill existing rows for compatibility
UPDATE api_keys SET
scopes = ARRAY[scope::api_key_scope],
allow_list = ARRAY['*:*'];
UPDATE api_keys SET scopes = ARRAY[scope::api_key_scope];
UPDATE api_keys SET allow_list = ARRAY['*:*'];
-- Enforce NOT NULL
ALTER TABLE api_keys ALTER COLUMN scopes SET NOT NULL;
@@ -1 +0,0 @@
DROP TABLE telemetry_locks;
@@ -1,12 +0,0 @@
CREATE TABLE telemetry_locks (
event_type TEXT NOT NULL CONSTRAINT telemetry_lock_event_type_constraint CHECK (event_type IN ('aibridge_interceptions_summary')),
period_ending_at TIMESTAMP WITH TIME ZONE NOT NULL,
PRIMARY KEY (event_type, period_ending_at)
);
COMMENT ON TABLE telemetry_locks IS 'Telemetry lock tracking table for deduplication of heartbeat events across replicas.';
COMMENT ON COLUMN telemetry_locks.event_type IS 'The type of event that was sent.';
COMMENT ON COLUMN telemetry_locks.period_ending_at IS 'The heartbeat period end timestamp.';
CREATE INDEX idx_telemetry_locks_period_ending_at ON telemetry_locks (period_ending_at);
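A minimal sketch of how a table like this deduplicates heartbeat events across replicas (an assumed usage pattern, not the repository's implementation; sqlDB and periodEndingAt are illustrative): each replica tries to claim the period by inserting the primary key, and a unique violation means another replica already reported it.

// Assumed usage sketch: claim the reporting period before sending telemetry.
_, err := sqlDB.ExecContext(ctx,
	`INSERT INTO telemetry_locks (event_type, period_ending_at) VALUES ($1, $2)`,
	"aibridge_interceptions_summary", periodEndingAt,
)
var pqErr *pq.Error
if xerrors.As(err, &pqErr) && pqErr.Code == "23505" { // unique_violation
	return nil // another replica already reported this period
}
if err != nil {
	return xerrors.Errorf("claim telemetry period: %w", err)
}
// ...send the summary event for this period...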
@@ -1,74 +0,0 @@
-- Drop view from 000390_tasks_with_status_user_fields.up.sql.
DROP VIEW IF EXISTS tasks_with_status;
-- Restore from 000382_add_columns_to_tasks_with_status.up.sql.
CREATE VIEW
tasks_with_status
AS
SELECT
tasks.*,
CASE
WHEN tasks.workspace_id IS NULL OR latest_build.job_status IS NULL THEN 'pending'::task_status
WHEN latest_build.job_status = 'failed' THEN 'error'::task_status
WHEN latest_build.transition IN ('stop', 'delete')
AND latest_build.job_status = 'succeeded' THEN 'paused'::task_status
WHEN latest_build.transition = 'start'
AND latest_build.job_status = 'pending' THEN 'initializing'::task_status
WHEN latest_build.transition = 'start' AND latest_build.job_status IN ('running', 'succeeded') THEN
CASE
WHEN agent_status.none THEN 'initializing'::task_status
WHEN agent_status.connecting THEN 'initializing'::task_status
WHEN agent_status.connected THEN
CASE
WHEN app_status.any_unhealthy THEN 'error'::task_status
WHEN app_status.any_initializing THEN 'initializing'::task_status
WHEN app_status.all_healthy_or_disabled THEN 'active'::task_status
ELSE 'unknown'::task_status
END
ELSE 'unknown'::task_status
END
ELSE 'unknown'::task_status
END AS status,
task_app.*
FROM
tasks
LEFT JOIN LATERAL (
SELECT workspace_build_number, workspace_agent_id, workspace_app_id
FROM task_workspace_apps task_app
WHERE task_id = tasks.id
ORDER BY workspace_build_number DESC
LIMIT 1
) task_app ON TRUE
LEFT JOIN LATERAL (
SELECT
workspace_build.transition,
provisioner_job.job_status,
workspace_build.job_id
FROM workspace_builds workspace_build
JOIN provisioner_jobs provisioner_job ON provisioner_job.id = workspace_build.job_id
WHERE workspace_build.workspace_id = tasks.workspace_id
AND workspace_build.build_number = task_app.workspace_build_number
) latest_build ON TRUE
CROSS JOIN LATERAL (
SELECT
COUNT(*) = 0 AS none,
bool_or(workspace_agent.lifecycle_state IN ('created', 'starting')) AS connecting,
bool_and(workspace_agent.lifecycle_state = 'ready') AS connected
FROM workspace_agents workspace_agent
WHERE workspace_agent.id = task_app.workspace_agent_id
) agent_status
CROSS JOIN LATERAL (
SELECT
bool_or(workspace_app.health = 'unhealthy') AS any_unhealthy,
bool_or(workspace_app.health = 'initializing') AS any_initializing,
bool_and(workspace_app.health IN ('healthy', 'disabled')) AS all_healthy_or_disabled
FROM workspace_apps workspace_app
WHERE workspace_app.id = task_app.workspace_app_id
) app_status
WHERE
tasks.deleted_at IS NULL;
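For readability, a rough Go transliteration of the status CASE above (the SQL view is authoritative; the parameter names here are illustrative):

// taskStatus mirrors the CASE expression in tasks_with_status.
func taskStatus(
	hasWorkspace bool,
	transition, jobStatus string, // latest build; jobStatus is "" when there is no build
	agentNone, agentConnecting, agentConnected bool,
	appAnyUnhealthy, appAnyInitializing, appAllHealthyOrDisabled bool,
) string {
	switch {
	case !hasWorkspace || jobStatus == "":
		return "pending"
	case jobStatus == "failed":
		return "error"
	case (transition == "stop" || transition == "delete") && jobStatus == "succeeded":
		return "paused"
	case transition == "start" && jobStatus == "pending":
		return "initializing"
	case transition == "start" && (jobStatus == "running" || jobStatus == "succeeded"):
		if agentNone || agentConnecting {
			return "initializing"
		}
		if agentConnected {
			switch {
			case appAnyUnhealthy:
				return "error"
			case appAnyInitializing:
				return "initializing"
			case appAllHealthyOrDisabled:
				return "active"
			}
		}
		return "unknown"
	default:
		return "unknown"
	}
}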
@@ -1,84 +0,0 @@
-- Drop view from 000382_add_columns_to_tasks_with_status.up.sql.
DROP VIEW IF EXISTS tasks_with_status;
-- Add owner_name, owner_avatar_url columns.
CREATE VIEW
tasks_with_status
AS
SELECT
tasks.*,
CASE
WHEN tasks.workspace_id IS NULL OR latest_build.job_status IS NULL THEN 'pending'::task_status
WHEN latest_build.job_status = 'failed' THEN 'error'::task_status
WHEN latest_build.transition IN ('stop', 'delete')
AND latest_build.job_status = 'succeeded' THEN 'paused'::task_status
WHEN latest_build.transition = 'start'
AND latest_build.job_status = 'pending' THEN 'initializing'::task_status
WHEN latest_build.transition = 'start' AND latest_build.job_status IN ('running', 'succeeded') THEN
CASE
WHEN agent_status.none THEN 'initializing'::task_status
WHEN agent_status.connecting THEN 'initializing'::task_status
WHEN agent_status.connected THEN
CASE
WHEN app_status.any_unhealthy THEN 'error'::task_status
WHEN app_status.any_initializing THEN 'initializing'::task_status
WHEN app_status.all_healthy_or_disabled THEN 'active'::task_status
ELSE 'unknown'::task_status
END
ELSE 'unknown'::task_status
END
ELSE 'unknown'::task_status
END AS status,
task_app.*,
task_owner.*
FROM
tasks
CROSS JOIN LATERAL (
SELECT
vu.username AS owner_username,
vu.name AS owner_name,
vu.avatar_url AS owner_avatar_url
FROM visible_users vu
WHERE vu.id = tasks.owner_id
) task_owner
LEFT JOIN LATERAL (
SELECT workspace_build_number, workspace_agent_id, workspace_app_id
FROM task_workspace_apps task_app
WHERE task_id = tasks.id
ORDER BY workspace_build_number DESC
LIMIT 1
) task_app ON TRUE
LEFT JOIN LATERAL (
SELECT
workspace_build.transition,
provisioner_job.job_status,
workspace_build.job_id
FROM workspace_builds workspace_build
JOIN provisioner_jobs provisioner_job ON provisioner_job.id = workspace_build.job_id
WHERE workspace_build.workspace_id = tasks.workspace_id
AND workspace_build.build_number = task_app.workspace_build_number
) latest_build ON TRUE
CROSS JOIN LATERAL (
SELECT
COUNT(*) = 0 AS none,
bool_or(workspace_agent.lifecycle_state IN ('created', 'starting')) AS connecting,
bool_and(workspace_agent.lifecycle_state = 'ready') AS connected
FROM workspace_agents workspace_agent
WHERE workspace_agent.id = task_app.workspace_agent_id
) agent_status
CROSS JOIN LATERAL (
SELECT
bool_or(workspace_app.health = 'unhealthy') AS any_unhealthy,
bool_or(workspace_app.health = 'initializing') AS any_initializing,
bool_and(workspace_app.health IN ('healthy', 'disabled')) AS all_healthy_or_disabled
FROM workspace_apps workspace_app
WHERE workspace_app.id = task_app.workspace_app_id
) app_status
WHERE
tasks.deleted_at IS NULL;
@@ -1,8 +0,0 @@
UPDATE notification_templates
SET enabled_by_default = true
WHERE id IN (
'8c5a4d12-9f7e-4b3a-a1c8-6e4f2d9b5a7c',
'3b7e8f1a-4c2d-49a6-b5e9-7f3a1c8d6b4e',
'bd4b7168-d05e-4e19-ad0f-3593b77aa90f',
'd4a6271c-cced-4ed0-84ad-afd02a9c7799'
);
@@ -1,8 +0,0 @@
UPDATE notification_templates
SET enabled_by_default = false
WHERE id IN (
'8c5a4d12-9f7e-4b3a-a1c8-6e4f2d9b5a7c',
'3b7e8f1a-4c2d-49a6-b5e9-7f3a1c8d6b4e',
'bd4b7168-d05e-4e19-ad0f-3593b77aa90f',
'd4a6271c-cced-4ed0-84ad-afd02a9c7799'
);
@@ -1,39 +0,0 @@
DROP VIEW workspaces_expanded;
-- Recreate the view from 000354_workspace_acl.up.sql
CREATE VIEW workspaces_expanded AS
SELECT workspaces.id,
workspaces.created_at,
workspaces.updated_at,
workspaces.owner_id,
workspaces.organization_id,
workspaces.template_id,
workspaces.deleted,
workspaces.name,
workspaces.autostart_schedule,
workspaces.ttl,
workspaces.last_used_at,
workspaces.dormant_at,
workspaces.deleting_at,
workspaces.automatic_updates,
workspaces.favorite,
workspaces.next_start_at,
workspaces.group_acl,
workspaces.user_acl,
visible_users.avatar_url AS owner_avatar_url,
visible_users.username AS owner_username,
visible_users.name AS owner_name,
organizations.name AS organization_name,
organizations.display_name AS organization_display_name,
organizations.icon AS organization_icon,
organizations.description AS organization_description,
templates.name AS template_name,
templates.display_name AS template_display_name,
templates.icon AS template_icon,
templates.description AS template_description
FROM (((workspaces
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
JOIN templates ON ((workspaces.template_id = templates.id)));
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
@@ -1,42 +0,0 @@
DROP VIEW workspaces_expanded;
-- Add nullable task_id to workspaces_expanded view
CREATE VIEW workspaces_expanded AS
SELECT workspaces.id,
workspaces.created_at,
workspaces.updated_at,
workspaces.owner_id,
workspaces.organization_id,
workspaces.template_id,
workspaces.deleted,
workspaces.name,
workspaces.autostart_schedule,
workspaces.ttl,
workspaces.last_used_at,
workspaces.dormant_at,
workspaces.deleting_at,
workspaces.automatic_updates,
workspaces.favorite,
workspaces.next_start_at,
workspaces.group_acl,
workspaces.user_acl,
visible_users.avatar_url AS owner_avatar_url,
visible_users.username AS owner_username,
visible_users.name AS owner_name,
organizations.name AS organization_name,
organizations.display_name AS organization_display_name,
organizations.icon AS organization_icon,
organizations.description AS organization_description,
templates.name AS template_name,
templates.display_name AS template_display_name,
templates.icon AS template_icon,
templates.description AS template_description,
tasks.id AS task_id
FROM ((((workspaces
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
JOIN templates ON ((workspaces.template_id = templates.id)))
LEFT JOIN tasks ON ((workspaces.id = tasks.workspace_id)));
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
@@ -1,57 +0,0 @@
-- Ensure api_keys and oauth2_provider_app_tokens have live data after
-- migration 000371 deletes expired rows.
INSERT INTO api_keys (
id,
hashed_secret,
user_id,
last_used,
expires_at,
created_at,
updated_at,
login_type,
lifetime_seconds,
ip_address,
token_name,
scopes,
allow_list
)
VALUES (
'fixture-api-key',
'\xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
'30095c71-380b-457a-8995-97b8ee6e5307',
NOW() - INTERVAL '1 hour',
NOW() + INTERVAL '30 days',
NOW() - INTERVAL '1 day',
NOW() - INTERVAL '1 day',
'password',
86400,
'0.0.0.0',
'fixture-api-key',
ARRAY['workspace:read']::api_key_scope[],
ARRAY['*:*']
)
ON CONFLICT (id) DO NOTHING;
INSERT INTO oauth2_provider_app_tokens (
id,
created_at,
expires_at,
hash_prefix,
refresh_hash,
app_secret_id,
api_key_id,
audience,
user_id
)
VALUES (
'9f92f3c9-811f-4f6f-9a1c-3f2eed1f9f15',
NOW() - INTERVAL '30 minutes',
NOW() + INTERVAL '30 days',
CAST('fixture-hash-prefix' AS bytea),
CAST('fixture-refresh-hash' AS bytea),
'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
'fixture-api-key',
'https://coder.example.com',
'30095c71-380b-457a-8995-97b8ee6e5307'
)
ON CONFLICT (id) DO NOTHING;
@@ -1,8 +0,0 @@
INSERT INTO telemetry_locks (
event_type,
period_ending_at
)
VALUES (
'aibridge_interceptions_summary',
'2025-01-01 00:00:00+00'::timestamptz
);
-2
View File
@@ -208,7 +208,6 @@ func (s APIKeyScopes) expandRBACScope() (rbac.Scope, error) {
for orgID, perms := range expanded.ByOrgID {
orgPerms := merged.ByOrgID[orgID]
orgPerms.Org = append(orgPerms.Org, perms.Org...)
orgPerms.Member = append(orgPerms.Member, perms.Member...)
merged.ByOrgID[orgID] = orgPerms
}
merged.User = append(merged.User, expanded.User...)
@@ -221,7 +220,6 @@ func (s APIKeyScopes) expandRBACScope() (rbac.Scope, error) {
merged.User = rbac.DeduplicatePermissions(merged.User)
for orgID, perms := range merged.ByOrgID {
perms.Org = rbac.DeduplicatePermissions(perms.Org)
perms.Member = rbac.DeduplicatePermissions(perms.Member)
merged.ByOrgID[orgID] = perms
}
-1
View File
@@ -321,7 +321,6 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
&i.TemplateVersionID,
&i.TemplateVersionName,
&i.LatestBuildCompletedAt,
-12
View File
@@ -4221,9 +4221,6 @@ type Task struct {
WorkspaceBuildNumber sql.NullInt32 `db:"workspace_build_number" json:"workspace_build_number"`
WorkspaceAgentID uuid.NullUUID `db:"workspace_agent_id" json:"workspace_agent_id"`
WorkspaceAppID uuid.NullUUID `db:"workspace_app_id" json:"workspace_app_id"`
OwnerUsername string `db:"owner_username" json:"owner_username"`
OwnerName string `db:"owner_name" json:"owner_name"`
OwnerAvatarUrl string `db:"owner_avatar_url" json:"owner_avatar_url"`
}
type TaskTable struct {
@@ -4253,14 +4250,6 @@ type TelemetryItem struct {
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// Telemetry lock tracking table for deduplication of heartbeat events across replicas.
type TelemetryLock struct {
// The type of event that was sent.
EventType string `db:"event_type" json:"event_type"`
// The heartbeat period end timestamp.
PeriodEndingAt time.Time `db:"period_ending_at" json:"period_ending_at"`
}
// Joins in the display name information such as username, avatar, and organization name.
type Template struct {
ID uuid.UUID `db:"id" json:"id"`
@@ -4663,7 +4652,6 @@ type Workspace struct {
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
TemplateIcon string `db:"template_icon" json:"template_icon"`
TemplateDescription string `db:"template_description" json:"template_description"`
TaskID uuid.NullUUID `db:"task_id" json:"task_id"`
}
type WorkspaceAgent struct {
+1 -19
View File
@@ -60,9 +60,6 @@ type sqlcQuerier interface {
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error)
BulkMarkNotificationMessagesSent(ctx context.Context, arg BulkMarkNotificationMessagesSentParams) (int64, error)
// Calculates the telemetry summary for a given provider, model, and client
// combination for telemetry reporting.
CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg CalculateAIBridgeInterceptionsTelemetrySummaryParams) (CalculateAIBridgeInterceptionsTelemetrySummaryRow, error)
ClaimPrebuiltWorkspace(ctx context.Context, arg ClaimPrebuiltWorkspaceParams) (ClaimPrebuiltWorkspaceRow, error)
CleanTailnetCoordinators(ctx context.Context) error
CleanTailnetLostPeers(ctx context.Context) error
@@ -110,8 +107,6 @@ type sqlcQuerier interface {
// A provisioner daemon with "zeroed" last_seen_at column indicates possible
// connectivity issues (no provisioner daemon activity since registration).
DeleteOldProvisionerDaemons(ctx context.Context) error
// Deletes old telemetry locks from the telemetry_locks table.
DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error
// If an agent hasn't connected in the last 7 days, we purge its logs.
// Exception: if the logs are related to the latest build, we keep those around.
// Logs can take up a lot of space, so it's important we clean up frequently.
@@ -269,9 +264,6 @@ type sqlcQuerier interface {
GetOrganizationResourceCountByID(ctx context.Context, organizationID uuid.UUID) (GetOrganizationResourceCountByIDRow, error)
GetOrganizations(ctx context.Context, arg GetOrganizationsParams) ([]Organization, error)
GetOrganizationsByUserID(ctx context.Context, arg GetOrganizationsByUserIDParams) ([]Organization, error)
// GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
// membership status for the prebuilds system user (org membership, group existence, group membership).
GetOrganizationsWithPrebuildStatus(ctx context.Context, arg GetOrganizationsWithPrebuildStatusParams) ([]GetOrganizationsWithPrebuildStatusRow, error)
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
GetPrebuildMetrics(ctx context.Context) ([]GetPrebuildMetricsRow, error)
GetPrebuildsSettings(ctx context.Context) (string, error)
@@ -567,12 +559,6 @@ type sqlcQuerier interface {
InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error)
InsertTask(ctx context.Context, arg InsertTaskParams) (TaskTable, error)
InsertTelemetryItemIfNotExists(ctx context.Context, arg InsertTelemetryItemIfNotExistsParams) error
// Inserts a new lock row into the telemetry_locks table. Replicas should call
// this function prior to attempting to generate or publish a heartbeat event to
// the telemetry service.
// If the query returns a duplicate primary key error, the replica should not
// attempt to generate or publish the event to the telemetry service.
InsertTelemetryLock(ctx context.Context, arg InsertTelemetryLockParams) error
InsertTemplate(ctx context.Context, arg InsertTemplateParams) error
InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) error
InsertTemplateVersionParameter(ctx context.Context, arg InsertTemplateVersionParameterParams) (TemplateVersionParameter, error)
@@ -609,9 +595,6 @@ type sqlcQuerier interface {
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]ListAIBridgeInterceptionsRow, error)
// Finds all unique AIBridge interception telemetry summaries combinations
// (provider, model, client) in the given timeframe for telemetry reporting.
ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error)
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
@@ -649,7 +632,6 @@ type sqlcQuerier interface {
// This will always work regardless of the current state of the template version.
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
@@ -670,7 +652,7 @@ type sqlcQuerier interface {
// Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
// inactive template version.
// This is an optimization to clean up stale pending jobs.
UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]UpdatePrebuildProvisionerJobWithCancelRow, error)
UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error)
UpdatePresetPrebuildStatus(ctx context.Context, arg UpdatePresetPrebuildStatusParams) error
UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
-68
View File
@@ -7248,9 +7248,7 @@ func TestTaskNameUniqueness(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
taskID := uuid.New()
task, err := db.InsertTask(ctx, database.InsertTaskParams{
ID: taskID,
OrganizationID: org.ID,
OwnerID: tt.ownerID,
Name: tt.taskName,
@@ -7265,7 +7263,6 @@ func TestTaskNameUniqueness(t *testing.T) {
require.NoError(t, err)
require.NotEqual(t, uuid.Nil, task.ID)
require.NotEqual(t, task1.ID, task.ID)
require.Equal(t, taskID, task.ID)
}
})
}
@@ -7727,68 +7724,3 @@ func TestUpdateTaskWorkspaceID(t *testing.T) {
})
}
}
func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
t.Run("NonExistingInterception", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
got, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
ID: uuid.New(),
EndedAt: time.Now(),
})
require.ErrorContains(t, err, "no rows in result set")
require.EqualValues(t, database.AIBridgeInterception{}, got)
})
t.Run("OK", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
user := dbgen.User(t, db, database.User{})
interceptions := []database.AIBridgeInterception{}
for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
insertParams := database.InsertAIBridgeInterceptionParams{
ID: uid,
InitiatorID: user.ID,
Metadata: json.RawMessage("{}"),
}
intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
require.NoError(t, err)
require.Equal(t, uid, intc.ID)
require.False(t, intc.EndedAt.Valid)
interceptions = append(interceptions, intc)
}
intc0 := interceptions[0]
endedAt := time.Now()
// Mark first interception as done
updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
ID: intc0.ID,
EndedAt: endedAt,
})
require.NoError(t, err)
require.EqualValues(t, updated.ID, intc0.ID)
require.True(t, updated.EndedAt.Valid)
require.WithinDuration(t, endedAt, updated.EndedAt.Time, 5*time.Second)
// Updating first interception again should fail
updated, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
ID: intc0.ID,
EndedAt: endedAt.Add(time.Hour),
})
require.ErrorIs(t, err, sql.ErrNoRows)
// Other interceptions should not have ended_at set
for _, intc := range interceptions[1:] {
got, err := db.GetAIBridgeInterceptionByID(ctx, intc.ID)
require.NoError(t, err)
require.False(t, got.EndedAt.Valid)
}
})
}
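The removed test above relies on the write-once semantics of UpdateAIBridgeInterceptionEnded: the UPDATE only matches rows whose ended_at is still NULL, so a repeat call for the same interception yields sql.ErrNoRows rather than overwriting the original end time. A minimal sketch of that pattern against plain database/sql (the function name markEndedOnce and the inlined SQL are illustrative, not Coder's implementation):

```go
package example

import (
	"context"
	"database/sql"
	"errors"
	"time"

	"github.com/google/uuid"
)

// markEndedOnce sketches the write-once update the removed query performs:
// only rows with ended_at still NULL are updated, so calling it again for the
// same interception reports "already ended" instead of moving the end time.
func markEndedOnce(ctx context.Context, db *sql.DB, id uuid.UUID, endedAt time.Time) (bool, error) {
	var updatedID uuid.UUID
	err := db.QueryRowContext(ctx,
		`UPDATE aibridge_interceptions
		 SET ended_at = $1
		 WHERE id = $2 AND ended_at IS NULL
		 RETURNING id`,
		endedAt, id,
	).Scan(&updatedID)
	if errors.Is(err, sql.ErrNoRows) {
		return false, nil // ended_at was already set (or the row does not exist)
	}
	if err != nil {
		return false, err
	}
	return true, nil
}
```

Treating sql.ErrNoRows as "already ended" is what lets the test assert that a second update fails without changing the stored timestamp.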
+25 -419
View File
@@ -111,164 +111,6 @@ func (q *sqlQuerier) ActivityBumpWorkspace(ctx context.Context, arg ActivityBump
return err
}
const calculateAIBridgeInterceptionsTelemetrySummary = `-- name: CalculateAIBridgeInterceptionsTelemetrySummary :one
WITH interceptions_in_range AS (
-- Get all matching interceptions in the given timeframe.
SELECT
id,
initiator_id,
(ended_at - started_at) AS duration
FROM
aibridge_interceptions
WHERE
provider = $1::text
AND model = $2::text
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
AND 'unknown' = $3::text
AND ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
AND ended_at >= $4::timestamptz
AND ended_at < $5::timestamptz
),
interception_counts AS (
SELECT
COUNT(id) AS interception_count,
COUNT(DISTINCT initiator_id) AS unique_initiator_count
FROM
interceptions_in_range
),
duration_percentiles AS (
SELECT
(COALESCE(PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p50_millis,
(COALESCE(PERCENTILE_CONT(0.90) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p90_millis,
(COALESCE(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p95_millis,
(COALESCE(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p99_millis
FROM
interceptions_in_range
),
token_aggregates AS (
SELECT
COALESCE(SUM(tu.input_tokens), 0) AS token_count_input,
COALESCE(SUM(tu.output_tokens), 0) AS token_count_output,
-- Cached tokens are stored in metadata JSON, extract if available.
-- Read tokens may be stored in:
-- - cache_read_input (Anthropic)
-- - prompt_cached (OpenAI)
COALESCE(SUM(
COALESCE((tu.metadata->>'cache_read_input')::bigint, 0) +
COALESCE((tu.metadata->>'prompt_cached')::bigint, 0)
), 0) AS token_count_cached_read,
-- Written tokens may be stored in:
-- - cache_creation_input (Anthropic)
-- Note that cache_ephemeral_5m_input and cache_ephemeral_1h_input on
-- Anthropic are included in the cache_creation_input field.
COALESCE(SUM(
COALESCE((tu.metadata->>'cache_creation_input')::bigint, 0)
), 0) AS token_count_cached_written,
COUNT(tu.id) AS token_usages_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_token_usages tu ON i.id = tu.interception_id
),
prompt_aggregates AS (
SELECT
COUNT(up.id) AS user_prompts_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_user_prompts up ON i.id = up.interception_id
),
tool_aggregates AS (
SELECT
COUNT(tu.id) FILTER (WHERE tu.injected = true) AS tool_calls_count_injected,
COUNT(tu.id) FILTER (WHERE tu.injected = false) AS tool_calls_count_non_injected,
COUNT(tu.id) FILTER (WHERE tu.injected = true AND tu.invocation_error IS NOT NULL) AS injected_tool_call_error_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_tool_usages tu ON i.id = tu.interception_id
)
SELECT
ic.interception_count::bigint AS interception_count,
dp.interception_duration_p50_millis::bigint AS interception_duration_p50_millis,
dp.interception_duration_p90_millis::bigint AS interception_duration_p90_millis,
dp.interception_duration_p95_millis::bigint AS interception_duration_p95_millis,
dp.interception_duration_p99_millis::bigint AS interception_duration_p99_millis,
ic.unique_initiator_count::bigint AS unique_initiator_count,
pa.user_prompts_count::bigint AS user_prompts_count,
tok_agg.token_usages_count::bigint AS token_usages_count,
tok_agg.token_count_input::bigint AS token_count_input,
tok_agg.token_count_output::bigint AS token_count_output,
tok_agg.token_count_cached_read::bigint AS token_count_cached_read,
tok_agg.token_count_cached_written::bigint AS token_count_cached_written,
tool_agg.tool_calls_count_injected::bigint AS tool_calls_count_injected,
tool_agg.tool_calls_count_non_injected::bigint AS tool_calls_count_non_injected,
tool_agg.injected_tool_call_error_count::bigint AS injected_tool_call_error_count
FROM
interception_counts ic,
duration_percentiles dp,
token_aggregates tok_agg,
prompt_aggregates pa,
tool_aggregates tool_agg
`
type CalculateAIBridgeInterceptionsTelemetrySummaryParams struct {
Provider string `db:"provider" json:"provider"`
Model string `db:"model" json:"model"`
Client string `db:"client" json:"client"`
EndedAtAfter time.Time `db:"ended_at_after" json:"ended_at_after"`
EndedAtBefore time.Time `db:"ended_at_before" json:"ended_at_before"`
}
type CalculateAIBridgeInterceptionsTelemetrySummaryRow struct {
InterceptionCount int64 `db:"interception_count" json:"interception_count"`
InterceptionDurationP50Millis int64 `db:"interception_duration_p50_millis" json:"interception_duration_p50_millis"`
InterceptionDurationP90Millis int64 `db:"interception_duration_p90_millis" json:"interception_duration_p90_millis"`
InterceptionDurationP95Millis int64 `db:"interception_duration_p95_millis" json:"interception_duration_p95_millis"`
InterceptionDurationP99Millis int64 `db:"interception_duration_p99_millis" json:"interception_duration_p99_millis"`
UniqueInitiatorCount int64 `db:"unique_initiator_count" json:"unique_initiator_count"`
UserPromptsCount int64 `db:"user_prompts_count" json:"user_prompts_count"`
TokenUsagesCount int64 `db:"token_usages_count" json:"token_usages_count"`
TokenCountInput int64 `db:"token_count_input" json:"token_count_input"`
TokenCountOutput int64 `db:"token_count_output" json:"token_count_output"`
TokenCountCachedRead int64 `db:"token_count_cached_read" json:"token_count_cached_read"`
TokenCountCachedWritten int64 `db:"token_count_cached_written" json:"token_count_cached_written"`
ToolCallsCountInjected int64 `db:"tool_calls_count_injected" json:"tool_calls_count_injected"`
ToolCallsCountNonInjected int64 `db:"tool_calls_count_non_injected" json:"tool_calls_count_non_injected"`
InjectedToolCallErrorCount int64 `db:"injected_tool_call_error_count" json:"injected_tool_call_error_count"`
}
// Calculates the telemetry summary for a given provider, model, and client
// combination for telemetry reporting.
func (q *sqlQuerier) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg CalculateAIBridgeInterceptionsTelemetrySummaryParams) (CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
row := q.db.QueryRowContext(ctx, calculateAIBridgeInterceptionsTelemetrySummary,
arg.Provider,
arg.Model,
arg.Client,
arg.EndedAtAfter,
arg.EndedAtBefore,
)
var i CalculateAIBridgeInterceptionsTelemetrySummaryRow
err := row.Scan(
&i.InterceptionCount,
&i.InterceptionDurationP50Millis,
&i.InterceptionDurationP90Millis,
&i.InterceptionDurationP95Millis,
&i.InterceptionDurationP99Millis,
&i.UniqueInitiatorCount,
&i.UserPromptsCount,
&i.TokenUsagesCount,
&i.TokenCountInput,
&i.TokenCountOutput,
&i.TokenCountCachedRead,
&i.TokenCountCachedWritten,
&i.ToolCallsCountInjected,
&i.ToolCallsCountNonInjected,
&i.InjectedToolCallErrorCount,
)
return i, err
}
const countAIBridgeInterceptions = `-- name: CountAIBridgeInterceptions :one
SELECT
COUNT(*)
@@ -805,57 +647,6 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
return items, nil
}
const listAIBridgeInterceptionsTelemetrySummaries = `-- name: ListAIBridgeInterceptionsTelemetrySummaries :many
SELECT
DISTINCT ON (provider, model, client)
provider,
model,
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
'unknown' AS client
FROM
aibridge_interceptions
WHERE
ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
AND ended_at >= $1::timestamptz
AND ended_at < $2::timestamptz
`
type ListAIBridgeInterceptionsTelemetrySummariesParams struct {
EndedAtAfter time.Time `db:"ended_at_after" json:"ended_at_after"`
EndedAtBefore time.Time `db:"ended_at_before" json:"ended_at_before"`
}
type ListAIBridgeInterceptionsTelemetrySummariesRow struct {
Provider string `db:"provider" json:"provider"`
Model string `db:"model" json:"model"`
Client string `db:"client" json:"client"`
}
// Finds all unique AIBridge interception telemetry summaries combinations
// (provider, model, client) in the given timeframe for telemetry reporting.
func (q *sqlQuerier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
rows, err := q.db.QueryContext(ctx, listAIBridgeInterceptionsTelemetrySummaries, arg.EndedAtAfter, arg.EndedAtBefore)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListAIBridgeInterceptionsTelemetrySummariesRow
for rows.Next() {
var i ListAIBridgeInterceptionsTelemetrySummariesRow
if err := rows.Scan(&i.Provider, &i.Model, &i.Client); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listAIBridgeTokenUsagesByInterceptionIDs = `-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many
SELECT
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
@@ -987,35 +778,6 @@ func (q *sqlQuerier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Contex
return items, nil
}
const updateAIBridgeInterceptionEnded = `-- name: UpdateAIBridgeInterceptionEnded :one
UPDATE aibridge_interceptions
SET ended_at = $1::timestamptz
WHERE
id = $2::uuid
AND ended_at IS NULL
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at
`
type UpdateAIBridgeInterceptionEndedParams struct {
EndedAt time.Time `db:"ended_at" json:"ended_at"`
ID uuid.UUID `db:"id" json:"id"`
}
func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error) {
row := q.db.QueryRowContext(ctx, updateAIBridgeInterceptionEnded, arg.EndedAt, arg.ID)
var i AIBridgeInterception
err := row.Scan(
&i.ID,
&i.InitiatorID,
&i.Provider,
&i.Model,
&i.StartedAt,
&i.Metadata,
&i.EndedAt,
)
return i, err
}
const deleteAPIKeyByID = `-- name: DeleteAPIKeyByID :exec
DELETE FROM
api_keys
@@ -8285,93 +8047,6 @@ func (q *sqlQuerier) FindMatchingPresetID(ctx context.Context, arg FindMatchingP
return template_version_preset_id, err
}
const getOrganizationsWithPrebuildStatus = `-- name: GetOrganizationsWithPrebuildStatus :many
WITH orgs_with_prebuilds AS (
-- Get unique organizations that have presets with prebuilds configured
SELECT DISTINCT o.id, o.name
FROM organizations o
INNER JOIN templates t ON t.organization_id = o.id
INNER JOIN template_versions tv ON tv.template_id = t.id
INNER JOIN template_version_presets tvp ON tvp.template_version_id = tv.id
WHERE tvp.desired_instances IS NOT NULL
),
prebuild_user_membership AS (
-- Check if the user is a member of the organizations
SELECT om.organization_id
FROM organization_members om
INNER JOIN orgs_with_prebuilds owp ON owp.id = om.organization_id
WHERE om.user_id = $1::uuid
),
prebuild_groups AS (
-- Check if the organizations have the prebuilds group
SELECT g.organization_id, g.id as group_id
FROM groups g
INNER JOIN orgs_with_prebuilds owp ON owp.id = g.organization_id
WHERE g.name = $2::text
),
prebuild_group_membership AS (
-- Check if the user is in the prebuilds group
SELECT pg.organization_id
FROM prebuild_groups pg
INNER JOIN group_members gm ON gm.group_id = pg.group_id
WHERE gm.user_id = $1::uuid
)
SELECT
owp.id AS organization_id,
owp.name AS organization_name,
(pum.organization_id IS NOT NULL)::boolean AS has_prebuild_user,
pg.group_id AS prebuilds_group_id,
(pgm.organization_id IS NOT NULL)::boolean AS has_prebuild_user_in_group
FROM orgs_with_prebuilds owp
LEFT JOIN prebuild_groups pg ON pg.organization_id = owp.id
LEFT JOIN prebuild_user_membership pum ON pum.organization_id = owp.id
LEFT JOIN prebuild_group_membership pgm ON pgm.organization_id = owp.id
`
type GetOrganizationsWithPrebuildStatusParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
GroupName string `db:"group_name" json:"group_name"`
}
type GetOrganizationsWithPrebuildStatusRow struct {
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
OrganizationName string `db:"organization_name" json:"organization_name"`
HasPrebuildUser bool `db:"has_prebuild_user" json:"has_prebuild_user"`
PrebuildsGroupID uuid.NullUUID `db:"prebuilds_group_id" json:"prebuilds_group_id"`
HasPrebuildUserInGroup bool `db:"has_prebuild_user_in_group" json:"has_prebuild_user_in_group"`
}
// GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
// membership status for the prebuilds system user (org membership, group existence, group membership).
func (q *sqlQuerier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg GetOrganizationsWithPrebuildStatusParams) ([]GetOrganizationsWithPrebuildStatusRow, error) {
rows, err := q.db.QueryContext(ctx, getOrganizationsWithPrebuildStatus, arg.UserID, arg.GroupName)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetOrganizationsWithPrebuildStatusRow
for rows.Next() {
var i GetOrganizationsWithPrebuildStatusRow
if err := rows.Scan(
&i.OrganizationID,
&i.OrganizationName,
&i.HasPrebuildUser,
&i.PrebuildsGroupID,
&i.HasPrebuildUserInGroup,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getPrebuildMetrics = `-- name: GetPrebuildMetrics :many
SELECT
t.name as template_name,
@@ -8774,8 +8449,12 @@ func (q *sqlQuerier) GetTemplatePresetsWithPrebuilds(ctx context.Context, templa
}
const updatePrebuildProvisionerJobWithCancel = `-- name: UpdatePrebuildProvisionerJobWithCancel :many
WITH jobs_to_cancel AS (
SELECT pj.id, w.id AS workspace_id, w.template_id, wpb.template_version_preset_id
UPDATE provisioner_jobs
SET
canceled_at = $1::timestamptz,
completed_at = $1::timestamptz
WHERE id IN (
SELECT pj.id
FROM provisioner_jobs pj
INNER JOIN workspace_prebuild_builds wpb ON wpb.job_id = pj.id
INNER JOIN workspaces w ON w.id = wpb.workspace_id
@@ -8794,13 +8473,7 @@ WITH jobs_to_cancel AS (
AND pj.canceled_at IS NULL
AND pj.completed_at IS NULL
)
UPDATE provisioner_jobs
SET
canceled_at = $1::timestamptz,
completed_at = $1::timestamptz
FROM jobs_to_cancel
WHERE provisioner_jobs.id = jobs_to_cancel.id
RETURNING jobs_to_cancel.id, jobs_to_cancel.workspace_id, jobs_to_cancel.template_id, jobs_to_cancel.template_version_preset_id
RETURNING id
`
type UpdatePrebuildProvisionerJobWithCancelParams struct {
@@ -8808,34 +8481,22 @@ type UpdatePrebuildProvisionerJobWithCancelParams struct {
PresetID uuid.NullUUID `db:"preset_id" json:"preset_id"`
}
type UpdatePrebuildProvisionerJobWithCancelRow struct {
ID uuid.UUID `db:"id" json:"id"`
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
TemplateID uuid.UUID `db:"template_id" json:"template_id"`
TemplateVersionPresetID uuid.NullUUID `db:"template_version_preset_id" json:"template_version_preset_id"`
}
// Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
// inactive template version.
// This is an optimization to clean up stale pending jobs.
func (q *sqlQuerier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]UpdatePrebuildProvisionerJobWithCancelRow, error) {
func (q *sqlQuerier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, updatePrebuildProvisionerJobWithCancel, arg.Now, arg.PresetID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []UpdatePrebuildProvisionerJobWithCancelRow
var items []uuid.UUID
for rows.Next() {
var i UpdatePrebuildProvisionerJobWithCancelRow
if err := rows.Scan(
&i.ID,
&i.WorkspaceID,
&i.TemplateID,
&i.TemplateVersionPresetID,
); err != nil {
var id uuid.UUID
if err := rows.Scan(&id); err != nil {
return nil, err
}
items = append(items, i)
items = append(items, id)
}
if err := rows.Close(); err != nil {
return nil, err
@@ -13065,7 +12726,7 @@ func (q *sqlQuerier) DeleteTask(ctx context.Context, arg DeleteTaskParams) (Task
}
const getTaskByID = `-- name: GetTaskByID :one
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE id = $1::uuid
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status WHERE id = $1::uuid
`
func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error) {
@@ -13086,15 +12747,12 @@ func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error
&i.WorkspaceBuildNumber,
&i.WorkspaceAgentID,
&i.WorkspaceAppID,
&i.OwnerUsername,
&i.OwnerName,
&i.OwnerAvatarUrl,
)
return i, err
}
const getTaskByWorkspaceID = `-- name: GetTaskByWorkspaceID :one
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE workspace_id = $1::uuid
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status WHERE workspace_id = $1::uuid
`
func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error) {
@@ -13115,9 +12773,6 @@ func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.
&i.WorkspaceBuildNumber,
&i.WorkspaceAgentID,
&i.WorkspaceAppID,
&i.OwnerUsername,
&i.OwnerName,
&i.OwnerAvatarUrl,
)
return i, err
}
@@ -13126,12 +12781,11 @@ const insertTask = `-- name: InsertTask :one
INSERT INTO tasks
(id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9)
(gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at
`
type InsertTaskParams struct {
ID uuid.UUID `db:"id" json:"id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
Name string `db:"name" json:"name"`
@@ -13144,7 +12798,6 @@ type InsertTaskParams struct {
func (q *sqlQuerier) InsertTask(ctx context.Context, arg InsertTaskParams) (TaskTable, error) {
row := q.db.QueryRowContext(ctx, insertTask,
arg.ID,
arg.OrganizationID,
arg.OwnerID,
arg.Name,
@@ -13171,7 +12824,7 @@ func (q *sqlQuerier) InsertTask(ctx context.Context, arg InsertTaskParams) (Task
}
const listTasks = `-- name: ListTasks :many
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status tws
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status tws
WHERE tws.deleted_at IS NULL
AND CASE WHEN $1::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.owner_id = $1::UUID ELSE TRUE END
AND CASE WHEN $2::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.organization_id = $2::UUID ELSE TRUE END
@@ -13209,9 +12862,6 @@ func (q *sqlQuerier) ListTasks(ctx context.Context, arg ListTasksParams) ([]Task
&i.WorkspaceBuildNumber,
&i.WorkspaceAgentID,
&i.WorkspaceAppID,
&i.OwnerUsername,
&i.OwnerName,
&i.OwnerAvatarUrl,
); err != nil {
return nil, err
}
@@ -13385,41 +13035,6 @@ func (q *sqlQuerier) UpsertTelemetryItem(ctx context.Context, arg UpsertTelemetr
return err
}
const deleteOldTelemetryLocks = `-- name: DeleteOldTelemetryLocks :exec
DELETE FROM
telemetry_locks
WHERE
period_ending_at < $1::timestamptz
`
// Deletes old telemetry locks from the telemetry_locks table.
func (q *sqlQuerier) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
_, err := q.db.ExecContext(ctx, deleteOldTelemetryLocks, periodEndingAtBefore)
return err
}
const insertTelemetryLock = `-- name: InsertTelemetryLock :exec
INSERT INTO
telemetry_locks (event_type, period_ending_at)
VALUES
($1, $2)
`
type InsertTelemetryLockParams struct {
EventType string `db:"event_type" json:"event_type"`
PeriodEndingAt time.Time `db:"period_ending_at" json:"period_ending_at"`
}
// Inserts a new lock row into the telemetry_locks table. Replicas should call
// this function prior to attempting to generate or publish a heartbeat event to
// the telemetry service.
// If the query returns a duplicate primary key error, the replica should not
// attempt to generate or publish the event to the telemetry service.
func (q *sqlQuerier) InsertTelemetryLock(ctx context.Context, arg InsertTelemetryLockParams) error {
_, err := q.db.ExecContext(ctx, insertTelemetryLock, arg.EventType, arg.PeriodEndingAt)
return err
}
const getTemplateAverageBuildTime = `-- name: GetTemplateAverageBuildTime :one
WITH build_times AS (
SELECT
@@ -21927,7 +21542,7 @@ func (q *sqlQuerier) GetWorkspaceACLByID(ctx context.Context, id uuid.UUID) (Get
const getWorkspaceByAgentID = `-- name: GetWorkspaceByAgentID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
FROM
workspaces_expanded as workspaces
WHERE
@@ -21988,14 +21603,13 @@ func (q *sqlQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUI
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
)
return i, err
}
const getWorkspaceByID = `-- name: GetWorkspaceByID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
FROM
workspaces_expanded
WHERE
@@ -22037,14 +21651,13 @@ func (q *sqlQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Worksp
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
)
return i, err
}
const getWorkspaceByOwnerIDAndName = `-- name: GetWorkspaceByOwnerIDAndName :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
FROM
workspaces_expanded as workspaces
WHERE
@@ -22093,14 +21706,13 @@ func (q *sqlQuerier) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWo
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
)
return i, err
}
const getWorkspaceByResourceID = `-- name: GetWorkspaceByResourceID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
FROM
workspaces_expanded as workspaces
WHERE
@@ -22156,14 +21768,13 @@ func (q *sqlQuerier) GetWorkspaceByResourceID(ctx context.Context, resourceID uu
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
)
return i, err
}
const getWorkspaceByWorkspaceAppID = `-- name: GetWorkspaceByWorkspaceAppID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
FROM
workspaces_expanded as workspaces
WHERE
@@ -22231,7 +21842,6 @@ func (q *sqlQuerier) GetWorkspaceByWorkspaceAppID(ctx context.Context, workspace
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
)
return i, err
}
@@ -22281,7 +21891,7 @@ SELECT
),
filtered_workspaces AS (
SELECT
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl, workspaces.owner_avatar_url, workspaces.owner_username, workspaces.owner_name, workspaces.organization_name, workspaces.organization_display_name, workspaces.organization_icon, workspaces.organization_description, workspaces.template_name, workspaces.template_display_name, workspaces.template_icon, workspaces.template_description, workspaces.task_id,
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl, workspaces.owner_avatar_url, workspaces.owner_username, workspaces.owner_name, workspaces.organization_name, workspaces.organization_display_name, workspaces.organization_icon, workspaces.organization_description, workspaces.template_name, workspaces.template_display_name, workspaces.template_icon, workspaces.template_description,
latest_build.template_version_id,
latest_build.template_version_name,
latest_build.completed_at as latest_build_completed_at,
@@ -22572,7 +22182,7 @@ WHERE
-- @authorize_filter
), filtered_workspaces_order AS (
SELECT
fw.id, fw.created_at, fw.updated_at, fw.owner_id, fw.organization_id, fw.template_id, fw.deleted, fw.name, fw.autostart_schedule, fw.ttl, fw.last_used_at, fw.dormant_at, fw.deleting_at, fw.automatic_updates, fw.favorite, fw.next_start_at, fw.group_acl, fw.user_acl, fw.owner_avatar_url, fw.owner_username, fw.owner_name, fw.organization_name, fw.organization_display_name, fw.organization_icon, fw.organization_description, fw.template_name, fw.template_display_name, fw.template_icon, fw.template_description, fw.task_id, fw.template_version_id, fw.template_version_name, fw.latest_build_completed_at, fw.latest_build_canceled_at, fw.latest_build_error, fw.latest_build_transition, fw.latest_build_status, fw.latest_build_has_ai_task, fw.latest_build_has_external_agent
fw.id, fw.created_at, fw.updated_at, fw.owner_id, fw.organization_id, fw.template_id, fw.deleted, fw.name, fw.autostart_schedule, fw.ttl, fw.last_used_at, fw.dormant_at, fw.deleting_at, fw.automatic_updates, fw.favorite, fw.next_start_at, fw.group_acl, fw.user_acl, fw.owner_avatar_url, fw.owner_username, fw.owner_name, fw.organization_name, fw.organization_display_name, fw.organization_icon, fw.organization_description, fw.template_name, fw.template_display_name, fw.template_icon, fw.template_description, fw.template_version_id, fw.template_version_name, fw.latest_build_completed_at, fw.latest_build_canceled_at, fw.latest_build_error, fw.latest_build_transition, fw.latest_build_status, fw.latest_build_has_ai_task, fw.latest_build_has_external_agent
FROM
filtered_workspaces fw
ORDER BY
@@ -22593,7 +22203,7 @@ WHERE
$25
), filtered_workspaces_order_with_summary AS (
SELECT
fwo.id, fwo.created_at, fwo.updated_at, fwo.owner_id, fwo.organization_id, fwo.template_id, fwo.deleted, fwo.name, fwo.autostart_schedule, fwo.ttl, fwo.last_used_at, fwo.dormant_at, fwo.deleting_at, fwo.automatic_updates, fwo.favorite, fwo.next_start_at, fwo.group_acl, fwo.user_acl, fwo.owner_avatar_url, fwo.owner_username, fwo.owner_name, fwo.organization_name, fwo.organization_display_name, fwo.organization_icon, fwo.organization_description, fwo.template_name, fwo.template_display_name, fwo.template_icon, fwo.template_description, fwo.task_id, fwo.template_version_id, fwo.template_version_name, fwo.latest_build_completed_at, fwo.latest_build_canceled_at, fwo.latest_build_error, fwo.latest_build_transition, fwo.latest_build_status, fwo.latest_build_has_ai_task, fwo.latest_build_has_external_agent
fwo.id, fwo.created_at, fwo.updated_at, fwo.owner_id, fwo.organization_id, fwo.template_id, fwo.deleted, fwo.name, fwo.autostart_schedule, fwo.ttl, fwo.last_used_at, fwo.dormant_at, fwo.deleting_at, fwo.automatic_updates, fwo.favorite, fwo.next_start_at, fwo.group_acl, fwo.user_acl, fwo.owner_avatar_url, fwo.owner_username, fwo.owner_name, fwo.organization_name, fwo.organization_display_name, fwo.organization_icon, fwo.organization_description, fwo.template_name, fwo.template_display_name, fwo.template_icon, fwo.template_description, fwo.template_version_id, fwo.template_version_name, fwo.latest_build_completed_at, fwo.latest_build_canceled_at, fwo.latest_build_error, fwo.latest_build_transition, fwo.latest_build_status, fwo.latest_build_has_ai_task, fwo.latest_build_has_external_agent
FROM
filtered_workspaces_order fwo
-- Return a technical summary row with total count of workspaces.
@@ -22629,7 +22239,6 @@ WHERE
'', -- template_display_name
'', -- template_icon
'', -- template_description
'00000000-0000-0000-0000-000000000000'::uuid, -- task_id
-- Extra columns added to ` + "`" + `filtered_workspaces` + "`" + `
'00000000-0000-0000-0000-000000000000'::uuid, -- template_version_id
'', -- template_version_name
@@ -22649,7 +22258,7 @@ WHERE
filtered_workspaces
)
SELECT
fwos.id, fwos.created_at, fwos.updated_at, fwos.owner_id, fwos.organization_id, fwos.template_id, fwos.deleted, fwos.name, fwos.autostart_schedule, fwos.ttl, fwos.last_used_at, fwos.dormant_at, fwos.deleting_at, fwos.automatic_updates, fwos.favorite, fwos.next_start_at, fwos.group_acl, fwos.user_acl, fwos.owner_avatar_url, fwos.owner_username, fwos.owner_name, fwos.organization_name, fwos.organization_display_name, fwos.organization_icon, fwos.organization_description, fwos.template_name, fwos.template_display_name, fwos.template_icon, fwos.template_description, fwos.task_id, fwos.template_version_id, fwos.template_version_name, fwos.latest_build_completed_at, fwos.latest_build_canceled_at, fwos.latest_build_error, fwos.latest_build_transition, fwos.latest_build_status, fwos.latest_build_has_ai_task, fwos.latest_build_has_external_agent,
fwos.id, fwos.created_at, fwos.updated_at, fwos.owner_id, fwos.organization_id, fwos.template_id, fwos.deleted, fwos.name, fwos.autostart_schedule, fwos.ttl, fwos.last_used_at, fwos.dormant_at, fwos.deleting_at, fwos.automatic_updates, fwos.favorite, fwos.next_start_at, fwos.group_acl, fwos.user_acl, fwos.owner_avatar_url, fwos.owner_username, fwos.owner_name, fwos.organization_name, fwos.organization_display_name, fwos.organization_icon, fwos.organization_description, fwos.template_name, fwos.template_display_name, fwos.template_icon, fwos.template_description, fwos.template_version_id, fwos.template_version_name, fwos.latest_build_completed_at, fwos.latest_build_canceled_at, fwos.latest_build_error, fwos.latest_build_transition, fwos.latest_build_status, fwos.latest_build_has_ai_task, fwos.latest_build_has_external_agent,
tc.count
FROM
filtered_workspaces_order_with_summary fwos
@@ -22717,7 +22326,6 @@ type GetWorkspacesRow struct {
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
TemplateIcon string `db:"template_icon" json:"template_icon"`
TemplateDescription string `db:"template_description" json:"template_description"`
TaskID uuid.NullUUID `db:"task_id" json:"task_id"`
TemplateVersionID uuid.UUID `db:"template_version_id" json:"template_version_id"`
TemplateVersionName sql.NullString `db:"template_version_name" json:"template_version_name"`
LatestBuildCompletedAt sql.NullTime `db:"latest_build_completed_at" json:"latest_build_completed_at"`
@@ -22800,7 +22408,6 @@ func (q *sqlQuerier) GetWorkspaces(ctx context.Context, arg GetWorkspacesParams)
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
&i.TemplateVersionID,
&i.TemplateVersionName,
&i.LatestBuildCompletedAt,
@@ -23531,7 +23138,6 @@ SET
WHERE
template_id = $3
AND dormant_at IS NOT NULL
AND deleted = false
-- Prebuilt workspaces (identified by having the prebuilds system user as owner_id)
-- should not have their dormant or deleting at set, as these are handled by the
-- prebuilds reconciliation loop.
-127
View File
@@ -6,14 +6,6 @@ INSERT INTO aibridge_interceptions (
)
RETURNING *;
-- name: UpdateAIBridgeInterceptionEnded :one
UPDATE aibridge_interceptions
SET ended_at = @ended_at::timestamptz
WHERE
id = @id::uuid
AND ended_at IS NULL
RETURNING *;
-- name: InsertAIBridgeTokenUsage :one
INSERT INTO aibridge_token_usages (
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
@@ -207,122 +199,3 @@ WHERE
ORDER BY
created_at ASC,
id ASC;
-- name: ListAIBridgeInterceptionsTelemetrySummaries :many
-- Finds all unique AIBridge interception telemetry summaries combinations
-- (provider, model, client) in the given timeframe for telemetry reporting.
SELECT
DISTINCT ON (provider, model, client)
provider,
model,
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
'unknown' AS client
FROM
aibridge_interceptions
WHERE
ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
AND ended_at >= @ended_at_after::timestamptz
AND ended_at < @ended_at_before::timestamptz;
-- name: CalculateAIBridgeInterceptionsTelemetrySummary :one
-- Calculates the telemetry summary for a given provider, model, and client
-- combination for telemetry reporting.
WITH interceptions_in_range AS (
-- Get all matching interceptions in the given timeframe.
SELECT
id,
initiator_id,
(ended_at - started_at) AS duration
FROM
aibridge_interceptions
WHERE
provider = @provider::text
AND model = @model::text
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
AND 'unknown' = @client::text
AND ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
AND ended_at >= @ended_at_after::timestamptz
AND ended_at < @ended_at_before::timestamptz
),
interception_counts AS (
SELECT
COUNT(id) AS interception_count,
COUNT(DISTINCT initiator_id) AS unique_initiator_count
FROM
interceptions_in_range
),
duration_percentiles AS (
SELECT
(COALESCE(PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p50_millis,
(COALESCE(PERCENTILE_CONT(0.90) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p90_millis,
(COALESCE(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p95_millis,
(COALESCE(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p99_millis
FROM
interceptions_in_range
),
token_aggregates AS (
SELECT
COALESCE(SUM(tu.input_tokens), 0) AS token_count_input,
COALESCE(SUM(tu.output_tokens), 0) AS token_count_output,
-- Cached tokens are stored in metadata JSON, extract if available.
-- Read tokens may be stored in:
-- - cache_read_input (Anthropic)
-- - prompt_cached (OpenAI)
COALESCE(SUM(
COALESCE((tu.metadata->>'cache_read_input')::bigint, 0) +
COALESCE((tu.metadata->>'prompt_cached')::bigint, 0)
), 0) AS token_count_cached_read,
-- Written tokens may be stored in:
-- - cache_creation_input (Anthropic)
-- Note that cache_ephemeral_5m_input and cache_ephemeral_1h_input on
-- Anthropic are included in the cache_creation_input field.
COALESCE(SUM(
COALESCE((tu.metadata->>'cache_creation_input')::bigint, 0)
), 0) AS token_count_cached_written,
COUNT(tu.id) AS token_usages_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_token_usages tu ON i.id = tu.interception_id
),
prompt_aggregates AS (
SELECT
COUNT(up.id) AS user_prompts_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_user_prompts up ON i.id = up.interception_id
),
tool_aggregates AS (
SELECT
COUNT(tu.id) FILTER (WHERE tu.injected = true) AS tool_calls_count_injected,
COUNT(tu.id) FILTER (WHERE tu.injected = false) AS tool_calls_count_non_injected,
COUNT(tu.id) FILTER (WHERE tu.injected = true AND tu.invocation_error IS NOT NULL) AS injected_tool_call_error_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_tool_usages tu ON i.id = tu.interception_id
)
SELECT
ic.interception_count::bigint AS interception_count,
dp.interception_duration_p50_millis::bigint AS interception_duration_p50_millis,
dp.interception_duration_p90_millis::bigint AS interception_duration_p90_millis,
dp.interception_duration_p95_millis::bigint AS interception_duration_p95_millis,
dp.interception_duration_p99_millis::bigint AS interception_duration_p99_millis,
ic.unique_initiator_count::bigint AS unique_initiator_count,
pa.user_prompts_count::bigint AS user_prompts_count,
tok_agg.token_usages_count::bigint AS token_usages_count,
tok_agg.token_count_input::bigint AS token_count_input,
tok_agg.token_count_output::bigint AS token_count_output,
tok_agg.token_count_cached_read::bigint AS token_count_cached_read,
tok_agg.token_count_cached_written::bigint AS token_count_cached_written,
tool_agg.tool_calls_count_injected::bigint AS tool_calls_count_injected,
tool_agg.tool_calls_count_non_injected::bigint AS tool_calls_count_non_injected,
tool_agg.injected_tool_call_error_count::bigint AS injected_tool_call_error_count
FROM
interception_counts ic,
duration_percentiles dp,
token_aggregates tok_agg,
prompt_aggregates pa,
tool_aggregates tool_agg
;
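The removed summary query above derives cached-token counts from provider-specific keys in the token-usage metadata JSON: cached reads may appear under cache_read_input (Anthropic) or prompt_cached (OpenAI), and cached writes under cache_creation_input. A minimal Go sketch of that extraction, assuming the values are stored as JSON numbers (cachedTokenCounts is a hypothetical helper, not part of the codebase):

```go
package example

import "encoding/json"

// cachedTokenCounts mirrors the metadata convention the removed query reads:
// cached reads live under "cache_read_input" (Anthropic) or "prompt_cached"
// (OpenAI); cached writes live under "cache_creation_input" (Anthropic).
// Missing keys contribute zero, matching the query's COALESCE calls.
func cachedTokenCounts(metadata []byte) (read, written int64, err error) {
	var m struct {
		CacheReadInput     int64 `json:"cache_read_input"`
		PromptCached       int64 `json:"prompt_cached"`
		CacheCreationInput int64 `json:"cache_creation_input"`
	}
	if err := json.Unmarshal(metadata, &m); err != nil {
		return 0, 0, err
	}
	return m.CacheReadInput + m.PromptCached, m.CacheCreationInput, nil
}
```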
+7 -53
View File
@@ -300,8 +300,12 @@ GROUP BY wpb.template_version_preset_id;
-- Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
-- inactive template version.
-- This is an optimization to clean up stale pending jobs.
WITH jobs_to_cancel AS (
SELECT pj.id, w.id AS workspace_id, w.template_id, wpb.template_version_preset_id
UPDATE provisioner_jobs
SET
canceled_at = @now::timestamptz,
completed_at = @now::timestamptz
WHERE id IN (
SELECT pj.id
FROM provisioner_jobs pj
INNER JOIN workspace_prebuild_builds wpb ON wpb.job_id = pj.id
INNER JOIN workspaces w ON w.id = wpb.workspace_id
@@ -320,54 +324,4 @@ WITH jobs_to_cancel AS (
AND pj.canceled_at IS NULL
AND pj.completed_at IS NULL
)
UPDATE provisioner_jobs
SET
canceled_at = @now::timestamptz,
completed_at = @now::timestamptz
FROM jobs_to_cancel
WHERE provisioner_jobs.id = jobs_to_cancel.id
RETURNING jobs_to_cancel.id, jobs_to_cancel.workspace_id, jobs_to_cancel.template_id, jobs_to_cancel.template_version_preset_id;
-- name: GetOrganizationsWithPrebuildStatus :many
-- GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
-- membership status for the prebuilds system user (org membership, group existence, group membership).
WITH orgs_with_prebuilds AS (
-- Get unique organizations that have presets with prebuilds configured
SELECT DISTINCT o.id, o.name
FROM organizations o
INNER JOIN templates t ON t.organization_id = o.id
INNER JOIN template_versions tv ON tv.template_id = t.id
INNER JOIN template_version_presets tvp ON tvp.template_version_id = tv.id
WHERE tvp.desired_instances IS NOT NULL
),
prebuild_user_membership AS (
-- Check if the user is a member of the organizations
SELECT om.organization_id
FROM organization_members om
INNER JOIN orgs_with_prebuilds owp ON owp.id = om.organization_id
WHERE om.user_id = @user_id::uuid
),
prebuild_groups AS (
-- Check if the organizations have the prebuilds group
SELECT g.organization_id, g.id as group_id
FROM groups g
INNER JOIN orgs_with_prebuilds owp ON owp.id = g.organization_id
WHERE g.name = @group_name::text
),
prebuild_group_membership AS (
-- Check if the user is in the prebuilds group
SELECT pg.organization_id
FROM prebuild_groups pg
INNER JOIN group_members gm ON gm.group_id = pg.group_id
WHERE gm.user_id = @user_id::uuid
)
SELECT
owp.id AS organization_id,
owp.name AS organization_name,
(pum.organization_id IS NOT NULL)::boolean AS has_prebuild_user,
pg.group_id AS prebuilds_group_id,
(pgm.organization_id IS NOT NULL)::boolean AS has_prebuild_user_in_group
FROM orgs_with_prebuilds owp
LEFT JOIN prebuild_groups pg ON pg.organization_id = owp.id
LEFT JOIN prebuild_user_membership pum ON pum.organization_id = owp.id
LEFT JOIN prebuild_group_membership pgm ON pgm.organization_id = owp.id;
RETURNING id;
+1 -1
View File
@@ -2,7 +2,7 @@
INSERT INTO tasks
(id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9)
(gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8)
RETURNING *;
-- name: UpdateTaskWorkspaceID :one
@@ -1,17 +0,0 @@
-- name: InsertTelemetryLock :exec
-- Inserts a new lock row into the telemetry_locks table. Replicas should call
-- this function prior to attempting to generate or publish a heartbeat event to
-- the telemetry service.
-- If the query returns a duplicate primary key error, the replica should not
-- attempt to generate or publish the event to the telemetry service.
INSERT INTO
telemetry_locks (event_type, period_ending_at)
VALUES
($1, $2);
-- name: DeleteOldTelemetryLocks :exec
-- Deletes old telemetry locks from the telemetry_locks table.
DELETE FROM
telemetry_locks
WHERE
period_ending_at < @period_ending_at_before::timestamptz;
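The deleted queries above describe a simple cross-replica deduplication protocol: before publishing a heartbeat event for a given period, a replica inserts a row into telemetry_locks, and a duplicate-key error means another replica already owns that period and this one should skip publishing. A minimal generic sketch of that pattern, assuming a Postgres connection through database/sql with the lib/pq driver (acquireTelemetryLock and the unique-violation check are illustrative, not Coder's implementation):

```go
package example

import (
	"context"
	"database/sql"
	"errors"
	"time"

	"github.com/lib/pq"
)

// acquireTelemetryLock inserts the lock row for one (event_type, period) pair.
// It returns true when this replica won the insert and may publish the event,
// and false when another replica already inserted the same row.
func acquireTelemetryLock(ctx context.Context, db *sql.DB, eventType string, periodEndingAt time.Time) (bool, error) {
	_, err := db.ExecContext(ctx,
		`INSERT INTO telemetry_locks (event_type, period_ending_at) VALUES ($1, $2)`,
		eventType, periodEndingAt,
	)
	if err == nil {
		return true, nil
	}
	var pqErr *pq.Error
	if errors.As(err, &pqErr) && pqErr.Code == "23505" { // unique_violation
		return false, nil
	}
	return false, err
}
```

Relying on the primary-key conflict, rather than a SELECT-then-INSERT check, avoids a race between replicas that wake up for the same period at the same time.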

Some files were not shown because too many files have changed in this diff.