Compare commits

..

3 Commits

Author SHA1 Message Date
Sas Swart 34c1370090 fix agent socket tests 2025-10-28 06:30:29 +00:00
Sas Swart 851c4f907c add a socket to the agent for local IPC 2025-10-28 06:26:49 +00:00
Sas Swart e3dfe45f35 LLM generated implementation of unit status change communication 2025-10-27 11:10:22 +00:00
227 changed files with 6355 additions and 7922 deletions
+3 -10
View File
@@ -5,13 +5,6 @@ runs:
using: "composite"
steps:
- name: Setup sqlc
# uses: sqlc-dev/setup-sqlc@c0209b9199cd1cce6a14fc27cabcec491b651761 # v4.0.0
# with:
# sqlc-version: "1.30.0"
# Switched to coder/sqlc fork to fix ambiguous column bug, see:
# - https://github.com/coder/sqlc/pull/1
# - https://github.com/sqlc-dev/sqlc/pull/4159
shell: bash
run: |
CGO_ENABLED=1 go install github.com/coder/sqlc/cmd/sqlc@aab4e865a51df0c43e1839f81a9d349b41d14f05
uses: sqlc-dev/setup-sqlc@c0209b9199cd1cce6a14fc27cabcec491b651761 # v4.0.0
with:
sqlc-version: "1.27.0"
+1 -1
View File
@@ -7,5 +7,5 @@ runs:
- name: Install Terraform
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2
with:
terraform_version: 1.13.4
terraform_version: 1.13.0
terraform_wrapper: false
+26 -20
View File
@@ -204,17 +204,9 @@ jobs:
# Needed for helm chart linting
- name: Install helm
# uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # v4.3.1
# with:
# version: v3.9.2
# The below is taken from https://helm.sh/docs/intro/install/#from-apt-debianubuntu
run: |
set -euo pipefail
sudo apt-get install curl gpg apt-transport-https --yes
curl -fsSL https://packages.buildkite.com/helm-linux/helm-debian/gpgkey | gpg --dearmor | sudo tee /usr/share/keyrings/helm.gpg > /dev/null
echo "deb [signed-by=/usr/share/keyrings/helm.gpg] https://packages.buildkite.com/helm-linux/helm-debian/any/ any main" | sudo tee /etc/apt/sources.list.d/helm-stable-debian.list
sudo apt-get update
sudo apt-get install helm
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # v4.3.1
with:
version: v3.9.2
- name: make lint
run: |
@@ -384,6 +376,13 @@ jobs:
id: go-paths
uses: ./.github/actions/setup-go-paths
- name: Download Go Build Cache
id: download-go-build-cache
uses: ./.github/actions/test-cache/download
with:
key-prefix: test-go-build-${{ runner.os }}-${{ runner.arch }}
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
- name: Setup Go
uses: ./.github/actions/setup-go
with:
@@ -391,7 +390,8 @@ jobs:
# download the toolchain configured in go.mod, so we don't
# need to reinstall it. It's faster on Windows runners.
use-preinstalled-go: ${{ runner.os == 'Windows' }}
use-cache: true
# Cache is already downloaded above
use-cache: false
- name: Setup Terraform
uses: ./.github/actions/setup-tf
@@ -500,11 +500,17 @@ jobs:
make test
- name: Upload failed test db dumps
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: failed-test-db-dump-${{matrix.os}}
path: "**/*.test.sql"
- name: Upload Go Build Cache
uses: ./.github/actions/test-cache/upload
with:
cache-key: ${{ steps.download-go-build-cache.outputs.cache-key }}
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
- name: Upload Test Cache
uses: ./.github/actions/test-cache/upload
with:
@@ -756,7 +762,7 @@ jobs:
- name: Upload Playwright Failed Tests
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: failed-test-videos${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/*.webm
@@ -764,7 +770,7 @@ jobs:
- name: Upload pprof dumps
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: debug-pprof-dumps${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/debug-pprof-*.txt
@@ -800,7 +806,7 @@ jobs:
# the check to pass. This is desired in PRs, but not in mainline.
- name: Publish to Chromatic (non-mainline)
if: github.ref != 'refs/heads/main' && github.repository_owner == 'coder'
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
uses: chromaui/action@4ffe736a2a8262ea28067ff05a13b635ba31ec05 # v13.3.0
env:
NODE_OPTIONS: "--max_old_space_size=4096"
STORYBOOK: true
@@ -832,7 +838,7 @@ jobs:
# infinitely "in progress" in mainline unless we re-review each build.
- name: Publish to Chromatic (mainline)
if: github.ref == 'refs/heads/main' && github.repository_owner == 'coder'
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
uses: chromaui/action@4ffe736a2a8262ea28067ff05a13b635ba31ec05 # v13.3.0
env:
NODE_OPTIONS: "--max_old_space_size=4096"
STORYBOOK: true
@@ -1030,7 +1036,7 @@ jobs:
- name: Upload build artifacts
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: dylibs
path: |
@@ -1195,7 +1201,7 @@ jobs:
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Download dylibs
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
with:
name: dylibs
path: ./build
@@ -1462,7 +1468,7 @@ jobs:
- name: Upload build artifacts
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: coder
path: |
+1 -1
View File
@@ -30,7 +30,7 @@ jobs:
- name: Setup Node
uses: ./.github/actions/setup-node
- uses: tj-actions/changed-files@dbf178ceecb9304128c8e0648591d71208c6e2c9 # v45.0.7
- uses: tj-actions/changed-files@d03a93c0dbfac6d6dd6a0d8a5e7daff992b07449 # v45.0.7
id: changed-files
with:
files: |
+2 -2
View File
@@ -36,11 +36,11 @@ jobs:
persist-credentials: false
- name: Setup Nix
uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
uses: nixbuild/nix-quick-install-action@1f095fee853b33114486cfdeae62fa099cda35a9 # v33
with:
# Pinning to 2.28 here, as Nix gets a "error: [json.exception.type_error.302] type must be array, but is string"
# on version 2.29 and above.
nix_version: "2.28.5"
nix_version: "2.28.4"
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
with:
+4 -4
View File
@@ -131,7 +131,7 @@ jobs:
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
- name: Upload build artifacts
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: dylibs
path: |
@@ -327,7 +327,7 @@ jobs:
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Download dylibs
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
with:
name: dylibs
path: ./build
@@ -761,7 +761,7 @@ jobs:
- name: Upload artifacts to actions (if dry-run)
if: ${{ inputs.dry_run }}
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: release-artifacts
path: |
@@ -777,7 +777,7 @@ jobs:
- name: Upload latest sbom artifact to actions (if dry-run)
if: inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: latest-sbom-artifact
path: ./coder_latest_sbom.spdx.json
+2 -2
View File
@@ -39,7 +39,7 @@ jobs:
# Upload the results as artifacts.
- name: "Upload artifact"
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: SARIF file
path: results.sarif
@@ -47,6 +47,6 @@ jobs:
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/upload-sarif@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
with:
sarif_file: results.sarif
+4 -4
View File
@@ -40,7 +40,7 @@ jobs:
uses: ./.github/actions/setup-go
- name: Initialize CodeQL
uses: github/codeql-action/init@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/init@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
with:
languages: go, javascript
@@ -50,7 +50,7 @@ jobs:
rm Makefile
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/analyze@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
- name: Send Slack notification on failure
if: ${{ failure() }}
@@ -154,13 +154,13 @@ jobs:
severity: "CRITICAL,HIGH"
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/upload-sarif@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
with:
sarif_file: trivy-results.sarif
category: "Trivy"
- name: Upload Trivy scan results as an artifact
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: trivy
path: trivy-results.sarif
+2 -2
View File
@@ -125,7 +125,7 @@ jobs:
egress-policy: audit
- name: Delete PR Cleanup workflow runs
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
with:
token: ${{ github.token }}
repository: ${{ github.repository }}
@@ -134,7 +134,7 @@ jobs:
delete_workflow_pattern: pr-cleanup.yaml
- name: Delete PR Deploy workflow skipped runs
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
with:
token: ${{ github.token }}
repository: ${{ github.repository }}
-2
View File
@@ -89,5 +89,3 @@ result
__debug_bin*
**/.claude/settings.local.json
/.env
+12
View File
@@ -18,6 +18,18 @@ coderd/rbac/ @Emyrk
scripts/apitypings/ @Emyrk
scripts/gensite/ @aslilac
site/ @aslilac @Parkreiner
site/src/hooks/ @Parkreiner
# These rules intentionally do not specify any owners. More specific rules
# override less specific rules, so these files are "ignored" by the site/ rule.
site/e2e/google/protobuf/timestampGenerated.ts
site/e2e/provisionerGenerated.ts
site/src/api/countriesGenerated.ts
site/src/api/rbacresourcesGenerated.ts
site/src/api/typesGenerated.ts
site/src/testHelpers/entities.ts
site/CLAUDE.md
# The blood and guts of the autostop algorithm, which is quite complex and
# requires elite ball knowledge of most of the scheduling code to make changes
# without inadvertently affecting other parts of the codebase.
+8 -13
View File
@@ -636,8 +636,8 @@ TAILNETTEST_MOCKS := \
tailnet/tailnettest/subscriptionmock.go
AIBRIDGED_MOCKS := \
enterprise/aibridged/aibridgedmock/clientmock.go \
enterprise/aibridged/aibridgedmock/poolmock.go
enterprise/x/aibridged/aibridgedmock/clientmock.go \
enterprise/x/aibridged/aibridgedmock/poolmock.go
GEN_FILES := \
tailnet/proto/tailnet.pb.go \
@@ -645,7 +645,7 @@ GEN_FILES := \
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
enterprise/x/aibridged/proto/aibridged.pb.go \
$(DB_GEN_FILES) \
$(SITE_GEN_FILES) \
coderd/rbac/object_gen.go \
@@ -697,7 +697,7 @@ gen/mark-fresh:
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
enterprise/x/aibridged/proto/aibridged.pb.go \
coderd/database/dump.sql \
$(DB_GEN_FILES) \
site/src/api/typesGenerated.ts \
@@ -768,8 +768,8 @@ codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agen
go generate ./codersdk/workspacesdk/agentconnmock/
touch "$@"
$(AIBRIDGED_MOCKS): enterprise/aibridged/client.go enterprise/aibridged/pool.go
go generate ./enterprise/aibridged/aibridgedmock/
$(AIBRIDGED_MOCKS): enterprise/x/aibridged/client.go enterprise/x/aibridged/pool.go
go generate ./enterprise/x/aibridged/aibridgedmock/
touch "$@"
agent/agentcontainers/dcspec/dcspec_gen.go: \
@@ -822,13 +822,13 @@ vpn/vpn.pb.go: vpn/vpn.proto
--go_opt=paths=source_relative \
./vpn/vpn.proto
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
enterprise/x/aibridged/proto/aibridged.pb.go: enterprise/x/aibridged/proto/aibridged.proto
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
--go-drpc_opt=paths=source_relative \
./enterprise/aibridged/proto/aibridged.proto
./enterprise/x/aibridged/proto/aibridged.proto
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
# -C sets the directory for the go run command
@@ -1182,8 +1182,3 @@ endif
dogfood/coder/nix.hash: flake.nix flake.lock
sha256sum flake.nix flake.lock >./dogfood/coder/nix.hash
# Count the number of test databases created per test package.
count-test-databases:
PGPASSWORD=postgres psql -h localhost -U postgres -d coder_testing -P pager=off -c 'SELECT test_package, count(*) as count from test_databases GROUP BY test_package ORDER BY count DESC'
.PHONY: count-test-databases
+74
View File
@@ -40,6 +40,7 @@ import (
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
@@ -91,6 +92,7 @@ type Options struct {
Devcontainers bool
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
Clock quartz.Clock
SocketPath string // Path for the agent socket server
}
type Client interface {
@@ -190,6 +192,7 @@ func New(options Options) Agent {
devcontainers: options.Devcontainers,
containerAPIOptions: options.DevcontainerAPIOptions,
socketPath: options.SocketPath,
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
@@ -271,6 +274,10 @@ type agent struct {
devcontainers bool
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API
// Socket server for CLI communication
socketPath string
socketServer *agentsocket.Server
}
func (a *agent) TailnetConn() *tailnet.Conn {
@@ -350,9 +357,69 @@ func (a *agent) init() {
s.ExperimentalContainers = a.devcontainers
},
)
// Initialize socket server for CLI communication
a.initSocketServer()
go a.runLoop()
}
// initSocketServer initializes the socket server for CLI communication.
// It is best-effort: when no path is configured the server is silently
// skipped, and a startup failure is logged as a warning rather than
// failing agent initialization.
func (a *agent) initSocketServer() {
	// Get socket path from options or environment
	socketPath := a.getSocketPath()
	if socketPath == "" {
		a.logger.Debug(a.hardCtx, "socket server disabled (no path configured)")
		return
	}

	// Create socket server
	server := agentsocket.NewServer(agentsocket.Config{
		Path:   socketPath,
		Logger: a.logger.Named("socket"),
	})

	// Register default handlers.
	// NOTE(review): the handler context is a one-time snapshot — the agent
	// ID stays "" and the status stays "starting" for the lifetime of the
	// server; confirm whether these should be refreshed once the manifest
	// arrives.
	handlerCtx := agentsocket.CreateHandlerContext(
		"", // Agent ID will be set when manifest is available
		buildinfo.Version(),
		"starting",
		time.Now(),
		a.logger,
	)
	agentsocket.RegisterDefaultHandlers(server, handlerCtx)

	// Start the server
	if err := server.Start(); err != nil {
		a.logger.Warn(a.hardCtx, "failed to start socket server", slog.Error(err))
		return
	}

	a.socketServer = server
	a.logger.Info(a.hardCtx, "socket server started", slog.F("path", socketPath))
}
// getSocketPath returns the socket path from options or environment, in
// priority order: the explicitly configured agent option first, then the
// CODER_AGENT_SOCKET_PATH environment variable. An empty result disables
// the socket server.
func (a *agent) getSocketPath() string {
	// Check if socket path is explicitly configured. Call the accessor
	// once instead of twice as before.
	if path := a.getSocketPathFromOptions(); path != "" {
		return path
	}

	// Check environment variable
	if path := os.Getenv("CODER_AGENT_SOCKET_PATH"); path != "" {
		return path
	}

	// Return empty to disable socket server
	return ""
}

// getSocketPathFromOptions returns the socket path configured via agent
// options (Options.SocketPath).
func (a *agent) getSocketPathFromOptions() string {
	return a.socketPath
}
// runLoop attempts to start the agent in a retry loop.
// Coder may be offline temporarily, a connection issue
// may be happening, but regardless after the intermittent
@@ -1931,6 +1998,13 @@ func (a *agent) Close() error {
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
}
// Close socket server
if a.socketServer != nil {
if err := a.socketServer.Stop(); err != nil {
a.logger.Error(a.hardCtx, "socket server close", slog.Error(err))
}
}
// Wait for the graceful shutdown to complete, but don't wait forever so
// that we don't break user expectations.
go func() {
+2
View File
@@ -682,6 +682,8 @@ func (api *API) updaterLoop() {
} else {
prevErr = nil
}
default:
api.logger.Debug(api.ctx, "updater loop ticker skipped, update in progress")
}
return nil // Always nil to keep the ticker going.
+214
View File
@@ -0,0 +1,214 @@
# Agent Socket API
The Agent Socket API provides a local communication channel between CLI commands running within a workspace and the Coder agent process. This enables new CLI commands to interact directly with the agent without going through the control plane.
## Overview
The socket server runs within the agent process and listens on a Unix domain socket (or named pipe on Windows). CLI commands can connect to this socket to query agent information, check health status, and perform other operations.
## Architecture
### Socket Server
- **Location**: `agent/agentsocket/`
- **Protocol**: JSON-RPC 2.0 over Unix domain socket
- **Platform Support**: Linux, macOS, Windows 10+ (build 17063+)
- **Authentication**: Pluggable middleware (no-auth by default)
### Client Library
- **Location**: `codersdk/agentsdk/socket_client.go`
- **Auto-discovery**: Automatically finds socket path
- **Type-safe**: Go client with proper error handling
## Socket Path Discovery
The socket path is determined in the following order:
1. **Environment Variable**: `CODER_AGENT_SOCKET_PATH`
2. **XDG Runtime Directory**: `$XDG_RUNTIME_DIR/coder-agent.sock`
3. **User Temp Directory**: `/tmp/coder-agent-{uid}.sock`
4. **Fallback**: `/tmp/coder-agent.sock`
## Protocol
### Request Format
```json
{
"version": "1.0",
"method": "ping",
"id": "request-123",
"params": {}
}
```
### Response Format
```json
{
"version": "1.0",
"id": "request-123",
"result": {
"message": "pong",
"timestamp": "2024-01-01T00:00:00Z"
}
}
```
### Error Format
```json
{
"version": "1.0",
"id": "request-123",
"error": {
"code": -32601,
"message": "Method not found",
"data": "nonexistent"
}
}
```
## Available Methods
### Core Methods
- `ping` - Health check with timestamp
- `health` - Agent status and uptime
- `agent.info` - Detailed agent information
- `methods.list` - List available methods
### Example Usage
```go
// Create client
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{})
if err != nil {
log.Fatal(err)
}
defer client.Close()
// Ping the agent
pingResp, err := client.Ping(ctx)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Agent responded: %s\n", pingResp.Message)
// Get agent info
info, err := client.AgentInfo(ctx)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Agent ID: %s, Version: %s\n", info.ID, info.Version)
```
## Adding New Handlers
### Server Side
```go
// Register a new handler
server.RegisterHandler("custom.method", func(ctx Context, req *Request) (*Response, error) {
// Handle the request
result := map[string]string{"status": "ok"}
return NewResponse(req.ID, result)
})
```
### Client Side
```go
// Add method to client
func (c *SocketClient) CustomMethod(ctx context.Context) (*CustomResponse, error) {
req := &Request{
Version: "1.0",
Method: "custom.method",
ID: generateRequestID(),
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, fmt.Errorf("custom method error: %s", resp.Error.Message)
}
var result CustomResponse
if err := json.Unmarshal(resp.Result, &result); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &result, nil
}
```
## Authentication
The socket server supports pluggable authentication middleware. By default, no authentication is performed (suitable for local-only communication).
### Custom Authentication
```go
type CustomAuthMiddleware struct {
// Add auth fields
}
func (m *CustomAuthMiddleware) Authenticate(ctx context.Context, conn net.Conn) (context.Context, error) {
// Implement authentication logic
// Return context with auth info or error
return ctx, nil
}
// Use in server config
server := agentsocket.NewServer(agentsocket.Config{
Path: socketPath,
Logger: logger,
AuthMiddleware: &CustomAuthMiddleware{},
})
```
## Configuration
### Agent Options
```go
options := agent.Options{
// ... other options
SocketPath: "/custom/path/agent.sock", // Optional, uses auto-discovery if empty
}
```
### Environment Variables
- `CODER_AGENT_SOCKET_PATH` - Override socket path
- `XDG_RUNTIME_DIR` - Used for socket path discovery
## Error Codes
| Code | Description |
|------|-------------|
| -32700 | Parse error |
| -32600 | Invalid request |
| -32601 | Method not found |
| -32602 | Invalid params |
| -32603 | Internal error |
## Platform Support
### Unix-like Systems (Linux, macOS)
- Uses Unix domain sockets
- Socket file permissions: 600 (owner read/write only)
- Auto-cleanup on shutdown
### Windows
- Uses Unix domain sockets (Windows 10 build 17063+)
- Falls back to named pipes if needed
- Simplified permission handling
## Security Considerations
1. **Local Only**: Socket is only accessible from within the workspace
2. **File Permissions**: Socket file is restricted to owner only
3. **No Network Access**: Unix domain sockets don't traverse network
4. **Authentication Ready**: Middleware pattern allows future auth implementation
## Future Extensibility
The design supports:
- **Protocol Versioning**: Request includes version field
- **Multiple Transports**: Interface-based design allows TCP/WebSocket later
- **Auth Plugins**: Middleware pattern for various auth methods
- **Custom Handlers**: Simple registration pattern for new commands
+23
View File
@@ -0,0 +1,23 @@
package agentsocket

import (
	"context"
	"net"
)

// AuthMiddleware defines the interface for authentication middleware
// applied to each incoming socket connection before any requests are read.
type AuthMiddleware interface {
	// Authenticate authenticates a connection and returns a context with
	// auth info; on error the server drops the connection without
	// processing any requests.
	Authenticate(ctx context.Context, conn net.Conn) (context.Context, error)
}

// NoAuthMiddleware is a no-op authentication middleware that accepts every
// connection. Intended for local-only sockets where OS-level file
// permissions are the access control.
type NoAuthMiddleware struct{}

// Authenticate implements AuthMiddleware but performs no authentication:
// it returns the given context unchanged and never fails.
func (*NoAuthMiddleware) Authenticate(ctx context.Context, conn net.Conn) (context.Context, error) {
	return ctx, nil
}

// Ensure NoAuthMiddleware implements AuthMiddleware
var _ AuthMiddleware = (*NoAuthMiddleware)(nil)
+108
View File
@@ -0,0 +1,108 @@
package agentsocket
import (
"time"
"cdr.dev/slog"
)
// AgentInfo represents information about the agent, returned by the
// "agent.info" method.
type AgentInfo struct {
	ID        string    `json:"id"`         // agent identifier; may be empty until set by the caller
	Version   string    `json:"version"`    // agent build version
	Status    string    `json:"status"`     // status string supplied when the handler context was created
	StartedAt time.Time `json:"started_at"` // when the agent started
	Uptime    string    `json:"uptime"`     // human-readable duration since StartedAt
}

// PingResponse represents a ping response ("ping" method).
type PingResponse struct {
	Message   string    `json:"message"`   // always "pong"
	Timestamp time.Time `json:"timestamp"` // server time when the ping was handled
}

// HealthResponse represents a health check response ("health" method).
type HealthResponse struct {
	Status    string    `json:"status"`    // status string from the handler context
	Timestamp time.Time `json:"timestamp"` // server time when the check ran
	Uptime    string    `json:"uptime"`    // human-readable duration since start
}

// HandlerContext provides context for handlers. It is captured by value
// when handlers are created, so it is a point-in-time snapshot of agent
// state rather than a live view.
type HandlerContext struct {
	AgentID   string
	Version   string
	Status    string
	StartedAt time.Time
	Logger    slog.Logger
}
// NewHandlers creates the default set of handlers keyed by method name:
// "ping", "health", "agent.info", and "methods.list". Every handler closes
// over the supplied handlerCtx snapshot.
func NewHandlers(handlerCtx HandlerContext) map[string]Handler {
	ping := func(_ Context, req *Request) (*Response, error) {
		return NewResponse(req.ID, PingResponse{
			Message:   "pong",
			Timestamp: time.Now(),
		})
	}

	health := func(_ Context, req *Request) (*Response, error) {
		return NewResponse(req.ID, HealthResponse{
			Status:    handlerCtx.Status,
			Timestamp: time.Now(),
			Uptime:    time.Since(handlerCtx.StartedAt).String(),
		})
	}

	info := func(_ Context, req *Request) (*Response, error) {
		return NewResponse(req.ID, AgentInfo{
			ID:        handlerCtx.AgentID,
			Version:   handlerCtx.Version,
			Status:    handlerCtx.Status,
			StartedAt: handlerCtx.StartedAt,
			Uptime:    time.Since(handlerCtx.StartedAt).String(),
		})
	}

	listMethods := func(_ Context, req *Request) (*Response, error) {
		return NewResponse(req.ID, []string{
			"ping",
			"health",
			"agent.info",
			"methods.list",
		})
	}

	return map[string]Handler{
		"ping":         ping,
		"health":       health,
		"agent.info":   info,
		"methods.list": listMethods,
	}
}
// RegisterDefaultHandlers registers the default set of handlers (see
// NewHandlers) with a server.
func RegisterDefaultHandlers(server *Server, ctx HandlerContext) {
	for method, handler := range NewHandlers(ctx) {
		server.RegisterHandler(method, handler)
	}
}
// CreateHandlerContext creates a handler context from agent information.
// The result is a plain value snapshot: handlers built from it will not
// observe later changes to agent ID or status.
func CreateHandlerContext(agentID, version, status string, startedAt time.Time, logger slog.Logger) HandlerContext {
	return HandlerContext{
		AgentID:   agentID,
		Version:   version,
		Status:    status,
		StartedAt: startedAt,
		Logger:    logger,
	}
}
+83
View File
@@ -0,0 +1,83 @@
package agentsocket
import (
"encoding/json"
"golang.org/x/xerrors"
)
// Protocol version for the agent socket API. Requests whose "version"
// field does not match are rejected with ErrCodeInvalidRequest.
const ProtocolVersion = "1.0"

// Request represents an incoming request to the agent socket
// (a JSON-RPC-2.0-style envelope).
type Request struct {
	Version string          `json:"version"`          // must equal ProtocolVersion
	Method  string          `json:"method"`           // handler name, e.g. "ping"
	ID      string          `json:"id,omitempty"`     // correlation id, echoed back in the response
	Params  json.RawMessage `json:"params,omitempty"` // method-specific params, left raw for handlers to decode
}

// Response represents a response from the agent socket. Exactly one of
// Result or Error is expected to be set.
type Response struct {
	Version string          `json:"version"`
	ID      string          `json:"id,omitempty"`
	Result  json.RawMessage `json:"result,omitempty"`
	Error   *Error          `json:"error,omitempty"`
}

// Error represents an error in the response.
type Error struct {
	Code    int    `json:"code"`           // one of the ErrCode* constants below
	Message string `json:"message"`        // short human-readable description
	Data    any    `json:"data,omitempty"` // optional extra detail (e.g. the offending method name)
}

// Standard error codes (values follow the JSON-RPC 2.0 convention).
const (
	ErrCodeParseError     = -32700
	ErrCodeInvalidRequest = -32600
	ErrCodeMethodNotFound = -32601
	ErrCodeInvalidParams  = -32602
	ErrCodeInternalError  = -32603
)
// NewError constructs a protocol Error with the given code, message, and
// optional data payload.
func NewError(code int, message string, data any) *Error {
	return &Error{Code: code, Message: message, Data: data}
}

// NewResponse creates a successful response for request id, JSON-encoding
// result into the Result field; it fails only if result cannot be
// marshaled.
func NewResponse(id string, result any) (*Response, error) {
	payload, err := json.Marshal(result)
	if err != nil {
		return nil, xerrors.Errorf("marshal result: %w", err)
	}
	resp := &Response{
		Version: ProtocolVersion,
		ID:      id,
		Result:  payload,
	}
	return resp, nil
}

// NewErrorResponse creates an error response for request id carrying err.
func NewErrorResponse(id string, err *Error) *Response {
	return &Response{
		Version: ProtocolVersion,
		ID:      id,
		Error:   err,
	}
}
// Handler represents a function that can handle a request for a single
// method. A returned error is converted by the server into an
// ErrCodeInternalError response.
type Handler func(ctx Context, req *Request) (*Response, error)

// Context provides context for request handling.
type Context struct {
	// Additional context can be added here in the future
	// For now, this is a placeholder for future auth context, etc.
}
+266
View File
@@ -0,0 +1,266 @@
package agentsocket
import (
"context"
"encoding/json"
"io"
"net"
"sync"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog"
)
// Server represents the agent socket server: it listens on a local socket
// and dispatches JSON requests to registered handlers.
type Server struct {
	logger         slog.Logger
	path           string             // socket path; resolved in Start when empty
	listener       net.Listener       // non-nil iff the server is running
	handlers       map[string]Handler // method name -> handler; guarded by mu
	authMiddleware AuthMiddleware     // per-connection auth; NoAuthMiddleware when unset
	mu             sync.RWMutex       // guards path, listener, handlers
	ctx            context.Context    // canceled on Stop to end the accept loop
	cancel         context.CancelFunc
	wg             sync.WaitGroup // tracks the accept loop and connection goroutines
}

// Config holds configuration for the socket server.
type Config struct {
	Path           string         // socket path; empty selects a platform default at Start
	Logger         slog.Logger
	AuthMiddleware AuthMiddleware // nil means no authentication (NoAuthMiddleware)
}
// NewServer creates a new agent socket server from config. The server is
// not listening until Start is called. A nil Config.AuthMiddleware is
// replaced with the no-op NoAuthMiddleware.
func NewServer(config Config) *Server {
	auth := config.AuthMiddleware
	if auth == nil {
		// Default auth middleware: accept every connection.
		auth = &NoAuthMiddleware{}
	}

	ctx, cancel := context.WithCancel(context.Background())
	return &Server{
		logger:         config.Logger.Named("agentsocket"),
		path:           config.Path,
		handlers:       map[string]Handler{},
		authMiddleware: auth,
		ctx:            ctx,
		cancel:         cancel,
	}
}
// Start starts the socket server: it resolves the socket path (falling
// back to a package default when Config.Path was empty), binds the
// listener, and spawns the accept loop. It fails if the server is already
// started or the path is unavailable.
func (s *Server) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.listener != nil {
		return xerrors.New("server already started")
	}

	// Get socket path. getDefaultSocketPath is defined elsewhere in this
	// package (not shown here) — presumably implementing the documented
	// discovery order (env var, XDG runtime dir, temp dir); confirm.
	path := s.path
	if path == "" {
		var err error
		path, err = getDefaultSocketPath()
		if err != nil {
			return xerrors.Errorf("get default socket path: %w", err)
		}
	}

	// Check if socket is available.
	// NOTE(review): this check races with createSocket below (TOCTOU) —
	// another process could claim the path in between; confirm that
	// createSocket fails cleanly on collision.
	if !isSocketAvailable(path) {
		return xerrors.Errorf("socket path %s is not available", path)
	}

	// Create socket listener
	listener, err := createSocket(s.ctx, path)
	if err != nil {
		return xerrors.Errorf("create socket: %w", err)
	}

	// Record the resolved path so Stop can clean up the right file.
	s.listener = listener
	s.path = path

	s.logger.Info(s.ctx, "agent socket server started", slog.F("path", path))

	// Start accepting connections
	s.wg.Add(1)
	go s.acceptConnections()

	return nil
}
// Stop stops the socket server and waits for in-flight connections to
// finish. It is a no-op (returns nil) when the server is not running.
//
// BUG FIX: the previous implementation held s.mu for the entire method,
// including across s.wg.Wait(). Connection goroutines acquire s.mu.RLock
// in handleRequest, so any request still being processed at shutdown
// would block on the lock while Stop blocked on the WaitGroup — a
// guaranteed deadlock. The lock is now released before waiting.
func (s *Server) Stop() error {
	s.mu.Lock()
	if s.listener == nil {
		s.mu.Unlock()
		return nil
	}

	s.logger.Info(s.ctx, "stopping agent socket server")

	// Cancel context to stop accepting new connections.
	s.cancel()

	// Close the listener to unblock Accept in the accept loop. Keep
	// s.listener set for now: acceptConnections dereferences it directly,
	// so clearing it before that goroutine exits could cause a nil-pointer
	// panic.
	if err := s.listener.Close(); err != nil {
		s.logger.Warn(s.ctx, "error closing socket listener", slog.Error(err))
	}
	path := s.path
	s.mu.Unlock()

	// Wait for the accept loop and all connection handlers WITHOUT holding
	// s.mu (see deadlock note above).
	s.wg.Wait()

	// Clean up socket file.
	if err := cleanupSocket(path); err != nil {
		s.logger.Warn(s.ctx, "error cleaning up socket file", slog.Error(err))
	}

	s.mu.Lock()
	s.listener = nil
	s.mu.Unlock()

	s.logger.Info(s.ctx, "agent socket server stopped")
	return nil
}
// RegisterHandler registers a handler for a method, replacing any handler
// previously registered under the same name. Safe for concurrent use.
func (s *Server) RegisterHandler(method string, handler Handler) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.handlers[method] = handler
}

// GetPath returns the socket path; after a successful Start this is the
// resolved path even when Config.Path was empty.
func (s *Server) GetPath() string {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.path
}
// acceptConnections accepts incoming connections until the server context
// is canceled (Stop closes the listener to unblock Accept). Each accepted
// connection is served in its own goroutine tracked by s.wg.
func (a *agent) _() {}

// NOTE: placeholder removed — see actual function below.
func (s *Server) acceptConnections() {
	defer s.wg.Done()

	for {
		// Fast-path exit once Stop has canceled the context.
		select {
		case <-s.ctx.Done():
			return
		default:
		}

		conn, err := s.listener.Accept()
		if err != nil {
			// Accept fails when the listener is closed during shutdown;
			// distinguish that from a transient error by re-checking the
			// context before logging.
			select {
			case <-s.ctx.Done():
				return
			default:
				s.logger.Warn(s.ctx, "error accepting connection", slog.Error(err))
				continue
			}
		}

		// Handle connection in a goroutine so one slow client cannot block
		// the accept loop.
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			s.handleConnection(conn)
		}()
	}
}
// handleConnection services a single client connection: it authenticates,
// then processes a stream of JSON requests until the client disconnects,
// an error occurs, or the connection idles past its deadline.
//
// BUG FIX: the previous implementation called conn.SetDeadline once
// (fixing BOTH read and write deadlines at accept+30s) and then refreshed
// only the read deadline per request. Response writes would therefore
// start failing 30 seconds after the connection was accepted, breaking
// long-lived clients. Both deadlines are now refreshed each iteration.
func (s *Server) handleConnection(conn net.Conn) {
	defer conn.Close()

	// Authenticate connection first to get a (possibly enriched) context.
	ctx, err := s.authMiddleware.Authenticate(s.ctx, conn)
	if err != nil {
		s.logger.Warn(s.ctx, "authentication failed", slog.Error(err))
		return
	}

	s.logger.Debug(ctx, "new connection accepted", slog.F("remote_addr", conn.RemoteAddr()))

	decoder := json.NewDecoder(conn)
	encoder := json.NewEncoder(conn)

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		// Refresh read AND write deadlines for this request/response round
		// trip (see bug-fix note above).
		if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
			s.logger.Warn(ctx, "failed to set connection deadline", slog.Error(err))
		}

		var req Request
		if err := decoder.Decode(&req); err != nil {
			if err == io.EOF {
				s.logger.Debug(ctx, "connection closed by client")
				return
			}
			s.logger.Warn(ctx, "error decoding request", slog.Error(err))

			// Best-effort parse-error response: the connection is torn
			// down regardless, so a failed write is only logged instead of
			// being silently discarded as before.
			resp := NewErrorResponse("", NewError(ErrCodeParseError, "Parse error", err.Error()))
			if err := encoder.Encode(resp); err != nil {
				s.logger.Debug(ctx, "failed to send parse error response", slog.Error(err))
			}
			return
		}

		// Dispatch and reply.
		resp := s.handleRequest(ctx, &req)
		if err := encoder.Encode(resp); err != nil {
			s.logger.Warn(ctx, "error sending response", slog.Error(err))
			return
		}
	}
}
// handleRequest handles a single request
//
// Validates the protocol version and method name, dispatches to the
// registered handler, and converts a handler error into an error response.
func (s *Server) handleRequest(ctx context.Context, req *Request) *Response {
	// Validate request
	if req.Version != ProtocolVersion {
		return NewErrorResponse(req.ID, NewError(ErrCodeInvalidRequest, "Unsupported version", req.Version))
	}
	if req.Method == "" {
		return NewErrorResponse(req.ID, NewError(ErrCodeInvalidRequest, "Missing method", nil))
	}
	// Get handler
	s.mu.RLock()
	handler, exists := s.handlers[req.Method]
	s.mu.RUnlock()
	if !exists {
		return NewErrorResponse(req.ID, NewError(ErrCodeMethodNotFound, "Method not found", req.Method))
	}
	// Call handler
	// NOTE(review): the request ID is stored on ctx here, but ctx is never
	// handed to the handler below — the handler receives an empty Context{}.
	// Either the value should be threaded through or this WithValue is dead
	// code; confirm the intended Context contract.
	type requestIDKey struct{}
	ctx = context.WithValue(ctx, requestIDKey{}, req.ID)
	resp, err := handler(Context{}, req)
	if err != nil {
		s.logger.Warn(ctx, "handler execution failed", slog.Error(err))
		return NewErrorResponse(req.ID, NewError(ErrCodeInternalError, "Internal error", err.Error()))
	}
	return resp
}
+250
View File
@@ -0,0 +1,250 @@
package agentsocket
import (
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog"
)
// TestServer_StartStop verifies the full happy path: the server starts,
// creates its socket file, answers a registered method over a real Unix
// socket connection, and returns the handler's payload intact.
func TestServer_StartStop(t *testing.T) {
	t.Parallel()
	// Create temporary socket path
	tmpDir := t.TempDir()
	socketPath := filepath.Join(tmpDir, "test.sock")
	// Create server
	server := NewServer(Config{
		Path:   socketPath,
		Logger: slog.Make().Leveled(slog.LevelDebug),
	})
	// Register a test handler
	server.RegisterHandler("test", func(ctx Context, req *Request) (*Response, error) {
		return NewResponse(req.ID, map[string]string{"message": "test response"})
	})
	// Start server
	err := server.Start()
	require.NoError(t, err)
	defer server.Stop()
	// Verify socket file exists
	_, err = os.Stat(socketPath)
	require.NoError(t, err)
	// Test connection
	conn, err := net.Dial("unix", socketPath)
	require.NoError(t, err)
	defer conn.Close()
	// Send test request
	req := Request{
		Version: "1.0",
		Method:  "test",
		ID:      "test-1",
	}
	err = json.NewEncoder(conn).Encode(req)
	require.NoError(t, err)
	// Read response
	var resp Response
	err = json.NewDecoder(conn).Decode(&resp)
	require.NoError(t, err)
	assert.Equal(t, "test-1", resp.ID)
	assert.Nil(t, resp.Error)
	assert.NotNil(t, resp.Result)
	// Verify response content
	var result map[string]string
	err = json.Unmarshal(resp.Result, &result)
	require.NoError(t, err)
	assert.Equal(t, "test response", result["message"])
}
// TestServer_ErrorHandling verifies that a request for an unregistered
// method yields a method-not-found error response.
func TestServer_ErrorHandling(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "test.sock")
	server := NewServer(Config{
		Path:   socketPath,
		Logger: slog.Make().Leveled(slog.LevelDebug),
	})
	require.NoError(t, server.Start())
	defer server.Stop()
	conn, err := net.Dial("unix", socketPath)
	require.NoError(t, err)
	defer conn.Close()
	// Ask for a method that was never registered.
	req := Request{
		Version: "1.0",
		Method:  "nonexistent",
		ID:      "test-1",
	}
	require.NoError(t, json.NewEncoder(conn).Encode(req))
	var resp Response
	require.NoError(t, json.NewDecoder(conn).Decode(&resp))
	assert.Equal(t, "test-1", resp.ID)
	assert.NotNil(t, resp.Error)
	assert.Equal(t, ErrCodeMethodNotFound, resp.Error.Code)
	assert.Equal(t, "Method not found", resp.Error.Message)
}
// TestServer_DefaultHandlers verifies that RegisterDefaultHandlers wires up
// the built-in "ping" method and that it answers with "pong".
func TestServer_DefaultHandlers(t *testing.T) {
	t.Parallel()
	// Create temporary socket path
	tmpDir := t.TempDir()
	socketPath := filepath.Join(tmpDir, "test.sock")
	// Create server
	server := NewServer(Config{
		Path:   socketPath,
		Logger: slog.Make().Leveled(slog.LevelDebug),
	})
	// Register default handlers
	handlerCtx := CreateHandlerContext(
		"test-agent-id",
		"1.0.0",
		"ready",
		time.Now().Add(-time.Hour), // started an hour ago, for uptime reporting
		slog.Make(),
	)
	RegisterDefaultHandlers(server, handlerCtx)
	// Start server
	err := server.Start()
	require.NoError(t, err)
	defer server.Stop()
	// Test ping
	conn, err := net.Dial("unix", socketPath)
	require.NoError(t, err)
	defer conn.Close()
	req := Request{
		Version: "1.0",
		Method:  "ping",
		ID:      "ping-1",
	}
	err = json.NewEncoder(conn).Encode(req)
	require.NoError(t, err)
	var resp Response
	err = json.NewDecoder(conn).Decode(&resp)
	require.NoError(t, err)
	assert.Equal(t, "ping-1", resp.ID)
	assert.Nil(t, resp.Error)
	var pingResp PingResponse
	err = json.Unmarshal(resp.Result, &pingResp)
	require.NoError(t, err)
	assert.Equal(t, "pong", pingResp.Message)
}
// TestServer_ConcurrentConnections opens several simultaneous client
// connections and checks that each gets its own successful response.
// Goroutines report outcomes over a channel so all assertions run on the
// test goroutine.
func TestServer_ConcurrentConnections(t *testing.T) {
	t.Parallel()
	// Create temporary socket path
	tmpDir := t.TempDir()
	socketPath := filepath.Join(tmpDir, "test.sock")
	// Create server
	server := NewServer(Config{
		Path:   socketPath,
		Logger: slog.Make().Leveled(slog.LevelDebug),
	})
	// Register a test handler
	server.RegisterHandler("test", func(ctx Context, req *Request) (*Response, error) {
		time.Sleep(10 * time.Millisecond) // Simulate some work
		return NewResponse(req.ID, map[string]string{"message": "test response"})
	})
	// Start server
	err := server.Start()
	require.NoError(t, err)
	defer server.Stop()
	// Test multiple concurrent connections
	const numConnections = 5
	results := make(chan error, numConnections)
	for i := 0; i < numConnections; i++ {
		go func(i int) {
			conn, err := net.Dial("unix", socketPath)
			if err != nil {
				results <- err
				return
			}
			defer conn.Close()
			req := Request{
				Version: "1.0",
				Method:  "test",
				ID:      fmt.Sprintf("test-%d", i),
			}
			err = json.NewEncoder(conn).Encode(req)
			if err != nil {
				results <- err
				return
			}
			var resp Response
			err = json.NewDecoder(conn).Decode(&resp)
			if err != nil {
				results <- err
				return
			}
			if resp.Error != nil {
				results <- xerrors.Errorf("server error: %s", resp.Error.Message)
				return
			}
			results <- nil
		}(i)
	}
	// Wait for all connections to complete
	for i := 0; i < numConnections; i++ {
		select {
		case err := <-results:
			require.NoError(t, err)
		case <-time.After(5 * time.Second):
			t.Fatal("timeout waiting for concurrent connections")
		}
	}
}
+106
View File
@@ -0,0 +1,106 @@
//go:build !windows
package agentsocket
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
"syscall"
"time"
"golang.org/x/xerrors"
)
// createSocket creates a Unix domain socket listener
//
// Any stale socket file at path is removed first, the parent directory is
// created with owner-only permissions, and the resulting socket file is
// restricted to the current user.
func createSocket(ctx context.Context, path string) (net.Listener, error) {
	// Reclaim a leftover socket file from a previous run, if any.
	if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
		return nil, xerrors.Errorf("remove existing socket: %w", err)
	}
	// Ensure the parent directory exists and is private to this user.
	if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
		return nil, xerrors.Errorf("create socket directory: %w", err)
	}
	listener, err := net.Listen("unix", path)
	if err != nil {
		return nil, xerrors.Errorf("listen on unix socket: %w", err)
	}
	// Restrict the socket itself to the owning user.
	if chmodErr := os.Chmod(path, 0o600); chmodErr != nil {
		listener.Close()
		return nil, xerrors.Errorf("set socket permissions: %w", chmodErr)
	}
	return listener, nil
}
// getDefaultSocketPath returns the default socket path for Unix-like systems
func getDefaultSocketPath() (string, error) {
// Try XDG_RUNTIME_DIR first
if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" {
return filepath.Join(runtimeDir, "coder-agent.sock"), nil
}
// Fall back to /tmp with user-specific path
uid := os.Getuid()
return filepath.Join("/tmp", fmt.Sprintf("coder-agent-%d.sock", uid)), nil
}
// cleanupSocket removes the socket file
func cleanupSocket(path string) error {
return os.Remove(path)
}
// isSocketAvailable checks if a socket path is available for use
func isSocketAvailable(path string) bool {
// Check if file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
return true
}
// Try to connect to see if it's actually listening
conn, err := net.Dial("unix", path)
if err != nil {
// If we can't connect, the socket is not in use
return true
}
conn.Close()
return false
}
// getSocketInfo returns information about the socket file
//
// Ownership (UID/GID) comes from the platform stat structure; on Unix this
// is always a *syscall.Stat_t.
func getSocketInfo(path string) (*SocketInfo, error) {
	fi, err := os.Stat(path)
	if err != nil {
		return nil, err
	}
	raw, ok := fi.Sys().(*syscall.Stat_t)
	if !ok {
		return nil, xerrors.New("unable to get stat_t from file info")
	}
	info := &SocketInfo{
		Path:    path,
		UID:     int(raw.Uid),
		GID:     int(raw.Gid),
		Mode:    fi.Mode(),
		ModTime: fi.ModTime(),
	}
	return info, nil
}
// SocketInfo contains information about a socket file
type SocketInfo struct {
	// Path is the filesystem location of the socket file.
	Path string
	// UID is the numeric owner of the socket file.
	UID int
	// GID is the numeric group of the socket file.
	GID int
	// Mode is the file's permission bits and type flags.
	Mode os.FileMode
	// ModTime is the file's last modification time.
	ModTime time.Time
}
+99
View File
@@ -0,0 +1,99 @@
//go:build windows
package agentsocket
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
"strconv"
"time"
)
// createSocket creates a Unix domain socket listener on Windows
// Falls back to named pipe if Unix sockets are not supported
func createSocket(ctx context.Context, path string) (net.Listener, error) {
// Try Unix domain socket first (Windows 10 build 17063+)
listener, err := net.Listen("unix", path)
if err == nil {
return listener, nil
}
// Fall back to named pipe
pipePath := `\\.\pipe\coder-agent`
return net.Listen("tcp", pipePath)
}
// getDefaultSocketPath returns the default socket path for Windows
func getDefaultSocketPath() (string, error) {
// Try to use a temporary directory
tempDir := os.TempDir()
if tempDir == "" {
tempDir = "C:\\temp"
}
// Create a user-specific subdirectory
uid := os.Getuid()
userDir := filepath.Join(tempDir, "coder-agent", strconv.Itoa(uid))
if err := os.MkdirAll(userDir, 0o700); err != nil {
return "", fmt.Errorf("create user directory: %w", err)
}
return filepath.Join(userDir, "agent.sock"), nil
}
// cleanupSocket removes the socket file
func cleanupSocket(path string) error {
return os.Remove(path)
}
// isSocketAvailable checks if a socket path is available for use
func isSocketAvailable(path string) bool {
// Check if file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
return true
}
// Try to connect to see if it's actually listening
conn, err := net.Dial("unix", path)
if err != nil {
// If we can't connect, the socket is not in use
return true
}
conn.Close()
return false
}
// getSocketInfo returns information about the socket file
//
// NOTE(review): this Windows variant is a placeholder — UID/GID are
// hard-coded to 0 and Owner/Group to "unknown". A complete implementation
// would read the file's security descriptor to resolve the owner SID.
func getSocketInfo(path string) (*SocketInfo, error) {
	stat, err := os.Stat(path)
	if err != nil {
		return nil, err
	}
	// On Windows, we'll use a simplified approach for now
	// In a real implementation, you'd get the security descriptor
	return &SocketInfo{
		Path:    path,
		UID:     0, // Simplified for now
		GID:     0, // Simplified for now
		Mode:    stat.Mode(),
		ModTime: stat.ModTime(),
		Owner:   "unknown",
		Group:   "unknown",
	}, nil
}
// SocketInfo contains information about a socket file
type SocketInfo struct {
	// Path is the filesystem location of the socket file.
	Path string
	// UID is the numeric owner; currently always 0 on Windows (placeholder).
	UID int
	// GID is the numeric group; currently always 0 on Windows (placeholder).
	GID int
	// Mode is the file's permission bits and type flags.
	Mode os.FileMode
	// ModTime is the file's last modification time.
	ModTime time.Time
	Owner   string // Windows SID string
	Group   string // Windows SID string
}
+227
View File
@@ -0,0 +1,227 @@
package unit
import (
"sync"
"golang.org/x/xerrors"
)
// ErrConsumerNotFound is returned when a consumer ID is not registered.
var ErrConsumerNotFound = xerrors.New("consumer not found")

// ErrCannotUpdateOtherConsumer is returned when attempting to update another consumer's status.
// NOTE(review): not referenced anywhere in this file; presumably reserved
// for callers — confirm it is actually used before removing.
var ErrCannotUpdateOtherConsumer = xerrors.New("cannot update other consumer's status")

// dependencyVertex represents a vertex in the dependency graph that is associated with a consumer.
type dependencyVertex[ConsumerID comparable] struct {
	// ID is the consumer this vertex stands for. Each consumer gets exactly
	// one stored instance (see DependencyTracker.consumerVertices) so
	// pointer-based vertex comparisons in the graph stay consistent.
	ID ConsumerID
}

// Dependency represents a dependency relationship between consumers.
type Dependency[StatusType, ConsumerID comparable] struct {
	// Consumer is the depending party.
	Consumer ConsumerID
	// DependsOn is the consumer whose status is awaited.
	DependsOn ConsumerID
	// RequiredStatus is the status DependsOn must reach.
	RequiredStatus StatusType
	// CurrentStatus is DependsOn's last reported status (zero value if it
	// has never reported).
	CurrentStatus StatusType
	// IsSatisfied reports whether CurrentStatus equals RequiredStatus.
	IsSatisfied bool
}
// DependencyTracker provides reactive dependency tracking over a Graph.
// It manages consumer registration, dependency relationships, and status updates
// with automatic recalculation of readiness when dependencies are satisfied.
type DependencyTracker[StatusType, ConsumerID comparable] struct {
	// mu guards every map below and the graph; readiness reads take RLock.
	mu sync.RWMutex
	// The underlying graph that stores dependency relationships
	graph *Graph[StatusType, *dependencyVertex[ConsumerID]]
	// Track current status of each consumer
	consumerStatus map[ConsumerID]StatusType
	// Track readiness state (cached to avoid repeated graph traversal)
	consumerReadiness map[ConsumerID]bool
	// Track which consumers are registered
	registeredConsumers map[ConsumerID]bool
	// Store vertex instances for each consumer to ensure consistent references
	consumerVertices map[ConsumerID]*dependencyVertex[ConsumerID]
}
// NewDependencyTracker creates an empty DependencyTracker ready for consumer
// registration.
func NewDependencyTracker[StatusType, ConsumerID comparable]() *DependencyTracker[StatusType, ConsumerID] {
	tracker := &DependencyTracker[StatusType, ConsumerID]{}
	tracker.graph = &Graph[StatusType, *dependencyVertex[ConsumerID]]{}
	tracker.consumerStatus = map[ConsumerID]StatusType{}
	tracker.consumerReadiness = map[ConsumerID]bool{}
	tracker.registeredConsumers = map[ConsumerID]bool{}
	tracker.consumerVertices = map[ConsumerID]*dependencyVertex[ConsumerID]{}
	return tracker
}
// Register adds id as a vertex in the dependency graph. A newly registered
// consumer has no dependencies and therefore starts out ready.
func (dt *DependencyTracker[StatusType, ConsumerID]) Register(id ConsumerID) error {
	dt.mu.Lock()
	defer dt.mu.Unlock()
	if dt.registeredConsumers[id] {
		return xerrors.Errorf("consumer %v is already registered", id)
	}
	// One vertex per consumer, stored so later lookups reuse the same pointer.
	dt.consumerVertices[id] = &dependencyVertex[ConsumerID]{ID: id}
	dt.registeredConsumers[id] = true
	// No dependencies yet, so the consumer is immediately ready.
	dt.consumerReadiness[id] = true
	return nil
}
// AddDependency adds a dependency relationship between consumers.
// The consumer depends on the dependsOn consumer reaching the requiredStatus.
//
// Both consumers must already be registered. Readiness for consumer is
// recalculated immediately, since the new edge may make it not-ready.
func (dt *DependencyTracker[StatusType, ConsumerID]) AddDependency(consumer ConsumerID, dependsOn ConsumerID, requiredStatus StatusType) error {
	dt.mu.Lock()
	defer dt.mu.Unlock()
	if !dt.registeredConsumers[consumer] {
		return xerrors.Errorf("consumer %v is not registered", consumer)
	}
	if !dt.registeredConsumers[dependsOn] {
		return xerrors.Errorf("consumer %v is not registered", dependsOn)
	}
	// Get the stored vertices for both consumers
	consumerVertex := dt.consumerVertices[consumer]
	dependsOnVertex := dt.consumerVertices[dependsOn]
	// Add the dependency edge to the graph
	// The edge goes from consumer to dependsOn, representing the dependency
	err := dt.graph.AddEdge(consumerVertex, dependsOnVertex, requiredStatus)
	if err != nil {
		return xerrors.Errorf("failed to add dependency: %w", err)
	}
	// Recalculate readiness for the consumer since it now has a dependency
	dt.recalculateReadinessUnsafe(consumer)
	return nil
}
// UpdateStatus updates a consumer's status and recalculates readiness for affected dependents.
//
// Only the direct dependents of consumer are recalculated; transitive
// dependents are unaffected because readiness depends on direct edges only.
func (dt *DependencyTracker[StatusType, ConsumerID]) UpdateStatus(consumer ConsumerID, newStatus StatusType) error {
	dt.mu.Lock()
	defer dt.mu.Unlock()
	if !dt.registeredConsumers[consumer] {
		return ErrConsumerNotFound
	}
	// Update the consumer's status
	dt.consumerStatus[consumer] = newStatus
	// Get all consumers that depend on this one (reverse adjacent vertices)
	consumerVertex := dt.consumerVertices[consumer]
	dependentEdges := dt.graph.GetReverseAdjacentVertices(consumerVertex)
	// Recalculate readiness for all dependents; edge.From is the depending
	// side of each reverse edge.
	for _, edge := range dependentEdges {
		dt.recalculateReadinessUnsafe(edge.From.ID)
	}
	return nil
}
// IsReady reports whether every dependency registered for consumer is
// currently satisfied. Consumers with no dependencies are always ready.
func (dt *DependencyTracker[StatusType, ConsumerID]) IsReady(consumer ConsumerID) (bool, error) {
	dt.mu.RLock()
	ready, registered := dt.consumerReadiness[consumer], dt.registeredConsumers[consumer]
	dt.mu.RUnlock()
	if !registered {
		return false, ErrConsumerNotFound
	}
	return ready, nil
}
// GetUnmetDependencies returns a list of unsatisfied dependencies for a consumer.
// A nil slice means every dependency is satisfied (or none exist).
func (dt *DependencyTracker[StatusType, ConsumerID]) GetUnmetDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
	dt.mu.RLock()
	defer dt.mu.RUnlock()
	if !dt.registeredConsumers[consumer] {
		return nil, ErrConsumerNotFound
	}
	vertex := dt.consumerVertices[consumer]
	var unmet []Dependency[StatusType, ConsumerID]
	for _, edge := range dt.graph.GetForwardAdjacentVertices(vertex) {
		target := edge.To.ID
		required := edge.Edge
		// When the dependency has never reported, `current` is the zero
		// status value, which is exactly what is reported as CurrentStatus.
		current, tracked := dt.consumerStatus[target]
		if tracked && current == required {
			continue
		}
		unmet = append(unmet, Dependency[StatusType, ConsumerID]{
			Consumer:       consumer,
			DependsOn:      target,
			RequiredStatus: required,
			CurrentStatus:  current,
			IsSatisfied:    false,
		})
	}
	return unmet, nil
}
// recalculateReadinessUnsafe recalculates the readiness state for a consumer.
// The caller must hold dt.mu for writing.
func (dt *DependencyTracker[StatusType, ConsumerID]) recalculateReadinessUnsafe(consumer ConsumerID) {
	vertex := dt.consumerVertices[consumer]
	// With zero dependencies the loop body never runs and ready stays true.
	ready := true
	for _, edge := range dt.graph.GetForwardAdjacentVertices(vertex) {
		current, tracked := dt.consumerStatus[edge.To.ID]
		if !tracked || current != edge.Edge {
			ready = false
			break
		}
	}
	dt.consumerReadiness[consumer] = ready
}
// GetGraph returns the underlying graph for visualization and debugging.
// This should be used carefully as it exposes the internal graph structure.
// NOTE(review): neither this accessor nor ExportDOT takes dt.mu, so callers
// racing with AddDependency may observe a mutating graph — confirm whether
// these are debug-only entry points.
func (dt *DependencyTracker[StatusType, ConsumerID]) GetGraph() *Graph[StatusType, *dependencyVertex[ConsumerID]] {
	return dt.graph
}

// ExportDOT exports the dependency graph to DOT format for visualization.
func (dt *DependencyTracker[StatusType, ConsumerID]) ExportDOT(name string) (string, error) {
	return dt.graph.ToDOT(name)
}
+692
View File
@@ -0,0 +1,692 @@
package unit_test
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/unit"
)
// testStatus is the status type used across these tests.
type testStatus string

const (
	statusInitialized testStatus = "initialized"
	statusStarted     testStatus = "started"
	statusRunning     testStatus = "running"
	statusCompleted   testStatus = "completed"
)

// testConsumerID identifies the fake services used across these tests.
type testConsumerID string

const (
	consumerA testConsumerID = "serviceA"
	consumerB testConsumerID = "serviceB"
	consumerC testConsumerID = "serviceC"
	consumerD testConsumerID = "serviceD"
)
// TestDependencyTracker_Register covers consumer registration: a new
// consumer is immediately ready, duplicate registration is rejected, and
// multiple independent consumers can coexist.
func TestDependencyTracker_Register(t *testing.T) {
	t.Parallel()
	t.Run("RegisterNewConsumer", func(t *testing.T) {
		t.Parallel()
		// Each parallel subtest owns its tracker. The original declared a
		// tracker in the parent test and shared it into this t.Parallel
		// subtest, which is both inconsistent with the sibling subtests and
		// fragile if another subtest ever touches the same instance.
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		// Consumer should be ready initially (no dependencies)
		ready, err := tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.True(t, ready)
	})
	t.Run("RegisterDuplicateConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerA)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "already registered")
	})
	t.Run("RegisterMultipleConsumers", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		consumers := []testConsumerID{consumerA, consumerB, consumerC}
		for _, consumer := range consumers {
			err := tracker.Register(consumer)
			require.NoError(t, err)
		}
		// All should be ready initially
		for _, consumer := range consumers {
			ready, err := tracker.IsReady(consumer)
			require.NoError(t, err)
			assert.True(t, ready)
		}
	})
}
// TestDependencyTracker_AddDependency covers edge creation: adding a
// dependency makes the depending consumer not-ready, and both endpoints must
// be registered first.
func TestDependencyTracker_AddDependency(t *testing.T) {
	t.Parallel()
	t.Run("AddDependencyBetweenRegisteredConsumers", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)
		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		// A should no longer be ready (depends on B)
		ready, err := tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)
		// B should still be ready (no dependencies)
		ready, err = tracker.IsReady(consumerB)
		require.NoError(t, err)
		assert.True(t, ready)
	})
	t.Run("AddDependencyWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		// Try to add dependency to unregistered consumer
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not registered")
	})
	t.Run("AddDependencyFromUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		err := tracker.Register(consumerB)
		require.NoError(t, err)
		// Try to add dependency from unregistered consumer
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not registered")
	})
}
// TestDependencyTracker_UpdateStatus covers status propagation: an update
// recalculates dependents' readiness, unknown consumers are rejected, and a
// linear A→B→C chain becomes ready link by link.
func TestDependencyTracker_UpdateStatus(t *testing.T) {
	t.Parallel()
	t.Run("UpdateStatusTriggersReadinessRecalculation", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)
		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		// Initially A is not ready
		ready, err := tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)
		// Update B to "running" - A should become ready
		err = tracker.UpdateStatus(consumerB, statusRunning)
		require.NoError(t, err)
		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.True(t, ready)
	})
	t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		err := tracker.UpdateStatus(consumerA, statusRunning)
		require.Error(t, err)
		assert.Equal(t, unit.ErrConsumerNotFound, err)
	})
	t.Run("LinearChainDependencies", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		// Register all consumers
		consumers := []testConsumerID{consumerA, consumerB, consumerC}
		for _, consumer := range consumers {
			err := tracker.Register(consumer)
			require.NoError(t, err)
		}
		// Create chain: A depends on B being "started", B depends on C being "completed"
		err := tracker.AddDependency(consumerA, consumerB, statusStarted)
		require.NoError(t, err)
		err = tracker.AddDependency(consumerB, consumerC, statusCompleted)
		require.NoError(t, err)
		// Initially only C is ready
		ready, err := tracker.IsReady(consumerC)
		require.NoError(t, err)
		assert.True(t, ready)
		ready, err = tracker.IsReady(consumerB)
		require.NoError(t, err)
		assert.False(t, ready)
		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)
		// Update C to "completed" - B should become ready
		err = tracker.UpdateStatus(consumerC, statusCompleted)
		require.NoError(t, err)
		ready, err = tracker.IsReady(consumerB)
		require.NoError(t, err)
		assert.True(t, ready)
		// A is still blocked: readiness depends on B's *status*, not B's
		// readiness, and B has not reported "started" yet.
		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)
		// Update B to "started" - A should become ready
		err = tracker.UpdateStatus(consumerB, statusStarted)
		require.NoError(t, err)
		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.True(t, ready)
	})
}
// TestDependencyTracker_GetUnmetDependencies covers the unmet-dependency
// report: empty for dependency-free consumers, populated with the required
// status while unsatisfied, empty again once satisfied, and an error for
// unknown consumers.
func TestDependencyTracker_GetUnmetDependencies(t *testing.T) {
	t.Parallel()
	t.Run("GetUnmetDependenciesForConsumerWithNoDependencies", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.NoError(t, err)
		assert.Empty(t, unmet)
	})
	t.Run("GetUnmetDependenciesForConsumerWithUnsatisfiedDependencies", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)
		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.NoError(t, err)
		require.Len(t, unmet, 1)
		assert.Equal(t, consumerA, unmet[0].Consumer)
		assert.Equal(t, consumerB, unmet[0].DependsOn)
		assert.Equal(t, statusRunning, unmet[0].RequiredStatus)
		assert.False(t, unmet[0].IsSatisfied)
	})
	t.Run("GetUnmetDependenciesForConsumerWithSatisfiedDependencies", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)
		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		// Update B to "running"
		err = tracker.UpdateStatus(consumerB, statusRunning)
		require.NoError(t, err)
		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.NoError(t, err)
		assert.Empty(t, unmet)
	})
	t.Run("GetUnmetDependenciesForUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.Error(t, err)
		assert.Equal(t, unit.ErrConsumerNotFound, err)
		assert.Nil(t, unmet)
	})
}
// TestDependencyTracker_ConcurrentOperations exercises the tracker's locking
// under concurrent status updates and readiness checks. Goroutines never
// call require/assert directly: require.NoError invokes t.FailNow, which the
// testing package only permits on the test goroutine, so failures are
// collected and asserted after wg.Wait.
func TestDependencyTracker_ConcurrentOperations(t *testing.T) {
	t.Parallel()
	t.Run("ConcurrentStatusUpdates", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		// Register consumers
		consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
		for _, consumer := range consumers {
			err := tracker.Register(consumer)
			require.NoError(t, err)
		}
		// Create dependencies: A depends on B, B depends on C, C depends on D
		err := tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		err = tracker.AddDependency(consumerB, consumerC, statusStarted)
		require.NoError(t, err)
		err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
		require.NoError(t, err)
		var wg sync.WaitGroup
		const numGoroutines = 10
		// Every goroutine applies the same idempotent sequence of updates;
		// errors land in a per-goroutine slot and are checked afterwards.
		errors := make([]error, numGoroutines)
		for i := 0; i < numGoroutines; i++ {
			wg.Add(1)
			go func(goroutineID int) {
				defer wg.Done()
				// Update D to completed (should make C ready)
				if err := tracker.UpdateStatus(consumerD, statusCompleted); err != nil {
					errors[goroutineID] = err
					return
				}
				// Update C to started (should make B ready)
				if err := tracker.UpdateStatus(consumerC, statusStarted); err != nil {
					errors[goroutineID] = err
					return
				}
				// Update B to running (should make A ready)
				if err := tracker.UpdateStatus(consumerB, statusRunning); err != nil {
					errors[goroutineID] = err
					return
				}
			}(i)
		}
		wg.Wait()
		// Check for any errors in goroutines
		for i, err := range errors {
			require.NoError(t, err, "goroutine %d had error", i)
		}
		// All consumers should be ready after the updates
		for _, consumer := range consumers {
			ready, err := tracker.IsReady(consumer)
			require.NoError(t, err)
			assert.True(t, ready)
		}
	})
	t.Run("ConcurrentReadinessChecks", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
		// Register consumers
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)
		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		var wg sync.WaitGroup
		const numGoroutines = 20
		// The original called require.NoError/assert.True from inside these
		// goroutines; failures are funneled through a channel instead.
		failures := make(chan string, numGoroutines)
		for i := 0; i < numGoroutines; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				// Check readiness multiple times
				for j := 0; j < 10; j++ {
					// A flips from not-ready to ready once B is updated
					// below; only the error return matters here.
					if _, err := tracker.IsReady(consumerA); err != nil {
						failures <- err.Error()
						return
					}
					ready, err := tracker.IsReady(consumerB)
					if err != nil {
						failures <- err.Error()
						return
					}
					// B has no dependencies and must always be ready.
					if !ready {
						failures <- "consumerB unexpectedly not ready"
						return
					}
				}
			}()
		}
		// Update B to "running" in the middle of readiness checks
		err = tracker.UpdateStatus(consumerB, statusRunning)
		require.NoError(t, err)
		wg.Wait()
		close(failures)
		for msg := range failures {
			t.Errorf("goroutine failure: %s", msg)
		}
	})
}
func TestDependencyTracker_MultipleDependencies(t *testing.T) {
t.Parallel()
t.Run("ConsumerWithMultipleDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
// Register all consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// A depends on B being "running" AND C being "started"
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
require.NoError(t, err)
// A should not be ready (depends on both B and C)
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update B to "running" - A should still not be ready (needs C too)
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "started" - A should now be ready
err = tracker.UpdateStatus(consumerC, statusStarted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
// ComplexDependencyChain: readiness must propagate through a diamond-shaped
// graph (A -> {B, C} -> D) and A only becomes ready once every edge's
// required status has been observed.
t.Run("ComplexDependencyChain", func(t *testing.T) {
	t.Parallel()
	tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
	// Register all consumers
	consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
	for _, consumer := range consumers {
		err := tracker.Register(consumer)
		require.NoError(t, err)
	}
	// Create complex dependency graph:
	// A depends on B being "running" AND C being "started"
	// B depends on D being "completed"
	// C depends on D being "completed"
	err := tracker.AddDependency(consumerA, consumerB, statusRunning)
	require.NoError(t, err)
	err = tracker.AddDependency(consumerA, consumerC, statusStarted)
	require.NoError(t, err)
	err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
	require.NoError(t, err)
	err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
	require.NoError(t, err)
	// Initially only D is ready (it has no dependencies of its own).
	ready, err := tracker.IsReady(consumerD)
	require.NoError(t, err)
	assert.True(t, ready)
	ready, err = tracker.IsReady(consumerB)
	require.NoError(t, err)
	assert.False(t, ready)
	ready, err = tracker.IsReady(consumerC)
	require.NoError(t, err)
	assert.False(t, ready)
	ready, err = tracker.IsReady(consumerA)
	require.NoError(t, err)
	assert.False(t, ready)
	// Update D to "completed" - B and C should become ready
	err = tracker.UpdateStatus(consumerD, statusCompleted)
	require.NoError(t, err)
	ready, err = tracker.IsReady(consumerB)
	require.NoError(t, err)
	assert.True(t, ready)
	ready, err = tracker.IsReady(consumerC)
	require.NoError(t, err)
	assert.True(t, ready)
	ready, err = tracker.IsReady(consumerA)
	require.NoError(t, err)
	assert.False(t, ready)
	// Update B to "running" - A should still not be ready (needs C)
	err = tracker.UpdateStatus(consumerB, statusRunning)
	require.NoError(t, err)
	ready, err = tracker.IsReady(consumerA)
	require.NoError(t, err)
	assert.False(t, ready)
	// Update C to "started" - A should now be ready
	err = tracker.UpdateStatus(consumerC, statusStarted)
	require.NoError(t, err)
	ready, err = tracker.IsReady(consumerA)
	require.NoError(t, err)
	assert.True(t, ready)
})
// DifferentStatusTypes: a single consumer may require distinct target
// statuses from different dependencies; both must be satisfied before it
// is considered ready.
t.Run("DifferentStatusTypes", func(t *testing.T) {
	t.Parallel()
	tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
	// Register consumers
	err := tracker.Register(consumerA)
	require.NoError(t, err)
	err = tracker.Register(consumerB)
	require.NoError(t, err)
	err = tracker.Register(consumerC)
	require.NoError(t, err)
	// A depends on B being "running" AND C being "completed"
	err = tracker.AddDependency(consumerA, consumerB, statusRunning)
	require.NoError(t, err)
	err = tracker.AddDependency(consumerA, consumerC, statusCompleted)
	require.NoError(t, err)
	// Update B to "running" but not C - A should not be ready
	err = tracker.UpdateStatus(consumerB, statusRunning)
	require.NoError(t, err)
	ready, err := tracker.IsReady(consumerA)
	require.NoError(t, err)
	assert.False(t, ready)
	// Update C to "completed" - A should now be ready
	err = tracker.UpdateStatus(consumerC, statusCompleted)
	require.NoError(t, err)
	ready, err = tracker.IsReady(consumerA)
	require.NoError(t, err)
	assert.True(t, ready)
})
}
// TestDependencyTracker_ErrorCases verifies that tracker operations fail
// predictably: every query or mutation against an unregistered consumer
// returns unit.ErrConsumerNotFound, and an edge that would close a cycle
// is rejected.
func TestDependencyTracker_ErrorCases(t *testing.T) {
	t.Parallel()
	t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		dt := unit.NewDependencyTracker[testStatus, testConsumerID]()
		// Updating a consumer that was never registered must be rejected.
		updateErr := dt.UpdateStatus(consumerA, statusRunning)
		require.Error(t, updateErr)
		assert.Equal(t, unit.ErrConsumerNotFound, updateErr)
	})
	t.Run("IsReadyWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		dt := unit.NewDependencyTracker[testStatus, testConsumerID]()
		ok, readyErr := dt.IsReady(consumerA)
		require.Error(t, readyErr)
		assert.Equal(t, unit.ErrConsumerNotFound, readyErr)
		// On error the readiness flag must be false, never true.
		assert.False(t, ok)
	})
	t.Run("GetUnmetDependenciesWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		dt := unit.NewDependencyTracker[testStatus, testConsumerID]()
		deps, unmetErr := dt.GetUnmetDependencies(consumerA)
		require.Error(t, unmetErr)
		assert.Equal(t, unit.ErrConsumerNotFound, unmetErr)
		assert.Nil(t, deps)
	})
	t.Run("AddDependencyWithUnregisteredConsumers", func(t *testing.T) {
		t.Parallel()
		dt := unit.NewDependencyTracker[testStatus, testConsumerID]()
		// Neither endpoint of the edge is registered, so the edge is refused.
		addErr := dt.AddDependency(consumerA, consumerB, statusRunning)
		require.Error(t, addErr)
		assert.Contains(t, addErr.Error(), "not registered")
	})
	t.Run("CyclicDependencyDetection", func(t *testing.T) {
		t.Parallel()
		dt := unit.NewDependencyTracker[testStatus, testConsumerID]()
		for _, consumer := range []testConsumerID{consumerA, consumerB} {
			require.NoError(t, dt.Register(consumer))
		}
		// A -> B is fine; the reverse edge would close a cycle and must fail.
		require.NoError(t, dt.AddDependency(consumerA, consumerB, statusRunning))
		cycleErr := dt.AddDependency(consumerB, consumerA, statusStarted)
		require.Error(t, cycleErr)
		assert.Contains(t, cycleErr.Error(), "would create a cycle")
	})
}
// TestDependencyTracker_ToDOT verifies that dependency graphs of varying
// shapes can be exported as non-empty Graphviz DOT ("digraph ...") output.
func TestDependencyTracker_ToDOT(t *testing.T) {
	t.Parallel()
	t.Run("ExportSimpleGraph", func(t *testing.T) {
		t.Parallel()
		dt := unit.NewDependencyTracker[testStatus, testConsumerID]()
		for _, consumer := range []testConsumerID{consumerA, consumerB} {
			require.NoError(t, dt.Register(consumer))
		}
		// Single edge: A depends on B being "running".
		require.NoError(t, dt.AddDependency(consumerA, consumerB, statusRunning))
		out, exportErr := dt.ExportDOT("test")
		require.NoError(t, exportErr)
		assert.NotEmpty(t, out)
		assert.Contains(t, out, "digraph")
	})
	t.Run("ExportComplexGraph", func(t *testing.T) {
		t.Parallel()
		dt := unit.NewDependencyTracker[testStatus, testConsumerID]()
		for _, consumer := range []testConsumerID{consumerA, consumerB, consumerC, consumerD} {
			require.NoError(t, dt.Register(consumer))
		}
		// Diamond: A depends on B and C, B depends on D, C depends on D.
		require.NoError(t, dt.AddDependency(consumerA, consumerB, statusRunning))
		require.NoError(t, dt.AddDependency(consumerA, consumerC, statusStarted))
		require.NoError(t, dt.AddDependency(consumerB, consumerD, statusCompleted))
		require.NoError(t, dt.AddDependency(consumerC, consumerD, statusCompleted))
		out, exportErr := dt.ExportDOT("complex")
		require.NoError(t, exportErr)
		assert.NotEmpty(t, out)
		assert.Contains(t, out, "digraph")
	})
}
-78
View File
@@ -1,78 +0,0 @@
package cli
import (
"encoding/csv"
"strings"
"github.com/spf13/pflag"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
var (
	_ pflag.SliceValue = &AllowListFlag{}
	_ pflag.Value      = &AllowListFlag{}
)

// AllowListFlag implements pflag.SliceValue for codersdk.APIAllowListTarget entries.
type AllowListFlag []codersdk.APIAllowListTarget

// AllowListFlagOf wraps an existing target slice so it can be bound directly
// to a command-line flag.
func AllowListFlagOf(al *[]codersdk.APIAllowListTarget) *AllowListFlag {
	return (*AllowListFlag)(al)
}

// String renders the flag as a comma-separated list of its entries.
func (a AllowListFlag) String() string {
	return strings.Join(a.GetSlice(), ",")
}

// Value returns the underlying slice of parsed allow-list targets.
func (a AllowListFlag) Value() []codersdk.APIAllowListTarget {
	return []codersdk.APIAllowListTarget(a)
}

// Type names the flag type for help output.
func (AllowListFlag) Type() string { return "allow-list" }

// Set parses one CSV record (so quoted fields may contain commas) and appends
// every field as an allow-list entry.
func (a *AllowListFlag) Set(set string) error {
	record, err := csv.NewReader(strings.NewReader(set)).Read()
	if err != nil {
		return xerrors.Errorf("parse allow list entries as csv: %w", err)
	}
	for _, field := range record {
		if appendErr := a.Append(field); appendErr != nil {
			return appendErr
		}
	}
	return nil
}

// Append validates a single entry (whitespace-trimmed, non-empty, parseable
// by codersdk.APIAllowListTarget.UnmarshalText) and adds it to the list.
func (a *AllowListFlag) Append(value string) error {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return xerrors.New("allow list entry cannot be empty")
	}
	var target codersdk.APIAllowListTarget
	if err := target.UnmarshalText([]byte(trimmed)); err != nil {
		return err
	}
	*a = append(*a, target)
	return nil
}

// Replace discards the current entries and re-parses the provided items in
// order, stopping at the first invalid one.
func (a *AllowListFlag) Replace(items []string) error {
	*a = []codersdk.APIAllowListTarget{}
	for _, item := range items {
		if err := a.Append(item); err != nil {
			return err
		}
	}
	return nil
}

// GetSlice returns each entry rendered back to its canonical string form.
func (a *AllowListFlag) GetSlice() []string {
	out := make([]string, 0, len(*a))
	for _, entry := range *a {
		out = append(out, entry.String())
	}
	return out
}
+67 -9
View File
@@ -2,6 +2,7 @@ package cli_test
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"io"
@@ -18,7 +19,10 @@ import (
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
@@ -39,22 +43,76 @@ func makeAITask(t *testing.T, db database.Store, orgID, adminID, ownerID uuid.UU
},
}).Do()
build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
ws := database.WorkspaceTable{
OrganizationID: orgID,
OwnerID: ownerID,
TemplateID: tv.Template.ID,
}).
}
build := dbfake.WorkspaceBuild(t, db, ws).
Seed(database.WorkspaceBuild{
TemplateVersionID: tv.TemplateVersion.ID,
Transition: transition,
}).
WithAgent().
WithTask(database.TaskTable{
Prompt: prompt,
}, nil).
Do()
}).WithAgent().Do()
dbgen.WorkspaceBuildParameters(t, db, []database.WorkspaceBuildParameter{
{
WorkspaceBuildID: build.Build.ID,
Name: codersdk.AITaskPromptParameterName,
Value: prompt,
},
})
agents, err := db.GetWorkspaceAgentsByWorkspaceAndBuildNumber(
dbauthz.AsSystemRestricted(context.Background()),
database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
WorkspaceID: build.Workspace.ID,
BuildNumber: build.Build.BuildNumber,
},
)
require.NoError(t, err)
require.NotEmpty(t, agents)
agentID := agents[0].ID
return build.Task
// Create a workspace app and set it as the sidebar app.
app := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{
AgentID: agentID,
Slug: "task-sidebar",
DisplayName: "Task Sidebar",
External: false,
})
// Update build flags to reference the sidebar app and HasAITask=true.
err = db.UpdateWorkspaceBuildFlagsByID(
dbauthz.AsSystemRestricted(context.Background()),
database.UpdateWorkspaceBuildFlagsByIDParams{
ID: build.Build.ID,
HasAITask: sql.NullBool{Bool: true, Valid: true},
HasExternalAgent: sql.NullBool{Bool: false, Valid: false},
SidebarAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
UpdatedAt: build.Build.UpdatedAt,
},
)
require.NoError(t, err)
// Create a task record in the tasks table for the new data model.
task := dbgen.Task(t, db, database.TaskTable{
OrganizationID: orgID,
OwnerID: ownerID,
Name: build.Workspace.Name,
WorkspaceID: uuid.NullUUID{UUID: build.Workspace.ID, Valid: true},
TemplateVersionID: tv.TemplateVersion.ID,
TemplateParameters: []byte("{}"),
Prompt: prompt,
CreatedAt: dbtime.Now(),
})
// Link the task to the workspace app.
dbgen.TaskWorkspaceApp(t, db, database.TaskWorkspaceApp{
TaskID: task.ID,
WorkspaceBuildNumber: build.Build.BuildNumber,
WorkspaceAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
WorkspaceAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
})
return task
}
func TestExpTaskList(t *testing.T) {
+4 -1
View File
@@ -293,6 +293,7 @@ func createAITaskTemplate(t *testing.T, client *codersdk.Client, orgID uuid.UUID
{
Type: &proto.Response_Plan{
Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
HasAiTasks: true,
},
},
@@ -327,7 +328,9 @@ func createAITaskTemplate(t *testing.T, client *codersdk.Client, orgID uuid.UUID
},
AiTasks: []*proto.AITask{
{
AppId: taskAppID.String(),
SidebarApp: &proto.AITaskSidebarApp{
Id: taskAppID.String(),
},
},
},
},
-47
View File
@@ -109,51 +109,6 @@ func (r *RootCmd) ssh() *serpent.Command {
}
},
),
CompletionHandler: func(inv *serpent.Invocation) []string {
client, err := r.InitClient(inv)
if err != nil {
return []string{}
}
res, err := client.Workspaces(inv.Context(), codersdk.WorkspaceFilter{
Owner: codersdk.Me,
})
if err != nil {
return []string{}
}
var mu sync.Mutex
var completions []string
var wg sync.WaitGroup
for _, ws := range res.Workspaces {
wg.Add(1)
go func() {
defer wg.Done()
resources, err := client.TemplateVersionResources(inv.Context(), ws.LatestBuild.TemplateVersionID)
if err != nil {
return
}
var agents []codersdk.WorkspaceAgent
for _, resource := range resources {
agents = append(agents, resource.Agents...)
}
mu.Lock()
defer mu.Unlock()
if len(agents) == 1 {
completions = append(completions, ws.Name)
} else {
for _, agent := range agents {
completions = append(completions, fmt.Sprintf("%s.%s", ws.Name, agent.Name))
}
}
}()
}
wg.Wait()
slices.Sort(completions)
return completions
},
Handler: func(inv *serpent.Invocation) (retErr error) {
client, err := r.InitClient(inv)
if err != nil {
@@ -951,8 +906,6 @@ func GetWorkspaceAndAgent(ctx context.Context, inv *serpent.Invocation, client *
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with active template version: %w", err)
}
_, _ = fmt.Fprintln(inv.Stdout, "Unable to start the workspace with template version from last build. Your workspace has been updated to the current active template version.")
default:
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with current template version: %w", err)
}
} else if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with current template version: %w", err)
-96
View File
@@ -2447,99 +2447,3 @@ func tempDirUnixSocket(t *testing.T) string {
return t.TempDir()
}
// TestSSH_Completion exercises shell-completion output for the ssh command:
// a single-agent workspace completes to the bare workspace name, a
// multi-agent workspace completes to workspace.agent pairs (and NOT the bare
// name), and a missing/unreachable server yields empty output rather than an
// error.
func TestSSH_Completion(t *testing.T) {
	t.Parallel()
	t.Run("SingleAgent", func(t *testing.T) {
		t.Parallel()
		client, workspace, agentToken := setupWorkspaceForAgent(t)
		_ = agenttest.New(t, client.URL, agentToken)
		coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
		var stdout bytes.Buffer
		// Empty trailing arg triggers completion for the workspace argument.
		inv, root := clitest.New(t, "ssh", "")
		inv.Stdout = &stdout
		inv.Environ.Set("COMPLETION_MODE", "1")
		clitest.SetupConfig(t, client, root)
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
		defer cancel()
		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)
		// For single-agent workspaces, the only completion should be the
		// bare workspace name.
		output := stdout.String()
		t.Logf("Completion output: %q", output)
		require.Contains(t, output, workspace.Name)
	})
	t.Run("MultiAgent", func(t *testing.T) {
		t.Parallel()
		client, store := coderdtest.NewWithDatabase(t, nil)
		first := coderdtest.CreateFirstUser(t, client)
		userClient, user := coderdtest.CreateAnotherUserMutators(t, client, first.OrganizationID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
			r.Username = "multiuser"
		})
		// Seed a workspace with two agents directly in the database.
		r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
			Name:           "multiworkspace",
			OrganizationID: first.OrganizationID,
			OwnerID:        user.ID,
		}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
			return []*proto.Agent{
				{
					Name: "agent1",
					Auth: &proto.Agent_Token{},
				},
				{
					Name: "agent2",
					Auth: &proto.Agent_Token{},
				},
			}
		}).Do()
		var stdout bytes.Buffer
		inv, root := clitest.New(t, "ssh", "")
		inv.Stdout = &stdout
		inv.Environ.Set("COMPLETION_MODE", "1")
		clitest.SetupConfig(t, userClient, root)
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
		defer cancel()
		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)
		// For multi-agent workspaces, completions should include the
		// workspace.agent format but NOT the bare workspace name.
		output := stdout.String()
		t.Logf("Completion output: %q", output)
		lines := strings.Split(strings.TrimSpace(output), "\n")
		require.NotContains(t, lines, r.Workspace.Name)
		require.Contains(t, output, r.Workspace.Name+".agent1")
		require.Contains(t, output, r.Workspace.Name+".agent2")
	})
	t.Run("NetworkError", func(t *testing.T) {
		t.Parallel()
		// No server config is set up here: completion must fail soft,
		// returning no error and printing nothing.
		var stdout bytes.Buffer
		inv, _ := clitest.New(t, "ssh", "")
		inv.Stdout = &stdout
		inv.Environ.Set("COMPLETION_MODE", "1")
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
		defer cancel()
		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)
		output := stdout.String()
		require.Empty(t, output)
	})
}
+1 -2
View File
@@ -90,7 +90,6 @@
"allow_renames": false,
"favorite": false,
"next_start_at": "====[timestamp]=====",
"is_prebuild": false,
"task_id": null
"is_prebuild": false
}
]
-35
View File
@@ -80,41 +80,6 @@ OPTIONS:
Periodically check for new releases of Coder and inform the owner. The
check is performed once per day.
AIBRIDGE OPTIONS:
--aibridge-anthropic-base-url string, $CODER_AIBRIDGE_ANTHROPIC_BASE_URL (default: https://api.anthropic.com/)
The base URL of the Anthropic API.
--aibridge-anthropic-key string, $CODER_AIBRIDGE_ANTHROPIC_KEY
The key to authenticate against the Anthropic API.
--aibridge-bedrock-access-key string, $CODER_AIBRIDGE_BEDROCK_ACCESS_KEY
The access key to authenticate against the AWS Bedrock API.
--aibridge-bedrock-access-key-secret string, $CODER_AIBRIDGE_BEDROCK_ACCESS_KEY_SECRET
The access key secret to use with the access key to authenticate
against the AWS Bedrock API.
--aibridge-bedrock-model string, $CODER_AIBRIDGE_BEDROCK_MODEL (default: global.anthropic.claude-sonnet-4-5-20250929-v1:0)
The model to use when making requests to the AWS Bedrock API.
--aibridge-bedrock-region string, $CODER_AIBRIDGE_BEDROCK_REGION
The AWS Bedrock API region.
--aibridge-bedrock-small-fastmodel string, $CODER_AIBRIDGE_BEDROCK_SMALL_FAST_MODEL (default: global.anthropic.claude-haiku-4-5-20251001-v1:0)
The small fast model to use when making requests to the AWS Bedrock
API. Claude Code uses Haiku-class models to perform background tasks.
See
https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
--aibridge-enabled bool, $CODER_AIBRIDGE_ENABLED (default: false)
Whether to start an in-memory aibridged instance.
--aibridge-openai-base-url string, $CODER_AIBRIDGE_OPENAI_BASE_URL (default: https://api.openai.com/v1/)
The base URL of the OpenAI API.
--aibridge-openai-key string, $CODER_AIBRIDGE_OPENAI_KEY
The key to authenticate against the OpenAI API.
CLIENT OPTIONS:
These options change the behavior of how clients interact with the Coder.
Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI.
-5
View File
@@ -16,10 +16,6 @@ USAGE:
$ coder tokens ls
- Create a scoped token:
$ coder tokens create --scope workspace:read --allow workspace:<uuid>
- Remove a token by ID:
$ coder tokens rm WuoWs4ZsMX
@@ -28,7 +24,6 @@ SUBCOMMANDS:
create Create a token
list List tokens
remove Delete a token
view Display detailed information about a token
———
Run `coder --help` for a list of global options.
+1 -9
View File
@@ -6,20 +6,12 @@ USAGE:
Create a token
OPTIONS:
--allow allow-list
Repeatable allow-list entry (<type>:<uuid>, e.g. workspace:1234-...).
--lifetime string, $CODER_TOKEN_LIFETIME
Duration for the token lifetime. Supports standard Go duration units
(ns, us, ms, s, m, h) plus d (days) and y (years). Examples: 8h, 30d,
1y, 1d12h30m.
Specify a duration for the lifetime of the token.
-n, --name string, $CODER_TOKEN_NAME
Specify a human-readable name.
--scope string-array
Repeatable scope to attach to the token (e.g. workspace:read).
-u, --user string, $CODER_TOKEN_USER
Specify the user to create the token for (Only works if logged in user
is admin).
+1 -1
View File
@@ -12,7 +12,7 @@ OPTIONS:
Specifies whether all users' tokens will be listed or not (must have
Owner role to see all tokens).
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at)
-c, --column [id|name|last used|expires at|created at|owner] (default: id,name,last used,expires at,created at)
Columns to display in table output.
-o, --output table|json (default: table)
-16
View File
@@ -1,16 +0,0 @@
coder v0.0.0-devel
USAGE:
coder tokens view [flags] <name|id>
Display detailed information about a token
OPTIONS:
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at,owner)
Columns to display in table output.
-o, --output table|json (default: table)
Output format.
———
Run `coder --help` for a list of global options.
+4 -21
View File
@@ -714,7 +714,8 @@ workspace_prebuilds:
# (default: 3, type: int)
failure_hard_limit: 3
aibridge:
# Whether to start an in-memory aibridged instance.
# Whether to start an in-memory aibridged instance ("aibridge" experiment must be
# enabled, too).
# (default: false, type: bool)
enabled: false
# The base URL of the OpenAI API.
@@ -725,25 +726,7 @@ aibridge:
openai_key: ""
# The base URL of the Anthropic API.
# (default: https://api.anthropic.com/, type: string)
anthropic_base_url: https://api.anthropic.com/
base_url: https://api.anthropic.com/
# The key to authenticate against the Anthropic API.
# (default: <unset>, type: string)
anthropic_key: ""
# The AWS Bedrock API region.
# (default: <unset>, type: string)
bedrock_region: ""
# The access key to authenticate against the AWS Bedrock API.
# (default: <unset>, type: string)
bedrock_access_key: ""
# The access key secret to use with the access key to authenticate against the AWS
# Bedrock API.
# (default: <unset>, type: string)
bedrock_access_key_secret: ""
# The model to use when making requests to the AWS Bedrock API.
# (default: global.anthropic.claude-sonnet-4-5-20250929-v1:0, type: string)
bedrock_model: global.anthropic.claude-sonnet-4-5-20250929-v1:0
# The small fast model to use when making requests to the AWS Bedrock API. Claude
# Code uses Haiku-class models to perform background tasks. See
# https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
# (default: global.anthropic.claude-haiku-4-5-20251001-v1:0, type: string)
bedrock_small_fast_model: global.anthropic.claude-haiku-4-5-20251001-v1:0
key: ""
+6 -104
View File
@@ -4,14 +4,12 @@ import (
"fmt"
"os"
"slices"
"sort"
"strings"
"time"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
@@ -29,10 +27,6 @@ func (r *RootCmd) tokens() *serpent.Command {
Description: "List your tokens",
Command: "coder tokens ls",
},
Example{
Description: "Create a scoped token",
Command: "coder tokens create --scope workspace:read --allow workspace:<uuid>",
},
Example{
Description: "Remove a token by ID",
Command: "coder tokens rm WuoWs4ZsMX",
@@ -45,7 +39,6 @@ func (r *RootCmd) tokens() *serpent.Command {
Children: []*serpent.Command{
r.createToken(),
r.listTokens(),
r.viewToken(),
r.removeToken(),
},
}
@@ -57,8 +50,6 @@ func (r *RootCmd) createToken() *serpent.Command {
tokenLifetime string
name string
user string
scopes []string
allowList []codersdk.APIAllowListTarget
)
cmd := &serpent.Command{
Use: "create",
@@ -97,18 +88,10 @@ func (r *RootCmd) createToken() *serpent.Command {
}
}
req := codersdk.CreateTokenRequest{
res, err := client.CreateToken(inv.Context(), userID, codersdk.CreateTokenRequest{
Lifetime: parsedLifetime,
TokenName: name,
}
if len(req.Scopes) == 0 {
req.Scopes = slice.StringEnums[codersdk.APIKeyScope](scopes)
}
if len(allowList) > 0 {
req.AllowList = append([]codersdk.APIAllowListTarget(nil), allowList...)
}
res, err := client.CreateToken(inv.Context(), userID, req)
})
if err != nil {
return xerrors.Errorf("create tokens: %w", err)
}
@@ -123,7 +106,7 @@ func (r *RootCmd) createToken() *serpent.Command {
{
Flag: "lifetime",
Env: "CODER_TOKEN_LIFETIME",
Description: "Duration for the token lifetime. Supports standard Go duration units (ns, us, ms, s, m, h) plus d (days) and y (years). Examples: 8h, 30d, 1y, 1d12h30m.",
Description: "Specify a duration for the lifetime of the token.",
Value: serpent.StringOf(&tokenLifetime),
},
{
@@ -140,16 +123,6 @@ func (r *RootCmd) createToken() *serpent.Command {
Description: "Specify the user to create the token for (Only works if logged in user is admin).",
Value: serpent.StringOf(&user),
},
{
Flag: "scope",
Description: "Repeatable scope to attach to the token (e.g. workspace:read).",
Value: serpent.StringArrayOf(&scopes),
},
{
Flag: "allow",
Description: "Repeatable allow-list entry (<type>:<uuid>, e.g. workspace:1234-...).",
Value: AllowListFlagOf(&allowList),
},
}
return cmd
@@ -163,8 +136,6 @@ type tokenListRow struct {
// For table format:
ID string `json:"-" table:"id,default_sort"`
TokenName string `json:"token_name" table:"name"`
Scopes string `json:"-" table:"scopes"`
Allow string `json:"-" table:"allow list"`
LastUsed time.Time `json:"-" table:"last used"`
ExpiresAt time.Time `json:"-" table:"expires at"`
CreatedAt time.Time `json:"-" table:"created at"`
@@ -172,47 +143,20 @@ type tokenListRow struct {
}
func tokenListRowFromToken(token codersdk.APIKeyWithOwner) tokenListRow {
return tokenListRowFromKey(token.APIKey, token.Username)
}
func tokenListRowFromKey(token codersdk.APIKey, owner string) tokenListRow {
return tokenListRow{
APIKey: token,
APIKey: token.APIKey,
ID: token.ID,
TokenName: token.TokenName,
Scopes: joinScopes(token.Scopes),
Allow: joinAllowList(token.AllowList),
LastUsed: token.LastUsed,
ExpiresAt: token.ExpiresAt,
CreatedAt: token.CreatedAt,
Owner: owner,
Owner: token.Username,
}
}
// joinScopes renders API key scopes as a sorted, comma-separated list for
// table output; an empty scope set renders as the empty string.
func joinScopes(scopes []codersdk.APIKeyScope) string {
	if len(scopes) == 0 {
		return ""
	}
	names := slice.ToStrings(scopes)
	// Sort so output is stable regardless of the order scopes were stored.
	sort.Strings(names)
	return strings.Join(names, ", ")
}
// joinAllowList renders allow-list targets as a sorted, comma-separated list
// for table output; an empty allow list renders as the empty string.
func joinAllowList(entries []codersdk.APIAllowListTarget) string {
	if len(entries) == 0 {
		return ""
	}
	rendered := make([]string, 0, len(entries))
	for _, entry := range entries {
		rendered = append(rendered, entry.String())
	}
	// Sort so output is stable regardless of the order entries were stored.
	sort.Strings(rendered)
	return strings.Join(rendered, ", ")
}
func (r *RootCmd) listTokens() *serpent.Command {
// we only display the 'owner' column if the --all argument is passed in
defaultCols := []string{"id", "name", "scopes", "allow list", "last used", "expires at", "created at"}
defaultCols := []string{"id", "name", "last used", "expires at", "created at"}
if slices.Contains(os.Args, "-a") || slices.Contains(os.Args, "--all") {
defaultCols = append(defaultCols, "owner")
}
@@ -282,48 +226,6 @@ func (r *RootCmd) listTokens() *serpent.Command {
return cmd
}
// viewToken returns the "tokens view" subcommand, which looks up a single
// token by name (falling back to treating the argument as an ID prefix) and
// prints it via the table/JSON output formatter.
func (r *RootCmd) viewToken() *serpent.Command {
	formatter := cliui.NewOutputFormatter(
		cliui.TableFormat([]tokenListRow{}, []string{"id", "name", "scopes", "allow list", "last used", "expires at", "created at", "owner"}),
		cliui.JSONFormat(),
	)
	cmd := &serpent.Command{
		Use:   "view <name|id>",
		Short: "Display detailed information about a token",
		Middleware: serpent.Chain(
			serpent.RequireNArgs(1),
		),
		Handler: func(inv *serpent.Invocation) error {
			client, err := r.InitClient(inv)
			if err != nil {
				return err
			}
			tokenName := inv.Args[0]
			// Prefer lookup by token name; on failure, retry using the
			// segment before the first "-" as an ID prefix.
			token, err := client.APIKeyByName(inv.Context(), codersdk.Me, tokenName)
			if err != nil {
				maybeID := strings.Split(tokenName, "-")[0]
				token, err = client.APIKeyByID(inv.Context(), codersdk.Me, maybeID)
				if err != nil {
					return xerrors.Errorf("fetch api key by name or id: %w", err)
				}
			}
			// Owner is passed as "" here; presumably only the list path
			// resolves owner usernames — confirm against tokenListRowFromKey.
			row := tokenListRowFromKey(*token, "")
			out, err := formatter.Format(inv.Context(), []tokenListRow{row})
			if err != nil {
				return err
			}
			_, err = fmt.Fprintln(inv.Stdout, out)
			return err
		},
	}
	formatter.AttachOptions(&cmd.Options)
	return cmd
}
func (r *RootCmd) removeToken() *serpent.Command {
cmd := &serpent.Command{
Use: "remove <name|id|token>",
+3 -56
View File
@@ -4,13 +4,10 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/google/uuid"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
@@ -49,18 +46,6 @@ func TestTokens(t *testing.T) {
require.NotEmpty(t, res)
id := res[:10]
allowWorkspaceID := uuid.New()
allowSpec := fmt.Sprintf("workspace:%s", allowWorkspaceID.String())
inv, root = clitest.New(t, "tokens", "create", "--name", "scoped-token", "--scope", string(codersdk.APIKeyScopeWorkspaceRead), "--allow", allowSpec)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
scopedTokenID := res[:10]
// Test creating a token for second user from first user's (admin) session
inv, root = clitest.New(t, "tokens", "create", "--name", "token-two", "--user", secondUser.ID.String())
clitest.SetupConfig(t, client, root)
@@ -82,7 +67,7 @@ func TestTokens(t *testing.T) {
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
// Result should only contain the tokens created for the admin user
// Result should only contain the token created for the admin user
require.Contains(t, res, "ID")
require.Contains(t, res, "EXPIRES AT")
require.Contains(t, res, "CREATED AT")
@@ -91,16 +76,6 @@ func TestTokens(t *testing.T) {
// Result should not contain the token created for the second user
require.NotContains(t, res, secondTokenID)
inv, root = clitest.New(t, "tokens", "view", "scoped-token")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.Contains(t, res, string(codersdk.APIKeyScopeWorkspaceRead))
require.Contains(t, res, allowSpec)
// Test listing tokens from the second user's session
inv, root = clitest.New(t, "tokens", "ls")
clitest.SetupConfig(t, secondUserClient, root)
@@ -126,14 +101,6 @@ func TestTokens(t *testing.T) {
// User (non-admin) should not be able to create a token for another user
require.Error(t, err)
inv, root = clitest.New(t, "tokens", "create", "--name", "invalid-allow", "--allow", "badvalue")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.Error(t, err)
require.Contains(t, err.Error(), "invalid allow_list entry")
inv, root = clitest.New(t, "tokens", "ls", "--output=json")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
@@ -143,17 +110,8 @@ func TestTokens(t *testing.T) {
var tokens []codersdk.APIKey
require.NoError(t, json.Unmarshal(buf.Bytes(), &tokens))
require.Len(t, tokens, 2)
tokenByName := make(map[string]codersdk.APIKey, len(tokens))
for _, tk := range tokens {
tokenByName[tk.TokenName] = tk
}
require.Contains(t, tokenByName, "token-one")
require.Contains(t, tokenByName, "scoped-token")
scopedToken := tokenByName["scoped-token"]
require.Contains(t, scopedToken.Scopes, codersdk.APIKeyScopeWorkspaceRead)
require.Len(t, scopedToken.AllowList, 1)
require.Equal(t, allowSpec, scopedToken.AllowList[0].String())
require.Len(t, tokens, 1)
require.Equal(t, id, tokens[0].ID)
// Delete by name
inv, root = clitest.New(t, "tokens", "rm", "token-one")
@@ -177,17 +135,6 @@ func TestTokens(t *testing.T) {
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
// Delete scoped token by ID
inv, root = clitest.New(t, "tokens", "rm", scopedTokenID)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
// Create third token
inv, root = clitest.New(t, "tokens", "create", "--name", "token-three")
clitest.SetupConfig(t, client, root)
-4
View File
@@ -239,10 +239,6 @@ func (a *API) Serve(ctx context.Context, l net.Listener) error {
return xerrors.Errorf("create agent API server: %w", err)
}
if err := a.ResourcesMonitoringAPI.InitMonitors(ctx); err != nil {
return xerrors.Errorf("initialize resource monitoring: %w", err)
}
return server.Serve(ctx, l)
}
+35 -52
View File
@@ -5,7 +5,6 @@ import (
"database/sql"
"errors"
"fmt"
"sync"
"time"
"golang.org/x/xerrors"
@@ -34,60 +33,42 @@ type ResourcesMonitoringAPI struct {
Debounce time.Duration
Config resourcesmonitor.Config
// Cache resource monitors on first call to avoid millions of DB queries per day.
memoryMonitor database.WorkspaceAgentMemoryResourceMonitor
volumeMonitors []database.WorkspaceAgentVolumeResourceMonitor
monitorsLock sync.RWMutex
}
// InitMonitors fetches resource monitors from the database and caches them.
// This must be called once after creating a ResourcesMonitoringAPI, the context should be
// the agent per-RPC connection context. If fetching fails with a real error (not sql.ErrNoRows), the
// connection should be torn down.
func (a *ResourcesMonitoringAPI) InitMonitors(ctx context.Context) error {
memMon, err := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("fetch memory resource monitor: %w", err)
}
// If sql.ErrNoRows, memoryMonitor stays as zero value (CreatedAt.IsZero() = true).
// Otherwise, store the fetched monitor.
if err == nil {
a.memoryMonitor = memMon
func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(ctx context.Context, _ *proto.GetResourcesMonitoringConfigurationRequest) (*proto.GetResourcesMonitoringConfigurationResponse, error) {
memoryMonitor, memoryErr := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
if memoryErr != nil && !errors.Is(memoryErr, sql.ErrNoRows) {
return nil, xerrors.Errorf("failed to fetch memory resource monitor: %w", memoryErr)
}
volMons, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
volumeMonitors, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil {
return xerrors.Errorf("fetch volume resource monitors: %w", err)
return nil, xerrors.Errorf("failed to fetch volume resource monitors: %w", err)
}
// 0 length is valid, indicating none configured, since the volume monitors in the DB can be many.
a.volumeMonitors = volMons
return nil
}
func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(_ context.Context, _ *proto.GetResourcesMonitoringConfigurationRequest) (*proto.GetResourcesMonitoringConfigurationResponse, error) {
return &proto.GetResourcesMonitoringConfigurationResponse{
Config: &proto.GetResourcesMonitoringConfigurationResponse_Config{
CollectionIntervalSeconds: int32(a.Config.CollectionInterval.Seconds()),
NumDatapoints: a.Config.NumDatapoints,
},
Memory: func() *proto.GetResourcesMonitoringConfigurationResponse_Memory {
if a.memoryMonitor.CreatedAt.IsZero() {
if memoryErr != nil {
return nil
}
return &proto.GetResourcesMonitoringConfigurationResponse_Memory{
Enabled: a.memoryMonitor.Enabled,
Enabled: memoryMonitor.Enabled,
}
}(),
Volumes: func() []*proto.GetResourcesMonitoringConfigurationResponse_Volume {
volumes := make([]*proto.GetResourcesMonitoringConfigurationResponse_Volume, 0, len(a.volumeMonitors))
for _, monitor := range a.volumeMonitors {
volumes := make([]*proto.GetResourcesMonitoringConfigurationResponse_Volume, 0, len(volumeMonitors))
for _, monitor := range volumeMonitors {
volumes = append(volumes, &proto.GetResourcesMonitoringConfigurationResponse_Volume{
Enabled: monitor.Enabled,
Path: monitor.Path,
})
}
return volumes
}(),
}, nil
@@ -96,10 +77,6 @@ func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(_ context.C
func (a *ResourcesMonitoringAPI) PushResourcesMonitoringUsage(ctx context.Context, req *proto.PushResourcesMonitoringUsageRequest) (*proto.PushResourcesMonitoringUsageResponse, error) {
var err error
// Lock for the entire push operation since calls are sequential from the agent
a.monitorsLock.Lock()
defer a.monitorsLock.Unlock()
if memoryErr := a.monitorMemory(ctx, req.Datapoints); memoryErr != nil {
err = errors.Join(err, xerrors.Errorf("monitor memory: %w", memoryErr))
}
@@ -112,7 +89,18 @@ func (a *ResourcesMonitoringAPI) PushResourcesMonitoringUsage(ctx context.Contex
}
func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints []*proto.PushResourcesMonitoringUsageRequest_Datapoint) error {
if !a.memoryMonitor.Enabled {
monitor, err := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil {
// It is valid for an agent to not have a memory monitor, so we
// do not want to treat it as an error.
if errors.Is(err, sql.ErrNoRows) {
return nil
}
return xerrors.Errorf("fetch memory resource monitor: %w", err)
}
if !monitor.Enabled {
return nil
}
@@ -121,15 +109,15 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
usageDatapoints = append(usageDatapoints, datapoint.Memory)
}
usageStates := resourcesmonitor.CalculateMemoryUsageStates(a.memoryMonitor, usageDatapoints)
usageStates := resourcesmonitor.CalculateMemoryUsageStates(monitor, usageDatapoints)
oldState := a.memoryMonitor.State
oldState := monitor.State
newState := resourcesmonitor.NextState(a.Config, oldState, usageStates)
debouncedUntil, shouldNotify := a.memoryMonitor.Debounce(a.Debounce, a.Clock.Now(), oldState, newState)
debouncedUntil, shouldNotify := monitor.Debounce(a.Debounce, a.Clock.Now(), oldState, newState)
//nolint:gocritic // We need to be able to update the resource monitor here.
err := a.Database.UpdateMemoryResourceMonitor(dbauthz.AsResourceMonitor(ctx), database.UpdateMemoryResourceMonitorParams{
err = a.Database.UpdateMemoryResourceMonitor(dbauthz.AsResourceMonitor(ctx), database.UpdateMemoryResourceMonitorParams{
AgentID: a.AgentID,
State: newState,
UpdatedAt: dbtime.Time(a.Clock.Now()),
@@ -139,11 +127,6 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
return xerrors.Errorf("update workspace monitor: %w", err)
}
// Update cached state
a.memoryMonitor.State = newState
a.memoryMonitor.DebouncedUntil = dbtime.Time(debouncedUntil)
a.memoryMonitor.UpdatedAt = dbtime.Time(a.Clock.Now())
if !shouldNotify {
return nil
}
@@ -160,7 +143,7 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
notifications.TemplateWorkspaceOutOfMemory,
map[string]string{
"workspace": workspace.Name,
"threshold": fmt.Sprintf("%d%%", a.memoryMonitor.Threshold),
"threshold": fmt.Sprintf("%d%%", monitor.Threshold),
},
map[string]any{
// NOTE(DanielleMaywood):
@@ -186,9 +169,14 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
}
func (a *ResourcesMonitoringAPI) monitorVolumes(ctx context.Context, datapoints []*proto.PushResourcesMonitoringUsageRequest_Datapoint) error {
volumeMonitors, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil {
return xerrors.Errorf("get or insert volume monitor: %w", err)
}
outOfDiskVolumes := make([]map[string]any, 0)
for i, monitor := range a.volumeMonitors {
for _, monitor := range volumeMonitors {
if !monitor.Enabled {
continue
}
@@ -231,11 +219,6 @@ func (a *ResourcesMonitoringAPI) monitorVolumes(ctx context.Context, datapoints
}); err != nil {
return xerrors.Errorf("update workspace monitor: %w", err)
}
// Update cached state
a.volumeMonitors[i].State = newState
a.volumeMonitors[i].DebouncedUntil = dbtime.Time(debouncedUntil)
a.volumeMonitors[i].UpdatedAt = dbtime.Time(a.Clock.Now())
}
if len(outOfDiskVolumes) == 0 {
@@ -101,9 +101,6 @@ func TestMemoryResourceMonitorDebounce(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: The monitor is given a state that will trigger NOK
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -307,9 +304,6 @@ func TestMemoryResourceMonitor(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
clock.Set(collectedAt)
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: datapoints,
@@ -343,8 +337,6 @@ func TestMemoryResourceMonitorMissingData(t *testing.T) {
State: database.WorkspaceAgentMonitorStateOK,
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two NOK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
@@ -395,9 +387,6 @@ func TestMemoryResourceMonitorMissingData(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two OK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -477,9 +466,6 @@ func TestVolumeResourceMonitorDebounce(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When:
// - First monitor is in a NOK state
// - Second monitor is in an OK state
@@ -756,9 +742,6 @@ func TestVolumeResourceMonitor(t *testing.T) {
Threshold: tt.thresholdPercent,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
clock.Set(collectedAt)
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: datapoints,
@@ -797,9 +780,6 @@ func TestVolumeResourceMonitorMultiple(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: both of them move to a NOK state
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -852,9 +832,6 @@ func TestVolumeResourceMonitorMissingData(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two NOK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -914,9 +891,6 @@ func TestVolumeResourceMonitorMissingData(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two OK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
+57 -25
View File
@@ -7,6 +7,7 @@ import (
"net/http"
"net/url"
"slices"
"strings"
"time"
"github.com/google/uuid"
@@ -23,12 +24,62 @@ import (
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/searchquery"
"github.com/coder/coder/v2/coderd/taskname"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
aiagentapi "github.com/coder/agentapi-sdk-go"
)
// This endpoint is experimental and not guaranteed to be stable, so we're not
// generating public-facing documentation for it.
func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
buildIDsParam := r.URL.Query().Get("build_ids")
if buildIDsParam == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "build_ids query parameter is required",
})
return
}
// Parse build IDs
buildIDStrings := strings.Split(buildIDsParam, ",")
buildIDs := make([]uuid.UUID, 0, len(buildIDStrings))
for _, idStr := range buildIDStrings {
id, err := uuid.Parse(strings.TrimSpace(idStr))
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid build ID format: %s", idStr),
Detail: err.Error(),
})
return
}
buildIDs = append(buildIDs, id)
}
parameters, err := api.Database.GetWorkspaceBuildParametersByBuildIDs(ctx, buildIDs)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace build parameters.",
Detail: err.Error(),
})
return
}
promptsByBuildID := make(map[string]string, len(parameters))
for _, param := range parameters {
if param.Name != codersdk.AITaskPromptParameterName {
continue
}
buildID := param.WorkspaceBuildID.String()
promptsByBuildID[buildID] = param.Value
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AITasksPromptsResponse{
Prompts: promptsByBuildID,
})
}
// @Summary Create a new AI task
// @Description: EXPERIMENTAL: this endpoint is experimental and not guaranteed to be stable.
// @ID create-task
@@ -123,31 +174,13 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
}
}
// Check if the template defines the AI Prompt parameter.
templateParams, err := api.Database.GetTemplateVersionParameters(ctx, req.TemplateVersionID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching template parameters.",
Detail: err.Error(),
})
return
}
var richParams []codersdk.WorkspaceBuildParameter
if _, hasAIPromptParam := slice.Find(templateParams, func(param database.TemplateVersionParameter) bool {
return param.Name == codersdk.AITaskPromptParameterName
}); hasAIPromptParam {
// Only add the AI Prompt parameter if the template defines it.
richParams = []codersdk.WorkspaceBuildParameter{
{Name: codersdk.AITaskPromptParameterName, Value: req.Input},
}
}
createReq := codersdk.CreateWorkspaceRequest{
Name: taskName,
TemplateVersionID: req.TemplateVersionID,
TemplateVersionPresetID: req.TemplateVersionPresetID,
RichParameterValues: richParams,
RichParameterValues: []codersdk.WorkspaceBuildParameter{
{Name: codersdk.AITaskPromptParameterName, Value: req.Input},
},
}
var owner workspaceOwner
@@ -208,7 +241,6 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
// Create task record in the database before creating the workspace so that
// we can request that the workspace be linked to it after creation.
dbTaskTable, err = tx.InsertTask(ctx, database.InsertTaskParams{
ID: uuid.New(),
OrganizationID: templateVersion.OrganizationID,
OwnerID: owner.ID,
Name: taskName,
@@ -306,8 +338,8 @@ func taskFromDBTaskAndWorkspace(dbTask database.Task, ws codersdk.Workspace) cod
ID: dbTask.ID,
OrganizationID: dbTask.OrganizationID,
OwnerID: dbTask.OwnerID,
OwnerName: dbTask.OwnerUsername,
OwnerAvatarURL: dbTask.OwnerAvatarUrl,
OwnerName: ws.OwnerName,
OwnerAvatarURL: ws.OwnerAvatarURL,
Name: dbTask.Name,
TemplateID: ws.TemplateID,
TemplateVersionID: dbTask.TemplateVersionID,
+139 -90
View File
@@ -1,7 +1,6 @@
package coderd_test
import (
"context"
"database/sql"
"encoding/json"
"io"
@@ -35,6 +34,128 @@ import (
"github.com/coder/coder/v2/testutil"
)
func TestAITasksPrompts(t *testing.T) {
t.Parallel()
t.Run("EmptyBuildIDs", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{})
_ = coderdtest.CreateFirstUser(t, client)
experimentalClient := codersdk.NewExperimentalClient(client)
ctx := testutil.Context(t, testutil.WaitShort)
// Test with empty build IDs
prompts, err := experimentalClient.AITaskPrompts(ctx, []uuid.UUID{})
require.NoError(t, err)
require.Empty(t, prompts.Prompts)
})
t.Run("MultipleBuilds", func(t *testing.T) {
t.Parallel()
adminClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
first := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, first.OrganizationID)
ctx := testutil.Context(t, testutil.WaitLong)
// Create a template with parameters
version := coderdtest.CreateTemplateVersion(t, adminClient, first.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: []*proto.Response{{
Type: &proto.Response_Plan{
Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{
{
Name: "param1",
Type: "string",
DefaultValue: "default1",
},
{
Name: codersdk.AITaskPromptParameterName,
Type: "string",
DefaultValue: "default2",
},
},
},
},
}},
ProvisionApply: echo.ApplyComplete,
})
template := coderdtest.CreateTemplate(t, adminClient, first.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, adminClient, version.ID)
// Create two workspaces with different parameters
workspace1 := coderdtest.CreateWorkspace(t, memberClient, template.ID, func(request *codersdk.CreateWorkspaceRequest) {
request.RichParameterValues = []codersdk.WorkspaceBuildParameter{
{Name: "param1", Value: "value1a"},
{Name: codersdk.AITaskPromptParameterName, Value: "value2a"},
}
})
coderdtest.AwaitWorkspaceBuildJobCompleted(t, memberClient, workspace1.LatestBuild.ID)
workspace2 := coderdtest.CreateWorkspace(t, memberClient, template.ID, func(request *codersdk.CreateWorkspaceRequest) {
request.RichParameterValues = []codersdk.WorkspaceBuildParameter{
{Name: "param1", Value: "value1b"},
{Name: codersdk.AITaskPromptParameterName, Value: "value2b"},
}
})
coderdtest.AwaitWorkspaceBuildJobCompleted(t, memberClient, workspace2.LatestBuild.ID)
workspace3 := coderdtest.CreateWorkspace(t, adminClient, template.ID, func(request *codersdk.CreateWorkspaceRequest) {
request.RichParameterValues = []codersdk.WorkspaceBuildParameter{
{Name: "param1", Value: "value1c"},
{Name: codersdk.AITaskPromptParameterName, Value: "value2c"},
}
})
coderdtest.AwaitWorkspaceBuildJobCompleted(t, adminClient, workspace3.LatestBuild.ID)
allBuildIDs := []uuid.UUID{workspace1.LatestBuild.ID, workspace2.LatestBuild.ID, workspace3.LatestBuild.ID}
experimentalMemberClient := codersdk.NewExperimentalClient(memberClient)
// Test parameters endpoint as member
prompts, err := experimentalMemberClient.AITaskPrompts(ctx, allBuildIDs)
require.NoError(t, err)
// we expect 2 prompts because the member client does not have access to workspace3
// since it was created by the admin client
require.Len(t, prompts.Prompts, 2)
// Check workspace1 parameters
build1Prompt := prompts.Prompts[workspace1.LatestBuild.ID.String()]
require.Equal(t, "value2a", build1Prompt)
// Check workspace2 parameters
build2Prompt := prompts.Prompts[workspace2.LatestBuild.ID.String()]
require.Equal(t, "value2b", build2Prompt)
experimentalAdminClient := codersdk.NewExperimentalClient(adminClient)
// Test parameters endpoint as admin
// we expect 3 prompts because the admin client has access to all workspaces
prompts, err = experimentalAdminClient.AITaskPrompts(ctx, allBuildIDs)
require.NoError(t, err)
require.Len(t, prompts.Prompts, 3)
// Check workspace3 parameters
build3Prompt := prompts.Prompts[workspace3.LatestBuild.ID.String()]
require.Equal(t, "value2c", build3Prompt)
})
t.Run("NonExistentBuildIDs", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{})
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitShort)
// Test with non-existent build IDs
nonExistentID := uuid.New()
experimentalClient := codersdk.NewExperimentalClient(client)
prompts, err := experimentalClient.AITaskPrompts(ctx, []uuid.UUID{nonExistentID})
require.NoError(t, err)
require.Empty(t, prompts.Prompts)
})
}
func TestTasks(t *testing.T) {
t.Parallel()
@@ -66,6 +187,7 @@ func TestTasks(t *testing.T) {
{
Type: &proto.Response_Plan{
Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
HasAiTasks: true,
},
},
@@ -136,9 +258,6 @@ func TestTasks(t *testing.T) {
// Wait for the workspace to be built.
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
if assert.True(t, workspace.TaskID.Valid, "task id should be set on workspace") {
assert.Equal(t, task.ID, workspace.TaskID.UUID, "workspace task id should match")
}
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// List tasks via experimental API and verify the prompt and status mapping.
@@ -177,9 +296,6 @@ func TestTasks(t *testing.T) {
// Get the workspace and wait for it to be ready.
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
if assert.True(t, ws.TaskID.Valid, "task id should be set on workspace") {
assert.Equal(t, task.ID, ws.TaskID.UUID, "workspace task id should match")
}
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
ws = coderdtest.MustWorkspace(t, client, task.WorkspaceID.UUID)
// Assert invariant: the workspace has exactly one resource with one agent with one app.
@@ -254,23 +370,24 @@ func TestTasks(t *testing.T) {
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
if assert.True(t, ws.TaskID.Valid, "task id should be set on workspace") {
assert.Equal(t, task.ID, ws.TaskID.UUID, "workspace task id should match")
}
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
err = exp.DeleteTask(ctx, "me", task.ID)
require.NoError(t, err, "delete task request should be accepted")
// Poll until the workspace is deleted.
testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
for {
dws, derr := client.DeletedWorkspace(ctx, task.WorkspaceID.UUID)
if !assert.NoError(t, derr, "expected to fetch deleted workspace before deadline") {
return false
if derr == nil && dws.LatestBuild.Status == codersdk.WorkspaceStatusDeleted {
break
}
t.Logf("workspace latest_build status: %q", dws.LatestBuild.Status)
return dws.LatestBuild.Status == codersdk.WorkspaceStatusDeleted
}, testutil.IntervalMedium, "workspace should be deleted before deadline")
if ctx.Err() != nil {
require.NoError(t, derr, "expected to fetch deleted workspace before deadline")
require.Equal(t, codersdk.WorkspaceStatusDeleted, dws.LatestBuild.Status, "workspace should be deleted before deadline")
break
}
time.Sleep(testutil.IntervalMedium)
}
})
t.Run("NotFound", func(t *testing.T) {
@@ -303,9 +420,6 @@ func TestTasks(t *testing.T) {
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
ws := coderdtest.CreateWorkspace(t, client, template.ID)
if assert.False(t, ws.TaskID.Valid, "task id should not be set on non-task workspace") {
assert.Zero(t, ws.TaskID, "non-task workspace task id should be empty")
}
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
exp := codersdk.NewExperimentalClient(client)
@@ -354,32 +468,6 @@ func TestTasks(t *testing.T) {
t.Fatalf("unexpected status code: %d (expected 403 or 404)", authErr.StatusCode())
}
})
t.Run("NoWorkspace", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
template := createAITemplate(t, client, user)
ctx := testutil.Context(t, testutil.WaitLong)
exp := codersdk.NewExperimentalClient(client)
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: "delete me",
})
require.NoError(t, err)
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
// Delete the task workspace
coderdtest.MustTransitionWorkspace(t, client, ws.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionDelete)
// We should still be able to fetch the task after deleting its workspace
task, err = exp.TaskByID(ctx, task.ID)
require.NoError(t, err, "fetching a task should still work after deleting its related workspace")
err = exp.DeleteTask(ctx, task.OwnerID.String(), task.ID)
require.NoError(t, err, "should be possible to delete a task with no workspace")
})
})
t.Run("Send", func(t *testing.T) {
@@ -694,51 +782,6 @@ func TestTasksCreate(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionApply: echo.ApplyComplete,
ProvisionPlan: []*proto.Response{
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
HasAiTasks: true,
}}},
},
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
expClient := codersdk.NewExperimentalClient(client)
task, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: taskPrompt,
})
require.NoError(t, err)
require.True(t, task.WorkspaceID.Valid)
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
assert.NotEmpty(t, task.Name)
assert.Equal(t, template.ID, task.TemplateID)
parameters, err := client.WorkspaceBuildParameters(ctx, ws.LatestBuild.ID)
require.NoError(t, err)
require.Len(t, parameters, 0)
})
t.Run("OK AIPromptBackCompat", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
taskPrompt = "Some task prompt"
)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
// Given: A template with an "AI Prompt" parameter
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
@@ -818,6 +861,7 @@ func TestTasksCreate(t *testing.T) {
ProvisionApply: echo.ApplyComplete,
ProvisionPlan: []*proto.Response{
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
HasAiTasks: true,
}}},
},
@@ -933,6 +977,7 @@ func TestTasksCreate(t *testing.T) {
ProvisionApply: echo.ApplyComplete,
ProvisionPlan: []*proto.Response{
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
HasAiTasks: true,
}}},
},
@@ -992,6 +1037,7 @@ func TestTasksCreate(t *testing.T) {
ProvisionApply: echo.ApplyComplete,
ProvisionPlan: []*proto.Response{
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
HasAiTasks: true,
}}},
},
@@ -1028,6 +1074,7 @@ func TestTasksCreate(t *testing.T) {
ProvisionApply: echo.ApplyComplete,
ProvisionPlan: []*proto.Response{
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
HasAiTasks: true,
}}},
},
@@ -1080,6 +1127,7 @@ func TestTasksCreate(t *testing.T) {
ProvisionApply: echo.ApplyComplete,
ProvisionPlan: []*proto.Response{
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
HasAiTasks: true,
}}},
},
@@ -1092,6 +1140,7 @@ func TestTasksCreate(t *testing.T) {
ProvisionApply: echo.ApplyComplete,
ProvisionPlan: []*proto.Response{
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
HasAiTasks: true,
}}},
},
+6 -59
View File
@@ -85,7 +85,7 @@ const docTemplate = `{
}
}
},
"/aibridge/interceptions": {
"/api/experimental/aibridge/interceptions": {
"get": {
"security": [
{
@@ -11668,35 +11668,12 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeBedrockConfig": {
"type": "object",
"properties": {
"access_key": {
"type": "string"
},
"access_key_secret": {
"type": "string"
},
"model": {
"type": "string"
},
"region": {
"type": "string"
},
"small_fast_model": {
"type": "string"
}
}
},
"codersdk.AIBridgeConfig": {
"type": "object",
"properties": {
"anthropic": {
"$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig"
},
"bedrock": {
"$ref": "#/definitions/codersdk.AIBridgeBedrockConfig"
},
"enabled": {
"type": "boolean"
},
@@ -11708,10 +11685,6 @@ const docTemplate = `{
"codersdk.AIBridgeInterception": {
"type": "object",
"properties": {
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
@@ -12523,13 +12496,6 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -13753,13 +13719,6 @@ const docTemplate = `{
"name": {
"type": "string"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific to the organization the role belongs to.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific to the organization the role belongs to.",
"type": "array",
@@ -14316,9 +14275,11 @@ const docTemplate = `{
"web-push",
"oauth2",
"mcp-server-http",
"workspace-sharing"
"workspace-sharing",
"aibridge"
],
"x-enum-comments": {
"ExperimentAIBridge": "Enables AI Bridge functionality.",
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
"ExperimentExample": "This isn't used for anything.",
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
@@ -14336,7 +14297,8 @@ const docTemplate = `{
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceSharing"
"ExperimentWorkspaceSharing",
"ExperimentAIBridge"
]
},
"codersdk.ExternalAPIKeyScopes": {
@@ -17524,13 +17486,6 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -19712,14 +19667,6 @@ const docTemplate = `{
"description": "OwnerName is the username of the owner of the workspace.",
"type": "string"
},
"task_id": {
"description": "TaskID, if set, indicates that the workspace is relevant to the given codersdk.Task.",
"allOf": [
{
"$ref": "#/definitions/uuid.NullUUID"
}
]
},
"template_active_version_id": {
"type": "string",
"format": "uuid"
+6 -59
View File
@@ -65,7 +65,7 @@
}
}
},
"/aibridge/interceptions": {
"/api/experimental/aibridge/interceptions": {
"get": {
"security": [
{
@@ -10364,35 +10364,12 @@
}
}
},
"codersdk.AIBridgeBedrockConfig": {
"type": "object",
"properties": {
"access_key": {
"type": "string"
},
"access_key_secret": {
"type": "string"
},
"model": {
"type": "string"
},
"region": {
"type": "string"
},
"small_fast_model": {
"type": "string"
}
}
},
"codersdk.AIBridgeConfig": {
"type": "object",
"properties": {
"anthropic": {
"$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig"
},
"bedrock": {
"$ref": "#/definitions/codersdk.AIBridgeBedrockConfig"
},
"enabled": {
"type": "boolean"
},
@@ -10404,10 +10381,6 @@
"codersdk.AIBridgeInterception": {
"type": "object",
"properties": {
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
@@ -11205,13 +11178,6 @@
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -12367,13 +12333,6 @@
"name": {
"type": "string"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific to the organization the role belongs to.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific to the organization the role belongs to.",
"type": "array",
@@ -12923,9 +12882,11 @@
"web-push",
"oauth2",
"mcp-server-http",
"workspace-sharing"
"workspace-sharing",
"aibridge"
],
"x-enum-comments": {
"ExperimentAIBridge": "Enables AI Bridge functionality.",
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
"ExperimentExample": "This isn't used for anything.",
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
@@ -12943,7 +12904,8 @@
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceSharing"
"ExperimentWorkspaceSharing",
"ExperimentAIBridge"
]
},
"codersdk.ExternalAPIKeyScopes": {
@@ -16016,13 +15978,6 @@
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -18098,14 +18053,6 @@
"description": "OwnerName is the username of the owner of the workspace.",
"type": "string"
},
"task_id": {
"description": "TaskID, if set, indicates that the workspace is relevant to the given codersdk.Task.",
"allOf": [
{
"$ref": "#/definitions/uuid.NullUUID"
}
]
},
"template_active_version_id": {
"type": "string",
"format": "uuid"
+2 -2
View File
@@ -509,11 +509,11 @@ func (api *API) auditLogResourceLink(ctx context.Context, alog database.GetAudit
if err != nil {
return ""
}
user, err := api.Database.GetUserByID(ctx, task.OwnerID)
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
if err != nil {
return ""
}
return fmt.Sprintf("/tasks/%s/%s", user.Username, task.ID)
return fmt.Sprintf("/tasks/%s/%s", workspace.OwnerName, task.Name)
default:
return ""
+10 -11
View File
@@ -50,13 +50,6 @@ func TestCheckPermissions(t *testing.T) {
},
Action: "read",
},
readOrgWorkspaces: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceWorkspace,
OrganizationID: adminUser.OrganizationID.String(),
},
Action: "read",
},
readMyself: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceUser,
@@ -65,10 +58,16 @@ func TestCheckPermissions(t *testing.T) {
Action: "read",
},
readOwnWorkspaces: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceWorkspace,
OwnerID: "me",
},
Action: "read",
},
readOrgWorkspaces: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceWorkspace,
OrganizationID: adminUser.OrganizationID.String(),
OwnerID: "me",
},
Action: "read",
},
@@ -93,9 +92,9 @@ func TestCheckPermissions(t *testing.T) {
UserID: adminUser.UserID,
Check: map[string]bool{
readAllUsers: true,
readOrgWorkspaces: true,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: true,
updateSpecificTemplate: true,
},
},
@@ -105,9 +104,9 @@ func TestCheckPermissions(t *testing.T) {
UserID: orgAdminUser.ID,
Check: map[string]bool{
readAllUsers: true,
readOrgWorkspaces: true,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: true,
updateSpecificTemplate: true,
},
},
@@ -117,9 +116,9 @@ func TestCheckPermissions(t *testing.T) {
UserID: memberUser.ID,
Check: map[string]bool{
readAllUsers: false,
readOrgWorkspaces: false,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: false,
updateSpecificTemplate: false,
},
},
-172
View File
@@ -1764,175 +1764,3 @@ func TestExecutorAutostartSkipsWhenNoProvisionersAvailable(t *testing.T) {
assert.Len(t, stats.Transitions, 1, "should create builds when provisioners are available")
}
func TestExecutorTaskWorkspace(t *testing.T) {
t.Parallel()
createTaskTemplate := func(t *testing.T, client *codersdk.Client, orgID uuid.UUID, ctx context.Context, defaultTTL time.Duration) codersdk.Template {
t.Helper()
taskAppID := uuid.New()
version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: []*proto.Response{
{
Type: &proto.Response_Plan{
Plan: &proto.PlanComplete{HasAiTasks: true},
},
},
},
ProvisionApply: []*proto.Response{
{
Type: &proto.Response_Apply{
Apply: &proto.ApplyComplete{
Resources: []*proto.Resource{
{
Agents: []*proto.Agent{
{
Id: uuid.NewString(),
Name: "dev",
Auth: &proto.Agent_Token{
Token: uuid.NewString(),
},
Apps: []*proto.App{
{
Id: taskAppID.String(),
Slug: "task-app",
},
},
},
},
},
},
AiTasks: []*proto.AITask{
{
AppId: taskAppID.String(),
},
},
},
},
},
},
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
if defaultTTL > 0 {
_, err := client.UpdateTemplateMeta(ctx, template.ID, codersdk.UpdateTemplateMeta{
DefaultTTLMillis: defaultTTL.Milliseconds(),
})
require.NoError(t, err)
}
return template
}
createTaskWorkspace := func(t *testing.T, client *codersdk.Client, template codersdk.Template, ctx context.Context, input string) codersdk.Workspace {
t.Helper()
exp := codersdk.NewExperimentalClient(client)
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: input,
})
require.NoError(t, err)
require.True(t, task.WorkspaceID.Valid, "task should have a workspace")
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
return workspace
}
t.Run("Autostart", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
sched = mustSchedule(t, "CRON_TZ=UTC 0 * * * *")
tickCh = make(chan time.Time)
statsCh = make(chan autobuild.Stats)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
AutobuildTicker: tickCh,
IncludeProvisionerDaemon: true,
AutobuildStats: statsCh,
})
admin = coderdtest.CreateFirstUser(t, client)
)
// Given: A task workspace
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 0)
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostart")
// Given: The task workspace has an autostart schedule
err := client.UpdateWorkspaceAutostart(ctx, workspace.ID, codersdk.UpdateWorkspaceAutostartRequest{
Schedule: ptr.Ref(sched.String()),
})
require.NoError(t, err)
// Given: That the workspace is in a stopped state.
workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop)
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
require.NoError(t, err)
// When: the autobuild executor ticks after the scheduled time
go func() {
tickTime := sched.Next(workspace.LatestBuild.CreatedAt)
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
tickCh <- tickTime
close(tickCh)
}()
// Then: We expect to see a start transition
stats := <-statsCh
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
assert.Equal(t, database.WorkspaceTransitionStart, stats.Transitions[workspace.ID], "should autostart the workspace")
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
})
t.Run("Autostop", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
tickCh = make(chan time.Time)
statsCh = make(chan autobuild.Stats)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
AutobuildTicker: tickCh,
IncludeProvisionerDaemon: true,
AutobuildStats: statsCh,
})
admin = coderdtest.CreateFirstUser(t, client)
)
// Given: A task workspace with an 8 hour deadline
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 8*time.Hour)
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostop")
// Given: The workspace is currently running
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
require.Equal(t, codersdk.WorkspaceTransitionStart, workspace.LatestBuild.Transition)
require.NotZero(t, workspace.LatestBuild.Deadline, "workspace should have a deadline for autostop")
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
require.NoError(t, err)
// When: the autobuild executor ticks after the deadline
go func() {
tickTime := workspace.LatestBuild.Deadline.Time.Add(time.Minute)
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
tickCh <- tickTime
close(tickCh)
}()
// Then: We expect to see a stop transition
stats := <-statsCh
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
assert.Equal(t, database.WorkspaceTransitionStop, stats.Transitions[workspace.ID], "should autostop the workspace")
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
})
}
+4 -1
View File
@@ -1021,7 +1021,10 @@ func New(options *Options) *API {
apiRateLimiter,
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
)
r.Route("/aitasks", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Get("/prompts", api.aiTasksPrompts)
})
r.Route("/tasks", func(r chi.Router) {
r.Use(apiKeyMiddleware)
-1
View File
@@ -14,7 +14,6 @@ const (
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsAiTaskSidebarAppIDRequired CheckConstraint = "workspace_builds_ai_task_sidebar_app_id_required" // workspace_builds
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
)
+8 -13
View File
@@ -714,13 +714,12 @@ func RBACRole(role rbac.Role) codersdk.Role {
orgPerms := role.ByOrgID[slim.OrganizationID]
return codersdk.Role{
Name: slim.Name,
OrganizationID: slim.OrganizationID,
DisplayName: slim.DisplayName,
SitePermissions: List(role.Site, RBACPermission),
UserPermissions: List(role.User, RBACPermission),
OrganizationPermissions: List(orgPerms.Org, RBACPermission),
OrganizationMemberPermissions: List(orgPerms.Member, RBACPermission),
Name: slim.Name,
OrganizationID: slim.OrganizationID,
DisplayName: slim.DisplayName,
SitePermissions: List(role.Site, RBACPermission),
OrganizationPermissions: List(orgPerms.Org, RBACPermission),
UserPermissions: List(role.User, RBACPermission),
}
}
@@ -735,8 +734,8 @@ func Role(role database.CustomRole) codersdk.Role {
OrganizationID: orgID,
DisplayName: role.DisplayName,
SitePermissions: List(role.SitePermissions, Permission),
UserPermissions: List(role.UserPermissions, Permission),
OrganizationPermissions: List(role.OrgPermissions, Permission),
UserPermissions: List(role.UserPermissions, Permission),
}
}
@@ -963,7 +962,7 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
// created_at ASC
return sdkToolUsages[i].CreatedAt.Before(sdkToolUsages[j].CreatedAt)
})
intc := codersdk.AIBridgeInterception{
return codersdk.AIBridgeInterception{
ID: interception.ID,
Initiator: MinimalUserFromVisibleUser(initiator),
Provider: interception.Provider,
@@ -974,10 +973,6 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
UserPrompts: sdkUserPrompts,
ToolUsages: sdkToolUsages,
}
if interception.EndedAt.Valid {
intc.EndedAt = &interception.EndedAt.Time
}
return intc
}
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
+18 -77
View File
@@ -254,7 +254,6 @@ var (
rbac.ResourceFile.Type: {policy.ActionRead}, // Required to read terraform files
rbac.ResourceNotificationMessage.Type: {policy.ActionCreate, policy.ActionRead},
rbac.ResourceSystem.Type: {policy.WildcardSymbol},
rbac.ResourceTask.Type: {policy.ActionRead, policy.ActionUpdate},
rbac.ResourceTemplate.Type: {policy.ActionRead, policy.ActionUpdate},
rbac.ResourceUser.Type: {policy.ActionRead},
rbac.ResourceWorkspace.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStart, policy.ActionWorkspaceStop},
@@ -396,13 +395,11 @@ var (
Identifier: rbac.RoleIdentifier{Name: "subagentapi"},
DisplayName: "Sub Agent API",
Site: []rbac.Permission{},
User: []rbac.Permission{},
User: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreateAgent, policy.ActionDeleteAgent},
}),
ByOrgID: map[string]rbac.OrgPermissions{
orgID.String(): {
Member: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreateAgent, policy.ActionDeleteAgent},
}),
},
orgID.String(): {},
},
},
}),
@@ -1293,17 +1290,14 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
return xerrors.Errorf("invalid role: %w", err)
}
if len(rbacRole.ByOrgID) > 0 && (len(rbacRole.Site) > 0 || len(rbacRole.User) > 0) {
// This is a choice to keep roles simple. If we allow mixing site and org
// scoped perms, then knowing who can do what gets more complicated. Roles
// should either be entirely org-scoped or entirely unrelated to
// organizations.
return xerrors.Errorf("invalid custom role, cannot assign both org-scoped and site/user permissions at the same time")
if len(rbacRole.ByOrgID) > 0 && len(rbacRole.Site) > 0 {
// This is a choice to keep roles simple. If we allow mixing site and org scoped perms, then knowing who can
// do what gets more complicated.
return xerrors.Errorf("invalid custom role, cannot assign both org and site permissions at the same time")
}
if len(rbacRole.ByOrgID) > 1 {
// Again to avoid more complexity in our roles. Roles are limited to one
// organization.
// Again to avoid more complexity in our roles
return xerrors.Errorf("invalid custom role, cannot assign permissions to more than 1 org at a time")
}
@@ -1319,18 +1313,7 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
for _, orgPerm := range perms.Org {
err := q.customRoleEscalationCheck(ctx, act, orgPerm, rbac.Object{OrgID: orgID, Type: orgPerm.ResourceType})
if err != nil {
return xerrors.Errorf("org=%q: org: %w", orgID, err)
}
}
for _, memberPerm := range perms.Member {
// The person giving the permission should still be required to have
// the permissions throughout the org in order to give individuals the
// same permission among their own resources, since the role can be given
// to anyone. The `Owner` is intentionally omitted from the `Object` to
// enforce this.
err := q.customRoleEscalationCheck(ctx, act, memberPerm, rbac.Object{OrgID: orgID, Type: memberPerm.ResourceType})
if err != nil {
return xerrors.Errorf("org=%q: member: %w", orgID, err)
return xerrors.Errorf("org=%q: %w", orgID, err)
}
}
}
@@ -1348,8 +1331,8 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.ProvisionerJob) error {
switch job.Type {
case database.ProvisionerJobTypeWorkspaceBuild:
// Authorized call to get workspace build. If we can read the build, we can
// read the job.
// Authorized call to get workspace build. If we can read the build, we
// can read the job.
_, err := q.GetWorkspaceBuildByJobID(ctx, job.ID)
if err != nil {
return xerrors.Errorf("fetch related workspace build: %w", err)
@@ -1392,8 +1375,8 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.Activi
}
func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
// Although this technically only reads users, only system-related functions
// should be allowed to call this.
// Although this technically only reads users, only system-related functions should be
// allowed to call this.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
@@ -1412,8 +1395,8 @@ func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg databas
}
func (q *querier) BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error {
// Could be any workspace and checking auth to each workspace is overkill for
// the purpose of this function.
// Could be any workspace and checking auth to each workspace is overkill for the purpose
// of this function.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceWorkspace.All()); err != nil {
return err
}
@@ -1441,13 +1424,6 @@ func (q *querier) BulkMarkNotificationMessagesSent(ctx context.Context, arg data
return q.db.BulkMarkNotificationMessagesSent(ctx, arg)
}
func (q *querier) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
return database.CalculateAIBridgeInterceptionsTelemetrySummaryRow{}, err
}
return q.db.CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg)
}
func (q *querier) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
empty := database.ClaimPrebuiltWorkspaceRow{}
@@ -1747,13 +1723,6 @@ func (q *querier) DeleteOldProvisionerDaemons(ctx context.Context) error {
return q.db.DeleteOldProvisionerDaemons(ctx)
}
func (q *querier) DeleteOldTelemetryLocks(ctx context.Context, beforeTime time.Time) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return err
}
return q.db.DeleteOldTelemetryLocks(ctx, beforeTime)
}
func (q *querier) DeleteOldWorkspaceAgentLogs(ctx context.Context, threshold time.Time) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return err
@@ -2649,13 +2618,6 @@ func (q *querier) GetOrganizationsByUserID(ctx context.Context, userID database.
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationsByUserID)(ctx, userID)
}
func (q *querier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOrganization.All()); err != nil {
return nil, err
}
return q.db.GetOrganizationsWithPrebuildStatus(ctx, arg)
}
func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
version, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
if err != nil {
@@ -4250,13 +4212,6 @@ func (q *querier) InsertTelemetryItemIfNotExists(ctx context.Context, arg databa
return q.db.InsertTelemetryItemIfNotExists(ctx, arg)
}
func (q *querier) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.InsertTelemetryLock(ctx, arg)
}
func (q *querier) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
obj := rbac.ResourceTemplate.InOrg(arg.OrganizationID)
if err := q.authorizeContext(ctx, policy.ActionCreate, obj); err != nil {
@@ -4568,13 +4523,6 @@ func (q *querier) ListAIBridgeInterceptions(ctx context.Context, arg database.Li
return q.db.ListAuthorizedAIBridgeInterceptions(ctx, arg, prep)
}
func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
return nil, err
}
return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
}
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
// This function is a system function until we implement a join for aibridge interceptions.
// Matches the behavior of the workspaces listing endpoint.
@@ -4763,13 +4711,6 @@ func (q *querier) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error {
return update(q.log, q.auth, fetch, q.db.UnfavoriteWorkspace)(ctx, id)
}
func (q *querier) UpdateAIBridgeInterceptionEnded(ctx context.Context, params database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, params.ID); err != nil {
return database.AIBridgeInterception{}, err
}
return q.db.UpdateAIBridgeInterceptionEnded(ctx, params)
}
func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
fetch := func(ctx context.Context, arg database.UpdateAPIKeyByIDParams) (database.APIKey, error) {
return q.db.GetAPIKeyByID(ctx, arg.ID)
@@ -4941,10 +4882,10 @@ func (q *querier) UpdateOrganizationDeletedByID(ctx context.Context, arg databas
return deleteQ(q.log, q.auth, q.db.GetOrganizationByID, deleteF)(ctx, arg.ID)
}
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
// Prebuild operation for canceling pending prebuild jobs from non-active template versions
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourcePrebuiltWorkspace); err != nil {
return []database.UpdatePrebuildProvisionerJobWithCancelRow{}, err
return []uuid.UUID{}, err
}
return q.db.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
}
+3 -45
View File
@@ -646,13 +646,10 @@ func (s *MethodTestSuite) TestProvisionerJob() {
PresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
Now: dbtime.Now(),
}
canceledJobs := []database.UpdatePrebuildProvisionerJobWithCancelRow{
{ID: uuid.New(), WorkspaceID: uuid.New(), TemplateID: uuid.New(), TemplateVersionPresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
{ID: uuid.New(), WorkspaceID: uuid.New(), TemplateID: uuid.New(), TemplateVersionPresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
}
jobIDs := []uuid.UUID{uuid.New(), uuid.New()}
dbm.EXPECT().UpdatePrebuildProvisionerJobWithCancel(gomock.Any(), arg).Return(canceledJobs, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourcePrebuiltWorkspace, policy.ActionUpdate).Returns(canceledJobs)
dbm.EXPECT().UpdatePrebuildProvisionerJobWithCancel(gomock.Any(), arg).Return(jobIDs, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourcePrebuiltWorkspace, policy.ActionUpdate).Returns(jobIDs)
}))
s.Run("GetProvisionerJobsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
org := testutil.Fake(s.T(), faker, database.Organization{})
@@ -3759,14 +3756,6 @@ func (s *MethodTestSuite) TestPrebuilds() {
dbm.EXPECT().GetPrebuildMetrics(gomock.Any()).Return([]database.GetPrebuildMetricsRow{}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead)
}))
s.Run("GetOrganizationsWithPrebuildStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.GetOrganizationsWithPrebuildStatusParams{
UserID: uuid.New(),
GroupName: "test",
}
dbm.EXPECT().GetOrganizationsWithPrebuildStatus(gomock.Any(), arg).Return([]database.GetOrganizationsWithPrebuildStatusRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceOrganization.All(), policy.ActionRead)
}))
s.Run("GetPrebuildsSettings", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetPrebuildsSettings(gomock.Any()).Return("{}", nil).AnyTimes()
check.Args().Asserts()
@@ -4628,35 +4617,4 @@ func (s *MethodTestSuite) TestAIBridge() {
db.EXPECT().ListAIBridgeToolUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeToolUsage{}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
}))
s.Run("UpdateAIBridgeInterceptionEnded", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intcID := uuid.UUID{1}
params := database.UpdateAIBridgeInterceptionEndedParams{ID: intcID}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intcID})
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intcID).Return(intc, nil).AnyTimes() // Validation.
db.EXPECT().UpdateAIBridgeInterceptionEnded(gomock.Any(), params).Return(intc, nil).AnyTimes()
check.Args(params).Asserts(intc, policy.ActionUpdate).Returns(intc)
}))
}
func (s *MethodTestSuite) TestTelemetry() {
s.Run("InsertTelemetryLock", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().InsertTelemetryLock(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
check.Args(database.InsertTelemetryLockParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate)
}))
s.Run("DeleteOldTelemetryLocks", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().DeleteOldTelemetryLocks(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
check.Args(time.Time{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
}))
s.Run("ListAIBridgeInterceptionsTelemetrySummaries", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().ListAIBridgeInterceptionsTelemetrySummaries(gomock.Any(), gomock.Any()).Return([]database.ListAIBridgeInterceptionsTelemetrySummariesRow{}, nil).AnyTimes()
check.Args(database.ListAIBridgeInterceptionsTelemetrySummariesParams{}).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead)
}))
s.Run("CalculateAIBridgeInterceptionsTelemetrySummary", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().CalculateAIBridgeInterceptionsTelemetrySummary(gomock.Any(), gomock.Any()).Return(database.CalculateAIBridgeInterceptionsTelemetrySummaryRow{}, nil).AnyTimes()
check.Args(database.CalculateAIBridgeInterceptionsTelemetrySummaryParams{}).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead)
}))
}
+7 -61
View File
@@ -41,7 +41,6 @@ type WorkspaceResponse struct {
Build database.WorkspaceBuild
AgentToken string
TemplateVersionResponse
Task database.Task
}
// WorkspaceBuildBuilder generates workspace builds and associated
@@ -58,7 +57,6 @@ type WorkspaceBuildBuilder struct {
agentToken string
jobStatus database.ProvisionerJobStatus
taskAppID uuid.UUID
taskSeed database.TaskTable
}
// WorkspaceBuild generates a workspace build for the provided workspace.
@@ -117,28 +115,25 @@ func (b WorkspaceBuildBuilder) WithAgent(mutations ...func([]*sdkproto.Agent) []
return b
}
func (b WorkspaceBuildBuilder) WithTask(taskSeed database.TaskTable, appSeed *sdkproto.App) WorkspaceBuildBuilder {
//nolint:revive // returns modified struct
b.taskSeed = taskSeed
if appSeed == nil {
appSeed = &sdkproto.App{}
func (b WorkspaceBuildBuilder) WithTask(seed *sdkproto.App) WorkspaceBuildBuilder {
if seed == nil {
seed = &sdkproto.App{}
}
var err error
//nolint: revive // returns modified struct
b.taskAppID, err = uuid.Parse(takeFirst(appSeed.Id, uuid.NewString()))
b.taskAppID, err = uuid.Parse(takeFirst(seed.Id, uuid.NewString()))
require.NoError(b.t, err)
return b.Params(database.WorkspaceBuildParameter{
Name: codersdk.AITaskPromptParameterName,
Value: b.taskSeed.Prompt,
Value: "list me",
}).WithAgent(func(a []*sdkproto.Agent) []*sdkproto.Agent {
a[0].Apps = []*sdkproto.App{
{
Id: b.taskAppID.String(),
Slug: takeFirst(appSeed.Slug, "task-app"),
Url: takeFirst(appSeed.Url, ""),
Slug: takeFirst(seed.Slug, "task-app"),
Url: takeFirst(seed.Url, ""),
},
}
return a
@@ -166,19 +161,6 @@ func (b WorkspaceBuildBuilder) Canceled() WorkspaceBuildBuilder {
// Workspace will be optionally populated if no ID is set on the provided
// workspace.
func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
var resp WorkspaceResponse
// Use transaction, like real wsbuilder.
err := b.db.InTx(func(tx database.Store) error {
//nolint:revive // calls do on modified struct
b.db = tx
resp = b.doInTX()
return nil
}, nil)
require.NoError(b.t, err)
return resp
}
func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
b.t.Helper()
jobID := uuid.New()
b.seed.ID = uuid.New()
@@ -230,37 +212,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
b.seed.WorkspaceID = b.ws.ID
b.seed.InitiatorID = takeFirst(b.seed.InitiatorID, b.ws.OwnerID)
// If a task was requested, ensure it exists and is associated with this
// workspace.
if b.taskAppID != uuid.Nil {
b.logger.Debug(context.Background(), "creating or updating task", "task_id", b.taskSeed.ID)
b.taskSeed.OrganizationID = takeFirst(b.taskSeed.OrganizationID, b.ws.OrganizationID)
b.taskSeed.OwnerID = takeFirst(b.taskSeed.OwnerID, b.ws.OwnerID)
b.taskSeed.Name = takeFirst(b.taskSeed.Name, b.ws.Name)
b.taskSeed.WorkspaceID = uuid.NullUUID{UUID: takeFirst(b.taskSeed.WorkspaceID.UUID, b.ws.ID), Valid: true}
b.taskSeed.TemplateVersionID = takeFirst(b.taskSeed.TemplateVersionID, b.seed.TemplateVersionID)
// Try to fetch existing task and update its workspace ID.
if task, err := b.db.GetTaskByID(ownerCtx, b.taskSeed.ID); err == nil {
if !task.WorkspaceID.Valid {
b.logger.Info(context.Background(), "updating task workspace id", "task_id", b.taskSeed.ID, "workspace_id", b.ws.ID)
_, err = b.db.UpdateTaskWorkspaceID(ownerCtx, database.UpdateTaskWorkspaceIDParams{
ID: b.taskSeed.ID,
WorkspaceID: uuid.NullUUID{UUID: b.ws.ID, Valid: true},
})
require.NoError(b.t, err, "update task workspace id")
} else if task.WorkspaceID.UUID != b.ws.ID {
require.Fail(b.t, "task already has a workspace id, mismatch", task.WorkspaceID.UUID, b.ws.ID)
}
} else if errors.Is(err, sql.ErrNoRows) {
task := dbgen.Task(b.t, b.db, b.taskSeed)
b.taskSeed.ID = task.ID
b.logger.Info(context.Background(), "created new task", "task_id", b.taskSeed.ID)
} else {
require.NoError(b.t, err, "get task by id")
}
}
// Create a provisioner job for the build!
payload, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: b.seed.ID,
@@ -373,11 +324,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
b.logger.Debug(context.Background(), "linked task to workspace build",
slog.F("task_id", task.ID),
slog.F("build_number", resp.Build.BuildNumber))
// Update task after linking.
task, err = b.db.GetTaskByID(ownerCtx, task.ID)
require.NoError(b.t, err, "get task by id")
resp.Task = task
}
for i := range b.params {
+1 -9
View File
@@ -1495,7 +1495,7 @@ func ClaimPrebuild(
return claimedWorkspace
}
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams, endedAt *time.Time) database.AIBridgeInterception {
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams) database.AIBridgeInterception {
interception, err := db.InsertAIBridgeInterception(genCtx, database.InsertAIBridgeInterceptionParams{
ID: takeFirst(seed.ID, uuid.New()),
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
@@ -1504,13 +1504,6 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
})
if endedAt != nil {
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
ID: interception.ID,
EndedAt: *endedAt,
})
require.NoError(t, err, "insert aibridge interception")
}
require.NoError(t, err, "insert aibridge interception")
return interception
}
@@ -1576,7 +1569,6 @@ func Task(t testing.TB, db database.Store, orig database.TaskTable) database.Tas
}
task, err := db.InsertTask(genCtx, database.InsertTaskParams{
ID: takeFirst(orig.ID, uuid.New()),
OrganizationID: orig.OrganizationID,
OwnerID: orig.OwnerID,
Name: takeFirst(orig.Name, taskname.GenerateFallback()),
+1 -43
View File
@@ -158,13 +158,6 @@ func (m queryMetricsStore) BulkMarkNotificationMessagesSent(ctx context.Context,
return r0, r1
}
func (m queryMetricsStore) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
start := time.Now()
r0, r1 := m.s.CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg)
m.queryLatencies.WithLabelValues("CalculateAIBridgeInterceptionsTelemetrySummary").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
start := time.Now()
r0, r1 := m.s.ClaimPrebuiltWorkspace(ctx, arg)
@@ -410,13 +403,6 @@ func (m queryMetricsStore) DeleteOldProvisionerDaemons(ctx context.Context) erro
return r0
}
func (m queryMetricsStore) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
start := time.Now()
r0 := m.s.DeleteOldTelemetryLocks(ctx, periodEndingAtBefore)
m.queryLatencies.WithLabelValues("DeleteOldTelemetryLocks").Observe(time.Since(start).Seconds())
return r0
}
func (m queryMetricsStore) DeleteOldWorkspaceAgentLogs(ctx context.Context, arg time.Time) error {
start := time.Now()
r0 := m.s.DeleteOldWorkspaceAgentLogs(ctx, arg)
@@ -1243,13 +1229,6 @@ func (m queryMetricsStore) GetOrganizationsByUserID(ctx context.Context, userID
return organizations, err
}
func (m queryMetricsStore) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
start := time.Now()
r0, r1 := m.s.GetOrganizationsWithPrebuildStatus(ctx, arg)
m.queryLatencies.WithLabelValues("GetOrganizationsWithPrebuildStatus").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
start := time.Now()
schemas, err := m.s.GetParameterSchemasByJobID(ctx, jobID)
@@ -2538,13 +2517,6 @@ func (m queryMetricsStore) InsertTelemetryItemIfNotExists(ctx context.Context, a
return r0
}
func (m queryMetricsStore) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
start := time.Now()
r0 := m.s.InsertTelemetryLock(ctx, arg)
m.queryLatencies.WithLabelValues("InsertTelemetryLock").Observe(time.Since(start).Seconds())
return r0
}
func (m queryMetricsStore) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
start := time.Now()
err := m.s.InsertTemplate(ctx, arg)
@@ -2762,13 +2734,6 @@ func (m queryMetricsStore) ListAIBridgeInterceptions(ctx context.Context, arg da
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
m.queryLatencies.WithLabelValues("ListAIBridgeInterceptionsTelemetrySummaries").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
@@ -2923,13 +2888,6 @@ func (m queryMetricsStore) UnfavoriteWorkspace(ctx context.Context, arg uuid.UUI
return r0
}
func (m queryMetricsStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, id database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
start := time.Now()
r0, r1 := m.s.UpdateAIBridgeInterceptionEnded(ctx, id)
m.queryLatencies.WithLabelValues("UpdateAIBridgeInterceptionEnded").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
start := time.Now()
err := m.s.UpdateAPIKeyByID(ctx, arg)
@@ -3049,7 +3007,7 @@ func (m queryMetricsStore) UpdateOrganizationDeletedByID(ctx context.Context, ar
return r0
}
func (m queryMetricsStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
func (m queryMetricsStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
m.queryLatencies.WithLabelValues("UpdatePrebuildProvisionerJobWithCancel").Observe(time.Since(start).Seconds())
+2 -90
View File
@@ -190,21 +190,6 @@ func (mr *MockStoreMockRecorder) BulkMarkNotificationMessagesSent(ctx, arg any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkMarkNotificationMessagesSent", reflect.TypeOf((*MockStore)(nil).BulkMarkNotificationMessagesSent), ctx, arg)
}
// CalculateAIBridgeInterceptionsTelemetrySummary mocks base method.
func (m *MockStore) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CalculateAIBridgeInterceptionsTelemetrySummary", ctx, arg)
ret0, _ := ret[0].(database.CalculateAIBridgeInterceptionsTelemetrySummaryRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CalculateAIBridgeInterceptionsTelemetrySummary indicates an expected call of CalculateAIBridgeInterceptionsTelemetrySummary.
func (mr *MockStoreMockRecorder) CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CalculateAIBridgeInterceptionsTelemetrySummary", reflect.TypeOf((*MockStore)(nil).CalculateAIBridgeInterceptionsTelemetrySummary), ctx, arg)
}
// ClaimPrebuiltWorkspace mocks base method.
func (m *MockStore) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
m.ctrl.T.Helper()
@@ -751,20 +736,6 @@ func (mr *MockStoreMockRecorder) DeleteOldProvisionerDaemons(ctx any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldProvisionerDaemons", reflect.TypeOf((*MockStore)(nil).DeleteOldProvisionerDaemons), ctx)
}
// DeleteOldTelemetryLocks mocks base method.
func (m *MockStore) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteOldTelemetryLocks", ctx, periodEndingAtBefore)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteOldTelemetryLocks indicates an expected call of DeleteOldTelemetryLocks.
func (mr *MockStoreMockRecorder) DeleteOldTelemetryLocks(ctx, periodEndingAtBefore any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldTelemetryLocks", reflect.TypeOf((*MockStore)(nil).DeleteOldTelemetryLocks), ctx, periodEndingAtBefore)
}
// DeleteOldWorkspaceAgentLogs mocks base method.
func (m *MockStore) DeleteOldWorkspaceAgentLogs(ctx context.Context, threshold time.Time) error {
m.ctrl.T.Helper()
@@ -2622,21 +2593,6 @@ func (mr *MockStoreMockRecorder) GetOrganizationsByUserID(ctx, arg any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsByUserID", reflect.TypeOf((*MockStore)(nil).GetOrganizationsByUserID), ctx, arg)
}
// GetOrganizationsWithPrebuildStatus mocks base method.
func (m *MockStore) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOrganizationsWithPrebuildStatus", ctx, arg)
ret0, _ := ret[0].([]database.GetOrganizationsWithPrebuildStatusRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrganizationsWithPrebuildStatus indicates an expected call of GetOrganizationsWithPrebuildStatus.
func (mr *MockStoreMockRecorder) GetOrganizationsWithPrebuildStatus(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsWithPrebuildStatus", reflect.TypeOf((*MockStore)(nil).GetOrganizationsWithPrebuildStatus), ctx, arg)
}
// GetParameterSchemasByJobID mocks base method.
func (m *MockStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
m.ctrl.T.Helper()
@@ -5436,20 +5392,6 @@ func (mr *MockStoreMockRecorder) InsertTelemetryItemIfNotExists(ctx, arg any) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTelemetryItemIfNotExists", reflect.TypeOf((*MockStore)(nil).InsertTelemetryItemIfNotExists), ctx, arg)
}
// InsertTelemetryLock mocks base method.
func (m *MockStore) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertTelemetryLock", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// InsertTelemetryLock indicates an expected call of InsertTelemetryLock.
func (mr *MockStoreMockRecorder) InsertTelemetryLock(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTelemetryLock", reflect.TypeOf((*MockStore)(nil).InsertTelemetryLock), ctx, arg)
}
// InsertTemplate mocks base method.
func (m *MockStore) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
m.ctrl.T.Helper()
@@ -5905,21 +5847,6 @@ func (mr *MockStoreMockRecorder) ListAIBridgeInterceptions(ctx, arg any) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptions), ctx, arg)
}
// ListAIBridgeInterceptionsTelemetrySummaries mocks base method.
func (m *MockStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAIBridgeInterceptionsTelemetrySummaries", ctx, arg)
ret0, _ := ret[0].([]database.ListAIBridgeInterceptionsTelemetrySummariesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAIBridgeInterceptionsTelemetrySummaries indicates an expected call of ListAIBridgeInterceptionsTelemetrySummaries.
func (mr *MockStoreMockRecorder) ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptionsTelemetrySummaries", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptionsTelemetrySummaries), ctx, arg)
}
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
m.ctrl.T.Helper()
@@ -6289,21 +6216,6 @@ func (mr *MockStoreMockRecorder) UnfavoriteWorkspace(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnfavoriteWorkspace", reflect.TypeOf((*MockStore)(nil).UnfavoriteWorkspace), ctx, id)
}
// UpdateAIBridgeInterceptionEnded mocks base method.
func (m *MockStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAIBridgeInterceptionEnded", ctx, arg)
ret0, _ := ret[0].(database.AIBridgeInterception)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateAIBridgeInterceptionEnded indicates an expected call of UpdateAIBridgeInterceptionEnded.
func (mr *MockStoreMockRecorder) UpdateAIBridgeInterceptionEnded(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAIBridgeInterceptionEnded", reflect.TypeOf((*MockStore)(nil).UpdateAIBridgeInterceptionEnded), ctx, arg)
}
// UpdateAPIKeyByID mocks base method.
func (m *MockStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
m.ctrl.T.Helper()
@@ -6555,10 +6467,10 @@ func (mr *MockStoreMockRecorder) UpdateOrganizationDeletedByID(ctx, arg any) *go
}
// UpdatePrebuildProvisionerJobWithCancel mocks base method.
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdatePrebuildProvisionerJobWithCancel", ctx, arg)
ret0, _ := ret[0].([]database.UpdatePrebuildProvisionerJobWithCancelRow)
ret0, _ := ret[0].([]uuid.UUID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-10
View File
@@ -24,12 +24,6 @@ const (
// but we won't touch the `connection_logs` table.
maxAuditLogConnectionEventAge = 90 * 24 * time.Hour // 90 days
auditLogConnectionEventBatchSize = 1000
// Telemetry heartbeats are used to deduplicate events across replicas. We
// don't need to persist heartbeat rows for longer than 24 hours, as they
// are only used for deduplication across replicas. The time needs to be
// long enough to cover the maximum interval of a heartbeat event (currently
// 1 hour) plus some buffer.
maxTelemetryHeartbeatAge = 24 * time.Hour
)
// New creates a new periodically purging database instance.
@@ -77,10 +71,6 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, clk quartz.
if err := tx.ExpirePrebuildsAPIKeys(ctx, dbtime.Time(start)); err != nil {
return xerrors.Errorf("failed to expire prebuilds user api keys: %w", err)
}
deleteOldTelemetryLocksBefore := start.Add(-maxTelemetryHeartbeatAge)
if err := tx.DeleteOldTelemetryLocks(ctx, deleteOldTelemetryLocksBefore); err != nil {
return xerrors.Errorf("failed to delete old telemetry locks: %w", err)
}
deleteOldAuditLogConnectionEventsBefore := start.Add(-maxAuditLogConnectionEventAge)
if err := tx.DeleteOldAuditLogConnectionEvents(ctx, database.DeleteOldAuditLogConnectionEventsParams{
-53
View File
@@ -704,56 +704,3 @@ func TestExpireOldAPIKeys(t *testing.T) {
// Out of an abundance of caution, we do not expire explicitly named prebuilds API keys.
assertKeyActive(namedPrebuildsAPIKey.ID)
}
func TestDeleteOldTelemetryHeartbeats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
clk := quartz.NewMock(t)
now := clk.Now().UTC()
// Insert telemetry heartbeats.
err := db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
EventType: "aibridge_interceptions_summary",
PeriodEndingAt: now.Add(-25 * time.Hour), // should be purged
})
require.NoError(t, err)
err = db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
EventType: "aibridge_interceptions_summary",
PeriodEndingAt: now.Add(-23 * time.Hour), // should be kept
})
require.NoError(t, err)
err = db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
EventType: "aibridge_interceptions_summary",
PeriodEndingAt: now, // should be kept
})
require.NoError(t, err)
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, clk)
defer closer.Close()
<-done // doTick() has now run.
require.Eventuallyf(t, func() bool {
// We use an SQL queries directly here because we don't expose queries
// for deleting heartbeats in the application code.
var totalCount int
err := sqlDB.QueryRowContext(ctx, `
SELECT COUNT(*) FROM telemetry_locks;
`).Scan(&totalCount)
assert.NoError(t, err)
var oldCount int
err = sqlDB.QueryRowContext(ctx, `
SELECT COUNT(*) FROM telemetry_locks WHERE period_ending_at < $1;
`, now.Add(-24*time.Hour)).Scan(&oldCount)
assert.NoError(t, err)
// Expect 2 heartbeats remaining and none older than 24 hours.
t.Logf("eventually: total count: %d, old count: %d", totalCount, oldCount)
return totalCount == 2 && oldCount == 0
}, testutil.WaitShort, testutil.IntervalFast, "it should delete old telemetry heartbeats")
}
+6 -51
View File
@@ -6,8 +6,6 @@ import (
_ "embed"
"fmt"
"os"
"runtime"
"strings"
"sync"
"time"
@@ -47,8 +45,6 @@ func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error
host = defaultConnectionParams.Host
port = defaultConnectionParams.Port
)
packageName := getTestPackageName(t)
testName := t.Name()
// Use a time-based prefix to make it easier to find the database
// when debugging.
@@ -59,9 +55,9 @@ func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error
}
dbName := now + "_" + dbSuffix
// TODO: add package and test name
_, err = b.coderTestingDB.Exec(
"INSERT INTO test_databases (name, process_uuid, test_package, test_name) VALUES ($1, $2, $3, $4)",
dbName, b.uuid, packageName, testName)
"INSERT INTO test_databases (name, process_uuid) VALUES ($1, $2)", dbName, b.uuid)
if err != nil {
return ConnectionParams{}, xerrors.Errorf("insert test_database row: %w", err)
}
@@ -108,10 +104,10 @@ func (b *Broker) clean(t TBSubset, dbName string) func() {
func (b *Broker) init(t TBSubset) error {
b.Lock()
defer b.Unlock()
b.refCount++
t.Cleanup(b.decRef)
if b.coderTestingDB != nil {
// already initialized
b.refCount++
t.Cleanup(b.decRef)
return nil
}
@@ -128,8 +124,8 @@ func (b *Broker) init(t TBSubset) error {
return xerrors.Errorf("open postgres connection: %w", err)
}
// coderTestingSQLInit is idempotent, so we can run it every time.
_, err = coderTestingDB.Exec(coderTestingSQLInit)
// creating the db can succeed even if the database doesn't exist. Ping it to find out.
err = coderTestingDB.Ping()
var pqErr *pq.Error
if xerrors.As(err, &pqErr) && pqErr.Code == "3D000" {
// database does not exist.
@@ -149,8 +145,6 @@ func (b *Broker) init(t TBSubset) error {
return xerrors.Errorf("ping '%s' database: %w", CoderTestingDBName, err)
}
b.coderTestingDB = coderTestingDB
b.refCount++
t.Cleanup(b.decRef)
if b.uuid == uuid.Nil {
b.uuid = uuid.New()
@@ -192,42 +186,3 @@ func (b *Broker) decRef() {
b.coderTestingDB = nil
}
}
// getTestPackageName returns the package name of the test that called it.
func getTestPackageName(t TBSubset) string {
packageName := "unknown"
// Ask runtime.Callers for up to 100 program counters, including runtime.Callers itself.
pc := make([]uintptr, 100)
n := runtime.Callers(0, pc)
if n == 0 {
// No PCs available. This can happen if the first argument to
// runtime.Callers is large.
//
// Return now to avoid processing the zero Frame that would
// otherwise be returned by frames.Next below.
t.Logf("could not determine test package name: no PCs available")
return packageName
}
pc = pc[:n] // pass only valid pcs to runtime.CallersFrames
frames := runtime.CallersFrames(pc)
// Loop to get frames.
// A fixed number of PCs can expand to an indefinite number of Frames.
for {
frame, more := frames.Next()
if strings.HasPrefix(frame.Function, "github.com/coder/coder/v2/") {
packageName = strings.SplitN(strings.TrimPrefix(frame.Function, "github.com/coder/coder/v2/"), ".", 2)[0]
}
if strings.HasPrefix(frame.Function, "testing") {
break
}
// Check whether there are more frames to process after this one.
if !more {
break
}
}
return packageName
}
@@ -1,13 +0,0 @@
package dbtestutil
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGetTestPackageName(t *testing.T) {
t.Parallel()
packageName := getTestPackageName(t)
require.Equal(t, "coderd/database/dbtestutil", packageName)
}
@@ -1,6 +1,3 @@
BEGIN TRANSACTION;
SELECT pg_advisory_xact_lock(7283699);
CREATE TABLE IF NOT EXISTS test_databases (
name text PRIMARY KEY,
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
@@ -9,10 +6,3 @@ CREATE TABLE IF NOT EXISTS test_databases (
);
CREATE INDEX IF NOT EXISTS test_databases_process_uuid ON test_databases (process_uuid, dropped_at);
ALTER TABLE test_databases ADD COLUMN IF NOT EXISTS test_name text;
COMMENT ON COLUMN test_databases.test_name IS 'Name of the test that created the database';
ALTER TABLE test_databases ADD COLUMN IF NOT EXISTS test_package text;
COMMENT ON COLUMN test_databases.test_package IS 'Package of the test that created the database';
COMMIT;
+14 -41
View File
@@ -1828,15 +1828,6 @@ CREATE TABLE tasks (
deleted_at timestamp with time zone
);
CREATE VIEW visible_users AS
SELECT users.id,
users.username,
users.name,
users.avatar_url
FROM users;
COMMENT ON VIEW visible_users IS 'Visible fields of users are allowed to be joined with other tables for including context of other resources.';
CREATE TABLE workspace_agents (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
@@ -1987,16 +1978,8 @@ CREATE VIEW tasks_with_status AS
END AS status,
task_app.workspace_build_number,
task_app.workspace_agent_id,
task_app.workspace_app_id,
task_owner.owner_username,
task_owner.owner_name,
task_owner.owner_avatar_url
FROM (((((tasks
CROSS JOIN LATERAL ( SELECT vu.username AS owner_username,
vu.name AS owner_name,
vu.avatar_url AS owner_avatar_url
FROM visible_users vu
WHERE (vu.id = tasks.owner_id)) task_owner)
task_app.workspace_app_id
FROM ((((tasks
LEFT JOIN LATERAL ( SELECT task_app_1.workspace_build_number,
task_app_1.workspace_agent_id,
task_app_1.workspace_app_id
@@ -2029,18 +2012,6 @@ CREATE TABLE telemetry_items (
updated_at timestamp with time zone DEFAULT now() NOT NULL
);
CREATE TABLE telemetry_locks (
event_type text NOT NULL,
period_ending_at timestamp with time zone NOT NULL,
CONSTRAINT telemetry_lock_event_type_constraint CHECK ((event_type = 'aibridge_interceptions_summary'::text))
);
COMMENT ON TABLE telemetry_locks IS 'Telemetry lock tracking table for deduplication of heartbeat events across replicas.';
COMMENT ON COLUMN telemetry_locks.event_type IS 'The type of event that was sent.';
COMMENT ON COLUMN telemetry_locks.period_ending_at IS 'The heartbeat period end timestamp.';
CREATE TABLE template_usage_stats (
start_time timestamp with time zone NOT NULL,
end_time timestamp with time zone NOT NULL,
@@ -2227,6 +2198,15 @@ COMMENT ON COLUMN template_versions.external_auth_providers IS 'IDs of External
COMMENT ON COLUMN template_versions.message IS 'Message describing the changes in this version of the template, similar to a Git commit message. Like a commit message, this should be a short, high-level description of the changes in this version of the template. This message is immutable and should not be updated after the fact.';
CREATE VIEW visible_users AS
SELECT users.id,
users.username,
users.name,
users.avatar_url
FROM users;
COMMENT ON VIEW visible_users IS 'Visible fields of users are allowed to be joined with other tables for including context of other resources.';
CREATE VIEW template_version_with_user AS
SELECT template_versions.id,
template_versions.template_id,
@@ -2922,13 +2902,11 @@ CREATE VIEW workspaces_expanded AS
templates.name AS template_name,
templates.display_name AS template_display_name,
templates.icon AS template_icon,
templates.description AS template_description,
tasks.id AS task_id
FROM ((((workspaces
templates.description AS template_description
FROM (((workspaces
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
JOIN templates ON ((workspaces.template_id = templates.id)))
LEFT JOIN tasks ON ((workspaces.id = tasks.workspace_id)));
JOIN templates ON ((workspaces.template_id = templates.id)));
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
@@ -3112,9 +3090,6 @@ ALTER TABLE ONLY tasks
ALTER TABLE ONLY telemetry_items
ADD CONSTRAINT telemetry_items_pkey PRIMARY KEY (key);
ALTER TABLE ONLY telemetry_locks
ADD CONSTRAINT telemetry_locks_pkey PRIMARY KEY (event_type, period_ending_at);
ALTER TABLE ONLY template_usage_stats
ADD CONSTRAINT template_usage_stats_pkey PRIMARY KEY (start_time, template_id, user_id);
@@ -3340,8 +3315,6 @@ CREATE INDEX idx_tailnet_tunnels_dst_id ON tailnet_tunnels USING hash (dst_id);
CREATE INDEX idx_tailnet_tunnels_src_id ON tailnet_tunnels USING hash (src_id);
CREATE INDEX idx_telemetry_locks_period_ending_at ON telemetry_locks USING btree (period_ending_at);
CREATE UNIQUE INDEX idx_template_version_presets_default ON template_version_presets USING btree (template_version_id) WHERE (is_default = true);
CREATE INDEX idx_template_versions_has_ai_task ON template_versions USING btree (has_ai_task);
@@ -1 +0,0 @@
DROP TABLE telemetry_locks;
@@ -1,12 +0,0 @@
CREATE TABLE telemetry_locks (
event_type TEXT NOT NULL CONSTRAINT telemetry_lock_event_type_constraint CHECK (event_type IN ('aibridge_interceptions_summary')),
period_ending_at TIMESTAMP WITH TIME ZONE NOT NULL,
PRIMARY KEY (event_type, period_ending_at)
);
COMMENT ON TABLE telemetry_locks IS 'Telemetry lock tracking table for deduplication of heartbeat events across replicas.';
COMMENT ON COLUMN telemetry_locks.event_type IS 'The type of event that was sent.';
COMMENT ON COLUMN telemetry_locks.period_ending_at IS 'The heartbeat period end timestamp.';
CREATE INDEX idx_telemetry_locks_period_ending_at ON telemetry_locks (period_ending_at);
@@ -1,74 +0,0 @@
-- Drop view from 000390_tasks_with_status_user_fields.up.sql.
DROP VIEW IF EXISTS tasks_with_status;
-- Restore from 000382_add_columns_to_tasks_with_status.up.sql.
CREATE VIEW
tasks_with_status
AS
SELECT
tasks.*,
CASE
WHEN tasks.workspace_id IS NULL OR latest_build.job_status IS NULL THEN 'pending'::task_status
WHEN latest_build.job_status = 'failed' THEN 'error'::task_status
WHEN latest_build.transition IN ('stop', 'delete')
AND latest_build.job_status = 'succeeded' THEN 'paused'::task_status
WHEN latest_build.transition = 'start'
AND latest_build.job_status = 'pending' THEN 'initializing'::task_status
WHEN latest_build.transition = 'start' AND latest_build.job_status IN ('running', 'succeeded') THEN
CASE
WHEN agent_status.none THEN 'initializing'::task_status
WHEN agent_status.connecting THEN 'initializing'::task_status
WHEN agent_status.connected THEN
CASE
WHEN app_status.any_unhealthy THEN 'error'::task_status
WHEN app_status.any_initializing THEN 'initializing'::task_status
WHEN app_status.all_healthy_or_disabled THEN 'active'::task_status
ELSE 'unknown'::task_status
END
ELSE 'unknown'::task_status
END
ELSE 'unknown'::task_status
END AS status,
task_app.*
FROM
tasks
LEFT JOIN LATERAL (
SELECT workspace_build_number, workspace_agent_id, workspace_app_id
FROM task_workspace_apps task_app
WHERE task_id = tasks.id
ORDER BY workspace_build_number DESC
LIMIT 1
) task_app ON TRUE
LEFT JOIN LATERAL (
SELECT
workspace_build.transition,
provisioner_job.job_status,
workspace_build.job_id
FROM workspace_builds workspace_build
JOIN provisioner_jobs provisioner_job ON provisioner_job.id = workspace_build.job_id
WHERE workspace_build.workspace_id = tasks.workspace_id
AND workspace_build.build_number = task_app.workspace_build_number
) latest_build ON TRUE
CROSS JOIN LATERAL (
SELECT
COUNT(*) = 0 AS none,
bool_or(workspace_agent.lifecycle_state IN ('created', 'starting')) AS connecting,
bool_and(workspace_agent.lifecycle_state = 'ready') AS connected
FROM workspace_agents workspace_agent
WHERE workspace_agent.id = task_app.workspace_agent_id
) agent_status
CROSS JOIN LATERAL (
SELECT
bool_or(workspace_app.health = 'unhealthy') AS any_unhealthy,
bool_or(workspace_app.health = 'initializing') AS any_initializing,
bool_and(workspace_app.health IN ('healthy', 'disabled')) AS all_healthy_or_disabled
FROM workspace_apps workspace_app
WHERE workspace_app.id = task_app.workspace_app_id
) app_status
WHERE
tasks.deleted_at IS NULL;
@@ -1,84 +0,0 @@
-- Drop view from 00037_add_columns_to_tasks_with_status.up.sql.
DROP VIEW IF EXISTS tasks_with_status;
-- Add owner_name, owner_avatar_url columns.
CREATE VIEW
tasks_with_status
AS
SELECT
tasks.*,
CASE
WHEN tasks.workspace_id IS NULL OR latest_build.job_status IS NULL THEN 'pending'::task_status
WHEN latest_build.job_status = 'failed' THEN 'error'::task_status
WHEN latest_build.transition IN ('stop', 'delete')
AND latest_build.job_status = 'succeeded' THEN 'paused'::task_status
WHEN latest_build.transition = 'start'
AND latest_build.job_status = 'pending' THEN 'initializing'::task_status
WHEN latest_build.transition = 'start' AND latest_build.job_status IN ('running', 'succeeded') THEN
CASE
WHEN agent_status.none THEN 'initializing'::task_status
WHEN agent_status.connecting THEN 'initializing'::task_status
WHEN agent_status.connected THEN
CASE
WHEN app_status.any_unhealthy THEN 'error'::task_status
WHEN app_status.any_initializing THEN 'initializing'::task_status
WHEN app_status.all_healthy_or_disabled THEN 'active'::task_status
ELSE 'unknown'::task_status
END
ELSE 'unknown'::task_status
END
ELSE 'unknown'::task_status
END AS status,
task_app.*,
task_owner.*
FROM
tasks
CROSS JOIN LATERAL (
SELECT
vu.username AS owner_username,
vu.name AS owner_name,
vu.avatar_url AS owner_avatar_url
FROM visible_users vu
WHERE vu.id = tasks.owner_id
) task_owner
LEFT JOIN LATERAL (
SELECT workspace_build_number, workspace_agent_id, workspace_app_id
FROM task_workspace_apps task_app
WHERE task_id = tasks.id
ORDER BY workspace_build_number DESC
LIMIT 1
) task_app ON TRUE
LEFT JOIN LATERAL (
SELECT
workspace_build.transition,
provisioner_job.job_status,
workspace_build.job_id
FROM workspace_builds workspace_build
JOIN provisioner_jobs provisioner_job ON provisioner_job.id = workspace_build.job_id
WHERE workspace_build.workspace_id = tasks.workspace_id
AND workspace_build.build_number = task_app.workspace_build_number
) latest_build ON TRUE
CROSS JOIN LATERAL (
SELECT
COUNT(*) = 0 AS none,
bool_or(workspace_agent.lifecycle_state IN ('created', 'starting')) AS connecting,
bool_and(workspace_agent.lifecycle_state = 'ready') AS connected
FROM workspace_agents workspace_agent
WHERE workspace_agent.id = task_app.workspace_agent_id
) agent_status
CROSS JOIN LATERAL (
SELECT
bool_or(workspace_app.health = 'unhealthy') AS any_unhealthy,
bool_or(workspace_app.health = 'initializing') AS any_initializing,
bool_and(workspace_app.health IN ('healthy', 'disabled')) AS all_healthy_or_disabled
FROM workspace_apps workspace_app
WHERE workspace_app.id = task_app.workspace_app_id
) app_status
WHERE
tasks.deleted_at IS NULL;
@@ -1,8 +0,0 @@
UPDATE notification_templates
SET enabled_by_default = true
WHERE id IN (
'8c5a4d12-9f7e-4b3a-a1c8-6e4f2d9b5a7c',
'3b7e8f1a-4c2d-49a6-b5e9-7f3a1c8d6b4e',
'bd4b7168-d05e-4e19-ad0f-3593b77aa90f',
'd4a6271c-cced-4ed0-84ad-afd02a9c7799'
);
@@ -1,8 +0,0 @@
UPDATE notification_templates
SET enabled_by_default = false
WHERE id IN (
'8c5a4d12-9f7e-4b3a-a1c8-6e4f2d9b5a7c',
'3b7e8f1a-4c2d-49a6-b5e9-7f3a1c8d6b4e',
'bd4b7168-d05e-4e19-ad0f-3593b77aa90f',
'd4a6271c-cced-4ed0-84ad-afd02a9c7799'
);
@@ -1,39 +0,0 @@
DROP VIEW workspaces_expanded;
-- Recreate the view from 000354_workspace_acl.up.sql
CREATE VIEW workspaces_expanded AS
SELECT workspaces.id,
workspaces.created_at,
workspaces.updated_at,
workspaces.owner_id,
workspaces.organization_id,
workspaces.template_id,
workspaces.deleted,
workspaces.name,
workspaces.autostart_schedule,
workspaces.ttl,
workspaces.last_used_at,
workspaces.dormant_at,
workspaces.deleting_at,
workspaces.automatic_updates,
workspaces.favorite,
workspaces.next_start_at,
workspaces.group_acl,
workspaces.user_acl,
visible_users.avatar_url AS owner_avatar_url,
visible_users.username AS owner_username,
visible_users.name AS owner_name,
organizations.name AS organization_name,
organizations.display_name AS organization_display_name,
organizations.icon AS organization_icon,
organizations.description AS organization_description,
templates.name AS template_name,
templates.display_name AS template_display_name,
templates.icon AS template_icon,
templates.description AS template_description
FROM (((workspaces
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
JOIN templates ON ((workspaces.template_id = templates.id)));
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
@@ -1,42 +0,0 @@
-- Up migration: rebuild workspaces_expanded to expose the (nullable)
-- id of the task associated with each workspace. A view cannot be
-- altered to add a column mid-select, so it is dropped and recreated.
DROP VIEW workspaces_expanded;
-- Add nullable task_id to workspaces_expanded view
CREATE VIEW workspaces_expanded AS
SELECT workspaces.id,
workspaces.created_at,
workspaces.updated_at,
workspaces.owner_id,
workspaces.organization_id,
workspaces.template_id,
workspaces.deleted,
workspaces.name,
workspaces.autostart_schedule,
workspaces.ttl,
workspaces.last_used_at,
workspaces.dormant_at,
workspaces.deleting_at,
workspaces.automatic_updates,
workspaces.favorite,
workspaces.next_start_at,
workspaces.group_acl,
workspaces.user_acl,
visible_users.avatar_url AS owner_avatar_url,
visible_users.username AS owner_username,
visible_users.name AS owner_name,
organizations.name AS organization_name,
organizations.display_name AS organization_display_name,
organizations.icon AS organization_icon,
organizations.description AS organization_description,
templates.name AS template_name,
templates.display_name AS template_display_name,
templates.icon AS template_icon,
templates.description AS template_description,
-- NULL when the workspace has no associated task (LEFT JOIN below).
-- NOTE(review): assumes at most one task per workspace; a workspace with
-- multiple tasks rows would be duplicated by this join — confirm the
-- tasks.workspace_id uniqueness constraint.
tasks.id AS task_id
FROM ((((workspaces
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
JOIN templates ON ((workspaces.template_id = templates.id)))
LEFT JOIN tasks ON ((workspaces.id = tasks.workspace_id)));
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
@@ -1,8 +0,0 @@
-- Seed/fixture row for the telemetry_locks table: claims the
-- 'aibridge_interceptions_summary' heartbeat lock for the period ending
-- 2025-01-01 00:00:00 UTC. Per the lock protocol, a replica that hits a
-- duplicate-key error on this insert must not publish the event.
INSERT INTO telemetry_locks (
event_type,
period_ending_at
)
VALUES (
'aibridge_interceptions_summary',
'2025-01-01 00:00:00+00'::timestamptz
);
-2
View File
@@ -208,7 +208,6 @@ func (s APIKeyScopes) expandRBACScope() (rbac.Scope, error) {
for orgID, perms := range expanded.ByOrgID {
orgPerms := merged.ByOrgID[orgID]
orgPerms.Org = append(orgPerms.Org, perms.Org...)
orgPerms.Member = append(orgPerms.Member, perms.Member...)
merged.ByOrgID[orgID] = orgPerms
}
merged.User = append(merged.User, expanded.User...)
@@ -221,7 +220,6 @@ func (s APIKeyScopes) expandRBACScope() (rbac.Scope, error) {
merged.User = rbac.DeduplicatePermissions(merged.User)
for orgID, perms := range merged.ByOrgID {
perms.Org = rbac.DeduplicatePermissions(perms.Org)
perms.Member = rbac.DeduplicatePermissions(perms.Member)
merged.ByOrgID[orgID] = perms
}
+1 -1
View File
@@ -321,7 +321,6 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
&i.TemplateVersionID,
&i.TemplateVersionName,
&i.LatestBuildCompletedAt,
@@ -329,6 +328,7 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
&i.LatestBuildError,
&i.LatestBuildTransition,
&i.LatestBuildStatus,
&i.LatestBuildHasAITask,
&i.LatestBuildHasExternalAgent,
&i.Count,
); err != nil {
+1 -13
View File
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.27.0
package database
@@ -4221,9 +4221,6 @@ type Task struct {
WorkspaceBuildNumber sql.NullInt32 `db:"workspace_build_number" json:"workspace_build_number"`
WorkspaceAgentID uuid.NullUUID `db:"workspace_agent_id" json:"workspace_agent_id"`
WorkspaceAppID uuid.NullUUID `db:"workspace_app_id" json:"workspace_app_id"`
OwnerUsername string `db:"owner_username" json:"owner_username"`
OwnerName string `db:"owner_name" json:"owner_name"`
OwnerAvatarUrl string `db:"owner_avatar_url" json:"owner_avatar_url"`
}
type TaskTable struct {
@@ -4253,14 +4250,6 @@ type TelemetryItem struct {
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// Telemetry lock tracking table for deduplication of heartbeat events across replicas.
// Rows are keyed by (event_type, period_ending_at): a replica inserts a row via
// InsertTelemetryLock before publishing; a duplicate-key error means another
// replica already owns the lock for that period and publishing is skipped.
type TelemetryLock struct {
// The type of event that was sent.
EventType string `db:"event_type" json:"event_type"`
// The heartbeat period end timestamp.
PeriodEndingAt time.Time `db:"period_ending_at" json:"period_ending_at"`
}
// Joins in the display name information such as username, avatar, and organization name.
type Template struct {
ID uuid.UUID `db:"id" json:"id"`
@@ -4663,7 +4652,6 @@ type Workspace struct {
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
TemplateIcon string `db:"template_icon" json:"template_icon"`
TemplateDescription string `db:"template_description" json:"template_description"`
TaskID uuid.NullUUID `db:"task_id" json:"task_id"`
}
type WorkspaceAgent struct {
+2 -20
View File
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.27.0
package database
@@ -60,9 +60,6 @@ type sqlcQuerier interface {
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error)
BulkMarkNotificationMessagesSent(ctx context.Context, arg BulkMarkNotificationMessagesSentParams) (int64, error)
// Calculates the telemetry summary for a given provider, model, and client
// combination for telemetry reporting.
CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg CalculateAIBridgeInterceptionsTelemetrySummaryParams) (CalculateAIBridgeInterceptionsTelemetrySummaryRow, error)
ClaimPrebuiltWorkspace(ctx context.Context, arg ClaimPrebuiltWorkspaceParams) (ClaimPrebuiltWorkspaceRow, error)
CleanTailnetCoordinators(ctx context.Context) error
CleanTailnetLostPeers(ctx context.Context) error
@@ -110,8 +107,6 @@ type sqlcQuerier interface {
// A provisioner daemon with "zeroed" last_seen_at column indicates possible
// connectivity issues (no provisioner daemon activity since registration).
DeleteOldProvisionerDaemons(ctx context.Context) error
// Deletes old telemetry locks from the telemetry_locks table.
DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error
// If an agent hasn't connected in the last 7 days, we purge it's logs.
// Exception: if the logs are related to the latest build, we keep those around.
// Logs can take up a lot of space, so it's important we clean up frequently.
@@ -269,9 +264,6 @@ type sqlcQuerier interface {
GetOrganizationResourceCountByID(ctx context.Context, organizationID uuid.UUID) (GetOrganizationResourceCountByIDRow, error)
GetOrganizations(ctx context.Context, arg GetOrganizationsParams) ([]Organization, error)
GetOrganizationsByUserID(ctx context.Context, arg GetOrganizationsByUserIDParams) ([]Organization, error)
// GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
// membership status for the prebuilds system user (org membership, group existence, group membership).
GetOrganizationsWithPrebuildStatus(ctx context.Context, arg GetOrganizationsWithPrebuildStatusParams) ([]GetOrganizationsWithPrebuildStatusRow, error)
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
GetPrebuildMetrics(ctx context.Context) ([]GetPrebuildMetricsRow, error)
GetPrebuildsSettings(ctx context.Context) (string, error)
@@ -567,12 +559,6 @@ type sqlcQuerier interface {
InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error)
InsertTask(ctx context.Context, arg InsertTaskParams) (TaskTable, error)
InsertTelemetryItemIfNotExists(ctx context.Context, arg InsertTelemetryItemIfNotExistsParams) error
// Inserts a new lock row into the telemetry_locks table. Replicas should call
// this function prior to attempting to generate or publish a heartbeat event to
// the telemetry service.
// If the query returns a duplicate primary key error, the replica should not
// attempt to generate or publish the event to the telemetry service.
InsertTelemetryLock(ctx context.Context, arg InsertTelemetryLockParams) error
InsertTemplate(ctx context.Context, arg InsertTemplateParams) error
InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) error
InsertTemplateVersionParameter(ctx context.Context, arg InsertTemplateVersionParameterParams) (TemplateVersionParameter, error)
@@ -609,9 +595,6 @@ type sqlcQuerier interface {
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]ListAIBridgeInterceptionsRow, error)
// Finds all unique AIBridge interception telemetry summaries combinations
// (provider, model, client) in the given timeframe for telemetry reporting.
ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error)
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
@@ -649,7 +632,6 @@ type sqlcQuerier interface {
// This will always work regardless of the current state of the template version.
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
@@ -670,7 +652,7 @@ type sqlcQuerier interface {
// Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
// inactive template version.
// This is an optimization to clean up stale pending jobs.
UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]UpdatePrebuildProvisionerJobWithCancelRow, error)
UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error)
UpdatePresetPrebuildStatus(ctx context.Context, arg UpdatePresetPrebuildStatusParams) error
UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
-68
View File
@@ -7248,9 +7248,7 @@ func TestTaskNameUniqueness(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
taskID := uuid.New()
task, err := db.InsertTask(ctx, database.InsertTaskParams{
ID: taskID,
OrganizationID: org.ID,
OwnerID: tt.ownerID,
Name: tt.taskName,
@@ -7265,7 +7263,6 @@ func TestTaskNameUniqueness(t *testing.T) {
require.NoError(t, err)
require.NotEqual(t, uuid.Nil, task.ID)
require.NotEqual(t, task1.ID, task.ID)
require.Equal(t, taskID, task.ID)
}
})
}
@@ -7727,68 +7724,3 @@ func TestUpdateTaskWorkspaceID(t *testing.T) {
})
}
}
// TestUpdateAIBridgeInterceptionEnded verifies the one-shot semantics of
// UpdateAIBridgeInterceptionEnded: it stamps ended_at on an interception
// exactly once, returns sql.ErrNoRows for unknown or already-ended rows,
// and leaves unrelated interceptions untouched.
func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
// Updating a row that does not exist must surface the driver's
// "no rows" error and a zero-value result.
t.Run("NonExistingInterception", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
got, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
ID: uuid.New(),
EndedAt: time.Now(),
})
require.ErrorContains(t, err, "no rows in result set")
require.EqualValues(t, database.AIBridgeInterception{}, got)
})
t.Run("OK", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
user := dbgen.User(t, db, database.User{})
// Insert three interceptions with deterministic IDs ({1}, {2}, {3});
// all start with a NULL ended_at.
interceptions := []database.AIBridgeInterception{}
for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
insertParams := database.InsertAIBridgeInterceptionParams{
ID: uid,
InitiatorID: user.ID,
Metadata: json.RawMessage("{}"),
}
intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
require.NoError(t, err)
require.Equal(t, uid, intc.ID)
require.False(t, intc.EndedAt.Valid)
interceptions = append(interceptions, intc)
}
intc0 := interceptions[0]
endedAt := time.Now()
// Mark first interception as done
updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
ID: intc0.ID,
EndedAt: endedAt,
})
require.NoError(t, err)
require.EqualValues(t, updated.ID, intc0.ID)
require.True(t, updated.EndedAt.Valid)
// Tolerance absorbs DB timestamp rounding / clock skew between the
// test process and the database.
require.WithinDuration(t, endedAt, updated.EndedAt.Time, 5*time.Second)
// Updating first interception again should fail
updated, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
ID: intc0.ID,
EndedAt: endedAt.Add(time.Hour),
})
require.ErrorIs(t, err, sql.ErrNoRows)
// Other interceptions should not have ended_at set
for _, intc := range interceptions[1:] {
got, err := db.GetAIBridgeInterceptionByID(ctx, intc.ID)
require.NoError(t, err)
require.False(t, got.EndedAt.Valid)
}
})
}
+48 -431
View File
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.27.0
package database
@@ -111,164 +111,6 @@ func (q *sqlQuerier) ActivityBumpWorkspace(ctx context.Context, arg ActivityBump
return err
}
const calculateAIBridgeInterceptionsTelemetrySummary = `-- name: CalculateAIBridgeInterceptionsTelemetrySummary :one
WITH interceptions_in_range AS (
-- Get all matching interceptions in the given timeframe.
SELECT
id,
initiator_id,
(ended_at - started_at) AS duration
FROM
aibridge_interceptions
WHERE
provider = $1::text
AND model = $2::text
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
AND 'unknown' = $3::text
AND ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
AND ended_at >= $4::timestamptz
AND ended_at < $5::timestamptz
),
interception_counts AS (
SELECT
COUNT(id) AS interception_count,
COUNT(DISTINCT initiator_id) AS unique_initiator_count
FROM
interceptions_in_range
),
duration_percentiles AS (
SELECT
(COALESCE(PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p50_millis,
(COALESCE(PERCENTILE_CONT(0.90) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p90_millis,
(COALESCE(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p95_millis,
(COALESCE(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p99_millis
FROM
interceptions_in_range
),
token_aggregates AS (
SELECT
COALESCE(SUM(tu.input_tokens), 0) AS token_count_input,
COALESCE(SUM(tu.output_tokens), 0) AS token_count_output,
-- Cached tokens are stored in metadata JSON, extract if available.
-- Read tokens may be stored in:
-- - cache_read_input (Anthropic)
-- - prompt_cached (OpenAI)
COALESCE(SUM(
COALESCE((tu.metadata->>'cache_read_input')::bigint, 0) +
COALESCE((tu.metadata->>'prompt_cached')::bigint, 0)
), 0) AS token_count_cached_read,
-- Written tokens may be stored in:
-- - cache_creation_input (Anthropic)
-- Note that cache_ephemeral_5m_input and cache_ephemeral_1h_input on
-- Anthropic are included in the cache_creation_input field.
COALESCE(SUM(
COALESCE((tu.metadata->>'cache_creation_input')::bigint, 0)
), 0) AS token_count_cached_written,
COUNT(tu.id) AS token_usages_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_token_usages tu ON i.id = tu.interception_id
),
prompt_aggregates AS (
SELECT
COUNT(up.id) AS user_prompts_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_user_prompts up ON i.id = up.interception_id
),
tool_aggregates AS (
SELECT
COUNT(tu.id) FILTER (WHERE tu.injected = true) AS tool_calls_count_injected,
COUNT(tu.id) FILTER (WHERE tu.injected = false) AS tool_calls_count_non_injected,
COUNT(tu.id) FILTER (WHERE tu.injected = true AND tu.invocation_error IS NOT NULL) AS injected_tool_call_error_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_tool_usages tu ON i.id = tu.interception_id
)
SELECT
ic.interception_count::bigint AS interception_count,
dp.interception_duration_p50_millis::bigint AS interception_duration_p50_millis,
dp.interception_duration_p90_millis::bigint AS interception_duration_p90_millis,
dp.interception_duration_p95_millis::bigint AS interception_duration_p95_millis,
dp.interception_duration_p99_millis::bigint AS interception_duration_p99_millis,
ic.unique_initiator_count::bigint AS unique_initiator_count,
pa.user_prompts_count::bigint AS user_prompts_count,
tok_agg.token_usages_count::bigint AS token_usages_count,
tok_agg.token_count_input::bigint AS token_count_input,
tok_agg.token_count_output::bigint AS token_count_output,
tok_agg.token_count_cached_read::bigint AS token_count_cached_read,
tok_agg.token_count_cached_written::bigint AS token_count_cached_written,
tool_agg.tool_calls_count_injected::bigint AS tool_calls_count_injected,
tool_agg.tool_calls_count_non_injected::bigint AS tool_calls_count_non_injected,
tool_agg.injected_tool_call_error_count::bigint AS injected_tool_call_error_count
FROM
interception_counts ic,
duration_percentiles dp,
token_aggregates tok_agg,
prompt_aggregates pa,
tool_aggregates tool_agg
`
type CalculateAIBridgeInterceptionsTelemetrySummaryParams struct {
Provider string `db:"provider" json:"provider"`
Model string `db:"model" json:"model"`
Client string `db:"client" json:"client"`
EndedAtAfter time.Time `db:"ended_at_after" json:"ended_at_after"`
EndedAtBefore time.Time `db:"ended_at_before" json:"ended_at_before"`
}
type CalculateAIBridgeInterceptionsTelemetrySummaryRow struct {
InterceptionCount int64 `db:"interception_count" json:"interception_count"`
InterceptionDurationP50Millis int64 `db:"interception_duration_p50_millis" json:"interception_duration_p50_millis"`
InterceptionDurationP90Millis int64 `db:"interception_duration_p90_millis" json:"interception_duration_p90_millis"`
InterceptionDurationP95Millis int64 `db:"interception_duration_p95_millis" json:"interception_duration_p95_millis"`
InterceptionDurationP99Millis int64 `db:"interception_duration_p99_millis" json:"interception_duration_p99_millis"`
UniqueInitiatorCount int64 `db:"unique_initiator_count" json:"unique_initiator_count"`
UserPromptsCount int64 `db:"user_prompts_count" json:"user_prompts_count"`
TokenUsagesCount int64 `db:"token_usages_count" json:"token_usages_count"`
TokenCountInput int64 `db:"token_count_input" json:"token_count_input"`
TokenCountOutput int64 `db:"token_count_output" json:"token_count_output"`
TokenCountCachedRead int64 `db:"token_count_cached_read" json:"token_count_cached_read"`
TokenCountCachedWritten int64 `db:"token_count_cached_written" json:"token_count_cached_written"`
ToolCallsCountInjected int64 `db:"tool_calls_count_injected" json:"tool_calls_count_injected"`
ToolCallsCountNonInjected int64 `db:"tool_calls_count_non_injected" json:"tool_calls_count_non_injected"`
InjectedToolCallErrorCount int64 `db:"injected_tool_call_error_count" json:"injected_tool_call_error_count"`
}
// Calculates the telemetry summary for a given provider, model, and client
// combination for telemetry reporting.
//
// Runs the single-row aggregate query bound above and scans counts, duration
// percentiles (milliseconds), token totals, and tool-call counts for
// interceptions whose ended_at falls in [EndedAtAfter, EndedAtBefore).
// Returns any error from the query or scan alongside the (possibly partial)
// row.
func (q *sqlQuerier) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg CalculateAIBridgeInterceptionsTelemetrySummaryParams) (CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
row := q.db.QueryRowContext(ctx, calculateAIBridgeInterceptionsTelemetrySummary,
arg.Provider,
arg.Model,
arg.Client,
arg.EndedAtAfter,
arg.EndedAtBefore,
)
var i CalculateAIBridgeInterceptionsTelemetrySummaryRow
// Scan order must match the SELECT column order in the query text.
err := row.Scan(
&i.InterceptionCount,
&i.InterceptionDurationP50Millis,
&i.InterceptionDurationP90Millis,
&i.InterceptionDurationP95Millis,
&i.InterceptionDurationP99Millis,
&i.UniqueInitiatorCount,
&i.UserPromptsCount,
&i.TokenUsagesCount,
&i.TokenCountInput,
&i.TokenCountOutput,
&i.TokenCountCachedRead,
&i.TokenCountCachedWritten,
&i.ToolCallsCountInjected,
&i.ToolCallsCountNonInjected,
&i.InjectedToolCallErrorCount,
)
return i, err
}
const countAIBridgeInterceptions = `-- name: CountAIBridgeInterceptions :one
SELECT
COUNT(*)
@@ -805,57 +647,6 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
return items, nil
}
const listAIBridgeInterceptionsTelemetrySummaries = `-- name: ListAIBridgeInterceptionsTelemetrySummaries :many
SELECT
DISTINCT ON (provider, model, client)
provider,
model,
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
'unknown' AS client
FROM
aibridge_interceptions
WHERE
ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
AND ended_at >= $1::timestamptz
AND ended_at < $2::timestamptz
`
type ListAIBridgeInterceptionsTelemetrySummariesParams struct {
EndedAtAfter time.Time `db:"ended_at_after" json:"ended_at_after"`
EndedAtBefore time.Time `db:"ended_at_before" json:"ended_at_before"`
}
type ListAIBridgeInterceptionsTelemetrySummariesRow struct {
Provider string `db:"provider" json:"provider"`
Model string `db:"model" json:"model"`
Client string `db:"client" json:"client"`
}
// Finds all unique AIBridge interception telemetry summaries combinations
// (provider, model, client) in the given timeframe for telemetry reporting.
//
// Only completed interceptions (ended_at IS NOT NULL) with ended_at in
// [EndedAtAfter, EndedAtBefore) are considered; client is currently the
// constant 'unknown' per the TODO in the query. Returns a nil slice when
// no rows match.
func (q *sqlQuerier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
rows, err := q.db.QueryContext(ctx, listAIBridgeInterceptionsTelemetrySummaries, arg.EndedAtAfter, arg.EndedAtBefore)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListAIBridgeInterceptionsTelemetrySummariesRow
for rows.Next() {
var i ListAIBridgeInterceptionsTelemetrySummariesRow
if err := rows.Scan(&i.Provider, &i.Model, &i.Client); err != nil {
return nil, err
}
items = append(items, i)
}
// Explicit Close surfaces errors the deferred Close would swallow.
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listAIBridgeTokenUsagesByInterceptionIDs = `-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many
SELECT
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
@@ -987,35 +778,6 @@ func (q *sqlQuerier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Contex
return items, nil
}
const updateAIBridgeInterceptionEnded = `-- name: UpdateAIBridgeInterceptionEnded :one
UPDATE aibridge_interceptions
SET ended_at = $1::timestamptz
WHERE
id = $2::uuid
AND ended_at IS NULL
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at
`
// UpdateAIBridgeInterceptionEndedParams carries the bind parameters for
// UpdateAIBridgeInterceptionEnded.
type UpdateAIBridgeInterceptionEndedParams struct {
// Timestamp to stamp into ended_at.
EndedAt time.Time `db:"ended_at" json:"ended_at"`
// Interception row to end.
ID uuid.UUID `db:"id" json:"id"`
}
// UpdateAIBridgeInterceptionEnded sets ended_at on the interception with the
// given ID, but only if ended_at is still NULL (see the WHERE clause in the
// query text), and returns the updated row. Because this is a :one query via
// QueryRowContext, a missing or already-ended interception yields
// sql.ErrNoRows from Scan.
func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error) {
row := q.db.QueryRowContext(ctx, updateAIBridgeInterceptionEnded, arg.EndedAt, arg.ID)
var i AIBridgeInterception
// Scan order matches the RETURNING column list of the query.
err := row.Scan(
&i.ID,
&i.InitiatorID,
&i.Provider,
&i.Model,
&i.StartedAt,
&i.Metadata,
&i.EndedAt,
)
return i, err
}
const deleteAPIKeyByID = `-- name: DeleteAPIKeyByID :exec
DELETE FROM
api_keys
@@ -8285,93 +8047,6 @@ func (q *sqlQuerier) FindMatchingPresetID(ctx context.Context, arg FindMatchingP
return template_version_preset_id, err
}
const getOrganizationsWithPrebuildStatus = `-- name: GetOrganizationsWithPrebuildStatus :many
WITH orgs_with_prebuilds AS (
-- Get unique organizations that have presets with prebuilds configured
SELECT DISTINCT o.id, o.name
FROM organizations o
INNER JOIN templates t ON t.organization_id = o.id
INNER JOIN template_versions tv ON tv.template_id = t.id
INNER JOIN template_version_presets tvp ON tvp.template_version_id = tv.id
WHERE tvp.desired_instances IS NOT NULL
),
prebuild_user_membership AS (
-- Check if the user is a member of the organizations
SELECT om.organization_id
FROM organization_members om
INNER JOIN orgs_with_prebuilds owp ON owp.id = om.organization_id
WHERE om.user_id = $1::uuid
),
prebuild_groups AS (
-- Check if the organizations have the prebuilds group
SELECT g.organization_id, g.id as group_id
FROM groups g
INNER JOIN orgs_with_prebuilds owp ON owp.id = g.organization_id
WHERE g.name = $2::text
),
prebuild_group_membership AS (
-- Check if the user is in the prebuilds group
SELECT pg.organization_id
FROM prebuild_groups pg
INNER JOIN group_members gm ON gm.group_id = pg.group_id
WHERE gm.user_id = $1::uuid
)
SELECT
owp.id AS organization_id,
owp.name AS organization_name,
(pum.organization_id IS NOT NULL)::boolean AS has_prebuild_user,
pg.group_id AS prebuilds_group_id,
(pgm.organization_id IS NOT NULL)::boolean AS has_prebuild_user_in_group
FROM orgs_with_prebuilds owp
LEFT JOIN prebuild_groups pg ON pg.organization_id = owp.id
LEFT JOIN prebuild_user_membership pum ON pum.organization_id = owp.id
LEFT JOIN prebuild_group_membership pgm ON pgm.organization_id = owp.id
`
type GetOrganizationsWithPrebuildStatusParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
GroupName string `db:"group_name" json:"group_name"`
}
type GetOrganizationsWithPrebuildStatusRow struct {
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
OrganizationName string `db:"organization_name" json:"organization_name"`
HasPrebuildUser bool `db:"has_prebuild_user" json:"has_prebuild_user"`
PrebuildsGroupID uuid.NullUUID `db:"prebuilds_group_id" json:"prebuilds_group_id"`
HasPrebuildUserInGroup bool `db:"has_prebuild_user_in_group" json:"has_prebuild_user_in_group"`
}
// GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
// membership status for the prebuilds system user (org membership, group existence, group membership).
//
// arg.UserID is the prebuilds system user to check membership for and
// arg.GroupName is the name of the prebuilds group looked up per
// organization. PrebuildsGroupID is NULL when the organization has no such
// group. Returns a nil slice when no organization has prebuilds configured.
func (q *sqlQuerier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg GetOrganizationsWithPrebuildStatusParams) ([]GetOrganizationsWithPrebuildStatusRow, error) {
rows, err := q.db.QueryContext(ctx, getOrganizationsWithPrebuildStatus, arg.UserID, arg.GroupName)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetOrganizationsWithPrebuildStatusRow
for rows.Next() {
var i GetOrganizationsWithPrebuildStatusRow
// Column order must match the query's final SELECT list.
if err := rows.Scan(
&i.OrganizationID,
&i.OrganizationName,
&i.HasPrebuildUser,
&i.PrebuildsGroupID,
&i.HasPrebuildUserInGroup,
); err != nil {
return nil, err
}
items = append(items, i)
}
// Explicit Close surfaces errors the deferred Close would swallow.
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getPrebuildMetrics = `-- name: GetPrebuildMetrics :many
SELECT
t.name as template_name,
@@ -8774,8 +8449,12 @@ func (q *sqlQuerier) GetTemplatePresetsWithPrebuilds(ctx context.Context, templa
}
const updatePrebuildProvisionerJobWithCancel = `-- name: UpdatePrebuildProvisionerJobWithCancel :many
WITH jobs_to_cancel AS (
SELECT pj.id, w.id AS workspace_id, w.template_id, wpb.template_version_preset_id
UPDATE provisioner_jobs
SET
canceled_at = $1::timestamptz,
completed_at = $1::timestamptz
WHERE id IN (
SELECT pj.id
FROM provisioner_jobs pj
INNER JOIN workspace_prebuild_builds wpb ON wpb.job_id = pj.id
INNER JOIN workspaces w ON w.id = wpb.workspace_id
@@ -8794,13 +8473,7 @@ WITH jobs_to_cancel AS (
AND pj.canceled_at IS NULL
AND pj.completed_at IS NULL
)
UPDATE provisioner_jobs
SET
canceled_at = $1::timestamptz,
completed_at = $1::timestamptz
FROM jobs_to_cancel
WHERE provisioner_jobs.id = jobs_to_cancel.id
RETURNING jobs_to_cancel.id, jobs_to_cancel.workspace_id, jobs_to_cancel.template_id, jobs_to_cancel.template_version_preset_id
RETURNING id
`
type UpdatePrebuildProvisionerJobWithCancelParams struct {
@@ -8808,34 +8481,22 @@ type UpdatePrebuildProvisionerJobWithCancelParams struct {
PresetID uuid.NullUUID `db:"preset_id" json:"preset_id"`
}
type UpdatePrebuildProvisionerJobWithCancelRow struct {
ID uuid.UUID `db:"id" json:"id"`
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
TemplateID uuid.UUID `db:"template_id" json:"template_id"`
TemplateVersionPresetID uuid.NullUUID `db:"template_version_preset_id" json:"template_version_preset_id"`
}
// Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
// inactive template version.
// This is an optimization to clean up stale pending jobs.
func (q *sqlQuerier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]UpdatePrebuildProvisionerJobWithCancelRow, error) {
func (q *sqlQuerier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, updatePrebuildProvisionerJobWithCancel, arg.Now, arg.PresetID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []UpdatePrebuildProvisionerJobWithCancelRow
var items []uuid.UUID
for rows.Next() {
var i UpdatePrebuildProvisionerJobWithCancelRow
if err := rows.Scan(
&i.ID,
&i.WorkspaceID,
&i.TemplateID,
&i.TemplateVersionPresetID,
); err != nil {
var id uuid.UUID
if err := rows.Scan(&id); err != nil {
return nil, err
}
items = append(items, i)
items = append(items, id)
}
if err := rows.Close(); err != nil {
return nil, err
@@ -13065,7 +12726,7 @@ func (q *sqlQuerier) DeleteTask(ctx context.Context, arg DeleteTaskParams) (Task
}
const getTaskByID = `-- name: GetTaskByID :one
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE id = $1::uuid
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status WHERE id = $1::uuid
`
func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error) {
@@ -13086,15 +12747,12 @@ func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error
&i.WorkspaceBuildNumber,
&i.WorkspaceAgentID,
&i.WorkspaceAppID,
&i.OwnerUsername,
&i.OwnerName,
&i.OwnerAvatarUrl,
)
return i, err
}
const getTaskByWorkspaceID = `-- name: GetTaskByWorkspaceID :one
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE workspace_id = $1::uuid
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status WHERE workspace_id = $1::uuid
`
func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error) {
@@ -13115,9 +12773,6 @@ func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.
&i.WorkspaceBuildNumber,
&i.WorkspaceAgentID,
&i.WorkspaceAppID,
&i.OwnerUsername,
&i.OwnerName,
&i.OwnerAvatarUrl,
)
return i, err
}
@@ -13126,12 +12781,11 @@ const insertTask = `-- name: InsertTask :one
INSERT INTO tasks
(id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9)
(gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at
`
type InsertTaskParams struct {
ID uuid.UUID `db:"id" json:"id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
Name string `db:"name" json:"name"`
@@ -13144,7 +12798,6 @@ type InsertTaskParams struct {
func (q *sqlQuerier) InsertTask(ctx context.Context, arg InsertTaskParams) (TaskTable, error) {
row := q.db.QueryRowContext(ctx, insertTask,
arg.ID,
arg.OrganizationID,
arg.OwnerID,
arg.Name,
@@ -13171,7 +12824,7 @@ func (q *sqlQuerier) InsertTask(ctx context.Context, arg InsertTaskParams) (Task
}
const listTasks = `-- name: ListTasks :many
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status tws
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status tws
WHERE tws.deleted_at IS NULL
AND CASE WHEN $1::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.owner_id = $1::UUID ELSE TRUE END
AND CASE WHEN $2::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.organization_id = $2::UUID ELSE TRUE END
@@ -13209,9 +12862,6 @@ func (q *sqlQuerier) ListTasks(ctx context.Context, arg ListTasksParams) ([]Task
&i.WorkspaceBuildNumber,
&i.WorkspaceAgentID,
&i.WorkspaceAppID,
&i.OwnerUsername,
&i.OwnerName,
&i.OwnerAvatarUrl,
); err != nil {
return nil, err
}
@@ -13385,41 +13035,6 @@ func (q *sqlQuerier) UpsertTelemetryItem(ctx context.Context, arg UpsertTelemetr
return err
}
const deleteOldTelemetryLocks = `-- name: DeleteOldTelemetryLocks :exec
DELETE FROM
telemetry_locks
WHERE
period_ending_at < $1::timestamptz
`
// Deletes old telemetry locks from the telemetry_locks table.
func (q *sqlQuerier) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
_, err := q.db.ExecContext(ctx, deleteOldTelemetryLocks, periodEndingAtBefore)
return err
}
const insertTelemetryLock = `-- name: InsertTelemetryLock :exec
INSERT INTO
telemetry_locks (event_type, period_ending_at)
VALUES
($1, $2)
`
type InsertTelemetryLockParams struct {
EventType string `db:"event_type" json:"event_type"`
PeriodEndingAt time.Time `db:"period_ending_at" json:"period_ending_at"`
}
// Inserts a new lock row into the telemetry_locks table. Replicas should call
// this function prior to attempting to generate or publish a heartbeat event to
// the telemetry service.
// If the query returns a duplicate primary key error, the replica should not
// attempt to generate or publish the event to the telemetry service.
func (q *sqlQuerier) InsertTelemetryLock(ctx context.Context, arg InsertTelemetryLockParams) error {
_, err := q.db.ExecContext(ctx, insertTelemetryLock, arg.EventType, arg.PeriodEndingAt)
return err
}
const getTemplateAverageBuildTime = `-- name: GetTemplateAverageBuildTime :one
WITH build_times AS (
SELECT
@@ -21927,7 +21542,7 @@ func (q *sqlQuerier) GetWorkspaceACLByID(ctx context.Context, id uuid.UUID) (Get
const getWorkspaceByAgentID = `-- name: GetWorkspaceByAgentID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
FROM
workspaces_expanded as workspaces
WHERE
@@ -21988,14 +21603,13 @@ func (q *sqlQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUI
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
)
return i, err
}
const getWorkspaceByID = `-- name: GetWorkspaceByID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
FROM
workspaces_expanded
WHERE
@@ -22037,14 +21651,13 @@ func (q *sqlQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Worksp
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
)
return i, err
}
const getWorkspaceByOwnerIDAndName = `-- name: GetWorkspaceByOwnerIDAndName :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
FROM
workspaces_expanded as workspaces
WHERE
@@ -22093,14 +21706,13 @@ func (q *sqlQuerier) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWo
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
)
return i, err
}
const getWorkspaceByResourceID = `-- name: GetWorkspaceByResourceID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
FROM
workspaces_expanded as workspaces
WHERE
@@ -22156,14 +21768,13 @@ func (q *sqlQuerier) GetWorkspaceByResourceID(ctx context.Context, resourceID uu
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
)
return i, err
}
const getWorkspaceByWorkspaceAppID = `-- name: GetWorkspaceByWorkspaceAppID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
FROM
workspaces_expanded as workspaces
WHERE
@@ -22231,7 +21842,6 @@ func (q *sqlQuerier) GetWorkspaceByWorkspaceAppID(ctx context.Context, workspace
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
)
return i, err
}
@@ -22281,7 +21891,7 @@ SELECT
),
filtered_workspaces AS (
SELECT
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl, workspaces.owner_avatar_url, workspaces.owner_username, workspaces.owner_name, workspaces.organization_name, workspaces.organization_display_name, workspaces.organization_icon, workspaces.organization_description, workspaces.template_name, workspaces.template_display_name, workspaces.template_icon, workspaces.template_description, workspaces.task_id,
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl, workspaces.owner_avatar_url, workspaces.owner_username, workspaces.owner_name, workspaces.organization_name, workspaces.organization_display_name, workspaces.organization_icon, workspaces.organization_description, workspaces.template_name, workspaces.template_display_name, workspaces.template_icon, workspaces.template_description,
latest_build.template_version_id,
latest_build.template_version_name,
latest_build.completed_at as latest_build_completed_at,
@@ -22289,6 +21899,7 @@ SELECT
latest_build.error as latest_build_error,
latest_build.transition as latest_build_transition,
latest_build.job_status as latest_build_status,
latest_build.has_ai_task as latest_build_has_ai_task,
latest_build.has_external_agent as latest_build_has_external_agent
FROM
workspaces_expanded as workspaces
@@ -22522,19 +22133,25 @@ WHERE
(latest_build.template_version_id = template.active_version_id) = $18 :: boolean
ELSE true
END
-- Filter by has_ai_task, checks if this is a task workspace.
-- Filter by has_ai_task in latest build
AND CASE
WHEN $19::boolean IS NOT NULL
THEN $19::boolean = EXISTS (
SELECT
1
FROM
tasks
WHERE
-- Consider all tasks, deleting a task does not turn the
-- workspace into a non-task workspace.
tasks.workspace_id = workspaces.id
)
WHEN $19 :: boolean IS NOT NULL THEN
(COALESCE(latest_build.has_ai_task, false) OR (
-- If the build has no AI task, it means that the provisioner job is in progress
-- and we don't know if it has an AI task yet. In this case, we optimistically
-- assume that it has an AI task if the AI Prompt parameter is not empty. This
-- lets the AI Task frontend spawn a task and see it immediately after instead of
-- having to wait for the build to complete.
latest_build.has_ai_task IS NULL AND
latest_build.completed_at IS NULL AND
EXISTS (
SELECT 1
FROM workspace_build_parameters
WHERE workspace_build_parameters.workspace_build_id = latest_build.id
AND workspace_build_parameters.name = 'AI Prompt'
AND workspace_build_parameters.value != ''
)
)) = ($19 :: boolean)
ELSE true
END
-- Filter by has_external_agent in latest build
@@ -22565,7 +22182,7 @@ WHERE
-- @authorize_filter
), filtered_workspaces_order AS (
SELECT
fw.id, fw.created_at, fw.updated_at, fw.owner_id, fw.organization_id, fw.template_id, fw.deleted, fw.name, fw.autostart_schedule, fw.ttl, fw.last_used_at, fw.dormant_at, fw.deleting_at, fw.automatic_updates, fw.favorite, fw.next_start_at, fw.group_acl, fw.user_acl, fw.owner_avatar_url, fw.owner_username, fw.owner_name, fw.organization_name, fw.organization_display_name, fw.organization_icon, fw.organization_description, fw.template_name, fw.template_display_name, fw.template_icon, fw.template_description, fw.task_id, fw.template_version_id, fw.template_version_name, fw.latest_build_completed_at, fw.latest_build_canceled_at, fw.latest_build_error, fw.latest_build_transition, fw.latest_build_status, fw.latest_build_has_external_agent
fw.id, fw.created_at, fw.updated_at, fw.owner_id, fw.organization_id, fw.template_id, fw.deleted, fw.name, fw.autostart_schedule, fw.ttl, fw.last_used_at, fw.dormant_at, fw.deleting_at, fw.automatic_updates, fw.favorite, fw.next_start_at, fw.group_acl, fw.user_acl, fw.owner_avatar_url, fw.owner_username, fw.owner_name, fw.organization_name, fw.organization_display_name, fw.organization_icon, fw.organization_description, fw.template_name, fw.template_display_name, fw.template_icon, fw.template_description, fw.template_version_id, fw.template_version_name, fw.latest_build_completed_at, fw.latest_build_canceled_at, fw.latest_build_error, fw.latest_build_transition, fw.latest_build_status, fw.latest_build_has_ai_task, fw.latest_build_has_external_agent
FROM
filtered_workspaces fw
ORDER BY
@@ -22586,7 +22203,7 @@ WHERE
$25
), filtered_workspaces_order_with_summary AS (
SELECT
fwo.id, fwo.created_at, fwo.updated_at, fwo.owner_id, fwo.organization_id, fwo.template_id, fwo.deleted, fwo.name, fwo.autostart_schedule, fwo.ttl, fwo.last_used_at, fwo.dormant_at, fwo.deleting_at, fwo.automatic_updates, fwo.favorite, fwo.next_start_at, fwo.group_acl, fwo.user_acl, fwo.owner_avatar_url, fwo.owner_username, fwo.owner_name, fwo.organization_name, fwo.organization_display_name, fwo.organization_icon, fwo.organization_description, fwo.template_name, fwo.template_display_name, fwo.template_icon, fwo.template_description, fwo.task_id, fwo.template_version_id, fwo.template_version_name, fwo.latest_build_completed_at, fwo.latest_build_canceled_at, fwo.latest_build_error, fwo.latest_build_transition, fwo.latest_build_status, fwo.latest_build_has_external_agent
fwo.id, fwo.created_at, fwo.updated_at, fwo.owner_id, fwo.organization_id, fwo.template_id, fwo.deleted, fwo.name, fwo.autostart_schedule, fwo.ttl, fwo.last_used_at, fwo.dormant_at, fwo.deleting_at, fwo.automatic_updates, fwo.favorite, fwo.next_start_at, fwo.group_acl, fwo.user_acl, fwo.owner_avatar_url, fwo.owner_username, fwo.owner_name, fwo.organization_name, fwo.organization_display_name, fwo.organization_icon, fwo.organization_description, fwo.template_name, fwo.template_display_name, fwo.template_icon, fwo.template_description, fwo.template_version_id, fwo.template_version_name, fwo.latest_build_completed_at, fwo.latest_build_canceled_at, fwo.latest_build_error, fwo.latest_build_transition, fwo.latest_build_status, fwo.latest_build_has_ai_task, fwo.latest_build_has_external_agent
FROM
filtered_workspaces_order fwo
-- Return a technical summary row with total count of workspaces.
@@ -22622,7 +22239,6 @@ WHERE
'', -- template_display_name
'', -- template_icon
'', -- template_description
'00000000-0000-0000-0000-000000000000'::uuid, -- task_id
-- Extra columns added to ` + "`" + `filtered_workspaces` + "`" + `
'00000000-0000-0000-0000-000000000000'::uuid, -- template_version_id
'', -- template_version_name
@@ -22631,6 +22247,7 @@ WHERE
'', -- latest_build_error
'start'::workspace_transition, -- latest_build_transition
'unknown'::provisioner_job_status, -- latest_build_status
false, -- latest_build_has_ai_task
false -- latest_build_has_external_agent
WHERE
$27 :: boolean = true
@@ -22641,7 +22258,7 @@ WHERE
filtered_workspaces
)
SELECT
fwos.id, fwos.created_at, fwos.updated_at, fwos.owner_id, fwos.organization_id, fwos.template_id, fwos.deleted, fwos.name, fwos.autostart_schedule, fwos.ttl, fwos.last_used_at, fwos.dormant_at, fwos.deleting_at, fwos.automatic_updates, fwos.favorite, fwos.next_start_at, fwos.group_acl, fwos.user_acl, fwos.owner_avatar_url, fwos.owner_username, fwos.owner_name, fwos.organization_name, fwos.organization_display_name, fwos.organization_icon, fwos.organization_description, fwos.template_name, fwos.template_display_name, fwos.template_icon, fwos.template_description, fwos.task_id, fwos.template_version_id, fwos.template_version_name, fwos.latest_build_completed_at, fwos.latest_build_canceled_at, fwos.latest_build_error, fwos.latest_build_transition, fwos.latest_build_status, fwos.latest_build_has_external_agent,
fwos.id, fwos.created_at, fwos.updated_at, fwos.owner_id, fwos.organization_id, fwos.template_id, fwos.deleted, fwos.name, fwos.autostart_schedule, fwos.ttl, fwos.last_used_at, fwos.dormant_at, fwos.deleting_at, fwos.automatic_updates, fwos.favorite, fwos.next_start_at, fwos.group_acl, fwos.user_acl, fwos.owner_avatar_url, fwos.owner_username, fwos.owner_name, fwos.organization_name, fwos.organization_display_name, fwos.organization_icon, fwos.organization_description, fwos.template_name, fwos.template_display_name, fwos.template_icon, fwos.template_description, fwos.template_version_id, fwos.template_version_name, fwos.latest_build_completed_at, fwos.latest_build_canceled_at, fwos.latest_build_error, fwos.latest_build_transition, fwos.latest_build_status, fwos.latest_build_has_ai_task, fwos.latest_build_has_external_agent,
tc.count
FROM
filtered_workspaces_order_with_summary fwos
@@ -22709,7 +22326,6 @@ type GetWorkspacesRow struct {
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
TemplateIcon string `db:"template_icon" json:"template_icon"`
TemplateDescription string `db:"template_description" json:"template_description"`
TaskID uuid.NullUUID `db:"task_id" json:"task_id"`
TemplateVersionID uuid.UUID `db:"template_version_id" json:"template_version_id"`
TemplateVersionName sql.NullString `db:"template_version_name" json:"template_version_name"`
LatestBuildCompletedAt sql.NullTime `db:"latest_build_completed_at" json:"latest_build_completed_at"`
@@ -22717,6 +22333,7 @@ type GetWorkspacesRow struct {
LatestBuildError sql.NullString `db:"latest_build_error" json:"latest_build_error"`
LatestBuildTransition WorkspaceTransition `db:"latest_build_transition" json:"latest_build_transition"`
LatestBuildStatus ProvisionerJobStatus `db:"latest_build_status" json:"latest_build_status"`
LatestBuildHasAITask sql.NullBool `db:"latest_build_has_ai_task" json:"latest_build_has_ai_task"`
LatestBuildHasExternalAgent sql.NullBool `db:"latest_build_has_external_agent" json:"latest_build_has_external_agent"`
Count int64 `db:"count" json:"count"`
}
@@ -22791,7 +22408,6 @@ func (q *sqlQuerier) GetWorkspaces(ctx context.Context, arg GetWorkspacesParams)
&i.TemplateDisplayName,
&i.TemplateIcon,
&i.TemplateDescription,
&i.TaskID,
&i.TemplateVersionID,
&i.TemplateVersionName,
&i.LatestBuildCompletedAt,
@@ -22799,6 +22415,7 @@ func (q *sqlQuerier) GetWorkspaces(ctx context.Context, arg GetWorkspacesParams)
&i.LatestBuildError,
&i.LatestBuildTransition,
&i.LatestBuildStatus,
&i.LatestBuildHasAITask,
&i.LatestBuildHasExternalAgent,
&i.Count,
); err != nil {
-127
View File
@@ -6,14 +6,6 @@ INSERT INTO aibridge_interceptions (
)
RETURNING *;
-- name: UpdateAIBridgeInterceptionEnded :one
UPDATE aibridge_interceptions
SET ended_at = @ended_at::timestamptz
WHERE
id = @id::uuid
AND ended_at IS NULL
RETURNING *;
-- name: InsertAIBridgeTokenUsage :one
INSERT INTO aibridge_token_usages (
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
@@ -207,122 +199,3 @@ WHERE
ORDER BY
created_at ASC,
id ASC;
-- name: ListAIBridgeInterceptionsTelemetrySummaries :many
-- Finds all unique AIBridge interception telemetry summaries combinations
-- (provider, model, client) in the given timeframe for telemetry reporting.
SELECT
DISTINCT ON (provider, model, client)
provider,
model,
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
'unknown' AS client
FROM
aibridge_interceptions
WHERE
ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
AND ended_at >= @ended_at_after::timestamptz
AND ended_at < @ended_at_before::timestamptz;
-- name: CalculateAIBridgeInterceptionsTelemetrySummary :one
-- Calculates the telemetry summary for a given provider, model, and client
-- combination for telemetry reporting.
WITH interceptions_in_range AS (
-- Get all matching interceptions in the given timeframe.
SELECT
id,
initiator_id,
(ended_at - started_at) AS duration
FROM
aibridge_interceptions
WHERE
provider = @provider::text
AND model = @model::text
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
AND 'unknown' = @client::text
AND ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
AND ended_at >= @ended_at_after::timestamptz
AND ended_at < @ended_at_before::timestamptz
),
interception_counts AS (
SELECT
COUNT(id) AS interception_count,
COUNT(DISTINCT initiator_id) AS unique_initiator_count
FROM
interceptions_in_range
),
duration_percentiles AS (
SELECT
(COALESCE(PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p50_millis,
(COALESCE(PERCENTILE_CONT(0.90) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p90_millis,
(COALESCE(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p95_millis,
(COALESCE(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p99_millis
FROM
interceptions_in_range
),
token_aggregates AS (
SELECT
COALESCE(SUM(tu.input_tokens), 0) AS token_count_input,
COALESCE(SUM(tu.output_tokens), 0) AS token_count_output,
-- Cached tokens are stored in metadata JSON, extract if available.
-- Read tokens may be stored in:
-- - cache_read_input (Anthropic)
-- - prompt_cached (OpenAI)
COALESCE(SUM(
COALESCE((tu.metadata->>'cache_read_input')::bigint, 0) +
COALESCE((tu.metadata->>'prompt_cached')::bigint, 0)
), 0) AS token_count_cached_read,
-- Written tokens may be stored in:
-- - cache_creation_input (Anthropic)
-- Note that cache_ephemeral_5m_input and cache_ephemeral_1h_input on
-- Anthropic are included in the cache_creation_input field.
COALESCE(SUM(
COALESCE((tu.metadata->>'cache_creation_input')::bigint, 0)
), 0) AS token_count_cached_written,
COUNT(tu.id) AS token_usages_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_token_usages tu ON i.id = tu.interception_id
),
prompt_aggregates AS (
SELECT
COUNT(up.id) AS user_prompts_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_user_prompts up ON i.id = up.interception_id
),
tool_aggregates AS (
SELECT
COUNT(tu.id) FILTER (WHERE tu.injected = true) AS tool_calls_count_injected,
COUNT(tu.id) FILTER (WHERE tu.injected = false) AS tool_calls_count_non_injected,
COUNT(tu.id) FILTER (WHERE tu.injected = true AND tu.invocation_error IS NOT NULL) AS injected_tool_call_error_count
FROM
interceptions_in_range i
LEFT JOIN
aibridge_tool_usages tu ON i.id = tu.interception_id
)
SELECT
ic.interception_count::bigint AS interception_count,
dp.interception_duration_p50_millis::bigint AS interception_duration_p50_millis,
dp.interception_duration_p90_millis::bigint AS interception_duration_p90_millis,
dp.interception_duration_p95_millis::bigint AS interception_duration_p95_millis,
dp.interception_duration_p99_millis::bigint AS interception_duration_p99_millis,
ic.unique_initiator_count::bigint AS unique_initiator_count,
pa.user_prompts_count::bigint AS user_prompts_count,
tok_agg.token_usages_count::bigint AS token_usages_count,
tok_agg.token_count_input::bigint AS token_count_input,
tok_agg.token_count_output::bigint AS token_count_output,
tok_agg.token_count_cached_read::bigint AS token_count_cached_read,
tok_agg.token_count_cached_written::bigint AS token_count_cached_written,
tool_agg.tool_calls_count_injected::bigint AS tool_calls_count_injected,
tool_agg.tool_calls_count_non_injected::bigint AS tool_calls_count_non_injected,
tool_agg.injected_tool_call_error_count::bigint AS injected_tool_call_error_count
FROM
interception_counts ic,
duration_percentiles dp,
token_aggregates tok_agg,
prompt_aggregates pa,
tool_aggregates tool_agg
;
+7 -53
View File
@@ -300,8 +300,12 @@ GROUP BY wpb.template_version_preset_id;
-- Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
-- inactive template version.
-- This is an optimization to clean up stale pending jobs.
WITH jobs_to_cancel AS (
SELECT pj.id, w.id AS workspace_id, w.template_id, wpb.template_version_preset_id
UPDATE provisioner_jobs
SET
canceled_at = @now::timestamptz,
completed_at = @now::timestamptz
WHERE id IN (
SELECT pj.id
FROM provisioner_jobs pj
INNER JOIN workspace_prebuild_builds wpb ON wpb.job_id = pj.id
INNER JOIN workspaces w ON w.id = wpb.workspace_id
@@ -320,54 +324,4 @@ WITH jobs_to_cancel AS (
AND pj.canceled_at IS NULL
AND pj.completed_at IS NULL
)
UPDATE provisioner_jobs
SET
canceled_at = @now::timestamptz,
completed_at = @now::timestamptz
FROM jobs_to_cancel
WHERE provisioner_jobs.id = jobs_to_cancel.id
RETURNING jobs_to_cancel.id, jobs_to_cancel.workspace_id, jobs_to_cancel.template_id, jobs_to_cancel.template_version_preset_id;
-- name: GetOrganizationsWithPrebuildStatus :many
-- GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
-- membership status for the prebuilds system user (org membership, group existence, group membership).
WITH orgs_with_prebuilds AS (
-- Get unique organizations that have presets with prebuilds configured
SELECT DISTINCT o.id, o.name
FROM organizations o
INNER JOIN templates t ON t.organization_id = o.id
INNER JOIN template_versions tv ON tv.template_id = t.id
INNER JOIN template_version_presets tvp ON tvp.template_version_id = tv.id
WHERE tvp.desired_instances IS NOT NULL
),
prebuild_user_membership AS (
-- Check if the user is a member of the organizations
SELECT om.organization_id
FROM organization_members om
INNER JOIN orgs_with_prebuilds owp ON owp.id = om.organization_id
WHERE om.user_id = @user_id::uuid
),
prebuild_groups AS (
-- Check if the organizations have the prebuilds group
SELECT g.organization_id, g.id as group_id
FROM groups g
INNER JOIN orgs_with_prebuilds owp ON owp.id = g.organization_id
WHERE g.name = @group_name::text
),
prebuild_group_membership AS (
-- Check if the user is in the prebuilds group
SELECT pg.organization_id
FROM prebuild_groups pg
INNER JOIN group_members gm ON gm.group_id = pg.group_id
WHERE gm.user_id = @user_id::uuid
)
SELECT
owp.id AS organization_id,
owp.name AS organization_name,
(pum.organization_id IS NOT NULL)::boolean AS has_prebuild_user,
pg.group_id AS prebuilds_group_id,
(pgm.organization_id IS NOT NULL)::boolean AS has_prebuild_user_in_group
FROM orgs_with_prebuilds owp
LEFT JOIN prebuild_groups pg ON pg.organization_id = owp.id
LEFT JOIN prebuild_user_membership pum ON pum.organization_id = owp.id
LEFT JOIN prebuild_group_membership pgm ON pgm.organization_id = owp.id;
RETURNING id;
+1 -1
View File
@@ -2,7 +2,7 @@
INSERT INTO tasks
(id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9)
(gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8)
RETURNING *;
-- name: UpdateTaskWorkspaceID :one
@@ -1,17 +0,0 @@
-- name: InsertTelemetryLock :exec
-- Inserts a new lock row into the telemetry_locks table. Replicas should call
-- this function prior to attempting to generate or publish a heartbeat event to
-- the telemetry service.
-- If the query returns a duplicate primary key error, the replica should not
-- attempt to generate or publish the event to the telemetry service.
INSERT INTO
telemetry_locks (event_type, period_ending_at)
VALUES
($1, $2);
-- name: DeleteOldTelemetryLocks :exec
-- Deletes old telemetry locks from the telemetry_locks table.
DELETE FROM
telemetry_locks
WHERE
period_ending_at < @period_ending_at_before::timestamptz;
+20 -13
View File
@@ -117,6 +117,7 @@ SELECT
latest_build.error as latest_build_error,
latest_build.transition as latest_build_transition,
latest_build.job_status as latest_build_status,
latest_build.has_ai_task as latest_build_has_ai_task,
latest_build.has_external_agent as latest_build_has_external_agent
FROM
workspaces_expanded as workspaces
@@ -350,19 +351,25 @@ WHERE
(latest_build.template_version_id = template.active_version_id) = sqlc.narg('using_active') :: boolean
ELSE true
END
-- Filter by has_ai_task, checks if this is a task workspace.
-- Filter by has_ai_task in latest build
AND CASE
WHEN sqlc.narg('has_ai_task')::boolean IS NOT NULL
THEN sqlc.narg('has_ai_task')::boolean = EXISTS (
SELECT
1
FROM
tasks
WHERE
-- Consider all tasks, deleting a task does not turn the
-- workspace into a non-task workspace.
tasks.workspace_id = workspaces.id
)
WHEN sqlc.narg('has_ai_task') :: boolean IS NOT NULL THEN
(COALESCE(latest_build.has_ai_task, false) OR (
-- If the build has no AI task, it means that the provisioner job is in progress
-- and we don't know if it has an AI task yet. In this case, we optimistically
-- assume that it has an AI task if the AI Prompt parameter is not empty. This
-- lets the AI Task frontend spawn a task and see it immediately after instead of
-- having to wait for the build to complete.
latest_build.has_ai_task IS NULL AND
latest_build.completed_at IS NULL AND
EXISTS (
SELECT 1
FROM workspace_build_parameters
WHERE workspace_build_parameters.workspace_build_id = latest_build.id
AND workspace_build_parameters.name = 'AI Prompt'
AND workspace_build_parameters.value != ''
)
)) = (sqlc.narg('has_ai_task') :: boolean)
ELSE true
END
-- Filter by has_external_agent in latest build
@@ -450,7 +457,6 @@ WHERE
'', -- template_display_name
'', -- template_icon
'', -- template_description
'00000000-0000-0000-0000-000000000000'::uuid, -- task_id
-- Extra columns added to `filtered_workspaces`
'00000000-0000-0000-0000-000000000000'::uuid, -- template_version_id
'', -- template_version_name
@@ -459,6 +465,7 @@ WHERE
'', -- latest_build_error
'start'::workspace_transition, -- latest_build_transition
'unknown'::provisioner_job_status, -- latest_build_status
false, -- latest_build_has_ai_task
false -- latest_build_has_external_agent
WHERE
@with_summary :: boolean = true
-1
View File
@@ -62,7 +62,6 @@ const (
UniqueTaskWorkspaceAppsPkey UniqueConstraint = "task_workspace_apps_pkey" // ALTER TABLE ONLY task_workspace_apps ADD CONSTRAINT task_workspace_apps_pkey PRIMARY KEY (task_id, workspace_build_number);
UniqueTasksPkey UniqueConstraint = "tasks_pkey" // ALTER TABLE ONLY tasks ADD CONSTRAINT tasks_pkey PRIMARY KEY (id);
UniqueTelemetryItemsPkey UniqueConstraint = "telemetry_items_pkey" // ALTER TABLE ONLY telemetry_items ADD CONSTRAINT telemetry_items_pkey PRIMARY KEY (key);
UniqueTelemetryLocksPkey UniqueConstraint = "telemetry_locks_pkey" // ALTER TABLE ONLY telemetry_locks ADD CONSTRAINT telemetry_locks_pkey PRIMARY KEY (event_type, period_ending_at);
UniqueTemplateUsageStatsPkey UniqueConstraint = "template_usage_stats_pkey" // ALTER TABLE ONLY template_usage_stats ADD CONSTRAINT template_usage_stats_pkey PRIMARY KEY (start_time, template_id, user_id);
UniqueTemplateVersionParametersTemplateVersionIDNameKey UniqueConstraint = "template_version_parameters_template_version_id_name_key" // ALTER TABLE ONLY template_version_parameters ADD CONSTRAINT template_version_parameters_template_version_id_name_key UNIQUE (template_version_id, name);
UniqueTemplateVersionPresetParametersPkey UniqueConstraint = "template_version_preset_parameters_pkey" // ALTER TABLE ONLY template_version_preset_parameters ADD CONSTRAINT template_version_preset_parameters_pkey PRIMARY KEY (id);
+1 -6
View File
@@ -37,18 +37,13 @@ type ReconciliationOrchestrator interface {
TrackResourceReplacement(ctx context.Context, workspaceID, buildID uuid.UUID, replacements []*sdkproto.ResourceReplacement)
}
// ReconcileStats contains statistics about a reconciliation cycle.
type ReconcileStats struct {
Elapsed time.Duration
}
type Reconciler interface {
StateSnapshotter
// ReconcileAll orchestrates the reconciliation of all prebuilds across all templates.
// It takes a global snapshot of the system state and then reconciles each preset
// in parallel, creating or deleting prebuilds as needed to reach their desired states.
ReconcileAll(ctx context.Context) (ReconcileStats, error)
ReconcileAll(ctx context.Context) error
}
// StateSnapshotter defines the operations necessary to capture workspace prebuilds state.
+1 -5
View File
@@ -17,11 +17,7 @@ func (NoopReconciler) Run(context.Context) {}
func (NoopReconciler) Stop(context.Context, error) {}
func (NoopReconciler) TrackResourceReplacement(context.Context, uuid.UUID, uuid.UUID, []*sdkproto.ResourceReplacement) {
}
func (NoopReconciler) ReconcileAll(context.Context) (ReconcileStats, error) {
return ReconcileStats{}, nil
}
func (NoopReconciler) ReconcileAll(context.Context) error { return nil }
func (NoopReconciler) SnapshotState(context.Context, database.Store) (*GlobalSnapshot, error) {
return &GlobalSnapshot{}, nil
}
-104
View File
@@ -1,104 +0,0 @@
# Rego authorization policy
## Code style
It's a good idea to consult the [Rego style guide](https://docs.styra.com/opa/rego-style-guide). The "Variables and Data Types" section in particular has some helpful and non-obvious advice in it.
## Debugging
Open Policy Agent provides a CLI and a playground that can be used for evaluating, formatting, testing, and linting policies.
### CLI
Below are some helpful commands you can use for debugging.
For full evaluation, run:
```sh
opa eval --format=pretty 'data.authz.allow' -d policy.rego -i input.json
```
For partial evaluation, run:
```sh
opa eval --partial --format=pretty 'data.authz.allow' -d policy.rego \
--unknowns input.object.owner --unknowns input.object.org_owner \
--unknowns input.object.acl_user_list --unknowns input.object.acl_group_list \
-i input.json
```
### Playground
Use the [Open Policy Agent Playground](https://play.openpolicyagent.org/) while editing to get linting, code formatting, and help with debugging!
You can use the contents of input.json as a starting point for your own testing input. Paste the contents of policy.rego into the left-hand side of the playground, and the contents of input.json into the "Input" section. Click "Evaluate" and you should see something like the following in the output.
```json
{
"allow": true,
"check_scope_allow_list": true,
"org": 0,
"org_member": 0,
"org_memberships": [],
"permission_allow": true,
"role_allow": true,
"scope_allow": true,
"scope_org": 0,
"scope_org_member": 0,
"scope_site": 1,
"scope_user": 0,
"site": 1,
"user": 0
}
```
## Levels
Permissions are evaluated at four levels: site, user, org, org_member.
For each level, two checks are performed:
- Do the subject's permissions allow them to perform this action?
- Does the subject's scope allow them to perform this action?
Each of these checks gets a "vote", which must be one of three values:
- -1 to deny (usually because of a negative permission)
- 0 to abstain (no matching permission)
- 1 to allow
If a level abstains, then the decision gets deferred to the next level. When
there is no "next" level to defer to it is equivalent to being denied.
### Scope
Additionally, each input has a "scope" that can be thought of as a second set of permissions, where each permission belongs to one of the four levels, exactly the same as role permissions. An action is only allowed if it is allowed by both the subject's permissions _and_ their current scope. This is to allow issuing tokens for a subject that have a subset of the full subject's permissions.
For example, you may have a scope like...
```json
{
"by_org_id": {
"<org_id>": {
"member": [{ "resource_type": "workspace", "action": "*" }]
}
}
}
```
...to limit the token to only accessing workspaces owned by the user within a specific org. This provides some assurances for an admin user, that the token can only access intended resources, rather than having full access to everything.
The final policy decision is determined by evaluating each of these checks in their proper precedence order from the `allow` rule.
## Unknown values
This policy is specifically constructed to compress to a set of queries if 'input.object.owner' and 'input.object.org_owner' are unknown. There is no specific set of rules that will guarantee that this policy has this property, however, there are some tricks. We have tests that enforce this property, so any changes that pass the tests will be okay.
Some general rules to follow:
1. Do not use unknown values in any [comprehensions](https://www.openpolicyagent.org/docs/latest/policy-language/#comprehensions) or iterations.
2. Use the unknown values as minimally as possible.
3. Avoid making code branches based on the value of the unknown field.
Unknown values are like a "set" of possible values (which is why rule 1 usually breaks things).
For example, in the org level rules, we calculate the "vote" for all orgs, rather than just the `input.object.org_owner`. This way, if the `org_owner` changes, then we don't need to recompute any votes; we already have it for the changed value. This means we don't need branching, because the end result is just a lookup table.
+35 -83
View File
@@ -58,68 +58,22 @@ This can be represented by the following truth table, where Y represents _positi
- `+site.app.*.read`: allowed to perform the `read` action against all objects of type `app` in a given Coder deployment.
- `-user.workspace.*.create`: user is not allowed to create workspaces.
## Levels
A user can be given (or deprived) a permission at several levels. Currently,
those levels are:
- Site-wide level
- Organization level
- User level
- Organization member level
The site-wide level is the most authoritative. Any permission granted or denied at the site-wide level is absolute. After checking the site-wide level, depending on whether the resource is owned by an organization or not, it will check the other levels.
- If the resource is owned by an organization, the next most authoritative level is the organization level. It acts like the site-wide level, but only for resources within the corresponding organization. The user can use that permission on any resource within that organization.
- After the organization level is the member level. This level only applies to resources that are owned by both the organization _and_ the user.
- If the resource is not owned by an organization, the next level to check is the user level. This level only applies to resources owned by the user and that are not owned by any organization.
```
┌──────────┐
│ Site │
└─────┬────┘
┌──────────┴───────────┐
┌──┤ Owned by an org? ├──┐
│ └──────────────────────┘ │
┌──┴──┐ ┌──┴─┐
│ Yes │ │ No │
└──┬──┘ └──┬─┘
┌────────┴─────────┐ ┌─────┴────┐
│ Organization │ │ User │
└────────┬─────────┘ └──────────┘
┌─────┴──────┐
│ Member │
└────────────┘
```
## Roles
A _role_ is a set of permissions. When evaluating a role's permission to form an action, all the relevant permissions for the role are combined at each level. Permissions at a higher level override permissions at a lower level.
The following tables show the per-level role evaluation. Y indicates that the role provides positive permissions, N indicates the role provides negative permissions, and _ indicates the role does not provide positive or negative permissions. YN_ indicates that the value in the cell does not matter for the access result. The table varies depending on if the resource belongs to an organization or not.
The following table shows the per-level role evaluation.
Y indicates that the role provides positive permissions, N indicates the role provides negative permissions, and _indicates the role does not provide positive or negative permissions. YN_ indicates that the value in the cell does not matter for the access result.
If the resource is owned by an organization, such as a template or a workspace:
| Role (example) | Site | Org | OrgMember | Result |
|--------------------------|------|------|-----------|--------|
| site-admin | Y | YN\_ | YN\_ | Y |
| negative-site-permission | N | YN\_ | YN\_ | N |
| org-admin | \_ | Y | YN\_ | Y |
| non-org-member | \_ | N | YN\_ | N |
| member-owned | \_ | \_ | Y | Y |
| not-member-owned | \_ | \_ | N | N |
| unauthenticated | \_ | \_ | \_ | N |
If the resource is not owned by an organization:
| Role (example) | Site | User | Result |
|--------------------------|------|------|--------|
| site-admin | Y | YN\_ | Y |
| negative-site-permission | N | YN\_ | N |
| user-owned | \_ | Y | Y |
| not-user-owned | \_ | N | N |
| unauthenticated | \_ | \_ | N |
| Role (example) | Site | Org | User | Result |
|-----------------|------|------|------|--------|
| site-admin | Y | YN\_ | YN\_ | Y |
| no-permission | N | YN\_ | YN\_ | N |
| org-admin | \_ | Y | YN\_ | Y |
| non-org-member | \_ | N | YN\_ | N |
| user | \_ | \_ | Y | Y |
| | \_ | \_ | N | N |
| unauthenticated | \_ | \_ | \_ | N |
## Scopes
@@ -137,17 +91,15 @@ The use case for specifying this type of permission in a role is limited, and do
Example of a scope for a workspace agent token, using an `allow_list` containing a single resource id.
```javascript
{
"scope": {
"name": "workspace_agent",
"display_name": "Workspace_Agent",
// The ID of the given workspace the agent token correlates to.
"allow_list": ["10d03e62-7703-4df5-a358-4f76577d4e2f"],
"site": [/* ... perms ... */],
"org": {/* ... perms ... */},
"user": [/* ... perms ... */]
}
}
"scope": {
"name": "workspace_agent",
"display_name": "Workspace_Agent",
// The ID of the given workspace the agent token correlates to.
"allow_list": ["10d03e62-7703-4df5-a358-4f76577d4e2f"],
"site": [/* ... perms ... */],
"org": {/* ... perms ... */},
"user": [/* ... perms ... */]
}
```
## OPA (Open Policy Agent)
@@ -172,31 +124,31 @@ To learn more about OPA and Rego, see https://www.openpolicyagent.org/docs.
There are two types of evaluation in OPA:
- **Full evaluation**: Produces a decision that can be enforced.
This is the default evaluation mode, where OPA evaluates the policy using `input` data that contains all known values and returns output data with the `allow` variable.
This is the default evaluation mode, where OPA evaluates the policy using `input` data that contains all known values and returns output data with the `allow` variable.
- **Partial evaluation**: Produces a new policy that can be evaluated later when the _unknowns_ become _known_.
This is an optimization in OPA where it evaluates as much of the policy as possible without resolving expressions that depend on _unknown_ values from the `input`.
To learn more about partial evaluation, see this [OPA blog post](https://blog.openpolicyagent.org/partial-evaluation-162750eaf422).
This is an optimization in OPA where it evaluates as much of the policy as possible without resolving expressions that depend on _unknown_ values from the `input`.
To learn more about partial evaluation, see this [OPA blog post](https://blog.openpolicyagent.org/partial-evaluation-162750eaf422).
Application of Full and Partial evaluation in `rbac` package:
- **Full Evaluation** is handled by the `RegoAuthorizer.Authorize()` method in [`authz.go`](authz.go).
This method determines whether a subject (user) can perform a specific action on an object.
It performs a full evaluation of the Rego policy, which returns the `allow` variable to decide whether access is granted (`true`) or denied (`false` or undefined).
This method determines whether a subject (user) can perform a specific action on an object.
It performs a full evaluation of the Rego policy, which returns the `allow` variable to decide whether access is granted (`true`) or denied (`false` or undefined).
- **Partial Evaluation** is handled by the `RegoAuthorizer.Prepare()` method in [`authz.go`](authz.go).
This method compiles OPAs partial evaluation queries into `SQL WHERE` clauses.
These clauses are then used to enforce authorization directly in database queries, rather than in application code.
This method compiles OPAs partial evaluation queries into `SQL WHERE` clauses.
These clauses are then used to enforce authorization directly in database queries, rather than in application code.
Authorization Patterns:
- Fetch-then-authorize: an object is first retrieved from the database, and a single authorization check is performed using full evaluation via `Authorize()`.
- Authorize-while-fetching: Partial evaluation via `Prepare()` is used to inject SQL filters directly into queries, allowing efficient authorization of many objects of the same type.
`dbauthz` methods that enforce authorization directly in the SQL query are prefixed with `Authorized`, for example, `GetAuthorizedWorkspaces`.
`dbauthz` methods that enforce authorization directly in the SQL query are prefixed with `Authorized`, for example, `GetAuthorizedWorkspaces`.
## Testing
- OPA Playground: https://play.openpolicyagent.org/
- OPA CLI (`opa eval`): useful for experimenting with different inputs and understanding how the policy behaves under various conditions.
`opa eval` returns the constraints that must be satisfied for a rule to evaluate to `true`.
`opa eval` returns the constraints that must be satisfied for a rule to evaluate to `true`.
- `opa eval` requires an `input.json` file containing the input data to run the policy against.
You can generate this file using the [gen_input.go](../../scripts/rbac-authz/gen_input.go) script.
Note: the script currently produces a fixed input. You may need to tweak it for your specific use case.
@@ -244,12 +196,12 @@ The script [`benchmark_authz.sh`](../../scripts/rbac-authz/benchmark_authz.sh) r
- To run benchmark on the current branch:
```bash
benchmark_authz.sh --single
```
```bash
benchmark_authz.sh --single
```
- To compare benchmarks between 2 branches:
```bash
benchmark_authz.sh --compare main prebuild_policy
```
```bash
benchmark_authz.sh --compare main prebuild_policy
```
-4
View File
@@ -165,10 +165,6 @@ func (role Role) regoValue() ast.Value {
ast.StringTerm("org"),
ast.NewTerm(regoSlice(p.Org)),
},
[2]*ast.Term{
ast.StringTerm("member"),
ast.NewTerm(regoSlice(p.Member)),
},
),
))
}
+50 -57
View File
@@ -287,7 +287,7 @@ func TestFilter(t *testing.T) {
func TestAuthorizeDomain(t *testing.T) {
t.Parallel()
defOrg := uuid.New()
unusedID := uuid.New()
unuseID := uuid.New()
allUsersGroup := "Everyone"
// orphanedUser has no organization
@@ -318,21 +318,21 @@ func TestAuthorizeDomain(t *testing.T) {
testAuthorize(t, "UserACLList", user, []authTestCase{
{
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(unusedID).WithACLUserList(map[string][]policy.Action{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(unuseID).WithACLUserList(map[string][]policy.Action{
user.ID: ResourceWorkspace.AvailableActions(),
}),
actions: ResourceWorkspace.AvailableActions(),
allow: true,
},
{
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(unusedID).WithACLUserList(map[string][]policy.Action{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(unuseID).WithACLUserList(map[string][]policy.Action{
user.ID: {policy.WildcardSymbol},
}),
actions: ResourceWorkspace.AvailableActions(),
allow: true,
},
{
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(unusedID).WithACLUserList(map[string][]policy.Action{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(unuseID).WithACLUserList(map[string][]policy.Action{
user.ID: {policy.ActionRead, policy.ActionUpdate},
}),
actions: []policy.Action{policy.ActionCreate, policy.ActionDelete},
@@ -350,21 +350,21 @@ func TestAuthorizeDomain(t *testing.T) {
testAuthorize(t, "GroupACLList", user, []authTestCase{
{
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
allUsersGroup: ResourceWorkspace.AvailableActions(),
}),
actions: ResourceWorkspace.AvailableActions(),
allow: true,
},
{
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
allUsersGroup: {policy.WildcardSymbol},
}),
actions: ResourceWorkspace.AvailableActions(),
allow: true,
},
{
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
allUsersGroup: {policy.ActionRead, policy.ActionUpdate},
}),
actions: []policy.Action{policy.ActionCreate, policy.ActionDelete},
@@ -389,14 +389,13 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.AnyOrganization().WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: true},
{resource: ResourceTemplate.AnyOrganization(), actions: []policy.Action{policy.ActionCreate}, allow: false},
// No org + me
{resource: ResourceWorkspace.WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: true},
{resource: ResourceWorkspace.All(), actions: ResourceWorkspace.AvailableActions(), allow: false},
// Other org + me
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
// Other org + other user
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
@@ -404,8 +403,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
// Other org + other us
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
})
@@ -436,8 +435,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.All(), actions: ResourceWorkspace.AvailableActions(), allow: false},
// Other org + me
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
// Other org + other user
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
@@ -445,8 +444,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
// Other org + other use
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
})
@@ -456,7 +455,6 @@ func TestAuthorizeDomain(t *testing.T) {
Scope: must(ExpandScope(ScopeAll)),
Roles: Roles{
must(RoleByName(ScopedRoleOrgAdmin(defOrg))),
must(RoleByName(ScopedRoleOrgMember(defOrg))),
must(RoleByName(RoleMember())),
},
}
@@ -471,14 +469,13 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.InOrg(defOrg), actions: workspaceExceptConnect, allow: true},
{resource: ResourceWorkspace.InOrg(defOrg), actions: workspaceConnect, allow: false},
// No org + me
{resource: ResourceWorkspace.WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: true},
{resource: ResourceWorkspace.All(), actions: ResourceWorkspace.AvailableActions(), allow: false},
// Other org + me
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
// Other org + other user
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), actions: workspaceExceptConnect, allow: true},
@@ -486,9 +483,9 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
// Other org + other user
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
// Other org + other use
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
})
@@ -515,8 +512,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.All(), actions: ResourceWorkspace.AvailableActions(), allow: true},
// Other org + me
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: true},
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: true},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: true},
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: true},
// Other org + other user
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: true},
@@ -524,8 +521,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: true},
// Other org + other use
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: true},
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: true},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: true},
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: true},
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: true},
})
@@ -549,14 +546,13 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner(user.ID), allow: true},
{resource: ResourceWorkspace.InOrg(defOrg), allow: false},
// No org + me
{resource: ResourceWorkspace.WithOwner(user.ID), allow: false},
{resource: ResourceWorkspace.WithOwner(user.ID), allow: true},
{resource: ResourceWorkspace.All(), allow: false},
// Other org + me
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), allow: false},
{resource: ResourceWorkspace.InOrg(unusedID), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID), allow: false},
// Other org + other user
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), allow: false},
@@ -564,8 +560,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.WithOwner("not-me"), allow: false},
// Other org + other use
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), allow: false},
{resource: ResourceWorkspace.InOrg(unusedID), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID), allow: false},
{resource: ResourceWorkspace.WithOwner("not-me"), allow: false},
}),
@@ -584,8 +580,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.All()},
// Other org + me
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID)},
{resource: ResourceWorkspace.InOrg(unusedID)},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID)},
{resource: ResourceWorkspace.InOrg(unuseID)},
// Other org + other user
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me")},
@@ -593,8 +589,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.WithOwner("not-me")},
// Other org + other use
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me")},
{resource: ResourceWorkspace.InOrg(unusedID)},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me")},
{resource: ResourceWorkspace.InOrg(unuseID)},
{resource: ResourceWorkspace.WithOwner("not-me")},
}),
@@ -613,8 +609,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceTemplate.All()},
// Other org + me
{resource: ResourceTemplate.InOrg(unusedID).WithOwner(user.ID)},
{resource: ResourceTemplate.InOrg(unusedID)},
{resource: ResourceTemplate.InOrg(unuseID).WithOwner(user.ID)},
{resource: ResourceTemplate.InOrg(unuseID)},
// Other org + other user
{resource: ResourceTemplate.InOrg(defOrg).WithOwner("not-me")},
@@ -622,8 +618,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceTemplate.WithOwner("not-me")},
// Other org + other use
{resource: ResourceTemplate.InOrg(unusedID).WithOwner("not-me")},
{resource: ResourceTemplate.InOrg(unusedID)},
{resource: ResourceTemplate.InOrg(unuseID).WithOwner("not-me")},
{resource: ResourceTemplate.InOrg(unuseID)},
{resource: ResourceTemplate.WithOwner("not-me")},
}),
@@ -651,7 +647,6 @@ func TestAuthorizeDomain(t *testing.T) {
ResourceType: "*",
Action: policy.ActionRead,
}},
Member: []Permission{},
},
},
},
@@ -673,8 +668,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.All(), allow: false},
// Other org + me
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), allow: false},
{resource: ResourceWorkspace.InOrg(unusedID), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID), allow: false},
// Other org + other user
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), allow: true},
@@ -682,8 +677,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.WithOwner("not-me"), allow: false},
// Other org + other use
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), allow: false},
{resource: ResourceWorkspace.InOrg(unusedID), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), allow: false},
{resource: ResourceWorkspace.InOrg(unuseID), allow: false},
{resource: ResourceWorkspace.WithOwner("not-me"), allow: false},
}),
@@ -704,8 +699,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.All()},
// Other org + me
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID)},
{resource: ResourceWorkspace.InOrg(unusedID)},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID)},
{resource: ResourceWorkspace.InOrg(unuseID)},
// Other org + other user
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me")},
@@ -713,8 +708,8 @@ func TestAuthorizeDomain(t *testing.T) {
{resource: ResourceWorkspace.WithOwner("not-me")},
// Other org + other use
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me")},
{resource: ResourceWorkspace.InOrg(unusedID)},
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me")},
{resource: ResourceWorkspace.InOrg(unuseID)},
{resource: ResourceWorkspace.WithOwner("not-me")},
}))
@@ -742,7 +737,6 @@ func TestAuthorizeLevels(t *testing.T) {
Action: "*",
},
},
Member: []Permission{},
},
},
},
@@ -1156,7 +1150,6 @@ func TestAuthorizeScope(t *testing.T) {
Org: Permissions(map[string][]policy.Action{
ResourceWorkspace.Type: {policy.ActionRead},
}),
Member: []Permission{},
},
},
},
@@ -1323,9 +1316,9 @@ type authTestCase struct {
func testAuthorize(t *testing.T, name string, subject Subject, sets ...[]authTestCase) {
t.Helper()
authorizer := NewAuthorizer(prometheus.NewRegistry())
for i, cases := range sets {
for j, c := range cases {
caseName := fmt.Sprintf("%s/Set%d/Case%d", name, i, j)
for _, cases := range sets {
for i, c := range cases {
caseName := fmt.Sprintf("%s/%d", name, i)
t.Run(caseName, func(t *testing.T) {
t.Parallel()
for _, a := range c.actions {
+4 -15
View File
@@ -23,13 +23,8 @@
"action": "*"
}
],
"user": [],
"by_org_id": {
"bf7b72bd-a2b1-4ef2-962c-1d698e0483f6": {
"org": [],
"member": []
}
}
"org": {},
"user": []
}
],
"groups": ["b617a647-b5d0-4cbe-9e40-26f89710bf18"],
@@ -43,19 +38,13 @@
"action": "*"
}
],
"org": {},
"user": [],
"by_org_id": {
"bf7b72bd-a2b1-4ef2-962c-1d698e0483f6": {
"org": [],
"member": []
}
},
"allow_list": [
{
"type": "workspace",
"id": "*"
}
]
}]
}
}
}
+289 -323
View File
@@ -2,426 +2,392 @@ package authz
import rego.v1
# Check the POLICY.md file before editing this!
#
# https://play.openpolicyagent.org/
#
# A great playground: https://play.openpolicyagent.org/
# Helpful cli commands to debug.
# opa eval --format=pretty 'data.authz.allow' -d policy.rego -i input.json
# opa eval --partial --format=pretty 'data.authz.allow' -d policy.rego --unknowns input.object.owner --unknowns input.object.org_owner --unknowns input.object.acl_user_list --unknowns input.object.acl_group_list -i input.json
#==============================================================================#
# Site level rules #
#==============================================================================#
#
# This policy is specifically constructed to compress to a set of queries if the
# object's 'owner' and 'org_owner' fields are unknown. There is no specific set
# of rules that will guarantee that this policy has this property. However, there
# are some tricks. A unit test will enforce this property, so any edits that pass
# the unit test will be ok.
#
# Tricks: (It's hard to really explain this, fiddling is required)
# 1. Do not use unknown fields in any comprehension or iteration.
# 2. Use the unknown fields as minimally as possible.
# 3. Avoid making code branches based on the value of the unknown field.
# Unknown values are like a "set" of possible values.
# (This is why rule 1 usually breaks things)
# For example:
# In the org section, we calculate the 'allow' number for all orgs, rather
# than just the input.object.org_owner. This is because if the org_owner
# changes, then we don't need to recompute any 'allow' sets. We already have
# the 'allow' for the changed value. So the answer is in a lookup table.
# The final statement 'num := allow[input.object.org_owner]' does not have
# different code branches based on the org_owner. 'num's value does, but
# that is the whole point of partial evaluation.
# Site level permissions allow the subject to use that permission on any object.
# For example, a site-level workspace.read permission means that the subject can
# see every workspace in the deployment, regardless of organization or owner.
# bool_flip(b) returns the logical negation of a boolean value 'b'.
# You cannot do 'x := !false', but you can do 'x := bool_flip(false)'
bool_flip(b) := false if {
b
}
bool_flip(b) := true if {
not b
}
# number(set) maps a set of boolean values to one of the following numbers:
# -1: deny (if 'false' value is in the set) => set is {true, false} or {false}
# 0: no decision (if the set is empty) => set is {}
# 1: allow (if only 'true' values are in the set) => set is {true}
# Return -1 if the set contains any 'false' value (i.e., an explicit deny)
number(set) := -1 if {
false in set
}
# Return 0 if the set is empty (no matching permissions)
number(set) := 0 if {
count(set) == 0
}
# Return 1 if the set is non-empty and contains no 'false' values (i.e., only allows)
number(set) := 1 if {
not false in set
set[_]
}
# Permission evaluation is structured into three levels: site, org, and user.
# For each level, two variables are computed:
# - <level>: the decision based on the subject's full set of roles for that level
# - scope_<level>: the decision based on the subject's scoped roles for that level
#
# Each of these variables is assigned one of three values:
# -1 => negative (deny)
# 0 => abstain (no matching permission)
# 1 => positive (allow)
#
# These values are computed by calling the corresponding <level>_allow functions.
# The final decision is derived from combining these values (see 'allow' rule).
# -------------------
# Site Level Rules
# -------------------
default site := 0
site := check_site_permissions(input.subject.roles)
site := site_allow(input.subject.roles)
default scope_site := 0
scope_site := site_allow([input.subject.scope])
scope_site := check_site_permissions([input.subject.scope])
check_site_permissions(roles) := vote if {
# site_allow receives a list of roles and returns a single number:
# -1 if any matching permission denies access
# 1 if there's at least one allow and no denies
# 0 if there are no matching permissions
site_allow(roles) := num if {
# allow is a set of boolean values (sets don't contain duplicates)
allow := {is_allowed |
# Iterate over all site permissions in all roles, and check which ones match
# the action and object type.
# Iterate over all site permissions in all roles
perm := roles[_].site[_]
perm.action in [input.action, "*"]
perm.resource_type in [input.object.type, "*"]
# If a negative matching permission was found, then we vote to disallow it.
# If the permission is not negative, then we vote to allow it.
# is_allowed is either 'true' or 'false' if a matching permission exists.
is_allowed := bool_flip(perm.negate)
}
vote := to_vote(allow)
num := number(allow)
}
#==============================================================================#
# User level rules #
#==============================================================================#
# -------------------
# Org Level Rules
# -------------------
# User level rules apply to all objects owned by the subject which are not also
# owned by an org. Permissions for objects which are "jointly" owned by an org
# instead defer to the org member level rules.
# org_members is the list of organizations the actor is apart of.
# TODO: Should there be an org_members for the scope too? Without it,
# the membership is determined by the user's roles, not their scope permissions.
# So if an owner (who is not an org member) has an org scope, that org scope
# will fail to return '1'. Since we assume all non members return '-1' for org
# level permissions.
# Adding a second org_members set might affect the partial evaluation.
# This is being left until org scopes are used.
org_members := {orgID |
input.subject.roles[_].by_org_id[orgID]
}
default user := 0
# 'org' is the same as 'site' except we need to iterate over each organization
# that the actor is a member of.
default org := 0
org := org_allow(input.subject.roles, "org")
user := check_user_permissions(input.subject.roles)
default scope_org := 0
scope_org := org_allow([input.subject.scope], "org")
default scope_user := 0
# org_allow_set is a helper function that iterates over all orgs that the actor
# is a member of. For each organization it sets the numerical allow value
# for the given object + action if the object is in the organization.
# The resulting value is a map that looks something like:
# {"10d03e62-7703-4df5-a358-4f76577d4e2f": 1, "5750d635-82e0-4681-bd44-815b18669d65": 1}
# The caller can use this output[<object.org_owner>] to get the final allow value.
#
# The reason we calculate this for all orgs, and not just the input.object.org_owner
# is that sometimes the input.object.org_owner is unknown. In those cases
# we have a list of org_ids that can we use in a SQL 'WHERE' clause.
org_allow_set(roles, key) := allow_set if {
allow_set := {id: num |
id := org_members[_]
set := {is_allowed |
# Iterate over all org permissions in all roles
perm := roles[_].by_org_id[id][key][_]
perm.action in [input.action, "*"]
perm.resource_type in [input.object.type, "*"]
scope_user := check_user_permissions([input.subject.scope])
# is_allowed is either 'true' or 'false' if a matching permission exists.
is_allowed := bool_flip(perm.negate)
}
num := number(set)
}
}
check_user_permissions(roles) := vote if {
# The object must be owned by the subject.
input.subject.id = input.object.owner
org_allow(roles, key) := num if {
# If the object has "any_org" set to true, then use the other
# org_allow block.
not input.object.any_org
allow := org_allow_set(roles, key)
# If there is an org, use org_member permissions instead
# Return only the org value of the input's org.
# The reason why we do not do this up front, is that we need to make sure
# this policy compresses down to simple queries. One way to ensure this is
# to keep unknown values out of comprehensions.
# (https://www.openpolicyagent.org/docs/latest/policy-language/#comprehensions)
num := allow[input.object.org_owner]
}
# This block states if "object.any_org" is set to true, then disregard the
# organization id the object is associated with. Instead, we check if the user
# can do the action on any organization.
# This is useful for UI elements when we want to conclude, "Can the user create
# a new template in any organization?"
# It is easier than iterating over every organization the user is apart of.
org_allow(roles, key) := num if {
input.object.any_org # if this is false, this code block is not used
allow := org_allow_set(roles, key)
# allow is a map of {"<org_id>": <number>}. We only care about values
# that are 1, and ignore the rest.
num := number([
keep |
# for every value in the mapping
value := allow[_]
# only keep values > 0.
# 1 = allow, 0 = abstain, -1 = deny
# We only need 1 explicit allow to allow the action.
# deny's and abstains are intentionally ignored.
value > 0
# result set is a set of [true,false,...]
# which "number()" will convert to a number.
keep := true
])
}
# 'org_mem' is set to true if the user is an org member
# If 'any_org' is set to true, use the other block to determine org membership.
org_mem if {
not input.object.any_org
input.object.org_owner != ""
input.object.org_owner in org_members
}
org_mem if {
input.object.any_org
count(org_members) > 0
}
org_ok if {
org_mem
}
# If the object has no organization, then the user is also considered part of
# the non-existent org.
org_ok if {
input.object.org_owner == ""
not input.object.any_org
}
# -------------------
# User Level Rules
# -------------------
# 'user' is the same as 'site', except it only applies if the user owns the object and
# the user is apart of the org (if the object has an org).
default user := 0
user := user_allow(input.subject.roles)
default scope_user := 0
scope_user := user_allow([input.subject.scope])
user_allow(roles) := num if {
input.object.owner != ""
input.subject.id = input.object.owner
allow := {is_allowed |
# Iterate over all user permissions in all roles, and check which ones match
# the action and object type.
# Iterate over all user permissions in all roles
perm := roles[_].user[_]
perm.action in [input.action, "*"]
perm.resource_type in [input.object.type, "*"]
# If a negative matching permission was found, then we vote to disallow it.
# If the permission is not negative, then we vote to allow it.
# is_allowed is either 'true' or 'false' if a matching permission exists.
is_allowed := bool_flip(perm.negate)
}
vote := to_vote(allow)
num := number(allow)
}
#==============================================================================#
# Org level rules #
#==============================================================================#
# Org level permissions are similar to `site`, except we need to iterate over
# each organization that the subject is a member of, and check against the
# organization that the object belongs to.
# For example, an organization-level workspace.read permission means that the
# subject can see every workspace in the organization, regardless of owner.
# org_memberships is the set of organizations the subject is apart of.
org_memberships := {org_id |
input.subject.roles[_].by_org_id[org_id]
# Scope allow_list is a list of resource (Type, ID) tuples explicitly allowed by the scope.
# If the list contains `(*,*)`, then all resources are allowed.
scope_allow_list if {
input.subject.scope.allow_list[_] == {"type": "*", "id": "*"}
}
# TODO: Should there be a scope_org_memberships too? Without it, the membership
# is determined by the user's roles, not their scope permissions.
#
# If an owner (who is not an org member) has an org scope, that org scope will
# fail to return '1', since we assume all non-members return '-1' for org level
# permissions. Adding a second set of org memberships might affect the partial
# evaluation. This is being left until org scopes are used.
default org := 0
org := check_org_permissions(input.subject.roles, "org")
default scope_org := 0
scope_org := check_org_permissions([input.subject.scope], "org")
# check_all_org_permissions creates a map from org ids to votes at each org
# level, for each org that the subject is a member of. It doesn't actually check
# if the object is in the same org. Instead we look up the correct vote from
# this map based on the object's org id in `check_org_permissions`.
# For example, the `org_map` will look something like this:
#
# {"<org_id_a>": 1, "<org_id_b>": 0, "<org_id_c>": -1}
#
# The caller then uses `output[input.object.org_owner]` to get the correct vote.
#
# We have to create this map, rather than just getting the vote of the object's
# org id because the org id _might_ be unknown. In order to make sure that this
# policy compresses down to simple queries we need to keep unknown values out of
# comprehensions.
check_all_org_permissions(roles, key) := {org_id: vote |
org_id := org_memberships[_]
allow := {is_allowed |
# Iterate over all site permissions in all roles, and check which ones match
# the action and object type.
perm := roles[_].by_org_id[org_id][key][_]
perm.action in [input.action, "*"]
perm.resource_type in [input.object.type, "*"]
# If a negative matching permission was found, then we vote to disallow it.
# If the permission is not negative, then we vote to allow it.
is_allowed := bool_flip(perm.negate)
}
vote := to_vote(allow)
# This is a shortcut if the allow_list contains (type, *), then allow all IDs of that type.
scope_allow_list if {
input.subject.scope.allow_list[_] == {"type": input.object.type, "id": "*"}
}
# This check handles the case where the org id is known.
check_org_permissions(roles, key) := vote if {
# Disallow setting any_org at the same time as an org id.
not input.object.any_org
# A comprehension that iterates over the allow_list and checks if the
# (object.type, object.id) is in the allowed ids.
scope_allow_list if {
# If the wildcard is listed in the allow_list, we do not care about the
# object.id. This line is included to prevent partial compilations from
# ever needing to include the object.id.
not {"type": "*", "id": "*"} in input.subject.scope.allow_list
# This is equivalent to the above line, as `type` is known at partial query time.
not {"type": input.object.type, "id": "*"} in input.subject.scope.allow_list
allow_map := check_all_org_permissions(roles, key)
# allows_ids is the set of all ids allowed for the given object.type
allowed_ids := {allowed_id |
# Iterate over all allow list elements
ele := input.subject.scope.allow_list[_]
ele.type in [input.object.type, "*"]
allowed_id := ele.id
}
# Return only the vote of the object's org.
vote := allow_map[input.object.org_owner]
# Return if the object.id is in the allowed ids
# This rule is evaluated at the end so the partial query can use the object.id
# against this precomputed set of allowed ids.
input.object.id in allowed_ids
}
# This check handles the case where we want to know if the user has the
# appropriate permission for any organization, without needing to know which.
# This is used in several places in the UI to determine if certain parts of the
# app should be accessible.
# For example, can the user create a new template in any organization? If yes,
# then we should show the "New template" button.
check_org_permissions(roles, key) := vote if {
# Require `any_org` to be set
input.object.any_org
# -------------------
# Role-Specific Rules
# -------------------
allow_map := check_all_org_permissions(roles, key)
# Since we're checking if the subject has the permission in _any_ org, we're
# essentially trying to find the highest vote from any org.
vote := max({vote |
some vote in allow_map
})
}
# is_org_member checks if the subject belong to the same organization as the
# object.
is_org_member if {
not input.object.any_org
input.object.org_owner != ""
input.object.org_owner in org_memberships
}
# ...if 'any_org' is set to true, we check if the subject is a member of any
# org.
is_org_member if {
input.object.any_org
count(org_memberships) > 0
}
#==============================================================================#
# Org member level rules #
#==============================================================================#
# Org member level permissions apply to all objects owned by the subject _and_
# the corresponding org. Permissions for objects which are not owned by an
# organization instead defer to the user level rules.
#
# The rules for this level are very similar to the rules for the organization
# level, and so we reuse the `check_org_permissions` function from those rules.
default org_member := 0
org_member := vote if {
# Object must be jointly owned by the user
input.object.owner != ""
input.subject.id = input.object.owner
vote := check_org_permissions(input.subject.roles, "member")
}
default scope_org_member := 0
scope_org_member := vote if {
# Object must be jointly owned by the user
input.object.owner != ""
input.subject.id = input.object.owner
vote := check_org_permissions([input.subject.scope], "member")
}
#==============================================================================#
# Role rules #
#==============================================================================#
# role_allow specifies all of the conditions under which a role can grant
# permission. These rules intentionally use the "unification" operator rather
# than the equality and inequality operators, because those operators do not
# work on partial values.
# https://www.openpolicyagent.org/docs/policy-language#unification-
# Site level authorization
role_allow if {
site = 1
}
# User level authorization
role_allow if {
not site = -1
user = 1
}
# Org level authorization
role_allow if {
not site = -1
org = 1
}
# Org member authorization
role_allow if {
not site = -1
not org = -1
org_member = 1
# If we are not a member of an org, and the object has an org, then we are
# not authorized. This is an "implied -1" for not being in the org.
org_ok
user = 1
}
#==============================================================================#
# Scope rules #
#==============================================================================#
# -------------------
# Scope-Specific Rules
# -------------------
# scope_allow specifies all of the conditions under which a scope can grant
# permission. These rules intentionally use the "unification" (=) operator
# rather than the equality (==) and inequality (!=) operators, because those
# operators do not work on partial values.
# https://www.openpolicyagent.org/docs/policy-language#unification-
# Site level scope enforcement
scope_allow if {
object_is_included_in_scope_allow_list
scope_allow_list
scope_site = 1
}
# User level scope enforcement
scope_allow if {
# User scope permissions must be allowed by the scope, and not denied
# by the site. The object *must not* be owned by an organization.
object_is_included_in_scope_allow_list
scope_allow_list
not scope_site = -1
scope_user = 1
}
# Org level scope enforcement
scope_allow if {
# Org member scope permissions must be allowed by the scope, and not denied
# by the site. The object *must* be owned by an organization.
object_is_included_in_scope_allow_list
not scope_site = -1
scope_org = 1
}
# Org member level scope enforcement
scope_allow if {
# Org member scope permissions must be allowed by the scope, and not denied
# by the site or org. The object *must* be owned by an organization.
object_is_included_in_scope_allow_list
scope_allow_list
not scope_site = -1
not scope_org = -1
scope_org_member = 1
# If we are not a member of an org, and the object has an org, then we are
# not authorized. This is an "implied -1" for not being in the org.
org_ok
scope_user = 1
}
# If *.* is allowed, then all objects are in scope.
object_is_included_in_scope_allow_list if {
{"type": "*", "id": "*"} in input.subject.scope.allow_list
}
# If <type>.* is allowed, then all objects of that type are in scope.
object_is_included_in_scope_allow_list if {
{"type": input.object.type, "id": "*"} in input.subject.scope.allow_list
}
# Check if the object type and ID match one of the allow list entries.
object_is_included_in_scope_allow_list if {
# Check that the wildcard rules do not apply. This prevents partial inputs
# from needing to include `input.object.id`.
not {"type": "*", "id": "*"} in input.subject.scope.allow_list
not {"type": input.object.type, "id": "*"} in input.subject.scope.allow_list
# Check which IDs from the allow list match the object type
allowed_ids_for_object_type := {it.id |
some it in input.subject.scope.allow_list
it.type in [input.object.type, "*"]
}
# Check if the input object ID is in the set of allowed IDs for the same
# object type. We do this at the end to keep `input.object.id` out of the
# comprehension because it might be unknown.
input.object.id in allowed_ids_for_object_type
}
#==============================================================================#
# ACL rules #
#==============================================================================#
# -------------------
# ACL-Specific Rules
# Access Control List
# -------------------
# ACL for users
acl_allow if {
# TODO: Should you have to be a member of the org too?
# Should you have to be a member of the org too?
perms := input.object.acl_user_list[input.subject.id]
# Check if either the action or * is allowed
some action in [input.action, "*"]
action in perms
# Either the input action or wildcard
[input.action, "*"][_] in perms
}
# ACL for groups
acl_allow if {
# If there is no organization owner, the object cannot be owned by an
# org-scoped group.
is_org_member
some group in input.subject.groups
# org_scoped team.
org_mem
group := input.subject.groups[_]
perms := input.object.acl_group_list[group]
# Check if either the action or * is allowed
some action in [input.action, "*"]
action in perms
# Either the input action or wildcard
[input.action, "*"][_] in perms
}
# ACL for the special "Everyone" groups
# ACL for 'all_users' special group
acl_allow if {
# If there is no organization owner, the object cannot be owned by an
# org-scoped group.
is_org_member
org_mem
perms := input.object.acl_group_list[input.object.org_owner]
# Check if either the action or * is allowed
some action in [input.action, "*"]
action in perms
[input.action, "*"][_] in perms
}
#==============================================================================#
# Allow #
#==============================================================================#
# The `allow` block is quite simple. Any check that voted no will cascade down.
# Authorization looks for any `allow` statement that is true. Multiple can be
# true! Note that the absence of `allow` means "unauthorized". An explicit
# `"allow": true` is required.
# -------------------
# Final Allow
#
# We check both the subject's permissions (given by their roles or by ACL) and
# the subject's scope. (The default scope is "*:*", allowing all actions.) Both
# a permission check (either from roles or ACL) and the scope check must vote to
# allow or the action is not authorized.
# A subject can be given permission by a role
permission_allow if role_allow
# A subject can be given permission by ACL
permission_allow if acl_allow
# The 'allow' block is quite simple. Any set with `-1` cascades down in levels.
# Authorization looks for any `allow` statement that is true. Multiple can be true!
# Note that the absence of `allow` means "unauthorized".
# An explicit `"allow": true` is required.
#
# Scope is also applied. The default scope is "wildcard:wildcard" allowing
# all actions. If the scope is not "1", then the action is not authorized.
#
# Allow query:
# data.authz.role_allow = true
# data.authz.scope_allow = true
# -------------------
# The role or the ACL must allow the action. Scopes can be used to limit,
# so scope_allow must always be true.
allow if {
# Must be allowed by the subject's permissions
permission_allow
# ...and allowed by the scope
role_allow
scope_allow
}
#==============================================================================#
# Utilities #
#==============================================================================#
# bool_flip returns the logical negation of a boolean value. You can't do
# 'x := not false', but you can do 'x := bool_flip(false)'
bool_flip(b) := false if {
b
}
bool_flip(b) if {
not b
}
# to_vote gives you a voting value from a set or list of booleans.
# {false,..} => deny (-1)
# {} => abstain (0)
# {true} => allow (1)
# Any set which contains a `false` should be considered a vote to deny.
to_vote(set) := -1 if {
false in set
}
# A set which is empty should be considered abstaining.
to_vote(set) := 0 if {
count(set) == 0
}
# A set which only contains true should be considered a vote to allow.
to_vote(set) := 1 if {
not false in set
true in set
# ACL list must also have the scope_allow to pass
allow if {
acl_allow
scope_allow
}
+10 -28
View File
@@ -295,11 +295,15 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
ResourceOauth2App.Type: {policy.ActionRead},
ResourceWorkspaceProxy.Type: {policy.ActionRead},
}),
User: append(allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember),
User: append(allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceUser, ResourceOrganizationMember),
Permissions(map[string][]policy.Action{
// Reduced permission set on dormant workspaces. No build, ssh, or exec
ResourceWorkspaceDormant.Type: {policy.ActionRead, policy.ActionDelete, policy.ActionCreate, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent, policy.ActionDeleteAgent},
// Users cannot do create/update/delete on themselves, but they
// can read their own details.
ResourceUser.Type: {policy.ActionRead, policy.ActionReadPersonal, policy.ActionUpdatePersonal},
// Can read their own organization member record
ResourceOrganizationMember.Type: {policy.ActionRead},
// Users can create provisioner daemons scoped to themselves.
ResourceProvisionerDaemon.Type: {policy.ActionRead, policy.ActionCreate, policy.ActionRead, policy.ActionUpdate},
})...,
@@ -427,7 +431,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
// Note: even without PrebuiltWorkspace permissions, access is still granted via Workspace permissions.
ResourcePrebuiltWorkspace.Type: {policy.ActionUpdate, policy.ActionDelete},
})...),
Member: []Permission{},
},
},
}
@@ -451,16 +454,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
// Can read available roles.
ResourceAssignOrgRole.Type: {policy.ActionRead},
}),
Member: append(allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceUser, ResourceOrganizationMember),
Permissions(map[string][]policy.Action{
// Reduced permission set on dormant workspaces. No build, ssh, or exec
ResourceWorkspaceDormant.Type: {policy.ActionRead, policy.ActionDelete, policy.ActionCreate, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent, policy.ActionDeleteAgent},
// Can read their own organization member record
ResourceOrganizationMember.Type: {policy.ActionRead},
// Users can create provisioner daemons scoped to themselves.
ResourceProvisionerDaemon.Type: {policy.ActionRead, policy.ActionCreate, policy.ActionRead, policy.ActionUpdate},
})...,
),
},
},
}
@@ -483,7 +476,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
ResourceOrganization.Type: {policy.ActionRead},
ResourceOrganizationMember.Type: {policy.ActionRead},
}),
Member: []Permission{},
},
},
}
@@ -510,7 +502,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
ResourceGroupMember.Type: ResourceGroupMember.AvailableActions(),
ResourceIdpsyncSettings.Type: {policy.ActionRead, policy.ActionUpdate},
}),
Member: []Permission{},
},
},
}
@@ -540,7 +531,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
ResourceProvisionerDaemon.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
ResourceProvisionerJobs.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreate},
}),
Member: []Permission{},
},
},
}
@@ -578,7 +568,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
Action: policy.ActionDeleteAgent,
},
},
Member: []Permission{},
},
},
}
@@ -691,10 +680,9 @@ func (perm Permission) Valid() error {
}
// Role is a set of permissions at multiple levels:
// - Site permissions apply EVERYWHERE
// - Org permissions apply to EVERYTHING in a given ORG
// - User permissions apply to all resources the user owns
// - OrgMember permissions apply to resources in the given org that the user owns
// - Site level permissions apply EVERYWHERE
// - Org level permissions apply to EVERYTHING in a given ORG
// - User level permissions are the lowest
// This is the type passed into the rego as a json payload.
// Users of this package should instead **only** use the role names, and
// this package will expand the role names into their json payloads.
@@ -715,8 +703,7 @@ type Role struct {
}
type OrgPermissions struct {
Org []Permission `json:"org"`
Member []Permission `json:"member"`
Org []Permission `json:"org"`
}
// Valid will check all it's permissions and ensure they are all correct
@@ -733,12 +720,7 @@ func (role Role) Valid() error {
for orgID, orgPermissions := range role.ByOrgID {
for _, perm := range orgPermissions.Org {
if err := perm.Valid(); err != nil {
errs = append(errs, xerrors.Errorf("org=%q: org %w", orgID, err))
}
}
for _, perm := range orgPermissions.Member {
if err := perm.Valid(); err != nil {
errs = append(errs, xerrors.Errorf("org=%q: member: %w", orgID, err))
errs = append(errs, xerrors.Errorf("org=%q: %w", orgID, err))
}
}
}
+7 -9
View File
@@ -33,11 +33,10 @@ func BenchmarkRBACValueAllocation(b *testing.B) {
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
}).
WithACLUserList(map[string][]policy.Action{
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
})
}).WithACLUserList(map[string][]policy.Action{
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
})
jsonSubject := authSubject{
ID: actor.ID,
@@ -108,7 +107,7 @@ func TestRegoInputValue(t *testing.T) {
t.Parallel()
// This is the input that would be passed to the rego policy.
jsonInput := map[string]any{
jsonInput := map[string]interface{}{
"subject": authSubject{
ID: actor.ID,
Roles: must(actor.Roles.Expand()),
@@ -139,7 +138,7 @@ func TestRegoInputValue(t *testing.T) {
t.Parallel()
// This is the input that would be passed to the rego policy.
jsonInput := map[string]any{
jsonInput := map[string]interface{}{
"subject": authSubject{
ID: actor.ID,
Roles: must(actor.Roles.Expand()),
@@ -147,7 +146,7 @@ func TestRegoInputValue(t *testing.T) {
Scope: must(actor.Scope.Expand()),
},
"action": action,
"object": map[string]any{
"object": map[string]interface{}{
"type": obj.Type,
},
}
@@ -283,6 +282,5 @@ func equalRoles(t *testing.T, a, b Role) {
bv, ok := b.ByOrgID[ak]
require.True(t, ok, "org permissions missing: %s", ak)
require.ElementsMatchf(t, av.Org, bv.Org, "org %s permissions", ak)
require.ElementsMatchf(t, av.Member, bv.Member, "member %s permissions", ak)
}
}
+4 -170
View File
@@ -28,6 +28,7 @@ import (
"google.golang.org/protobuf/types/known/wrapperspb"
"cdr.dev/slog"
"github.com/coder/coder/v2/buildinfo"
clitelemetry "github.com/coder/coder/v2/cli/telemetry"
"github.com/coder/coder/v2/coderd/database"
@@ -35,7 +36,6 @@ import (
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
)
const (
@@ -48,7 +48,6 @@ type Options struct {
Disabled bool
Database database.Store
Logger slog.Logger
Clock quartz.Clock
// URL is an endpoint to direct telemetry towards!
URL *url.URL
Experiments codersdk.Experiments
@@ -66,9 +65,6 @@ type Options struct {
// Duplicate data will be sent, it's on the server-side to index by UUID.
// Data is anonymized prior to being sent!
func New(options Options) (Reporter, error) {
if options.Clock == nil {
options.Clock = quartz.NewReal()
}
if options.SnapshotFrequency == 0 {
// Report once every 30mins by default!
options.SnapshotFrequency = 30 * time.Minute
@@ -90,7 +86,7 @@ func New(options Options) (Reporter, error) {
options: options,
deploymentURL: deploymentURL,
snapshotURL: snapshotURL,
startedAt: dbtime.Time(options.Clock.Now()).UTC(),
startedAt: dbtime.Now(),
client: &http.Client{},
}
go reporter.runSnapshotter()
@@ -170,7 +166,7 @@ func (r *remoteReporter) Close() {
return
}
close(r.closed)
now := dbtime.Time(r.options.Clock.Now()).UTC()
now := dbtime.Now()
r.shutdownAt = &now
if r.Enabled() {
// Report a final collection of telemetry prior to close!
@@ -416,7 +412,7 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
ctx = r.ctx
// For resources that grow in size very quickly (like workspace builds),
// we only report events that occurred within the past hour.
createdAfter = dbtime.Time(r.options.Clock.Now().Add(-1 * time.Hour)).UTC()
createdAfter = dbtime.Now().Add(-1 * time.Hour)
eg errgroup.Group
snapshot = &Snapshot{
DeploymentID: r.options.DeploymentID,
@@ -748,14 +744,6 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
}
return nil
})
eg.Go(func() error {
summaries, err := r.generateAIBridgeInterceptionsSummaries(ctx)
if err != nil {
return xerrors.Errorf("generate AIBridge interceptions telemetry summaries: %w", err)
}
snapshot.AIBridgeInterceptionsSummaries = summaries
return nil
})
err := eg.Wait()
if err != nil {
@@ -764,76 +752,6 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
return snapshot, nil
}
func (r *remoteReporter) generateAIBridgeInterceptionsSummaries(ctx context.Context) ([]AIBridgeInterceptionsSummary, error) {
// Get the current timeframe, which is the previous hour.
now := dbtime.Time(r.options.Clock.Now()).UTC()
endedAtBefore := now.Truncate(time.Hour)
endedAtAfter := endedAtBefore.Add(-1 * time.Hour)
// Note: we don't use a transaction for this function since we do tolerate
// some errors, like duplicate lock rows, and we also calculate
// summaries in parallel.
// Claim the heartbeat lock row for this hour.
err := r.options.Database.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
EventType: "aibridge_interceptions_summary",
PeriodEndingAt: endedAtBefore,
})
if database.IsUniqueViolation(err, database.UniqueTelemetryLocksPkey) {
// Another replica has already claimed the lock row for this hour.
r.options.Logger.Debug(ctx, "aibridge interceptions telemetry lock already claimed for this hour by another replica, skipping", slog.F("period_ending_at", endedAtBefore))
return nil, nil
}
if err != nil {
return nil, xerrors.Errorf("insert AIBridge interceptions telemetry lock (period_ending_at=%q): %w", endedAtBefore, err)
}
// List the summary categories that need to be calculated.
summaryCategories, err := r.options.Database.ListAIBridgeInterceptionsTelemetrySummaries(ctx, database.ListAIBridgeInterceptionsTelemetrySummariesParams{
EndedAtAfter: endedAtAfter, // inclusive
EndedAtBefore: endedAtBefore, // exclusive
})
if err != nil {
return nil, xerrors.Errorf("list AIBridge interceptions telemetry summaries (startedAtAfter=%q, endedAtBefore=%q): %w", endedAtAfter, endedAtBefore, err)
}
// Calculate and convert the summaries for all categories.
var (
eg, egCtx = errgroup.WithContext(ctx)
mu sync.Mutex
summaries = make([]AIBridgeInterceptionsSummary, 0, len(summaryCategories))
)
for _, category := range summaryCategories {
eg.Go(func() error {
summary, err := r.options.Database.CalculateAIBridgeInterceptionsTelemetrySummary(egCtx, database.CalculateAIBridgeInterceptionsTelemetrySummaryParams{
Provider: category.Provider,
Model: category.Model,
Client: category.Client,
EndedAtAfter: endedAtAfter,
EndedAtBefore: endedAtBefore,
})
if err != nil {
return xerrors.Errorf("calculate AIBridge interceptions telemetry summary (provider=%q, model=%q, client=%q, startedAtAfter=%q, endedAtBefore=%q): %w", category.Provider, category.Model, category.Client, endedAtAfter, endedAtBefore, err)
}
// Double check that at least one interception was found in the
// timeframe.
if summary.InterceptionCount == 0 {
return nil
}
converted := ConvertAIBridgeInterceptionsSummary(endedAtBefore, category.Provider, category.Model, category.Client, summary)
mu.Lock()
defer mu.Unlock()
summaries = append(summaries, converted)
return nil
})
}
return summaries, eg.Wait()
}
// ConvertAPIKey anonymizes an API key.
func ConvertAPIKey(apiKey database.APIKey) APIKey {
a := APIKey{
@@ -1305,7 +1223,6 @@ type Snapshot struct {
TelemetryItems []TelemetryItem `json:"telemetry_items"`
UserTailnetConnections []UserTailnetConnection `json:"user_tailnet_connections"`
PrebuiltWorkspaces []PrebuiltWorkspace `json:"prebuilt_workspaces"`
AIBridgeInterceptionsSummaries []AIBridgeInterceptionsSummary `json:"aibridge_interceptions_summaries"`
}
// Deployment contains information about the host running Coder.
@@ -1942,89 +1859,6 @@ type PrebuiltWorkspace struct {
Count int `json:"count"`
}
type AIBridgeInterceptionsSummaryDurationMillis struct {
P50 int64 `json:"p50"`
P90 int64 `json:"p90"`
P95 int64 `json:"p95"`
P99 int64 `json:"p99"`
}
type AIBridgeInterceptionsSummaryTokenCount struct {
Input int64 `json:"input"`
Output int64 `json:"output"`
CachedRead int64 `json:"cached_read"`
CachedWritten int64 `json:"cached_written"`
}
type AIBridgeInterceptionsSummaryToolCallsCount struct {
Injected int64 `json:"injected"`
NonInjected int64 `json:"non_injected"`
}
// AIBridgeInterceptionsSummary is a summary of aggregated AI Bridge
// interception data over a period of 1 hour. We send a summary each hour for
// each unique provider + model + client combination.
type AIBridgeInterceptionsSummary struct {
ID uuid.UUID `json:"id"`
// The end of the hour for which the summary is taken. This will always be a
// UTC timestamp truncated to the hour.
Timestamp time.Time `json:"timestamp"`
Provider string `json:"provider"`
Model string `json:"model"`
Client string `json:"client"`
InterceptionCount int64 `json:"interception_count"`
InterceptionDurationMillis AIBridgeInterceptionsSummaryDurationMillis `json:"interception_duration_millis"`
// Map of route to number of interceptions.
// e.g. "/v1/chat/completions:blocking", "/v1/chat/completions:streaming"
InterceptionsByRoute map[string]int64 `json:"interceptions_by_route"`
UniqueInitiatorCount int64 `json:"unique_initiator_count"`
UserPromptsCount int64 `json:"user_prompts_count"`
TokenUsagesCount int64 `json:"token_usages_count"`
TokenCount AIBridgeInterceptionsSummaryTokenCount `json:"token_count"`
ToolCallsCount AIBridgeInterceptionsSummaryToolCallsCount `json:"tool_calls_count"`
InjectedToolCallErrorCount int64 `json:"injected_tool_call_error_count"`
}
func ConvertAIBridgeInterceptionsSummary(endTime time.Time, provider, model, client string, summary database.CalculateAIBridgeInterceptionsTelemetrySummaryRow) AIBridgeInterceptionsSummary {
return AIBridgeInterceptionsSummary{
ID: uuid.New(),
Timestamp: endTime,
Provider: provider,
Model: model,
Client: client,
InterceptionCount: summary.InterceptionCount,
InterceptionDurationMillis: AIBridgeInterceptionsSummaryDurationMillis{
P50: summary.InterceptionDurationP50Millis,
P90: summary.InterceptionDurationP90Millis,
P95: summary.InterceptionDurationP95Millis,
P99: summary.InterceptionDurationP99Millis,
},
// TODO: currently we don't track by route
InterceptionsByRoute: make(map[string]int64),
UniqueInitiatorCount: summary.UniqueInitiatorCount,
UserPromptsCount: summary.UserPromptsCount,
TokenUsagesCount: summary.TokenUsagesCount,
TokenCount: AIBridgeInterceptionsSummaryTokenCount{
Input: summary.TokenCountInput,
Output: summary.TokenCountOutput,
CachedRead: summary.TokenCountCachedRead,
CachedWritten: summary.TokenCountCachedWritten,
},
ToolCallsCount: AIBridgeInterceptionsSummaryToolCallsCount{
Injected: summary.ToolCallsCountInjected,
NonInjected: summary.ToolCallsCountNonInjected,
},
InjectedToolCallErrorCount: summary.InjectedToolCallErrorCount,
}
}
type noopReporter struct{}
func (*noopReporter) Report(_ *Snapshot) {}
+2 -127
View File
@@ -28,7 +28,6 @@ import (
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestMain(m *testing.M) {
@@ -45,7 +44,6 @@ func TestTelemetry(t *testing.T) {
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
now := dbtime.Now()
org, err := db.GetDefaultOrganization(ctx)
require.NoError(t, err)
@@ -210,88 +208,12 @@ func TestTelemetry(t *testing.T) {
AgentID: wsagent.ID,
})
previousAIBridgeInterceptionPeriod := now.Truncate(time.Hour)
user2 := dbgen.User(t, db, database.User{})
aiBridgeInterception1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
InitiatorID: user.ID,
Provider: "anthropic",
Model: "deanseek",
StartedAt: previousAIBridgeInterceptionPeriod.Add(-30 * time.Minute),
}, nil)
_ = dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
InterceptionID: aiBridgeInterception1.ID,
InputTokens: 100,
OutputTokens: 200,
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
})
_ = dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
InterceptionID: aiBridgeInterception1.ID,
})
_ = dbgen.AIBridgeToolUsage(t, db, database.InsertAIBridgeToolUsageParams{
InterceptionID: aiBridgeInterception1.ID,
Injected: true,
InvocationError: sql.NullString{String: "error1", Valid: true},
})
_, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
ID: aiBridgeInterception1.ID,
EndedAt: aiBridgeInterception1.StartedAt.Add(1 * time.Minute), // 1 minute duration
})
require.NoError(t, err)
aiBridgeInterception2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
InitiatorID: user2.ID,
Provider: aiBridgeInterception1.Provider,
Model: aiBridgeInterception1.Model,
StartedAt: aiBridgeInterception1.StartedAt,
}, nil)
_ = dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
InterceptionID: aiBridgeInterception2.ID,
InputTokens: 100,
OutputTokens: 200,
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
})
_ = dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
InterceptionID: aiBridgeInterception2.ID,
})
_ = dbgen.AIBridgeToolUsage(t, db, database.InsertAIBridgeToolUsageParams{
InterceptionID: aiBridgeInterception2.ID,
Injected: false,
})
_, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
ID: aiBridgeInterception2.ID,
EndedAt: aiBridgeInterception2.StartedAt.Add(2 * time.Minute), // 2 minute duration
})
require.NoError(t, err)
aiBridgeInterception3 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
InitiatorID: user2.ID,
Provider: "openai",
Model: "gpt-5",
StartedAt: aiBridgeInterception1.StartedAt,
}, nil)
_, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
ID: aiBridgeInterception3.ID,
EndedAt: aiBridgeInterception3.StartedAt.Add(3 * time.Minute), // 3 minute duration
})
require.NoError(t, err)
_ = dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
InitiatorID: user2.ID,
Provider: "openai",
Model: "gpt-5",
StartedAt: aiBridgeInterception1.StartedAt,
}, nil)
// not ended, so it should not affect summaries
clock := quartz.NewMock(t)
clock.Set(now)
_, snapshot := collectSnapshot(ctx, t, db, func(opts telemetry.Options) telemetry.Options {
opts.Clock = clock
return opts
})
_, snapshot := collectSnapshot(ctx, t, db, nil)
require.Len(t, snapshot.ProvisionerJobs, 2)
require.Len(t, snapshot.Licenses, 1)
require.Len(t, snapshot.Templates, 2)
require.Len(t, snapshot.TemplateVersions, 3)
require.Len(t, snapshot.Users, 2)
require.Len(t, snapshot.Users, 1)
require.Len(t, snapshot.Groups, 2)
// 1 member in the everyone group + 1 member in the custom group
require.Len(t, snapshot.GroupMembers, 2)
@@ -365,53 +287,6 @@ func TestTelemetry(t *testing.T) {
for _, entity := range snapshot.Templates {
require.Equal(t, entity.OrganizationID, org.ID)
}
// 2 unique provider + model + client combinations
require.Len(t, snapshot.AIBridgeInterceptionsSummaries, 2)
snapshot1 := snapshot.AIBridgeInterceptionsSummaries[0]
snapshot2 := snapshot.AIBridgeInterceptionsSummaries[1]
if snapshot1.Provider != aiBridgeInterception1.Provider {
snapshot1, snapshot2 = snapshot2, snapshot1
}
require.Equal(t, snapshot1.Provider, aiBridgeInterception1.Provider)
require.Equal(t, snapshot1.Model, aiBridgeInterception1.Model)
require.Equal(t, snapshot1.Client, "unknown") // no client info yet
require.EqualValues(t, snapshot1.InterceptionCount, 2)
require.EqualValues(t, snapshot1.InterceptionsByRoute, map[string]int64{}) // no route info yet
require.EqualValues(t, snapshot1.InterceptionDurationMillis.P50, 90_000)
require.EqualValues(t, snapshot1.InterceptionDurationMillis.P90, 114_000)
require.EqualValues(t, snapshot1.InterceptionDurationMillis.P95, 117_000)
require.EqualValues(t, snapshot1.InterceptionDurationMillis.P99, 119_400)
require.EqualValues(t, snapshot1.UniqueInitiatorCount, 2)
require.EqualValues(t, snapshot1.UserPromptsCount, 2)
require.EqualValues(t, snapshot1.TokenUsagesCount, 2)
require.EqualValues(t, snapshot1.TokenCount.Input, 200)
require.EqualValues(t, snapshot1.TokenCount.Output, 400)
require.EqualValues(t, snapshot1.TokenCount.CachedRead, 600)
require.EqualValues(t, snapshot1.TokenCount.CachedWritten, 800)
require.EqualValues(t, snapshot1.ToolCallsCount.Injected, 1)
require.EqualValues(t, snapshot1.ToolCallsCount.NonInjected, 1)
require.EqualValues(t, snapshot1.InjectedToolCallErrorCount, 1)
require.Equal(t, snapshot2.Provider, aiBridgeInterception3.Provider)
require.Equal(t, snapshot2.Model, aiBridgeInterception3.Model)
require.Equal(t, snapshot2.Client, "unknown") // no client info yet
require.EqualValues(t, snapshot2.InterceptionCount, 1)
require.EqualValues(t, snapshot2.InterceptionsByRoute, map[string]int64{}) // no route info yet
require.EqualValues(t, snapshot2.InterceptionDurationMillis.P50, 180_000)
require.EqualValues(t, snapshot2.InterceptionDurationMillis.P90, 180_000)
require.EqualValues(t, snapshot2.InterceptionDurationMillis.P95, 180_000)
require.EqualValues(t, snapshot2.InterceptionDurationMillis.P99, 180_000)
require.EqualValues(t, snapshot2.UniqueInitiatorCount, 1)
require.EqualValues(t, snapshot2.UserPromptsCount, 0)
require.EqualValues(t, snapshot2.TokenUsagesCount, 0)
require.EqualValues(t, snapshot2.TokenCount.Input, 0)
require.EqualValues(t, snapshot2.TokenCount.Output, 0)
require.EqualValues(t, snapshot2.TokenCount.CachedRead, 0)
require.EqualValues(t, snapshot2.TokenCount.CachedWritten, 0)
require.EqualValues(t, snapshot2.ToolCallsCount.Injected, 0)
require.EqualValues(t, snapshot2.ToolCallsCount.NonInjected, 0)
})
t.Run("HashedEmail", func(t *testing.T) {
t.Parallel()
+2 -10
View File
@@ -324,19 +324,11 @@ func (p *DBTokenProvider) authorizeRequest(ctx context.Context, roles *rbac.Subj
// rbacResourceOwned is for the level "authenticated". We still need to
// make sure the API key has permissions to connect to the actor's own
// workspace. Scopes would prevent this.
// TODO: This is an odd repercussion of the org_member permission level.
// This Object used to not specify an org restriction, and `InOrg` would
// actually have a significantly different meaning (only sharing with
// other authenticated users in the same org, whereas the existing behavior
// is to share with any authenticated user). Because workspaces are always
// jointly owned by an organization, there _must_ be an org restriction on
// the object to check the proper permissions. AnyOrg is almost the same,
// but technically excludes users who are not in any organization. This is
// the closest we can get though without more significant refactoring.
rbacResourceOwned rbac.Object = rbac.ResourceWorkspace.WithOwner(roles.ID).AnyOrganization()
rbacResourceOwned rbac.Object = rbac.ResourceWorkspace.WithOwner(roles.ID)
)
if dbReq.AccessMethod == AccessMethodTerminal {
rbacAction = policy.ActionSSH
rbacResourceOwned = rbac.ResourceWorkspace.WithOwner(roles.ID)
}
// Do a standard RBAC check. This accounts for share level "owner" and any
-1
View File
@@ -2654,7 +2654,6 @@ func convertWorkspace(
Favorite: requesterFavorite,
NextStartAt: nextStartAt,
IsPrebuild: workspace.IsPrebuild(),
TaskID: workspace.TaskID,
}, nil
}
+66 -40
View File
@@ -4700,16 +4700,11 @@ func TestWorkspaceFilterHasAITask(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
// Helper function to create workspace with optional task.
createWorkspace := func(jobCompleted, createTask bool, prompt string) uuid.UUID {
// TODO(mafredri): The bellow comment is based on deprecated logic and
// kept only present to test that the old observable behavior works as
// intended.
//
// Helper function to create workspace with AI task configuration
createWorkspaceWithAIConfig := func(hasAITask sql.NullBool, jobCompleted bool, aiTaskPrompt *string) database.WorkspaceTable {
// When a provisioner job uses these tags, no provisioner will match it.
// We do this so jobs will always be stuck in "pending", allowing us to
// exercise the intermediary state when has_ai_task is nil and we
// compensate by looking at pending provisioning jobs.
// We do this so jobs will always be stuck in "pending", allowing us to exercise the intermediary state when
// has_ai_task is nil and we compensate by looking at pending provisioning jobs.
// See GetWorkspaces clauses.
unpickableTags := database.StringMap{"custom": "true"}
@@ -4728,71 +4723,102 @@ func TestWorkspaceFilterHasAITask(t *testing.T) {
jobConfig.CompletedAt = sql.NullTime{Time: time.Now(), Valid: true}
}
job := dbgen.ProvisionerJob(t, db, pubsub, jobConfig)
res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{JobID: job.ID})
agnt := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ResourceID: res.ID})
taskApp := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{AgentID: agnt.ID})
var sidebarAppID uuid.UUID
if hasAITask.Bool {
sidebarApp := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{AgentID: agnt.ID})
sidebarAppID = sidebarApp.ID
}
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: ws.ID,
TemplateVersionID: version.ID,
InitiatorID: user.UserID,
JobID: job.ID,
BuildNumber: 1,
AITaskSidebarAppID: uuid.NullUUID{UUID: taskApp.ID, Valid: createTask},
HasAITask: hasAITask,
AITaskSidebarAppID: uuid.NullUUID{UUID: sidebarAppID, Valid: sidebarAppID != uuid.Nil},
})
if createTask {
task := dbgen.Task(t, db, database.TaskTable{
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
TemplateVersionID: version.ID,
Prompt: prompt,
})
dbgen.TaskWorkspaceApp(t, db, database.TaskWorkspaceApp{
TaskID: task.ID,
WorkspaceBuildNumber: build.BuildNumber,
WorkspaceAgentID: uuid.NullUUID{UUID: agnt.ID, Valid: true},
WorkspaceAppID: uuid.NullUUID{UUID: taskApp.ID, Valid: true},
if aiTaskPrompt != nil {
err := db.InsertWorkspaceBuildParameters(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceBuildParametersParams{
WorkspaceBuildID: build.ID,
Name: []string{provider.TaskPromptParameterName},
Value: []string{*aiTaskPrompt},
})
require.NoError(t, err)
}
return ws.ID
return ws
}
// Create workspaces with tasks.
wsWithTask1 := createWorkspace(true, true, "Build me a web app")
wsWithTask2 := createWorkspace(false, true, "Another task")
// Create test workspaces with different AI task configurations
wsWithAITask := createWorkspaceWithAIConfig(sql.NullBool{Bool: true, Valid: true}, true, nil)
wsWithoutAITask := createWorkspaceWithAIConfig(sql.NullBool{Bool: false, Valid: true}, false, nil)
// Create workspaces without tasks
wsWithoutTask1 := createWorkspace(true, false, "")
wsWithoutTask2 := createWorkspace(false, false, "")
aiTaskPrompt := "Build me a web app"
wsWithAITaskParam := createWorkspaceWithAIConfig(sql.NullBool{Valid: false}, false, &aiTaskPrompt)
anotherTaskPrompt := "Another task"
wsCompletedWithAITaskParam := createWorkspaceWithAIConfig(sql.NullBool{Valid: false}, true, &anotherTaskPrompt)
emptyPrompt := ""
wsWithEmptyAITaskParam := createWorkspaceWithAIConfig(sql.NullBool{Valid: false}, false, &emptyPrompt)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Debug: Check all workspaces without filter first
allRes, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{})
require.NoError(t, err)
t.Logf("Total workspaces created: %d", len(allRes.Workspaces))
for i, ws := range allRes.Workspaces {
t.Logf("All Workspace %d: ID=%s, Name=%s, Build ID=%s, Job ID=%s", i, ws.ID, ws.Name, ws.LatestBuild.ID, ws.LatestBuild.Job.ID)
}
// Test filtering for workspaces with AI tasks
// Should include: wsWithTask1 and wsWithTask2
// Should include: wsWithAITask (has_ai_task=true) and wsWithAITaskParam (null + incomplete + param)
res, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{
FilterQuery: "has-ai-task:true",
})
require.NoError(t, err)
t.Logf("Expected 2 workspaces for has-ai-task:true, got %d", len(res.Workspaces))
t.Logf("Expected workspaces: %s, %s", wsWithAITask.ID, wsWithAITaskParam.ID)
for i, ws := range res.Workspaces {
t.Logf("AI Task True Workspace %d: ID=%s, Name=%s", i, ws.ID, ws.Name)
}
require.Len(t, res.Workspaces, 2)
workspaceIDs := []uuid.UUID{res.Workspaces[0].ID, res.Workspaces[1].ID}
require.Contains(t, workspaceIDs, wsWithTask1)
require.Contains(t, workspaceIDs, wsWithTask2)
require.Contains(t, workspaceIDs, wsWithAITask.ID)
require.Contains(t, workspaceIDs, wsWithAITaskParam.ID)
// Test filtering for workspaces without AI tasks
// Should include: wsWithoutTask1, wsWithoutTask2, wsWithoutTask3
// Should include: wsWithoutAITask, wsCompletedWithAITaskParam, wsWithEmptyAITaskParam
res, err = client.Workspaces(ctx, codersdk.WorkspaceFilter{
FilterQuery: "has-ai-task:false",
})
require.NoError(t, err)
require.Len(t, res.Workspaces, 2)
workspaceIDs = []uuid.UUID{res.Workspaces[0].ID, res.Workspaces[1].ID}
require.Contains(t, workspaceIDs, wsWithoutTask1)
require.Contains(t, workspaceIDs, wsWithoutTask2)
// Debug: print what we got
t.Logf("Expected 3 workspaces for has-ai-task:false, got %d", len(res.Workspaces))
for i, ws := range res.Workspaces {
t.Logf("Workspace %d: ID=%s, Name=%s", i, ws.ID, ws.Name)
}
t.Logf("Expected IDs: %s, %s, %s", wsWithoutAITask.ID, wsCompletedWithAITaskParam.ID, wsWithEmptyAITaskParam.ID)
require.Len(t, res.Workspaces, 3)
workspaceIDs = []uuid.UUID{res.Workspaces[0].ID, res.Workspaces[1].ID, res.Workspaces[2].ID}
require.Contains(t, workspaceIDs, wsWithoutAITask.ID)
require.Contains(t, workspaceIDs, wsCompletedWithAITaskParam.ID)
require.Contains(t, workspaceIDs, wsWithEmptyAITaskParam.ID)
// Test no filter returns all
res, err = client.Workspaces(ctx, codersdk.WorkspaceFilter{})
require.NoError(t, err)
require.Len(t, res.Workspaces, 4)
require.Len(t, res.Workspaces, 5)
}
func TestWorkspaceAppUpsertRestart(t *testing.T) {
+254
View File
@@ -0,0 +1,254 @@
package agentsdk
import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"time"
"golang.org/x/xerrors"
)
// SocketClient provides a client for communicating with the agent socket
// over a single unix-domain connection. The client performs no internal
// locking, so concurrent method calls may interleave request/response
// frames on the shared connection — serialize access externally if needed.
type SocketClient struct {
	// conn is the unix-domain socket connection established by
	// NewSocketClient and released by Close.
	conn net.Conn
}
// SocketConfig holds configuration for the socket client.
type SocketConfig struct {
	// Path is the socket path. Optional: when empty, NewSocketClient
	// auto-discovers the socket via CODER_AGENT_SOCKET_PATH and a set
	// of well-known filesystem locations.
	Path string // Socket path (optional, will auto-discover if not set)
}
// NewSocketClient dials the agent's unix-domain socket and returns a
// connected client. When config.Path is empty, the socket location is
// auto-discovered (environment variable first, then well-known paths).
// The caller owns the returned client and must Close it.
func NewSocketClient(config SocketConfig) (*SocketClient, error) {
	socketPath := config.Path
	if socketPath == "" {
		discovered, err := discoverSocketPath()
		if err != nil {
			return nil, xerrors.Errorf("discover socket path: %w", err)
		}
		socketPath = discovered
	}

	conn, err := net.Dial("unix", socketPath)
	if err != nil {
		return nil, xerrors.Errorf("connect to socket: %w", err)
	}

	client := &SocketClient{conn: conn}
	return client, nil
}
// Close closes the underlying socket connection. Subsequent method
// calls on the client will fail, since they all operate on this
// single connection.
func (c *SocketClient) Close() error {
	return c.conn.Close()
}
// Ping performs a round-trip "ping" request over the agent socket and
// returns the decoded response. A server-reported error is surfaced as
// a Go error rather than a response value.
func (c *SocketClient) Ping(ctx context.Context) (*PingResponse, error) {
	request := &Request{
		Version: "1.0",
		Method:  "ping",
		ID:      generateRequestID(),
	}

	response, err := c.sendRequest(ctx, request)
	if err != nil {
		return nil, err
	}
	if response.Error != nil {
		return nil, xerrors.Errorf("ping error: %s", response.Error.Message)
	}

	pingResp := &PingResponse{}
	if unmarshalErr := json.Unmarshal(response.Result, pingResp); unmarshalErr != nil {
		return nil, xerrors.Errorf("unmarshal ping response: %w", unmarshalErr)
	}
	return pingResp, nil
}
// Health performs a round-trip "health" request over the agent socket
// and returns the decoded health status. A server-reported error is
// surfaced as a Go error rather than a response value.
func (c *SocketClient) Health(ctx context.Context) (*HealthResponse, error) {
	request := &Request{
		Version: "1.0",
		Method:  "health",
		ID:      generateRequestID(),
	}

	response, err := c.sendRequest(ctx, request)
	if err != nil {
		return nil, err
	}
	if response.Error != nil {
		return nil, xerrors.Errorf("health error: %s", response.Error.Message)
	}

	healthResp := &HealthResponse{}
	if unmarshalErr := json.Unmarshal(response.Result, healthResp); unmarshalErr != nil {
		return nil, xerrors.Errorf("unmarshal health response: %w", unmarshalErr)
	}
	return healthResp, nil
}
// AgentInfo performs a round-trip "agent.info" request over the agent
// socket and returns the decoded agent metadata. A server-reported
// error is surfaced as a Go error rather than a response value.
func (c *SocketClient) AgentInfo(ctx context.Context) (*AgentInfo, error) {
	request := &Request{
		Version: "1.0",
		Method:  "agent.info",
		ID:      generateRequestID(),
	}

	response, err := c.sendRequest(ctx, request)
	if err != nil {
		return nil, err
	}
	if response.Error != nil {
		return nil, xerrors.Errorf("agent info error: %s", response.Error.Message)
	}

	info := &AgentInfo{}
	if unmarshalErr := json.Unmarshal(response.Result, info); unmarshalErr != nil {
		return nil, xerrors.Errorf("unmarshal agent info response: %w", unmarshalErr)
	}
	return info, nil
}
// ListMethods performs a round-trip "methods.list" request over the
// agent socket and returns the method names the agent supports. A
// server-reported error is surfaced as a Go error.
func (c *SocketClient) ListMethods(ctx context.Context) ([]string, error) {
	request := &Request{
		Version: "1.0",
		Method:  "methods.list",
		ID:      generateRequestID(),
	}

	response, err := c.sendRequest(ctx, request)
	if err != nil {
		return nil, err
	}
	if response.Error != nil {
		return nil, xerrors.Errorf("list methods error: %s", response.Error.Message)
	}

	var methods []string
	if unmarshalErr := json.Unmarshal(response.Result, &methods); unmarshalErr != nil {
		return nil, xerrors.Errorf("unmarshal methods response: %w", unmarshalErr)
	}
	return methods, nil
}
// sendRequest encodes req onto the socket and decodes a single JSON
// response. The read/write deadlines are taken from ctx when it carries
// a deadline; otherwise a 30-second default is applied per call.
//
// NOTE: only the context's deadline is honored — cancellation via
// ctx.Done() is not watched mid-call. (Previously the context was
// ignored entirely and the 30-second deadline was unconditional.)
//
// The method assumes exclusive use of the connection: a concurrent
// caller could interleave frames, since no locking is performed.
func (c *SocketClient) sendRequest(ctx context.Context, req *Request) (*Response, error) {
	// Derive a single absolute deadline for both the write and the read.
	deadline, ok := ctx.Deadline()
	if !ok {
		deadline = time.Now().Add(30 * time.Second)
	}

	// Set write deadline
	if err := c.conn.SetWriteDeadline(deadline); err != nil {
		return nil, xerrors.Errorf("set write deadline: %w", err)
	}
	// Send request
	if err := json.NewEncoder(c.conn).Encode(req); err != nil {
		return nil, xerrors.Errorf("send request: %w", err)
	}

	// Set read deadline
	if err := c.conn.SetReadDeadline(deadline); err != nil {
		return nil, xerrors.Errorf("set read deadline: %w", err)
	}
	// Read response
	var resp Response
	if err := json.NewDecoder(c.conn).Decode(&resp); err != nil {
		return nil, xerrors.Errorf("read response: %w", err)
	}
	return &resp, nil
}
// discoverSocketPath locates the agent's unix socket.
//
// Resolution order:
//  1. CODER_AGENT_SOCKET_PATH environment variable — returned verbatim,
//     without checking that the file exists.
//  2. $XDG_RUNTIME_DIR/coder-agent.sock — only when XDG_RUNTIME_DIR is
//     set. (Previously an unset variable produced the relative candidate
//     "coder-agent.sock", silently resolved against the working
//     directory, because filepath.Join("", name) == name.)
//  3. <tmpdir>/coder-agent-<uid>.sock — user-specific temp socket.
//  4. <tmpdir>/coder-agent.sock — shared fallback.
//
// Candidates 2-4 are returned only if they exist on disk.
func discoverSocketPath() (string, error) {
	// Check environment variable first
	if path := os.Getenv("CODER_AGENT_SOCKET_PATH"); path != "" {
		return path, nil
	}

	// Build candidate paths, skipping the XDG entry when the runtime
	// directory is not configured.
	var paths []string
	if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" {
		paths = append(paths, filepath.Join(runtimeDir, "coder-agent.sock"))
	}
	paths = append(paths,
		// User-specific temp directory (os.Getuid is -1 on Windows,
		// which still yields a well-formed, if unusual, filename).
		filepath.Join(os.TempDir(), fmt.Sprintf("coder-agent-%d.sock", os.Getuid())),
		// Fallback temp directory
		filepath.Join(os.TempDir(), "coder-agent.sock"),
	)

	for _, path := range paths {
		if _, err := os.Stat(path); err == nil {
			return path, nil
		}
	}
	return "", xerrors.New("agent socket not found")
}
// generateRequestID returns a request identifier derived from the
// current wall-clock time in nanoseconds. Two calls within the same
// nanosecond are not guaranteed to produce distinct IDs.
func generateRequestID() string {
	nanos := time.Now().UnixNano()
	return fmt.Sprint(nanos)
}
// Request represents a socket request envelope, JSON-encoded onto the
// connection one object per call.
type Request struct {
	// Version is the protocol version; this client always sends "1.0".
	Version string `json:"version"`
	// Method is the RPC method name, e.g. "ping" or "methods.list".
	Method string `json:"method"`
	// ID correlates a response with its request; optional.
	ID string `json:"id,omitempty"`
	// Params carries method-specific arguments, if any. Unused by the
	// request helpers in this file.
	Params json.RawMessage `json:"params,omitempty"`
}
// Response represents a socket response envelope decoded from the
// connection.
type Response struct {
	// Version is the protocol version reported by the server —
	// presumably mirroring the request's version; not validated here.
	Version string `json:"version"`
	// ID echoes the request ID, when the server supplies one.
	ID string `json:"id,omitempty"`
	// Result holds the raw method-specific payload; callers unmarshal
	// it into the appropriate typed response.
	Result json.RawMessage `json:"result,omitempty"`
	// Error is non-nil when the server reports a failure; Result is
	// then typically absent.
	Error *Error `json:"error,omitempty"`
}
// Error represents a server-reported socket error carried inside a
// Response.
type Error struct {
	// Code is a numeric error code assigned by the server.
	Code int `json:"code"`
	// Message is the human-readable error description; the client
	// surfaces it via xerrors wrapping.
	Message string `json:"message"`
	// Data holds optional structured error details.
	Data any `json:"data,omitempty"`
}
// PingResponse represents the payload of a successful "ping" reply.
type PingResponse struct {
	// Message is the server's ping acknowledgement text.
	Message string `json:"message"`
	// Timestamp is the server-side time of the reply.
	Timestamp time.Time `json:"timestamp"`
}
// HealthResponse represents the payload of a successful "health" reply.
type HealthResponse struct {
	// Status is the agent's health state as reported by the server.
	Status string `json:"status"`
	// Timestamp is the server-side time of the health check.
	Timestamp time.Time `json:"timestamp"`
	// Uptime is the agent's uptime, pre-formatted by the server as a
	// string rather than a duration.
	Uptime string `json:"uptime"`
}
// AgentInfo represents agent metadata returned by the "agent.info"
// method. (Note: this type shares its name with the SocketClient
// method that fetches it, which is legal in Go.)
type AgentInfo struct {
	// ID is the agent's identifier.
	ID string `json:"id"`
	// Version is the agent's build version.
	Version string `json:"version"`
	// Status is the agent's current state as reported by the server.
	Status string `json:"status"`
	// StartedAt is when the agent started.
	StartedAt time.Time `json:"started_at"`
	// Uptime is the agent's uptime, pre-formatted by the server as a
	// string rather than a duration.
	Uptime string `json:"uptime"`
}

Some files were not shown because too many files have changed in this diff Show More