Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 34c1370090 | |||
| 851c4f907c | |||
| e3dfe45f35 |
@@ -5,13 +5,6 @@ runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Setup sqlc
|
||||
# uses: sqlc-dev/setup-sqlc@c0209b9199cd1cce6a14fc27cabcec491b651761 # v4.0.0
|
||||
# with:
|
||||
# sqlc-version: "1.30.0"
|
||||
|
||||
# Switched to coder/sqlc fork to fix ambiguous column bug, see:
|
||||
# - https://github.com/coder/sqlc/pull/1
|
||||
# - https://github.com/sqlc-dev/sqlc/pull/4159
|
||||
shell: bash
|
||||
run: |
|
||||
CGO_ENABLED=1 go install github.com/coder/sqlc/cmd/sqlc@aab4e865a51df0c43e1839f81a9d349b41d14f05
|
||||
uses: sqlc-dev/setup-sqlc@c0209b9199cd1cce6a14fc27cabcec491b651761 # v4.0.0
|
||||
with:
|
||||
sqlc-version: "1.27.0"
|
||||
|
||||
@@ -7,5 +7,5 @@ runs:
|
||||
- name: Install Terraform
|
||||
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2
|
||||
with:
|
||||
terraform_version: 1.13.4
|
||||
terraform_version: 1.13.0
|
||||
terraform_wrapper: false
|
||||
|
||||
+26
-20
@@ -204,17 +204,9 @@ jobs:
|
||||
|
||||
# Needed for helm chart linting
|
||||
- name: Install helm
|
||||
# uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # v4.3.1
|
||||
# with:
|
||||
# version: v3.9.2
|
||||
# The below is taken from https://helm.sh/docs/intro/install/#from-apt-debianubuntu
|
||||
run: |
|
||||
set -euo pipefail
|
||||
sudo apt-get install curl gpg apt-transport-https --yes
|
||||
curl -fsSL https://packages.buildkite.com/helm-linux/helm-debian/gpgkey | gpg --dearmor | sudo tee /usr/share/keyrings/helm.gpg > /dev/null
|
||||
echo "deb [signed-by=/usr/share/keyrings/helm.gpg] https://packages.buildkite.com/helm-linux/helm-debian/any/ any main" | sudo tee /etc/apt/sources.list.d/helm-stable-debian.list
|
||||
sudo apt-get update
|
||||
sudo apt-get install helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # v4.3.1
|
||||
with:
|
||||
version: v3.9.2
|
||||
|
||||
- name: make lint
|
||||
run: |
|
||||
@@ -384,6 +376,13 @@ jobs:
|
||||
id: go-paths
|
||||
uses: ./.github/actions/setup-go-paths
|
||||
|
||||
- name: Download Go Build Cache
|
||||
id: download-go-build-cache
|
||||
uses: ./.github/actions/test-cache/download
|
||||
with:
|
||||
key-prefix: test-go-build-${{ runner.os }}-${{ runner.arch }}
|
||||
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
|
||||
|
||||
- name: Setup Go
|
||||
uses: ./.github/actions/setup-go
|
||||
with:
|
||||
@@ -391,7 +390,8 @@ jobs:
|
||||
# download the toolchain configured in go.mod, so we don't
|
||||
# need to reinstall it. It's faster on Windows runners.
|
||||
use-preinstalled-go: ${{ runner.os == 'Windows' }}
|
||||
use-cache: true
|
||||
# Cache is already downloaded above
|
||||
use-cache: false
|
||||
|
||||
- name: Setup Terraform
|
||||
uses: ./.github/actions/setup-tf
|
||||
@@ -500,11 +500,17 @@ jobs:
|
||||
make test
|
||||
|
||||
- name: Upload failed test db dumps
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: failed-test-db-dump-${{matrix.os}}
|
||||
path: "**/*.test.sql"
|
||||
|
||||
- name: Upload Go Build Cache
|
||||
uses: ./.github/actions/test-cache/upload
|
||||
with:
|
||||
cache-key: ${{ steps.download-go-build-cache.outputs.cache-key }}
|
||||
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
|
||||
|
||||
- name: Upload Test Cache
|
||||
uses: ./.github/actions/test-cache/upload
|
||||
with:
|
||||
@@ -756,7 +762,7 @@ jobs:
|
||||
|
||||
- name: Upload Playwright Failed Tests
|
||||
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: failed-test-videos${{ matrix.variant.premium && '-premium' || '' }}
|
||||
path: ./site/test-results/**/*.webm
|
||||
@@ -764,7 +770,7 @@ jobs:
|
||||
|
||||
- name: Upload pprof dumps
|
||||
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: debug-pprof-dumps${{ matrix.variant.premium && '-premium' || '' }}
|
||||
path: ./site/test-results/**/debug-pprof-*.txt
|
||||
@@ -800,7 +806,7 @@ jobs:
|
||||
# the check to pass. This is desired in PRs, but not in mainline.
|
||||
- name: Publish to Chromatic (non-mainline)
|
||||
if: github.ref != 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
|
||||
uses: chromaui/action@4ffe736a2a8262ea28067ff05a13b635ba31ec05 # v13.3.0
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
@@ -832,7 +838,7 @@ jobs:
|
||||
# infinitely "in progress" in mainline unless we re-review each build.
|
||||
- name: Publish to Chromatic (mainline)
|
||||
if: github.ref == 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
|
||||
uses: chromaui/action@4ffe736a2a8262ea28067ff05a13b635ba31ec05 # v13.3.0
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
@@ -1030,7 +1036,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifacts
|
||||
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: dylibs
|
||||
path: |
|
||||
@@ -1195,7 +1201,7 @@ jobs:
|
||||
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
|
||||
|
||||
- name: Download dylibs
|
||||
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
|
||||
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
|
||||
with:
|
||||
name: dylibs
|
||||
path: ./build
|
||||
@@ -1462,7 +1468,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifacts
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: coder
|
||||
path: |
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
- name: Setup Node
|
||||
uses: ./.github/actions/setup-node
|
||||
|
||||
- uses: tj-actions/changed-files@dbf178ceecb9304128c8e0648591d71208c6e2c9 # v45.0.7
|
||||
- uses: tj-actions/changed-files@d03a93c0dbfac6d6dd6a0d8a5e7daff992b07449 # v45.0.7
|
||||
id: changed-files
|
||||
with:
|
||||
files: |
|
||||
|
||||
@@ -36,11 +36,11 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Nix
|
||||
uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
|
||||
uses: nixbuild/nix-quick-install-action@1f095fee853b33114486cfdeae62fa099cda35a9 # v33
|
||||
with:
|
||||
# Pinning to 2.28 here, as Nix gets a "error: [json.exception.type_error.302] type must be array, but is string"
|
||||
# on version 2.29 and above.
|
||||
nix_version: "2.28.5"
|
||||
nix_version: "2.28.4"
|
||||
|
||||
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
|
||||
with:
|
||||
|
||||
@@ -131,7 +131,7 @@ jobs:
|
||||
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
|
||||
|
||||
- name: Upload build artifacts
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: dylibs
|
||||
path: |
|
||||
@@ -327,7 +327,7 @@ jobs:
|
||||
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
|
||||
|
||||
- name: Download dylibs
|
||||
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
|
||||
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
|
||||
with:
|
||||
name: dylibs
|
||||
path: ./build
|
||||
@@ -761,7 +761,7 @@ jobs:
|
||||
|
||||
- name: Upload artifacts to actions (if dry-run)
|
||||
if: ${{ inputs.dry_run }}
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: release-artifacts
|
||||
path: |
|
||||
@@ -777,7 +777,7 @@ jobs:
|
||||
|
||||
- name: Upload latest sbom artifact to actions (if dry-run)
|
||||
if: inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true'
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: latest-sbom-artifact
|
||||
path: ./coder_latest_sbom.spdx.json
|
||||
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
|
||||
# Upload the results as artifacts.
|
||||
- name: "Upload artifact"
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: SARIF file
|
||||
path: results.sarif
|
||||
@@ -47,6 +47,6 @@ jobs:
|
||||
|
||||
# Upload the results to GitHub's code scanning dashboard.
|
||||
- name: "Upload to code-scanning"
|
||||
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
|
||||
uses: github/codeql-action/init@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
with:
|
||||
languages: go, javascript
|
||||
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
rm Makefile
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
|
||||
uses: github/codeql-action/analyze@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
|
||||
- name: Send Slack notification on failure
|
||||
if: ${{ failure() }}
|
||||
@@ -154,13 +154,13 @@ jobs:
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
- name: Upload Trivy scan results to GitHub Security tab
|
||||
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
with:
|
||||
sarif_file: trivy-results.sarif
|
||||
category: "Trivy"
|
||||
|
||||
- name: Upload Trivy scan results as an artifact
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: trivy
|
||||
path: trivy-results.sarif
|
||||
|
||||
@@ -125,7 +125,7 @@ jobs:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Delete PR Cleanup workflow runs
|
||||
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
|
||||
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
|
||||
with:
|
||||
token: ${{ github.token }}
|
||||
repository: ${{ github.repository }}
|
||||
@@ -134,7 +134,7 @@ jobs:
|
||||
delete_workflow_pattern: pr-cleanup.yaml
|
||||
|
||||
- name: Delete PR Deploy workflow skipped runs
|
||||
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
|
||||
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
|
||||
with:
|
||||
token: ${{ github.token }}
|
||||
repository: ${{ github.repository }}
|
||||
|
||||
@@ -89,5 +89,3 @@ result
|
||||
__debug_bin*
|
||||
|
||||
**/.claude/settings.local.json
|
||||
|
||||
/.env
|
||||
|
||||
+12
@@ -18,6 +18,18 @@ coderd/rbac/ @Emyrk
|
||||
scripts/apitypings/ @Emyrk
|
||||
scripts/gensite/ @aslilac
|
||||
|
||||
site/ @aslilac @Parkreiner
|
||||
site/src/hooks/ @Parkreiner
|
||||
# These rules intentionally do not specify any owners. More specific rules
|
||||
# override less specific rules, so these files are "ignored" by the site/ rule.
|
||||
site/e2e/google/protobuf/timestampGenerated.ts
|
||||
site/e2e/provisionerGenerated.ts
|
||||
site/src/api/countriesGenerated.ts
|
||||
site/src/api/rbacresourcesGenerated.ts
|
||||
site/src/api/typesGenerated.ts
|
||||
site/src/testHelpers/entities.ts
|
||||
site/CLAUDE.md
|
||||
|
||||
# The blood and guts of the autostop algorithm, which is quite complex and
|
||||
# requires elite ball knowledge of most of the scheduling code to make changes
|
||||
# without inadvertently affecting other parts of the codebase.
|
||||
|
||||
@@ -636,8 +636,8 @@ TAILNETTEST_MOCKS := \
|
||||
tailnet/tailnettest/subscriptionmock.go
|
||||
|
||||
AIBRIDGED_MOCKS := \
|
||||
enterprise/aibridged/aibridgedmock/clientmock.go \
|
||||
enterprise/aibridged/aibridgedmock/poolmock.go
|
||||
enterprise/x/aibridged/aibridgedmock/clientmock.go \
|
||||
enterprise/x/aibridged/aibridgedmock/poolmock.go
|
||||
|
||||
GEN_FILES := \
|
||||
tailnet/proto/tailnet.pb.go \
|
||||
@@ -645,7 +645,7 @@ GEN_FILES := \
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
enterprise/aibridged/proto/aibridged.pb.go \
|
||||
enterprise/x/aibridged/proto/aibridged.pb.go \
|
||||
$(DB_GEN_FILES) \
|
||||
$(SITE_GEN_FILES) \
|
||||
coderd/rbac/object_gen.go \
|
||||
@@ -697,7 +697,7 @@ gen/mark-fresh:
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
enterprise/aibridged/proto/aibridged.pb.go \
|
||||
enterprise/x/aibridged/proto/aibridged.pb.go \
|
||||
coderd/database/dump.sql \
|
||||
$(DB_GEN_FILES) \
|
||||
site/src/api/typesGenerated.ts \
|
||||
@@ -768,8 +768,8 @@ codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agen
|
||||
go generate ./codersdk/workspacesdk/agentconnmock/
|
||||
touch "$@"
|
||||
|
||||
$(AIBRIDGED_MOCKS): enterprise/aibridged/client.go enterprise/aibridged/pool.go
|
||||
go generate ./enterprise/aibridged/aibridgedmock/
|
||||
$(AIBRIDGED_MOCKS): enterprise/x/aibridged/client.go enterprise/x/aibridged/pool.go
|
||||
go generate ./enterprise/x/aibridged/aibridgedmock/
|
||||
touch "$@"
|
||||
|
||||
agent/agentcontainers/dcspec/dcspec_gen.go: \
|
||||
@@ -822,13 +822,13 @@ vpn/vpn.pb.go: vpn/vpn.proto
|
||||
--go_opt=paths=source_relative \
|
||||
./vpn/vpn.proto
|
||||
|
||||
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
|
||||
enterprise/x/aibridged/proto/aibridged.pb.go: enterprise/x/aibridged/proto/aibridged.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/aibridged/proto/aibridged.proto
|
||||
./enterprise/x/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
|
||||
# -C sets the directory for the go run command
|
||||
@@ -1182,8 +1182,3 @@ endif
|
||||
|
||||
dogfood/coder/nix.hash: flake.nix flake.lock
|
||||
sha256sum flake.nix flake.lock >./dogfood/coder/nix.hash
|
||||
|
||||
# Count the number of test databases created per test package.
|
||||
count-test-databases:
|
||||
PGPASSWORD=postgres psql -h localhost -U postgres -d coder_testing -P pager=off -c 'SELECT test_package, count(*) as count from test_databases GROUP BY test_package ORDER BY count DESC'
|
||||
.PHONY: count-test-databases
|
||||
|
||||
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentscripts"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
|
||||
@@ -91,6 +92,7 @@ type Options struct {
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
Clock quartz.Clock
|
||||
SocketPath string // Path for the agent socket server
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
@@ -190,6 +192,7 @@ func New(options Options) Agent {
|
||||
|
||||
devcontainers: options.Devcontainers,
|
||||
containerAPIOptions: options.DevcontainerAPIOptions,
|
||||
socketPath: options.SocketPath,
|
||||
}
|
||||
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
|
||||
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
|
||||
@@ -271,6 +274,10 @@ type agent struct {
|
||||
devcontainers bool
|
||||
containerAPIOptions []agentcontainers.Option
|
||||
containerAPI *agentcontainers.API
|
||||
|
||||
// Socket server for CLI communication
|
||||
socketPath string
|
||||
socketServer *agentsocket.Server
|
||||
}
|
||||
|
||||
func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
@@ -350,9 +357,69 @@ func (a *agent) init() {
|
||||
s.ExperimentalContainers = a.devcontainers
|
||||
},
|
||||
)
|
||||
|
||||
// Initialize socket server for CLI communication
|
||||
a.initSocketServer()
|
||||
|
||||
go a.runLoop()
|
||||
}
|
||||
|
||||
// initSocketServer initializes the socket server for CLI communication
|
||||
func (a *agent) initSocketServer() {
|
||||
// Get socket path from options or environment
|
||||
socketPath := a.getSocketPath()
|
||||
if socketPath == "" {
|
||||
a.logger.Debug(a.hardCtx, "socket server disabled (no path configured)")
|
||||
return
|
||||
}
|
||||
|
||||
// Create socket server
|
||||
server := agentsocket.NewServer(agentsocket.Config{
|
||||
Path: socketPath,
|
||||
Logger: a.logger.Named("socket"),
|
||||
})
|
||||
|
||||
// Register default handlers
|
||||
handlerCtx := agentsocket.CreateHandlerContext(
|
||||
"", // Agent ID will be set when manifest is available
|
||||
buildinfo.Version(),
|
||||
"starting",
|
||||
time.Now(),
|
||||
a.logger,
|
||||
)
|
||||
agentsocket.RegisterDefaultHandlers(server, handlerCtx)
|
||||
|
||||
// Start the server
|
||||
if err := server.Start(); err != nil {
|
||||
a.logger.Warn(a.hardCtx, "failed to start socket server", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
a.socketServer = server
|
||||
a.logger.Info(a.hardCtx, "socket server started", slog.F("path", socketPath))
|
||||
}
|
||||
|
||||
// getSocketPath returns the socket path from options or environment
|
||||
func (a *agent) getSocketPath() string {
|
||||
// Check if socket path is explicitly configured
|
||||
if a.getSocketPathFromOptions() != "" {
|
||||
return a.getSocketPathFromOptions()
|
||||
}
|
||||
|
||||
// Check environment variable
|
||||
if path := os.Getenv("CODER_AGENT_SOCKET_PATH"); path != "" {
|
||||
return path
|
||||
}
|
||||
|
||||
// Return empty to disable socket server
|
||||
return ""
|
||||
}
|
||||
|
||||
// getSocketPathFromOptions returns the socket path from agent options
|
||||
func (a *agent) getSocketPathFromOptions() string {
|
||||
return a.socketPath
|
||||
}
|
||||
|
||||
// runLoop attempts to start the agent in a retry loop.
|
||||
// Coder may be offline temporarily, a connection issue
|
||||
// may be happening, but regardless after the intermittent
|
||||
@@ -1931,6 +1998,13 @@ func (a *agent) Close() error {
|
||||
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
|
||||
}
|
||||
|
||||
// Close socket server
|
||||
if a.socketServer != nil {
|
||||
if err := a.socketServer.Stop(); err != nil {
|
||||
a.logger.Error(a.hardCtx, "socket server close", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for the graceful shutdown to complete, but don't wait forever so
|
||||
// that we don't break user expectations.
|
||||
go func() {
|
||||
|
||||
@@ -682,6 +682,8 @@ func (api *API) updaterLoop() {
|
||||
} else {
|
||||
prevErr = nil
|
||||
}
|
||||
default:
|
||||
api.logger.Debug(api.ctx, "updater loop ticker skipped, update in progress")
|
||||
}
|
||||
|
||||
return nil // Always nil to keep the ticker going.
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
# Agent Socket API
|
||||
|
||||
The Agent Socket API provides a local communication channel between CLI commands running within a workspace and the Coder agent process. This enables new CLI commands to interact directly with the agent without going through the control plane.
|
||||
|
||||
## Overview
|
||||
|
||||
The socket server runs within the agent process and listens on a Unix domain socket (or named pipe on Windows). CLI commands can connect to this socket to query agent information, check health status, and perform other operations.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Socket Server
|
||||
- **Location**: `agent/agentsocket/`
|
||||
- **Protocol**: JSON-RPC 2.0 over Unix domain socket
|
||||
- **Platform Support**: Linux, macOS, Windows 10+ (build 17063+)
|
||||
- **Authentication**: Pluggable middleware (no-auth by default)
|
||||
|
||||
### Client Library
|
||||
- **Location**: `codersdk/agentsdk/socket_client.go`
|
||||
- **Auto-discovery**: Automatically finds socket path
|
||||
- **Type-safe**: Go client with proper error handling
|
||||
|
||||
## Socket Path Discovery
|
||||
|
||||
The socket path is determined in the following order:
|
||||
|
||||
1. **Environment Variable**: `CODER_AGENT_SOCKET_PATH`
|
||||
2. **XDG Runtime Directory**: `$XDG_RUNTIME_DIR/coder-agent.sock`
|
||||
3. **User Temp Directory**: `/tmp/coder-agent-{uid}.sock`
|
||||
4. **Fallback**: `/tmp/coder-agent.sock`
|
||||
|
||||
## Protocol
|
||||
|
||||
### Request Format
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"method": "ping",
|
||||
"id": "request-123",
|
||||
"params": {}
|
||||
}
|
||||
```
|
||||
|
||||
### Response Format
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"id": "request-123",
|
||||
"result": {
|
||||
"message": "pong",
|
||||
"timestamp": "2024-01-01T00:00:00Z"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Error Format
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"id": "request-123",
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": "Method not found",
|
||||
"data": "nonexistent"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Available Methods
|
||||
|
||||
### Core Methods
|
||||
- `ping` - Health check with timestamp
|
||||
- `health` - Agent status and uptime
|
||||
- `agent.info` - Detailed agent information
|
||||
- `methods.list` - List available methods
|
||||
|
||||
### Example Usage
|
||||
|
||||
```go
|
||||
// Create client
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Ping the agent
|
||||
pingResp, err := client.Ping(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Agent responded: %s\n", pingResp.Message)
|
||||
|
||||
// Get agent info
|
||||
info, err := client.AgentInfo(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Agent ID: %s, Version: %s\n", info.ID, info.Version)
|
||||
```
|
||||
|
||||
## Adding New Handlers
|
||||
|
||||
### Server Side
|
||||
```go
|
||||
// Register a new handler
|
||||
server.RegisterHandler("custom.method", func(ctx Context, req *Request) (*Response, error) {
|
||||
// Handle the request
|
||||
result := map[string]string{"status": "ok"}
|
||||
return NewResponse(req.ID, result)
|
||||
})
|
||||
```
|
||||
|
||||
### Client Side
|
||||
```go
|
||||
// Add method to client
|
||||
func (c *SocketClient) CustomMethod(ctx context.Context) (*CustomResponse, error) {
|
||||
req := &Request{
|
||||
Version: "1.0",
|
||||
Method: "custom.method",
|
||||
ID: generateRequestID(),
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("custom method error: %s", resp.Error.Message)
|
||||
}
|
||||
|
||||
var result CustomResponse
|
||||
if err := json.Unmarshal(resp.Result, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
The socket server supports pluggable authentication middleware. By default, no authentication is performed (suitable for local-only communication).
|
||||
|
||||
### Custom Authentication
|
||||
```go
|
||||
type CustomAuthMiddleware struct {
|
||||
// Add auth fields
|
||||
}
|
||||
|
||||
func (m *CustomAuthMiddleware) Authenticate(ctx context.Context, conn net.Conn) (context.Context, error) {
|
||||
// Implement authentication logic
|
||||
// Return context with auth info or error
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// Use in server config
|
||||
server := agentsocket.NewServer(agentsocket.Config{
|
||||
Path: socketPath,
|
||||
Logger: logger,
|
||||
AuthMiddleware: &CustomAuthMiddleware{},
|
||||
})
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Agent Options
|
||||
```go
|
||||
options := agent.Options{
|
||||
// ... other options
|
||||
SocketPath: "/custom/path/agent.sock", // Optional, uses auto-discovery if empty
|
||||
}
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
- `CODER_AGENT_SOCKET_PATH` - Override socket path
|
||||
- `XDG_RUNTIME_DIR` - Used for socket path discovery
|
||||
|
||||
## Error Codes
|
||||
|
||||
| Code | Description |
|
||||
|------|-------------|
|
||||
| -32700 | Parse error |
|
||||
| -32600 | Invalid request |
|
||||
| -32601 | Method not found |
|
||||
| -32602 | Invalid params |
|
||||
| -32603 | Internal error |
|
||||
|
||||
## Platform Support
|
||||
|
||||
### Unix-like Systems (Linux, macOS)
|
||||
- Uses Unix domain sockets
|
||||
- Socket file permissions: 600 (owner read/write only)
|
||||
- Auto-cleanup on shutdown
|
||||
|
||||
### Windows
|
||||
- Uses Unix domain sockets (Windows 10 build 17063+)
|
||||
- Falls back to named pipes if needed
|
||||
- Simplified permission handling
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Local Only**: Socket is only accessible from within the workspace
|
||||
2. **File Permissions**: Socket file is restricted to owner only
|
||||
3. **No Network Access**: Unix domain sockets don't traverse network
|
||||
4. **Authentication Ready**: Middleware pattern allows future auth implementation
|
||||
|
||||
## Future Extensibility
|
||||
|
||||
The design supports:
|
||||
- **Protocol Versioning**: Request includes version field
|
||||
- **Multiple Transports**: Interface-based design allows TCP/WebSocket later
|
||||
- **Auth Plugins**: Middleware pattern for various auth methods
|
||||
- **Custom Handlers**: Simple registration pattern for new commands
|
||||
@@ -0,0 +1,23 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// AuthMiddleware defines the interface for authentication middleware
|
||||
type AuthMiddleware interface {
|
||||
// Authenticate authenticates a connection and returns a context with auth info
|
||||
Authenticate(ctx context.Context, conn net.Conn) (context.Context, error)
|
||||
}
|
||||
|
||||
// NoAuthMiddleware is a no-op authentication middleware
|
||||
type NoAuthMiddleware struct{}
|
||||
|
||||
// Authenticate implements AuthMiddleware but performs no authentication
|
||||
func (*NoAuthMiddleware) Authenticate(ctx context.Context, conn net.Conn) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// Ensure NoAuthMiddleware implements AuthMiddleware
|
||||
var _ AuthMiddleware = (*NoAuthMiddleware)(nil)
|
||||
@@ -0,0 +1,108 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// AgentInfo represents information about the agent
|
||||
type AgentInfo struct {
|
||||
ID string `json:"id"`
|
||||
Version string `json:"version"`
|
||||
Status string `json:"status"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
|
||||
// PingResponse represents a ping response
|
||||
type PingResponse struct {
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// HealthResponse represents a health check response
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
|
||||
// HandlerContext provides context for handlers
|
||||
type HandlerContext struct {
|
||||
AgentID string
|
||||
Version string
|
||||
Status string
|
||||
StartedAt time.Time
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
// NewHandlers creates the default set of handlers
|
||||
func NewHandlers(handlerCtx HandlerContext) map[string]Handler {
|
||||
handlers := make(map[string]Handler)
|
||||
|
||||
// Ping handler
|
||||
handlers["ping"] = func(_ Context, req *Request) (*Response, error) {
|
||||
resp := PingResponse{
|
||||
Message: "pong",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
return NewResponse(req.ID, resp)
|
||||
}
|
||||
|
||||
// Health check handler
|
||||
handlers["health"] = func(_ Context, req *Request) (*Response, error) {
|
||||
uptime := time.Since(handlerCtx.StartedAt)
|
||||
resp := HealthResponse{
|
||||
Status: handlerCtx.Status,
|
||||
Timestamp: time.Now(),
|
||||
Uptime: uptime.String(),
|
||||
}
|
||||
return NewResponse(req.ID, resp)
|
||||
}
|
||||
|
||||
// Agent info handler
|
||||
handlers["agent.info"] = func(_ Context, req *Request) (*Response, error) {
|
||||
uptime := time.Since(handlerCtx.StartedAt)
|
||||
resp := AgentInfo{
|
||||
ID: handlerCtx.AgentID,
|
||||
Version: handlerCtx.Version,
|
||||
Status: handlerCtx.Status,
|
||||
StartedAt: handlerCtx.StartedAt,
|
||||
Uptime: uptime.String(),
|
||||
}
|
||||
return NewResponse(req.ID, resp)
|
||||
}
|
||||
|
||||
// List methods handler
|
||||
handlers["methods.list"] = func(_ Context, req *Request) (*Response, error) {
|
||||
methods := []string{
|
||||
"ping",
|
||||
"health",
|
||||
"agent.info",
|
||||
"methods.list",
|
||||
}
|
||||
return NewResponse(req.ID, methods)
|
||||
}
|
||||
|
||||
return handlers
|
||||
}
|
||||
|
||||
// RegisterDefaultHandlers registers the default set of handlers with a server
|
||||
func RegisterDefaultHandlers(server *Server, ctx HandlerContext) {
|
||||
handlers := NewHandlers(ctx)
|
||||
for method, handler := range handlers {
|
||||
server.RegisterHandler(method, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateHandlerContext creates a handler context from agent information
|
||||
func CreateHandlerContext(agentID, version, status string, startedAt time.Time, logger slog.Logger) HandlerContext {
|
||||
return HandlerContext{
|
||||
AgentID: agentID,
|
||||
Version: version,
|
||||
Status: status,
|
||||
StartedAt: startedAt,
|
||||
Logger: logger,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Protocol version for the agent socket API
|
||||
const ProtocolVersion = "1.0"
|
||||
|
||||
// Request represents an incoming request to the agent socket
|
||||
type Request struct {
|
||||
Version string `json:"version"`
|
||||
Method string `json:"method"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// Response represents a response from the agent socket
|
||||
type Response struct {
|
||||
Version string `json:"version"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Error represents an error in the response
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Standard error codes
|
||||
const (
|
||||
ErrCodeParseError = -32700
|
||||
ErrCodeInvalidRequest = -32600
|
||||
ErrCodeMethodNotFound = -32601
|
||||
ErrCodeInvalidParams = -32602
|
||||
ErrCodeInternalError = -32603
|
||||
)
|
||||
|
||||
// NewError creates a new error response
|
||||
func NewError(code int, message string, data any) *Error {
|
||||
return &Error{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// NewResponse creates a successful response
|
||||
func NewResponse(id string, result any) (*Response, error) {
|
||||
resultBytes, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal result: %w", err)
|
||||
}
|
||||
|
||||
return &Response{
|
||||
Version: ProtocolVersion,
|
||||
ID: id,
|
||||
Result: resultBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewErrorResponse creates an error response
|
||||
func NewErrorResponse(id string, err *Error) *Response {
|
||||
return &Response{
|
||||
Version: ProtocolVersion,
|
||||
ID: id,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler represents a function that can handle a request
|
||||
type Handler func(ctx Context, req *Request) (*Response, error)
|
||||
|
||||
// Context provides context for request handling
|
||||
type Context struct {
|
||||
// Additional context can be added here in the future
|
||||
// For now, this is a placeholder for future auth context, etc.
|
||||
}
|
||||
@@ -0,0 +1,266 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// Server represents the agent socket server
|
||||
type Server struct {
|
||||
logger slog.Logger
|
||||
path string
|
||||
listener net.Listener
|
||||
handlers map[string]Handler
|
||||
authMiddleware AuthMiddleware
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// Config holds configuration for the socket server
|
||||
type Config struct {
|
||||
Path string
|
||||
Logger slog.Logger
|
||||
AuthMiddleware AuthMiddleware
|
||||
}
|
||||
|
||||
// NewServer creates a new agent socket server
|
||||
func NewServer(config Config) *Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
server := &Server{
|
||||
logger: config.Logger.Named("agentsocket"),
|
||||
path: config.Path,
|
||||
handlers: make(map[string]Handler),
|
||||
authMiddleware: config.AuthMiddleware,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Set default auth middleware if none provided
|
||||
if server.authMiddleware == nil {
|
||||
server.authMiddleware = &NoAuthMiddleware{}
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
// Start starts the socket server
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener != nil {
|
||||
return xerrors.New("server already started")
|
||||
}
|
||||
|
||||
// Get socket path
|
||||
path := s.path
|
||||
if path == "" {
|
||||
var err error
|
||||
path, err = getDefaultSocketPath()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get default socket path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if socket is available
|
||||
if !isSocketAvailable(path) {
|
||||
return xerrors.Errorf("socket path %s is not available", path)
|
||||
}
|
||||
|
||||
// Create socket listener
|
||||
listener, err := createSocket(s.ctx, path)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create socket: %w", err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
s.path = path
|
||||
|
||||
s.logger.Info(s.ctx, "agent socket server started", slog.F("path", path))
|
||||
|
||||
// Start accepting connections
|
||||
s.wg.Add(1)
|
||||
go s.acceptConnections()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the socket server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Info(s.ctx, "stopping agent socket server")
|
||||
|
||||
// Cancel context to stop accepting new connections
|
||||
s.cancel()
|
||||
|
||||
// Close listener
|
||||
if err := s.listener.Close(); err != nil {
|
||||
s.logger.Warn(s.ctx, "error closing socket listener", slog.Error(err))
|
||||
}
|
||||
|
||||
// Wait for all connections to finish
|
||||
s.wg.Wait()
|
||||
|
||||
// Clean up socket file
|
||||
if err := cleanupSocket(s.path); err != nil {
|
||||
s.logger.Warn(s.ctx, "error cleaning up socket file", slog.Error(err))
|
||||
}
|
||||
|
||||
s.listener = nil
|
||||
s.logger.Info(s.ctx, "agent socket server stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a method
|
||||
func (s *Server) RegisterHandler(method string, handler Handler) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.handlers[method] = handler
|
||||
}
|
||||
|
||||
// GetPath returns the socket path
|
||||
func (s *Server) GetPath() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.path
|
||||
}
|
||||
|
||||
// acceptConnections accepts incoming connections
|
||||
func (s *Server) acceptConnections() {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
s.logger.Warn(s.ctx, "error accepting connection", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Handle connection in a goroutine
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.handleConnection(conn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection handles a single connection
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
// Authenticate connection first to get context
|
||||
ctx, err := s.authMiddleware.Authenticate(s.ctx, conn)
|
||||
if err != nil {
|
||||
s.logger.Warn(s.ctx, "authentication failed", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Set connection deadline
|
||||
if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
s.logger.Warn(ctx, "failed to set connection deadline", slog.Error(err))
|
||||
}
|
||||
|
||||
s.logger.Debug(ctx, "new connection accepted", slog.F("remote_addr", conn.RemoteAddr()))
|
||||
|
||||
// Handle requests
|
||||
decoder := json.NewDecoder(conn)
|
||||
encoder := json.NewEncoder(conn)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Set read deadline
|
||||
if err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
s.logger.Warn(ctx, "failed to set read deadline", slog.Error(err))
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
if err == io.EOF {
|
||||
s.logger.Debug(ctx, "connection closed by client")
|
||||
return
|
||||
}
|
||||
s.logger.Warn(ctx, "error decoding request", slog.Error(err))
|
||||
|
||||
// Send error response
|
||||
resp := NewErrorResponse("", NewError(ErrCodeParseError, "Parse error", err.Error()))
|
||||
encoder.Encode(resp)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle request
|
||||
resp := s.handleRequest(ctx, &req)
|
||||
|
||||
// Send response
|
||||
if err := encoder.Encode(resp); err != nil {
|
||||
s.logger.Warn(ctx, "error sending response", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest handles a single request
|
||||
func (s *Server) handleRequest(ctx context.Context, req *Request) *Response {
|
||||
// Validate request
|
||||
if req.Version != ProtocolVersion {
|
||||
return NewErrorResponse(req.ID, NewError(ErrCodeInvalidRequest, "Unsupported version", req.Version))
|
||||
}
|
||||
|
||||
if req.Method == "" {
|
||||
return NewErrorResponse(req.ID, NewError(ErrCodeInvalidRequest, "Missing method", nil))
|
||||
}
|
||||
|
||||
// Get handler
|
||||
s.mu.RLock()
|
||||
handler, exists := s.handlers[req.Method]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return NewErrorResponse(req.ID, NewError(ErrCodeMethodNotFound, "Method not found", req.Method))
|
||||
}
|
||||
|
||||
// Call handler
|
||||
type requestIDKey struct{}
|
||||
ctx = context.WithValue(ctx, requestIDKey{}, req.ID)
|
||||
resp, err := handler(Context{}, req)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "handler execution failed", slog.Error(err))
|
||||
return NewErrorResponse(req.ID, NewError(ErrCodeInternalError, "Internal error", err.Error()))
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
@@ -0,0 +1,250 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
func TestServer_StartStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create temporary socket path
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Create server
|
||||
server := NewServer(Config{
|
||||
Path: socketPath,
|
||||
Logger: slog.Make().Leveled(slog.LevelDebug),
|
||||
})
|
||||
|
||||
// Register a test handler
|
||||
server.RegisterHandler("test", func(ctx Context, req *Request) (*Response, error) {
|
||||
return NewResponse(req.ID, map[string]string{"message": "test response"})
|
||||
})
|
||||
|
||||
// Start server
|
||||
err := server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
// Verify socket file exists
|
||||
_, err = os.Stat(socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test connection
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send test request
|
||||
req := Request{
|
||||
Version: "1.0",
|
||||
Method: "test",
|
||||
ID: "test-1",
|
||||
}
|
||||
|
||||
err = json.NewEncoder(conn).Encode(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read response
|
||||
var resp Response
|
||||
err = json.NewDecoder(conn).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test-1", resp.ID)
|
||||
assert.Nil(t, resp.Error)
|
||||
assert.NotNil(t, resp.Result)
|
||||
|
||||
// Verify response content
|
||||
var result map[string]string
|
||||
err = json.Unmarshal(resp.Result, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test response", result["message"])
|
||||
}
|
||||
|
||||
func TestServer_ErrorHandling(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create temporary socket path
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Create server
|
||||
server := NewServer(Config{
|
||||
Path: socketPath,
|
||||
Logger: slog.Make().Leveled(slog.LevelDebug),
|
||||
})
|
||||
|
||||
// Start server
|
||||
err := server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
// Test connection
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send request for non-existent method
|
||||
req := Request{
|
||||
Version: "1.0",
|
||||
Method: "nonexistent",
|
||||
ID: "test-1",
|
||||
}
|
||||
|
||||
err = json.NewEncoder(conn).Encode(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read response
|
||||
var resp Response
|
||||
err = json.NewDecoder(conn).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test-1", resp.ID)
|
||||
assert.NotNil(t, resp.Error)
|
||||
assert.Equal(t, ErrCodeMethodNotFound, resp.Error.Code)
|
||||
assert.Equal(t, "Method not found", resp.Error.Message)
|
||||
}
|
||||
|
||||
func TestServer_DefaultHandlers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create temporary socket path
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Create server
|
||||
server := NewServer(Config{
|
||||
Path: socketPath,
|
||||
Logger: slog.Make().Leveled(slog.LevelDebug),
|
||||
})
|
||||
|
||||
// Register default handlers
|
||||
handlerCtx := CreateHandlerContext(
|
||||
"test-agent-id",
|
||||
"1.0.0",
|
||||
"ready",
|
||||
time.Now().Add(-time.Hour),
|
||||
slog.Make(),
|
||||
)
|
||||
RegisterDefaultHandlers(server, handlerCtx)
|
||||
|
||||
// Start server
|
||||
err := server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
// Test ping
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
req := Request{
|
||||
Version: "1.0",
|
||||
Method: "ping",
|
||||
ID: "ping-1",
|
||||
}
|
||||
|
||||
err = json.NewEncoder(conn).Encode(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp Response
|
||||
err = json.NewDecoder(conn).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "ping-1", resp.ID)
|
||||
assert.Nil(t, resp.Error)
|
||||
|
||||
var pingResp PingResponse
|
||||
err = json.Unmarshal(resp.Result, &pingResp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pong", pingResp.Message)
|
||||
}
|
||||
|
||||
func TestServer_ConcurrentConnections(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create temporary socket path
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Create server
|
||||
server := NewServer(Config{
|
||||
Path: socketPath,
|
||||
Logger: slog.Make().Leveled(slog.LevelDebug),
|
||||
})
|
||||
|
||||
// Register a test handler
|
||||
server.RegisterHandler("test", func(ctx Context, req *Request) (*Response, error) {
|
||||
time.Sleep(10 * time.Millisecond) // Simulate some work
|
||||
return NewResponse(req.ID, map[string]string{"message": "test response"})
|
||||
})
|
||||
|
||||
// Start server
|
||||
err := server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
// Test multiple concurrent connections
|
||||
const numConnections = 5
|
||||
results := make(chan error, numConnections)
|
||||
|
||||
for i := 0; i < numConnections; i++ {
|
||||
go func(i int) {
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
req := Request{
|
||||
Version: "1.0",
|
||||
Method: "test",
|
||||
ID: fmt.Sprintf("test-%d", i),
|
||||
}
|
||||
|
||||
err = json.NewEncoder(conn).Encode(req)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
var resp Response
|
||||
err = json.NewDecoder(conn).Decode(&resp)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
results <- xerrors.Errorf("server error: %s", resp.Error.Message)
|
||||
return
|
||||
}
|
||||
|
||||
results <- nil
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all connections to complete
|
||||
for i := 0; i < numConnections; i++ {
|
||||
select {
|
||||
case err := <-results:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for concurrent connections")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
//go:build !windows
|
||||
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// createSocket creates a Unix domain socket listener
|
||||
func createSocket(ctx context.Context, path string) (net.Listener, error) {
|
||||
// Remove existing socket file if it exists
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return nil, xerrors.Errorf("remove existing socket: %w", err)
|
||||
}
|
||||
|
||||
// Create parent directory if it doesn't exist
|
||||
parentDir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(parentDir, 0o700); err != nil {
|
||||
return nil, xerrors.Errorf("create socket directory: %w", err)
|
||||
}
|
||||
|
||||
// Create Unix domain socket listener
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listen on unix socket: %w", err)
|
||||
}
|
||||
|
||||
// Set socket permissions to be accessible only by the current user
|
||||
if err := os.Chmod(path, 0o600); err != nil {
|
||||
listener.Close()
|
||||
return nil, xerrors.Errorf("set socket permissions: %w", err)
|
||||
}
|
||||
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// getDefaultSocketPath returns the default socket path for Unix-like systems
|
||||
func getDefaultSocketPath() (string, error) {
|
||||
// Try XDG_RUNTIME_DIR first
|
||||
if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" {
|
||||
return filepath.Join(runtimeDir, "coder-agent.sock"), nil
|
||||
}
|
||||
|
||||
// Fall back to /tmp with user-specific path
|
||||
uid := os.Getuid()
|
||||
return filepath.Join("/tmp", fmt.Sprintf("coder-agent-%d.sock", uid)), nil
|
||||
}
|
||||
|
||||
// cleanupSocket removes the socket file
|
||||
func cleanupSocket(path string) error {
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
// isSocketAvailable checks if a socket path is available for use
|
||||
func isSocketAvailable(path string) bool {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Try to connect to see if it's actually listening
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
// If we can't connect, the socket is not in use
|
||||
return true
|
||||
}
|
||||
conn.Close()
|
||||
return false
|
||||
}
|
||||
|
||||
// getSocketInfo returns information about the socket file
|
||||
func getSocketInfo(path string) (*SocketInfo, error) {
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sys, ok := stat.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return nil, xerrors.New("unable to get stat_t from file info")
|
||||
}
|
||||
return &SocketInfo{
|
||||
Path: path,
|
||||
UID: int(sys.Uid),
|
||||
GID: int(sys.Gid),
|
||||
Mode: stat.Mode(),
|
||||
ModTime: stat.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SocketInfo contains information about a socket file
|
||||
type SocketInfo struct {
|
||||
Path string
|
||||
UID int
|
||||
GID int
|
||||
Mode os.FileMode
|
||||
ModTime time.Time
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
//go:build windows
|
||||
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// createSocket creates a Unix domain socket listener on Windows
|
||||
// Falls back to named pipe if Unix sockets are not supported
|
||||
func createSocket(ctx context.Context, path string) (net.Listener, error) {
|
||||
// Try Unix domain socket first (Windows 10 build 17063+)
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err == nil {
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// Fall back to named pipe
|
||||
pipePath := `\\.\pipe\coder-agent`
|
||||
return net.Listen("tcp", pipePath)
|
||||
}
|
||||
|
||||
// getDefaultSocketPath returns the default socket path for Windows
|
||||
func getDefaultSocketPath() (string, error) {
|
||||
// Try to use a temporary directory
|
||||
tempDir := os.TempDir()
|
||||
if tempDir == "" {
|
||||
tempDir = "C:\\temp"
|
||||
}
|
||||
|
||||
// Create a user-specific subdirectory
|
||||
uid := os.Getuid()
|
||||
userDir := filepath.Join(tempDir, "coder-agent", strconv.Itoa(uid))
|
||||
|
||||
if err := os.MkdirAll(userDir, 0o700); err != nil {
|
||||
return "", fmt.Errorf("create user directory: %w", err)
|
||||
}
|
||||
|
||||
return filepath.Join(userDir, "agent.sock"), nil
|
||||
}
|
||||
|
||||
// cleanupSocket removes the socket file
|
||||
func cleanupSocket(path string) error {
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
// isSocketAvailable checks if a socket path is available for use
|
||||
func isSocketAvailable(path string) bool {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Try to connect to see if it's actually listening
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
// If we can't connect, the socket is not in use
|
||||
return true
|
||||
}
|
||||
conn.Close()
|
||||
return false
|
||||
}
|
||||
|
||||
// getSocketInfo returns information about the socket file
|
||||
func getSocketInfo(path string) (*SocketInfo, error) {
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// On Windows, we'll use a simplified approach for now
|
||||
// In a real implementation, you'd get the security descriptor
|
||||
return &SocketInfo{
|
||||
Path: path,
|
||||
UID: 0, // Simplified for now
|
||||
GID: 0, // Simplified for now
|
||||
Mode: stat.Mode(),
|
||||
ModTime: stat.ModTime(),
|
||||
Owner: "unknown",
|
||||
Group: "unknown",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SocketInfo contains information about a socket file
|
||||
type SocketInfo struct {
|
||||
Path string
|
||||
UID int
|
||||
GID int
|
||||
Mode os.FileMode
|
||||
ModTime time.Time
|
||||
Owner string // Windows SID string
|
||||
Group string // Windows SID string
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package unit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// ErrConsumerNotFound is returned when a consumer ID is not registered.
|
||||
var ErrConsumerNotFound = xerrors.New("consumer not found")
|
||||
|
||||
// ErrCannotUpdateOtherConsumer is returned when attempting to update another consumer's status.
|
||||
var ErrCannotUpdateOtherConsumer = xerrors.New("cannot update other consumer's status")
|
||||
|
||||
// dependencyVertex represents a vertex in the dependency graph that is associated with a consumer.
|
||||
type dependencyVertex[ConsumerID comparable] struct {
|
||||
ID ConsumerID
|
||||
}
|
||||
|
||||
// Dependency represents a dependency relationship between consumers.
|
||||
type Dependency[StatusType, ConsumerID comparable] struct {
|
||||
Consumer ConsumerID
|
||||
DependsOn ConsumerID
|
||||
RequiredStatus StatusType
|
||||
CurrentStatus StatusType
|
||||
IsSatisfied bool
|
||||
}
|
||||
|
||||
// DependencyTracker provides reactive dependency tracking over a Graph.
|
||||
// It manages consumer registration, dependency relationships, and status updates
|
||||
// with automatic recalculation of readiness when dependencies are satisfied.
|
||||
type DependencyTracker[StatusType, ConsumerID comparable] struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// The underlying graph that stores dependency relationships
|
||||
graph *Graph[StatusType, *dependencyVertex[ConsumerID]]
|
||||
|
||||
// Track current status of each consumer
|
||||
consumerStatus map[ConsumerID]StatusType
|
||||
|
||||
// Track readiness state (cached to avoid repeated graph traversal)
|
||||
consumerReadiness map[ConsumerID]bool
|
||||
|
||||
// Track which consumers are registered
|
||||
registeredConsumers map[ConsumerID]bool
|
||||
|
||||
// Store vertex instances for each consumer to ensure consistent references
|
||||
consumerVertices map[ConsumerID]*dependencyVertex[ConsumerID]
|
||||
}
|
||||
|
||||
// NewDependencyTracker creates a new DependencyTracker instance.
|
||||
func NewDependencyTracker[StatusType, ConsumerID comparable]() *DependencyTracker[StatusType, ConsumerID] {
|
||||
return &DependencyTracker[StatusType, ConsumerID]{
|
||||
graph: &Graph[StatusType, *dependencyVertex[ConsumerID]]{},
|
||||
consumerStatus: make(map[ConsumerID]StatusType),
|
||||
consumerReadiness: make(map[ConsumerID]bool),
|
||||
registeredConsumers: make(map[ConsumerID]bool),
|
||||
consumerVertices: make(map[ConsumerID]*dependencyVertex[ConsumerID]),
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a new consumer as a vertex in the dependency graph.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) Register(id ConsumerID) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if dt.registeredConsumers[id] {
|
||||
return xerrors.Errorf("consumer %v is already registered", id)
|
||||
}
|
||||
|
||||
// Create and store the vertex for this consumer
|
||||
vertex := &dependencyVertex[ConsumerID]{ID: id}
|
||||
dt.consumerVertices[id] = vertex
|
||||
dt.registeredConsumers[id] = true
|
||||
dt.consumerReadiness[id] = true // New consumers start as ready (no dependencies)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDependency adds a dependency relationship between consumers.
|
||||
// The consumer depends on the dependsOn consumer reaching the requiredStatus.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) AddDependency(consumer ConsumerID, dependsOn ConsumerID, requiredStatus StatusType) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return xerrors.Errorf("consumer %v is not registered", consumer)
|
||||
}
|
||||
if !dt.registeredConsumers[dependsOn] {
|
||||
return xerrors.Errorf("consumer %v is not registered", dependsOn)
|
||||
}
|
||||
|
||||
// Get the stored vertices for both consumers
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
dependsOnVertex := dt.consumerVertices[dependsOn]
|
||||
|
||||
// Add the dependency edge to the graph
|
||||
// The edge goes from consumer to dependsOn, representing the dependency
|
||||
err := dt.graph.AddEdge(consumerVertex, dependsOnVertex, requiredStatus)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to add dependency: %w", err)
|
||||
}
|
||||
|
||||
// Recalculate readiness for the consumer since it now has a dependency
|
||||
dt.recalculateReadinessUnsafe(consumer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates a consumer's status and recalculates readiness for affected dependents.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) UpdateStatus(consumer ConsumerID, newStatus StatusType) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return ErrConsumerNotFound
|
||||
}
|
||||
|
||||
// Update the consumer's status
|
||||
dt.consumerStatus[consumer] = newStatus
|
||||
|
||||
// Get all consumers that depend on this one (reverse adjacent vertices)
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
dependentEdges := dt.graph.GetReverseAdjacentVertices(consumerVertex)
|
||||
|
||||
// Recalculate readiness for all dependents
|
||||
for _, edge := range dependentEdges {
|
||||
dt.recalculateReadinessUnsafe(edge.From.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsReady checks if all dependencies for a consumer are satisfied.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) IsReady(consumer ConsumerID) (bool, error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return false, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
return dt.consumerReadiness[consumer], nil
|
||||
}
|
||||
|
||||
// GetUnmetDependencies returns a list of unsatisfied dependencies for a consumer.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) GetUnmetDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return nil, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
var unmetDependencies []Dependency[StatusType, ConsumerID]
|
||||
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists {
|
||||
// If the dependency consumer has no status, it's not satisfied
|
||||
var zeroStatus StatusType
|
||||
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: zeroStatus, // Zero value
|
||||
IsSatisfied: false,
|
||||
})
|
||||
} else {
|
||||
isSatisfied := currentStatus == requiredStatus
|
||||
if !isSatisfied {
|
||||
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: currentStatus,
|
||||
IsSatisfied: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return unmetDependencies, nil
|
||||
}
|
||||
|
||||
// recalculateReadinessUnsafe recalculates the readiness state for a consumer.
|
||||
// This method assumes the caller holds the write lock.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) recalculateReadinessUnsafe(consumer ConsumerID) {
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
// If there are no dependencies, the consumer is ready
|
||||
if len(forwardEdges) == 0 {
|
||||
dt.consumerReadiness[consumer] = true
|
||||
return
|
||||
}
|
||||
|
||||
// Check if all dependencies are satisfied
|
||||
allSatisfied := true
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists || currentStatus != requiredStatus {
|
||||
allSatisfied = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
dt.consumerReadiness[consumer] = allSatisfied
|
||||
}
|
||||
|
||||
// GetGraph returns the underlying graph for visualization and debugging.
|
||||
// This should be used carefully as it exposes the internal graph structure.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) GetGraph() *Graph[StatusType, *dependencyVertex[ConsumerID]] {
|
||||
return dt.graph
|
||||
}
|
||||
|
||||
// ExportDOT exports the dependency graph to DOT format for visualization.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) ExportDOT(name string) (string, error) {
|
||||
return dt.graph.ToDOT(name)
|
||||
}
|
||||
@@ -0,0 +1,692 @@
|
||||
package unit_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
)
|
||||
|
||||
type testStatus string
|
||||
|
||||
const (
|
||||
statusInitialized testStatus = "initialized"
|
||||
statusStarted testStatus = "started"
|
||||
statusRunning testStatus = "running"
|
||||
statusCompleted testStatus = "completed"
|
||||
)
|
||||
|
||||
type testConsumerID string
|
||||
|
||||
const (
|
||||
consumerA testConsumerID = "serviceA"
|
||||
consumerB testConsumerID = "serviceB"
|
||||
consumerC testConsumerID = "serviceC"
|
||||
consumerD testConsumerID = "serviceD"
|
||||
)
|
||||
|
||||
func TestDependencyTracker_Register(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
t.Run("RegisterNewConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Consumer should be ready initially (no dependencies)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("RegisterDuplicateConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tracker.Register(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already registered")
|
||||
})
|
||||
|
||||
t.Run("RegisterMultipleConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// All should be ready initially
|
||||
for _, consumer := range consumers {
|
||||
ready, err := tracker.IsReady(consumer)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_AddDependency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("AddDependencyBetweenRegisteredConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A should no longer be ready (depends on B)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// B should still be ready (no dependencies)
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("AddDependencyWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to add dependency to unregistered consumer
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
|
||||
t.Run("AddDependencyFromUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to add dependency from unregistered consumer
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_UpdateStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UpdateStatusTriggersReadinessRecalculation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially A is not ready
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should become ready
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
err := tracker.UpdateStatus(consumerA, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("LinearChainDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create chain: A depends on B being "started", B depends on C being "completed"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially only C is ready
|
||||
ready, err := tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "completed" - B should become ready
|
||||
err = tracker.UpdateStatus(consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "started" - A should become ready
|
||||
err = tracker.UpdateStatus(consumerB, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_GetUnmetDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithNoDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithUnsatisfiedDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, unmet, 1)
|
||||
|
||||
assert.Equal(t, consumerA, unmet[0].Consumer)
|
||||
assert.Equal(t, consumerB, unmet[0].DependsOn)
|
||||
assert.Equal(t, statusRunning, unmet[0].RequiredStatus)
|
||||
assert.False(t, unmet[0].IsSatisfied)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithSatisfiedDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update B to "running"
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.Nil(t, unmet)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ConcurrentOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConcurrentStatusUpdates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create dependencies: A depends on B, B depends on C, C depends on D
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 10
|
||||
|
||||
// Launch goroutines that update statuses
|
||||
errors := make([]error, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Update D to completed (should make C ready)
|
||||
err := tracker.UpdateStatus(consumerD, statusCompleted)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
|
||||
// Update C to started (should make B ready)
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
|
||||
// Update B to running (should make A ready)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check for any errors in goroutines
|
||||
for i, err := range errors {
|
||||
require.NoError(t, err, "goroutine %d had error", i)
|
||||
}
|
||||
|
||||
// All consumers should be ready after the updates
|
||||
for _, consumer := range consumers {
|
||||
ready, err := tracker.IsReady(consumer)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConcurrentReadinessChecks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 20
|
||||
|
||||
// Launch goroutines that check readiness
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Check readiness multiple times
|
||||
for j := 0; j < 10; j++ {
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
// Initially should be false, then true after B is updated
|
||||
_ = ready
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
// B should always be ready (no dependencies)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Update B to "running" in the middle of readiness checks
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_MultipleDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConsumerWithMultipleDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// A depends on B being "running" AND C being "started"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A should not be ready (depends on both B and C)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should still not be ready (needs C too)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "started" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("ComplexDependencyChain", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create complex dependency graph:
|
||||
// A depends on B being "running" AND C being "started"
|
||||
// B depends on D being "completed"
|
||||
// C depends on D being "completed"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially only D is ready
|
||||
ready, err := tracker.IsReady(consumerD)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update D to "completed" - B and C should become ready
|
||||
err = tracker.UpdateStatus(consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should still not be ready (needs C)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "started" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("DifferentStatusTypes", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerC)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running" AND C being "completed"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update B to "running" but not C - A should not be ready
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "completed" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ErrorCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
err := tracker.UpdateStatus(consumerA, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("IsReadyWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.False(t, ready)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.Nil(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("AddDependencyWithUnregisteredConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Try to add dependency with unregistered consumers
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
|
||||
t.Run("CyclicDependencyDetection", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to make B depend on A (creates cycle)
|
||||
err = tracker.AddDependency(consumerB, consumerA, statusStarted)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "would create a cycle")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ToDOT(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ExportSimpleGraph", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add dependency
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
dot, err := tracker.ExportDOT("test")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dot)
|
||||
assert.Contains(t, dot, "digraph")
|
||||
})
|
||||
|
||||
t.Run("ExportComplexGraph", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create complex dependency graph
|
||||
// A depends on B and C, B depends on D, C depends on D
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
dot, err := tracker.ExportDOT("complex")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dot)
|
||||
assert.Contains(t, dot, "digraph")
|
||||
})
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
var (
|
||||
_ pflag.SliceValue = &AllowListFlag{}
|
||||
_ pflag.Value = &AllowListFlag{}
|
||||
)
|
||||
|
||||
// AllowListFlag implements pflag.SliceValue for codersdk.APIAllowListTarget entries.
|
||||
type AllowListFlag []codersdk.APIAllowListTarget
|
||||
|
||||
func AllowListFlagOf(al *[]codersdk.APIAllowListTarget) *AllowListFlag {
|
||||
return (*AllowListFlag)(al)
|
||||
}
|
||||
|
||||
func (a AllowListFlag) String() string {
|
||||
return strings.Join(a.GetSlice(), ",")
|
||||
}
|
||||
|
||||
func (a AllowListFlag) Value() []codersdk.APIAllowListTarget {
|
||||
return []codersdk.APIAllowListTarget(a)
|
||||
}
|
||||
|
||||
func (AllowListFlag) Type() string { return "allow-list" }
|
||||
|
||||
func (a *AllowListFlag) Set(set string) error {
|
||||
values, err := csv.NewReader(strings.NewReader(set)).Read()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse allow list entries as csv: %w", err)
|
||||
}
|
||||
for _, v := range values {
|
||||
if err := a.Append(v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AllowListFlag) Append(value string) error {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return xerrors.New("allow list entry cannot be empty")
|
||||
}
|
||||
var target codersdk.APIAllowListTarget
|
||||
if err := target.UnmarshalText([]byte(value)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*a = append(*a, target)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AllowListFlag) Replace(items []string) error {
|
||||
*a = []codersdk.APIAllowListTarget{}
|
||||
for _, item := range items {
|
||||
if err := a.Append(item); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AllowListFlag) GetSlice() []string {
|
||||
out := make([]string, len(*a))
|
||||
for i, entry := range *a {
|
||||
out[i] = entry.String()
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"io"
|
||||
@@ -18,7 +19,10 @@ import (
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
@@ -39,22 +43,76 @@ func makeAITask(t *testing.T, db database.Store, orgID, adminID, ownerID uuid.UU
|
||||
},
|
||||
}).Do()
|
||||
|
||||
build := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
ws := database.WorkspaceTable{
|
||||
OrganizationID: orgID,
|
||||
OwnerID: ownerID,
|
||||
TemplateID: tv.Template.ID,
|
||||
}).
|
||||
}
|
||||
build := dbfake.WorkspaceBuild(t, db, ws).
|
||||
Seed(database.WorkspaceBuild{
|
||||
TemplateVersionID: tv.TemplateVersion.ID,
|
||||
Transition: transition,
|
||||
}).
|
||||
WithAgent().
|
||||
WithTask(database.TaskTable{
|
||||
Prompt: prompt,
|
||||
}, nil).
|
||||
Do()
|
||||
}).WithAgent().Do()
|
||||
dbgen.WorkspaceBuildParameters(t, db, []database.WorkspaceBuildParameter{
|
||||
{
|
||||
WorkspaceBuildID: build.Build.ID,
|
||||
Name: codersdk.AITaskPromptParameterName,
|
||||
Value: prompt,
|
||||
},
|
||||
})
|
||||
agents, err := db.GetWorkspaceAgentsByWorkspaceAndBuildNumber(
|
||||
dbauthz.AsSystemRestricted(context.Background()),
|
||||
database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
|
||||
WorkspaceID: build.Workspace.ID,
|
||||
BuildNumber: build.Build.BuildNumber,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, agents)
|
||||
agentID := agents[0].ID
|
||||
|
||||
return build.Task
|
||||
// Create a workspace app and set it as the sidebar app.
|
||||
app := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{
|
||||
AgentID: agentID,
|
||||
Slug: "task-sidebar",
|
||||
DisplayName: "Task Sidebar",
|
||||
External: false,
|
||||
})
|
||||
|
||||
// Update build flags to reference the sidebar app and HasAITask=true.
|
||||
err = db.UpdateWorkspaceBuildFlagsByID(
|
||||
dbauthz.AsSystemRestricted(context.Background()),
|
||||
database.UpdateWorkspaceBuildFlagsByIDParams{
|
||||
ID: build.Build.ID,
|
||||
HasAITask: sql.NullBool{Bool: true, Valid: true},
|
||||
HasExternalAgent: sql.NullBool{Bool: false, Valid: false},
|
||||
SidebarAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
|
||||
UpdatedAt: build.Build.UpdatedAt,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a task record in the tasks table for the new data model.
|
||||
task := dbgen.Task(t, db, database.TaskTable{
|
||||
OrganizationID: orgID,
|
||||
OwnerID: ownerID,
|
||||
Name: build.Workspace.Name,
|
||||
WorkspaceID: uuid.NullUUID{UUID: build.Workspace.ID, Valid: true},
|
||||
TemplateVersionID: tv.TemplateVersion.ID,
|
||||
TemplateParameters: []byte("{}"),
|
||||
Prompt: prompt,
|
||||
CreatedAt: dbtime.Now(),
|
||||
})
|
||||
|
||||
// Link the task to the workspace app.
|
||||
dbgen.TaskWorkspaceApp(t, db, database.TaskWorkspaceApp{
|
||||
TaskID: task.ID,
|
||||
WorkspaceBuildNumber: build.Build.BuildNumber,
|
||||
WorkspaceAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
||||
WorkspaceAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
|
||||
})
|
||||
|
||||
return task
|
||||
}
|
||||
|
||||
func TestExpTaskList(t *testing.T) {
|
||||
|
||||
@@ -293,6 +293,7 @@ func createAITaskTemplate(t *testing.T, client *codersdk.Client, orgID uuid.UUID
|
||||
{
|
||||
Type: &proto.Response_Plan{
|
||||
Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
},
|
||||
},
|
||||
@@ -327,7 +328,9 @@ func createAITaskTemplate(t *testing.T, client *codersdk.Client, orgID uuid.UUID
|
||||
},
|
||||
AiTasks: []*proto.AITask{
|
||||
{
|
||||
AppId: taskAppID.String(),
|
||||
SidebarApp: &proto.AITaskSidebarApp{
|
||||
Id: taskAppID.String(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
-47
@@ -109,51 +109,6 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
}
|
||||
},
|
||||
),
|
||||
CompletionHandler: func(inv *serpent.Invocation) []string {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
res, err := client.Workspaces(inv.Context(), codersdk.WorkspaceFilter{
|
||||
Owner: codersdk.Me,
|
||||
})
|
||||
if err != nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var completions []string
|
||||
var wg sync.WaitGroup
|
||||
for _, ws := range res.Workspaces {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resources, err := client.TemplateVersionResources(inv.Context(), ws.LatestBuild.TemplateVersionID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var agents []codersdk.WorkspaceAgent
|
||||
for _, resource := range resources {
|
||||
agents = append(agents, resource.Agents...)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(agents) == 1 {
|
||||
completions = append(completions, ws.Name)
|
||||
} else {
|
||||
for _, agent := range agents {
|
||||
completions = append(completions, fmt.Sprintf("%s.%s", ws.Name, agent.Name))
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
slices.Sort(completions)
|
||||
return completions
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) (retErr error) {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
@@ -951,8 +906,6 @@ func GetWorkspaceAndAgent(ctx context.Context, inv *serpent.Invocation, client *
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with active template version: %w", err)
|
||||
}
|
||||
_, _ = fmt.Fprintln(inv.Stdout, "Unable to start the workspace with template version from last build. Your workspace has been updated to the current active template version.")
|
||||
default:
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with current template version: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with current template version: %w", err)
|
||||
|
||||
@@ -2447,99 +2447,3 @@ func tempDirUnixSocket(t *testing.T) string {
|
||||
|
||||
return t.TempDir()
|
||||
}
|
||||
|
||||
func TestSSH_Completion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("SingleAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, workspace, agentToken := setupWorkspaceForAgent(t)
|
||||
_ = agenttest.New(t, client.URL, agentToken)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
|
||||
var stdout bytes.Buffer
|
||||
inv, root := clitest.New(t, "ssh", "")
|
||||
inv.Stdout = &stdout
|
||||
inv.Environ.Set("COMPLETION_MODE", "1")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// For single-agent workspaces, the only completion should be the
|
||||
// bare workspace name.
|
||||
output := stdout.String()
|
||||
t.Logf("Completion output: %q", output)
|
||||
require.Contains(t, output, workspace.Name)
|
||||
})
|
||||
|
||||
t.Run("MultiAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, store := coderdtest.NewWithDatabase(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
userClient, user := coderdtest.CreateAnotherUserMutators(t, client, first.OrganizationID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
|
||||
r.Username = "multiuser"
|
||||
})
|
||||
|
||||
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
Name: "multiworkspace",
|
||||
OrganizationID: first.OrganizationID,
|
||||
OwnerID: user.ID,
|
||||
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
|
||||
return []*proto.Agent{
|
||||
{
|
||||
Name: "agent1",
|
||||
Auth: &proto.Agent_Token{},
|
||||
},
|
||||
{
|
||||
Name: "agent2",
|
||||
Auth: &proto.Agent_Token{},
|
||||
},
|
||||
}
|
||||
}).Do()
|
||||
|
||||
var stdout bytes.Buffer
|
||||
inv, root := clitest.New(t, "ssh", "")
|
||||
inv.Stdout = &stdout
|
||||
inv.Environ.Set("COMPLETION_MODE", "1")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// For multi-agent workspaces, completions should include the
|
||||
// workspace.agent format but NOT the bare workspace name.
|
||||
output := stdout.String()
|
||||
t.Logf("Completion output: %q", output)
|
||||
lines := strings.Split(strings.TrimSpace(output), "\n")
|
||||
require.NotContains(t, lines, r.Workspace.Name)
|
||||
require.Contains(t, output, r.Workspace.Name+".agent1")
|
||||
require.Contains(t, output, r.Workspace.Name+".agent2")
|
||||
})
|
||||
|
||||
t.Run("NetworkError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var stdout bytes.Buffer
|
||||
inv, _ := clitest.New(t, "ssh", "")
|
||||
inv.Stdout = &stdout
|
||||
inv.Environ.Set("COMPLETION_MODE", "1")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
output := stdout.String()
|
||||
require.Empty(t, output)
|
||||
})
|
||||
}
|
||||
|
||||
+1
-2
@@ -90,7 +90,6 @@
|
||||
"allow_renames": false,
|
||||
"favorite": false,
|
||||
"next_start_at": "====[timestamp]=====",
|
||||
"is_prebuild": false,
|
||||
"task_id": null
|
||||
"is_prebuild": false
|
||||
}
|
||||
]
|
||||
|
||||
-35
@@ -80,41 +80,6 @@ OPTIONS:
|
||||
Periodically check for new releases of Coder and inform the owner. The
|
||||
check is performed once per day.
|
||||
|
||||
AIBRIDGE OPTIONS:
|
||||
--aibridge-anthropic-base-url string, $CODER_AIBRIDGE_ANTHROPIC_BASE_URL (default: https://api.anthropic.com/)
|
||||
The base URL of the Anthropic API.
|
||||
|
||||
--aibridge-anthropic-key string, $CODER_AIBRIDGE_ANTHROPIC_KEY
|
||||
The key to authenticate against the Anthropic API.
|
||||
|
||||
--aibridge-bedrock-access-key string, $CODER_AIBRIDGE_BEDROCK_ACCESS_KEY
|
||||
The access key to authenticate against the AWS Bedrock API.
|
||||
|
||||
--aibridge-bedrock-access-key-secret string, $CODER_AIBRIDGE_BEDROCK_ACCESS_KEY_SECRET
|
||||
The access key secret to use with the access key to authenticate
|
||||
against the AWS Bedrock API.
|
||||
|
||||
--aibridge-bedrock-model string, $CODER_AIBRIDGE_BEDROCK_MODEL (default: global.anthropic.claude-sonnet-4-5-20250929-v1:0)
|
||||
The model to use when making requests to the AWS Bedrock API.
|
||||
|
||||
--aibridge-bedrock-region string, $CODER_AIBRIDGE_BEDROCK_REGION
|
||||
The AWS Bedrock API region.
|
||||
|
||||
--aibridge-bedrock-small-fastmodel string, $CODER_AIBRIDGE_BEDROCK_SMALL_FAST_MODEL (default: global.anthropic.claude-haiku-4-5-20251001-v1:0)
|
||||
The small fast model to use when making requests to the AWS Bedrock
|
||||
API. Claude Code uses Haiku-class models to perform background tasks.
|
||||
See
|
||||
https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
|
||||
|
||||
--aibridge-enabled bool, $CODER_AIBRIDGE_ENABLED (default: false)
|
||||
Whether to start an in-memory aibridged instance.
|
||||
|
||||
--aibridge-openai-base-url string, $CODER_AIBRIDGE_OPENAI_BASE_URL (default: https://api.openai.com/v1/)
|
||||
The base URL of the OpenAI API.
|
||||
|
||||
--aibridge-openai-key string, $CODER_AIBRIDGE_OPENAI_KEY
|
||||
The key to authenticate against the OpenAI API.
|
||||
|
||||
CLIENT OPTIONS:
|
||||
These options change the behavior of how clients interact with the Coder.
|
||||
Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI.
|
||||
|
||||
-5
@@ -16,10 +16,6 @@ USAGE:
|
||||
|
||||
$ coder tokens ls
|
||||
|
||||
- Create a scoped token:
|
||||
|
||||
$ coder tokens create --scope workspace:read --allow workspace:<uuid>
|
||||
|
||||
- Remove a token by ID:
|
||||
|
||||
$ coder tokens rm WuoWs4ZsMX
|
||||
@@ -28,7 +24,6 @@ SUBCOMMANDS:
|
||||
create Create a token
|
||||
list List tokens
|
||||
remove Delete a token
|
||||
view Display detailed information about a token
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
+1
-9
@@ -6,20 +6,12 @@ USAGE:
|
||||
Create a token
|
||||
|
||||
OPTIONS:
|
||||
--allow allow-list
|
||||
Repeatable allow-list entry (<type>:<uuid>, e.g. workspace:1234-...).
|
||||
|
||||
--lifetime string, $CODER_TOKEN_LIFETIME
|
||||
Duration for the token lifetime. Supports standard Go duration units
|
||||
(ns, us, ms, s, m, h) plus d (days) and y (years). Examples: 8h, 30d,
|
||||
1y, 1d12h30m.
|
||||
Specify a duration for the lifetime of the token.
|
||||
|
||||
-n, --name string, $CODER_TOKEN_NAME
|
||||
Specify a human-readable name.
|
||||
|
||||
--scope string-array
|
||||
Repeatable scope to attach to the token (e.g. workspace:read).
|
||||
|
||||
-u, --user string, $CODER_TOKEN_USER
|
||||
Specify the user to create the token for (Only works if logged in user
|
||||
is admin).
|
||||
|
||||
+1
-1
@@ -12,7 +12,7 @@ OPTIONS:
|
||||
Specifies whether all users' tokens will be listed or not (must have
|
||||
Owner role to see all tokens).
|
||||
|
||||
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at)
|
||||
-c, --column [id|name|last used|expires at|created at|owner] (default: id,name,last used,expires at,created at)
|
||||
Columns to display in table output.
|
||||
|
||||
-o, --output table|json (default: table)
|
||||
|
||||
-16
@@ -1,16 +0,0 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder tokens view [flags] <name|id>
|
||||
|
||||
Display detailed information about a token
|
||||
|
||||
OPTIONS:
|
||||
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at,owner)
|
||||
Columns to display in table output.
|
||||
|
||||
-o, --output table|json (default: table)
|
||||
Output format.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
+4
-21
@@ -714,7 +714,8 @@ workspace_prebuilds:
|
||||
# (default: 3, type: int)
|
||||
failure_hard_limit: 3
|
||||
aibridge:
|
||||
# Whether to start an in-memory aibridged instance.
|
||||
# Whether to start an in-memory aibridged instance ("aibridge" experiment must be
|
||||
# enabled, too).
|
||||
# (default: false, type: bool)
|
||||
enabled: false
|
||||
# The base URL of the OpenAI API.
|
||||
@@ -725,25 +726,7 @@ aibridge:
|
||||
openai_key: ""
|
||||
# The base URL of the Anthropic API.
|
||||
# (default: https://api.anthropic.com/, type: string)
|
||||
anthropic_base_url: https://api.anthropic.com/
|
||||
base_url: https://api.anthropic.com/
|
||||
# The key to authenticate against the Anthropic API.
|
||||
# (default: <unset>, type: string)
|
||||
anthropic_key: ""
|
||||
# The AWS Bedrock API region.
|
||||
# (default: <unset>, type: string)
|
||||
bedrock_region: ""
|
||||
# The access key to authenticate against the AWS Bedrock API.
|
||||
# (default: <unset>, type: string)
|
||||
bedrock_access_key: ""
|
||||
# The access key secret to use with the access key to authenticate against the AWS
|
||||
# Bedrock API.
|
||||
# (default: <unset>, type: string)
|
||||
bedrock_access_key_secret: ""
|
||||
# The model to use when making requests to the AWS Bedrock API.
|
||||
# (default: global.anthropic.claude-sonnet-4-5-20250929-v1:0, type: string)
|
||||
bedrock_model: global.anthropic.claude-sonnet-4-5-20250929-v1:0
|
||||
# The small fast model to use when making requests to the AWS Bedrock API. Claude
|
||||
# Code uses Haiku-class models to perform background tasks. See
|
||||
# https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
|
||||
# (default: global.anthropic.claude-haiku-4-5-20251001-v1:0, type: string)
|
||||
bedrock_small_fast_model: global.anthropic.claude-haiku-4-5-20251001-v1:0
|
||||
key: ""
|
||||
|
||||
+6
-104
@@ -4,14 +4,12 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
@@ -29,10 +27,6 @@ func (r *RootCmd) tokens() *serpent.Command {
|
||||
Description: "List your tokens",
|
||||
Command: "coder tokens ls",
|
||||
},
|
||||
Example{
|
||||
Description: "Create a scoped token",
|
||||
Command: "coder tokens create --scope workspace:read --allow workspace:<uuid>",
|
||||
},
|
||||
Example{
|
||||
Description: "Remove a token by ID",
|
||||
Command: "coder tokens rm WuoWs4ZsMX",
|
||||
@@ -45,7 +39,6 @@ func (r *RootCmd) tokens() *serpent.Command {
|
||||
Children: []*serpent.Command{
|
||||
r.createToken(),
|
||||
r.listTokens(),
|
||||
r.viewToken(),
|
||||
r.removeToken(),
|
||||
},
|
||||
}
|
||||
@@ -57,8 +50,6 @@ func (r *RootCmd) createToken() *serpent.Command {
|
||||
tokenLifetime string
|
||||
name string
|
||||
user string
|
||||
scopes []string
|
||||
allowList []codersdk.APIAllowListTarget
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "create",
|
||||
@@ -97,18 +88,10 @@ func (r *RootCmd) createToken() *serpent.Command {
|
||||
}
|
||||
}
|
||||
|
||||
req := codersdk.CreateTokenRequest{
|
||||
res, err := client.CreateToken(inv.Context(), userID, codersdk.CreateTokenRequest{
|
||||
Lifetime: parsedLifetime,
|
||||
TokenName: name,
|
||||
}
|
||||
if len(req.Scopes) == 0 {
|
||||
req.Scopes = slice.StringEnums[codersdk.APIKeyScope](scopes)
|
||||
}
|
||||
if len(allowList) > 0 {
|
||||
req.AllowList = append([]codersdk.APIAllowListTarget(nil), allowList...)
|
||||
}
|
||||
|
||||
res, err := client.CreateToken(inv.Context(), userID, req)
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create tokens: %w", err)
|
||||
}
|
||||
@@ -123,7 +106,7 @@ func (r *RootCmd) createToken() *serpent.Command {
|
||||
{
|
||||
Flag: "lifetime",
|
||||
Env: "CODER_TOKEN_LIFETIME",
|
||||
Description: "Duration for the token lifetime. Supports standard Go duration units (ns, us, ms, s, m, h) plus d (days) and y (years). Examples: 8h, 30d, 1y, 1d12h30m.",
|
||||
Description: "Specify a duration for the lifetime of the token.",
|
||||
Value: serpent.StringOf(&tokenLifetime),
|
||||
},
|
||||
{
|
||||
@@ -140,16 +123,6 @@ func (r *RootCmd) createToken() *serpent.Command {
|
||||
Description: "Specify the user to create the token for (Only works if logged in user is admin).",
|
||||
Value: serpent.StringOf(&user),
|
||||
},
|
||||
{
|
||||
Flag: "scope",
|
||||
Description: "Repeatable scope to attach to the token (e.g. workspace:read).",
|
||||
Value: serpent.StringArrayOf(&scopes),
|
||||
},
|
||||
{
|
||||
Flag: "allow",
|
||||
Description: "Repeatable allow-list entry (<type>:<uuid>, e.g. workspace:1234-...).",
|
||||
Value: AllowListFlagOf(&allowList),
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
@@ -163,8 +136,6 @@ type tokenListRow struct {
|
||||
// For table format:
|
||||
ID string `json:"-" table:"id,default_sort"`
|
||||
TokenName string `json:"token_name" table:"name"`
|
||||
Scopes string `json:"-" table:"scopes"`
|
||||
Allow string `json:"-" table:"allow list"`
|
||||
LastUsed time.Time `json:"-" table:"last used"`
|
||||
ExpiresAt time.Time `json:"-" table:"expires at"`
|
||||
CreatedAt time.Time `json:"-" table:"created at"`
|
||||
@@ -172,47 +143,20 @@ type tokenListRow struct {
|
||||
}
|
||||
|
||||
func tokenListRowFromToken(token codersdk.APIKeyWithOwner) tokenListRow {
|
||||
return tokenListRowFromKey(token.APIKey, token.Username)
|
||||
}
|
||||
|
||||
func tokenListRowFromKey(token codersdk.APIKey, owner string) tokenListRow {
|
||||
return tokenListRow{
|
||||
APIKey: token,
|
||||
APIKey: token.APIKey,
|
||||
ID: token.ID,
|
||||
TokenName: token.TokenName,
|
||||
Scopes: joinScopes(token.Scopes),
|
||||
Allow: joinAllowList(token.AllowList),
|
||||
LastUsed: token.LastUsed,
|
||||
ExpiresAt: token.ExpiresAt,
|
||||
CreatedAt: token.CreatedAt,
|
||||
Owner: owner,
|
||||
Owner: token.Username,
|
||||
}
|
||||
}
|
||||
|
||||
func joinScopes(scopes []codersdk.APIKeyScope) string {
|
||||
if len(scopes) == 0 {
|
||||
return ""
|
||||
}
|
||||
vals := slice.ToStrings(scopes)
|
||||
sort.Strings(vals)
|
||||
return strings.Join(vals, ", ")
|
||||
}
|
||||
|
||||
func joinAllowList(entries []codersdk.APIAllowListTarget) string {
|
||||
if len(entries) == 0 {
|
||||
return ""
|
||||
}
|
||||
vals := make([]string, len(entries))
|
||||
for i, entry := range entries {
|
||||
vals[i] = entry.String()
|
||||
}
|
||||
sort.Strings(vals)
|
||||
return strings.Join(vals, ", ")
|
||||
}
|
||||
|
||||
func (r *RootCmd) listTokens() *serpent.Command {
|
||||
// we only display the 'owner' column if the --all argument is passed in
|
||||
defaultCols := []string{"id", "name", "scopes", "allow list", "last used", "expires at", "created at"}
|
||||
defaultCols := []string{"id", "name", "last used", "expires at", "created at"}
|
||||
if slices.Contains(os.Args, "-a") || slices.Contains(os.Args, "--all") {
|
||||
defaultCols = append(defaultCols, "owner")
|
||||
}
|
||||
@@ -282,48 +226,6 @@ func (r *RootCmd) listTokens() *serpent.Command {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) viewToken() *serpent.Command {
|
||||
formatter := cliui.NewOutputFormatter(
|
||||
cliui.TableFormat([]tokenListRow{}, []string{"id", "name", "scopes", "allow list", "last used", "expires at", "created at", "owner"}),
|
||||
cliui.JSONFormat(),
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "view <name|id>",
|
||||
Short: "Display detailed information about a token",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenName := inv.Args[0]
|
||||
token, err := client.APIKeyByName(inv.Context(), codersdk.Me, tokenName)
|
||||
if err != nil {
|
||||
maybeID := strings.Split(tokenName, "-")[0]
|
||||
token, err = client.APIKeyByID(inv.Context(), codersdk.Me, maybeID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fetch api key by name or id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
row := tokenListRowFromKey(*token, "")
|
||||
out, err := formatter.Format(inv.Context(), []tokenListRow{row})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintln(inv.Stdout, out)
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
formatter.AttachOptions(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) removeToken() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "remove <name|id|token>",
|
||||
|
||||
+3
-56
@@ -4,13 +4,10 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -49,18 +46,6 @@ func TestTokens(t *testing.T) {
|
||||
require.NotEmpty(t, res)
|
||||
id := res[:10]
|
||||
|
||||
allowWorkspaceID := uuid.New()
|
||||
allowSpec := fmt.Sprintf("workspace:%s", allowWorkspaceID.String())
|
||||
inv, root = clitest.New(t, "tokens", "create", "--name", "scoped-token", "--scope", string(codersdk.APIKeyScopeWorkspaceRead), "--allow", allowSpec)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
scopedTokenID := res[:10]
|
||||
|
||||
// Test creating a token for second user from first user's (admin) session
|
||||
inv, root = clitest.New(t, "tokens", "create", "--name", "token-two", "--user", secondUser.ID.String())
|
||||
clitest.SetupConfig(t, client, root)
|
||||
@@ -82,7 +67,7 @@ func TestTokens(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
// Result should only contain the tokens created for the admin user
|
||||
// Result should only contain the token created for the admin user
|
||||
require.Contains(t, res, "ID")
|
||||
require.Contains(t, res, "EXPIRES AT")
|
||||
require.Contains(t, res, "CREATED AT")
|
||||
@@ -91,16 +76,6 @@ func TestTokens(t *testing.T) {
|
||||
// Result should not contain the token created for the second user
|
||||
require.NotContains(t, res, secondTokenID)
|
||||
|
||||
inv, root = clitest.New(t, "tokens", "view", "scoped-token")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.Contains(t, res, string(codersdk.APIKeyScopeWorkspaceRead))
|
||||
require.Contains(t, res, allowSpec)
|
||||
|
||||
// Test listing tokens from the second user's session
|
||||
inv, root = clitest.New(t, "tokens", "ls")
|
||||
clitest.SetupConfig(t, secondUserClient, root)
|
||||
@@ -126,14 +101,6 @@ func TestTokens(t *testing.T) {
|
||||
// User (non-admin) should not be able to create a token for another user
|
||||
require.Error(t, err)
|
||||
|
||||
inv, root = clitest.New(t, "tokens", "create", "--name", "invalid-allow", "--allow", "badvalue")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid allow_list entry")
|
||||
|
||||
inv, root = clitest.New(t, "tokens", "ls", "--output=json")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
@@ -143,17 +110,8 @@ func TestTokens(t *testing.T) {
|
||||
|
||||
var tokens []codersdk.APIKey
|
||||
require.NoError(t, json.Unmarshal(buf.Bytes(), &tokens))
|
||||
require.Len(t, tokens, 2)
|
||||
tokenByName := make(map[string]codersdk.APIKey, len(tokens))
|
||||
for _, tk := range tokens {
|
||||
tokenByName[tk.TokenName] = tk
|
||||
}
|
||||
require.Contains(t, tokenByName, "token-one")
|
||||
require.Contains(t, tokenByName, "scoped-token")
|
||||
scopedToken := tokenByName["scoped-token"]
|
||||
require.Contains(t, scopedToken.Scopes, codersdk.APIKeyScopeWorkspaceRead)
|
||||
require.Len(t, scopedToken.AllowList, 1)
|
||||
require.Equal(t, allowSpec, scopedToken.AllowList[0].String())
|
||||
require.Len(t, tokens, 1)
|
||||
require.Equal(t, id, tokens[0].ID)
|
||||
|
||||
// Delete by name
|
||||
inv, root = clitest.New(t, "tokens", "rm", "token-one")
|
||||
@@ -177,17 +135,6 @@ func TestTokens(t *testing.T) {
|
||||
require.NotEmpty(t, res)
|
||||
require.Contains(t, res, "deleted")
|
||||
|
||||
// Delete scoped token by ID
|
||||
inv, root = clitest.New(t, "tokens", "rm", scopedTokenID)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
require.Contains(t, res, "deleted")
|
||||
|
||||
// Create third token
|
||||
inv, root = clitest.New(t, "tokens", "create", "--name", "token-three")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
@@ -239,10 +239,6 @@ func (a *API) Serve(ctx context.Context, l net.Listener) error {
|
||||
return xerrors.Errorf("create agent API server: %w", err)
|
||||
}
|
||||
|
||||
if err := a.ResourcesMonitoringAPI.InitMonitors(ctx); err != nil {
|
||||
return xerrors.Errorf("initialize resource monitoring: %w", err)
|
||||
}
|
||||
|
||||
return server.Serve(ctx, l)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
@@ -34,60 +33,42 @@ type ResourcesMonitoringAPI struct {
|
||||
|
||||
Debounce time.Duration
|
||||
Config resourcesmonitor.Config
|
||||
|
||||
// Cache resource monitors on first call to avoid millions of DB queries per day.
|
||||
memoryMonitor database.WorkspaceAgentMemoryResourceMonitor
|
||||
volumeMonitors []database.WorkspaceAgentVolumeResourceMonitor
|
||||
monitorsLock sync.RWMutex
|
||||
}
|
||||
|
||||
// InitMonitors fetches resource monitors from the database and caches them.
|
||||
// This must be called once after creating a ResourcesMonitoringAPI, the context should be
|
||||
// the agent per-RPC connection context. If fetching fails with a real error (not sql.ErrNoRows), the
|
||||
// connection should be torn down.
|
||||
func (a *ResourcesMonitoringAPI) InitMonitors(ctx context.Context) error {
|
||||
memMon, err := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("fetch memory resource monitor: %w", err)
|
||||
}
|
||||
// If sql.ErrNoRows, memoryMonitor stays as zero value (CreatedAt.IsZero() = true).
|
||||
// Otherwise, store the fetched monitor.
|
||||
if err == nil {
|
||||
a.memoryMonitor = memMon
|
||||
func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(ctx context.Context, _ *proto.GetResourcesMonitoringConfigurationRequest) (*proto.GetResourcesMonitoringConfigurationResponse, error) {
|
||||
memoryMonitor, memoryErr := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
if memoryErr != nil && !errors.Is(memoryErr, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("failed to fetch memory resource monitor: %w", memoryErr)
|
||||
}
|
||||
|
||||
volMons, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
volumeMonitors, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fetch volume resource monitors: %w", err)
|
||||
return nil, xerrors.Errorf("failed to fetch volume resource monitors: %w", err)
|
||||
}
|
||||
// 0 length is valid, indicating none configured, since the volume monitors in the DB can be many.
|
||||
a.volumeMonitors = volMons
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(_ context.Context, _ *proto.GetResourcesMonitoringConfigurationRequest) (*proto.GetResourcesMonitoringConfigurationResponse, error) {
|
||||
return &proto.GetResourcesMonitoringConfigurationResponse{
|
||||
Config: &proto.GetResourcesMonitoringConfigurationResponse_Config{
|
||||
CollectionIntervalSeconds: int32(a.Config.CollectionInterval.Seconds()),
|
||||
NumDatapoints: a.Config.NumDatapoints,
|
||||
},
|
||||
Memory: func() *proto.GetResourcesMonitoringConfigurationResponse_Memory {
|
||||
if a.memoryMonitor.CreatedAt.IsZero() {
|
||||
if memoryErr != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &proto.GetResourcesMonitoringConfigurationResponse_Memory{
|
||||
Enabled: a.memoryMonitor.Enabled,
|
||||
Enabled: memoryMonitor.Enabled,
|
||||
}
|
||||
}(),
|
||||
Volumes: func() []*proto.GetResourcesMonitoringConfigurationResponse_Volume {
|
||||
volumes := make([]*proto.GetResourcesMonitoringConfigurationResponse_Volume, 0, len(a.volumeMonitors))
|
||||
for _, monitor := range a.volumeMonitors {
|
||||
volumes := make([]*proto.GetResourcesMonitoringConfigurationResponse_Volume, 0, len(volumeMonitors))
|
||||
for _, monitor := range volumeMonitors {
|
||||
volumes = append(volumes, &proto.GetResourcesMonitoringConfigurationResponse_Volume{
|
||||
Enabled: monitor.Enabled,
|
||||
Path: monitor.Path,
|
||||
})
|
||||
}
|
||||
|
||||
return volumes
|
||||
}(),
|
||||
}, nil
|
||||
@@ -96,10 +77,6 @@ func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(_ context.C
|
||||
func (a *ResourcesMonitoringAPI) PushResourcesMonitoringUsage(ctx context.Context, req *proto.PushResourcesMonitoringUsageRequest) (*proto.PushResourcesMonitoringUsageResponse, error) {
|
||||
var err error
|
||||
|
||||
// Lock for the entire push operation since calls are sequential from the agent
|
||||
a.monitorsLock.Lock()
|
||||
defer a.monitorsLock.Unlock()
|
||||
|
||||
if memoryErr := a.monitorMemory(ctx, req.Datapoints); memoryErr != nil {
|
||||
err = errors.Join(err, xerrors.Errorf("monitor memory: %w", memoryErr))
|
||||
}
|
||||
@@ -112,7 +89,18 @@ func (a *ResourcesMonitoringAPI) PushResourcesMonitoringUsage(ctx context.Contex
|
||||
}
|
||||
|
||||
func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints []*proto.PushResourcesMonitoringUsageRequest_Datapoint) error {
|
||||
if !a.memoryMonitor.Enabled {
|
||||
monitor, err := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
if err != nil {
|
||||
// It is valid for an agent to not have a memory monitor, so we
|
||||
// do not want to treat it as an error.
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return xerrors.Errorf("fetch memory resource monitor: %w", err)
|
||||
}
|
||||
|
||||
if !monitor.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -121,15 +109,15 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
|
||||
usageDatapoints = append(usageDatapoints, datapoint.Memory)
|
||||
}
|
||||
|
||||
usageStates := resourcesmonitor.CalculateMemoryUsageStates(a.memoryMonitor, usageDatapoints)
|
||||
usageStates := resourcesmonitor.CalculateMemoryUsageStates(monitor, usageDatapoints)
|
||||
|
||||
oldState := a.memoryMonitor.State
|
||||
oldState := monitor.State
|
||||
newState := resourcesmonitor.NextState(a.Config, oldState, usageStates)
|
||||
|
||||
debouncedUntil, shouldNotify := a.memoryMonitor.Debounce(a.Debounce, a.Clock.Now(), oldState, newState)
|
||||
debouncedUntil, shouldNotify := monitor.Debounce(a.Debounce, a.Clock.Now(), oldState, newState)
|
||||
|
||||
//nolint:gocritic // We need to be able to update the resource monitor here.
|
||||
err := a.Database.UpdateMemoryResourceMonitor(dbauthz.AsResourceMonitor(ctx), database.UpdateMemoryResourceMonitorParams{
|
||||
err = a.Database.UpdateMemoryResourceMonitor(dbauthz.AsResourceMonitor(ctx), database.UpdateMemoryResourceMonitorParams{
|
||||
AgentID: a.AgentID,
|
||||
State: newState,
|
||||
UpdatedAt: dbtime.Time(a.Clock.Now()),
|
||||
@@ -139,11 +127,6 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
|
||||
return xerrors.Errorf("update workspace monitor: %w", err)
|
||||
}
|
||||
|
||||
// Update cached state
|
||||
a.memoryMonitor.State = newState
|
||||
a.memoryMonitor.DebouncedUntil = dbtime.Time(debouncedUntil)
|
||||
a.memoryMonitor.UpdatedAt = dbtime.Time(a.Clock.Now())
|
||||
|
||||
if !shouldNotify {
|
||||
return nil
|
||||
}
|
||||
@@ -160,7 +143,7 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
|
||||
notifications.TemplateWorkspaceOutOfMemory,
|
||||
map[string]string{
|
||||
"workspace": workspace.Name,
|
||||
"threshold": fmt.Sprintf("%d%%", a.memoryMonitor.Threshold),
|
||||
"threshold": fmt.Sprintf("%d%%", monitor.Threshold),
|
||||
},
|
||||
map[string]any{
|
||||
// NOTE(DanielleMaywood):
|
||||
@@ -186,9 +169,14 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
|
||||
}
|
||||
|
||||
func (a *ResourcesMonitoringAPI) monitorVolumes(ctx context.Context, datapoints []*proto.PushResourcesMonitoringUsageRequest_Datapoint) error {
|
||||
volumeMonitors, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get or insert volume monitor: %w", err)
|
||||
}
|
||||
|
||||
outOfDiskVolumes := make([]map[string]any, 0)
|
||||
|
||||
for i, monitor := range a.volumeMonitors {
|
||||
for _, monitor := range volumeMonitors {
|
||||
if !monitor.Enabled {
|
||||
continue
|
||||
}
|
||||
@@ -231,11 +219,6 @@ func (a *ResourcesMonitoringAPI) monitorVolumes(ctx context.Context, datapoints
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update workspace monitor: %w", err)
|
||||
}
|
||||
|
||||
// Update cached state
|
||||
a.volumeMonitors[i].State = newState
|
||||
a.volumeMonitors[i].DebouncedUntil = dbtime.Time(debouncedUntil)
|
||||
a.volumeMonitors[i].UpdatedAt = dbtime.Time(a.Clock.Now())
|
||||
}
|
||||
|
||||
if len(outOfDiskVolumes) == 0 {
|
||||
|
||||
@@ -101,9 +101,6 @@ func TestMemoryResourceMonitorDebounce(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: The monitor is given a state that will trigger NOK
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
|
||||
@@ -307,9 +304,6 @@ func TestMemoryResourceMonitor(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
clock.Set(collectedAt)
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: datapoints,
|
||||
@@ -343,8 +337,6 @@ func TestMemoryResourceMonitorMissingData(t *testing.T) {
|
||||
State: database.WorkspaceAgentMonitorStateOK,
|
||||
Threshold: 80,
|
||||
})
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: A datapoint is missing, surrounded by two NOK datapoints.
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
@@ -395,9 +387,6 @@ func TestMemoryResourceMonitorMissingData(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: A datapoint is missing, surrounded by two OK datapoints.
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
|
||||
@@ -477,9 +466,6 @@ func TestVolumeResourceMonitorDebounce(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When:
|
||||
// - First monitor is in a NOK state
|
||||
// - Second monitor is in an OK state
|
||||
@@ -756,9 +742,6 @@ func TestVolumeResourceMonitor(t *testing.T) {
|
||||
Threshold: tt.thresholdPercent,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
clock.Set(collectedAt)
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: datapoints,
|
||||
@@ -797,9 +780,6 @@ func TestVolumeResourceMonitorMultiple(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: both of them move to a NOK state
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
|
||||
@@ -852,9 +832,6 @@ func TestVolumeResourceMonitorMissingData(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: A datapoint is missing, surrounded by two NOK datapoints.
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
|
||||
@@ -914,9 +891,6 @@ func TestVolumeResourceMonitorMissingData(t *testing.T) {
|
||||
Threshold: 80,
|
||||
})
|
||||
|
||||
// Initialize API to fetch and cache the monitors
|
||||
require.NoError(t, api.InitMonitors(context.Background()))
|
||||
|
||||
// When: A datapoint is missing, surrounded by two OK datapoints.
|
||||
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
|
||||
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
|
||||
|
||||
+57
-25
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -23,12 +24,62 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/searchquery"
|
||||
"github.com/coder/coder/v2/coderd/taskname"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
|
||||
aiagentapi "github.com/coder/agentapi-sdk-go"
|
||||
)
|
||||
|
||||
// This endpoint is experimental and not guaranteed to be stable, so we're not
|
||||
// generating public-facing documentation for it.
|
||||
func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
buildIDsParam := r.URL.Query().Get("build_ids")
|
||||
if buildIDsParam == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "build_ids query parameter is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse build IDs
|
||||
buildIDStrings := strings.Split(buildIDsParam, ",")
|
||||
buildIDs := make([]uuid.UUID, 0, len(buildIDStrings))
|
||||
for _, idStr := range buildIDStrings {
|
||||
id, err := uuid.Parse(strings.TrimSpace(idStr))
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Invalid build ID format: %s", idStr),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
buildIDs = append(buildIDs, id)
|
||||
}
|
||||
|
||||
parameters, err := api.Database.GetWorkspaceBuildParametersByBuildIDs(ctx, buildIDs)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace build parameters.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
promptsByBuildID := make(map[string]string, len(parameters))
|
||||
for _, param := range parameters {
|
||||
if param.Name != codersdk.AITaskPromptParameterName {
|
||||
continue
|
||||
}
|
||||
buildID := param.WorkspaceBuildID.String()
|
||||
promptsByBuildID[buildID] = param.Value
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AITasksPromptsResponse{
|
||||
Prompts: promptsByBuildID,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Create a new AI task
|
||||
// @Description: EXPERIMENTAL: this endpoint is experimental and not guaranteed to be stable.
|
||||
// @ID create-task
|
||||
@@ -123,31 +174,13 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the template defines the AI Prompt parameter.
|
||||
templateParams, err := api.Database.GetTemplateVersionParameters(ctx, req.TemplateVersionID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching template parameters.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var richParams []codersdk.WorkspaceBuildParameter
|
||||
if _, hasAIPromptParam := slice.Find(templateParams, func(param database.TemplateVersionParameter) bool {
|
||||
return param.Name == codersdk.AITaskPromptParameterName
|
||||
}); hasAIPromptParam {
|
||||
// Only add the AI Prompt parameter if the template defines it.
|
||||
richParams = []codersdk.WorkspaceBuildParameter{
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: req.Input},
|
||||
}
|
||||
}
|
||||
|
||||
createReq := codersdk.CreateWorkspaceRequest{
|
||||
Name: taskName,
|
||||
TemplateVersionID: req.TemplateVersionID,
|
||||
TemplateVersionPresetID: req.TemplateVersionPresetID,
|
||||
RichParameterValues: richParams,
|
||||
RichParameterValues: []codersdk.WorkspaceBuildParameter{
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: req.Input},
|
||||
},
|
||||
}
|
||||
|
||||
var owner workspaceOwner
|
||||
@@ -208,7 +241,6 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
|
||||
// Create task record in the database before creating the workspace so that
|
||||
// we can request that the workspace be linked to it after creation.
|
||||
dbTaskTable, err = tx.InsertTask(ctx, database.InsertTaskParams{
|
||||
ID: uuid.New(),
|
||||
OrganizationID: templateVersion.OrganizationID,
|
||||
OwnerID: owner.ID,
|
||||
Name: taskName,
|
||||
@@ -306,8 +338,8 @@ func taskFromDBTaskAndWorkspace(dbTask database.Task, ws codersdk.Workspace) cod
|
||||
ID: dbTask.ID,
|
||||
OrganizationID: dbTask.OrganizationID,
|
||||
OwnerID: dbTask.OwnerID,
|
||||
OwnerName: dbTask.OwnerUsername,
|
||||
OwnerAvatarURL: dbTask.OwnerAvatarUrl,
|
||||
OwnerName: ws.OwnerName,
|
||||
OwnerAvatarURL: ws.OwnerAvatarURL,
|
||||
Name: dbTask.Name,
|
||||
TemplateID: ws.TemplateID,
|
||||
TemplateVersionID: dbTask.TemplateVersionID,
|
||||
|
||||
+139
-90
@@ -1,7 +1,6 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"io"
|
||||
@@ -35,6 +34,128 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestAITasksPrompts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptyBuildIDs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
experimentalClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Test with empty build IDs
|
||||
prompts, err := experimentalClient.AITaskPrompts(ctx, []uuid.UUID{})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, prompts.Prompts)
|
||||
})
|
||||
|
||||
t.Run("MultipleBuilds", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
adminClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
first := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, first.OrganizationID)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Create a template with parameters
|
||||
version := coderdtest.CreateTemplateVersion(t, adminClient, first.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionPlan: []*proto.Response{{
|
||||
Type: &proto.Response_Plan{
|
||||
Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{
|
||||
{
|
||||
Name: "param1",
|
||||
Type: "string",
|
||||
DefaultValue: "default1",
|
||||
},
|
||||
{
|
||||
Name: codersdk.AITaskPromptParameterName,
|
||||
Type: "string",
|
||||
DefaultValue: "default2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
})
|
||||
template := coderdtest.CreateTemplate(t, adminClient, first.OrganizationID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, adminClient, version.ID)
|
||||
|
||||
// Create two workspaces with different parameters
|
||||
workspace1 := coderdtest.CreateWorkspace(t, memberClient, template.ID, func(request *codersdk.CreateWorkspaceRequest) {
|
||||
request.RichParameterValues = []codersdk.WorkspaceBuildParameter{
|
||||
{Name: "param1", Value: "value1a"},
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: "value2a"},
|
||||
}
|
||||
})
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, memberClient, workspace1.LatestBuild.ID)
|
||||
|
||||
workspace2 := coderdtest.CreateWorkspace(t, memberClient, template.ID, func(request *codersdk.CreateWorkspaceRequest) {
|
||||
request.RichParameterValues = []codersdk.WorkspaceBuildParameter{
|
||||
{Name: "param1", Value: "value1b"},
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: "value2b"},
|
||||
}
|
||||
})
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, memberClient, workspace2.LatestBuild.ID)
|
||||
|
||||
workspace3 := coderdtest.CreateWorkspace(t, adminClient, template.ID, func(request *codersdk.CreateWorkspaceRequest) {
|
||||
request.RichParameterValues = []codersdk.WorkspaceBuildParameter{
|
||||
{Name: "param1", Value: "value1c"},
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: "value2c"},
|
||||
}
|
||||
})
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, adminClient, workspace3.LatestBuild.ID)
|
||||
allBuildIDs := []uuid.UUID{workspace1.LatestBuild.ID, workspace2.LatestBuild.ID, workspace3.LatestBuild.ID}
|
||||
|
||||
experimentalMemberClient := codersdk.NewExperimentalClient(memberClient)
|
||||
// Test parameters endpoint as member
|
||||
prompts, err := experimentalMemberClient.AITaskPrompts(ctx, allBuildIDs)
|
||||
require.NoError(t, err)
|
||||
// we expect 2 prompts because the member client does not have access to workspace3
|
||||
// since it was created by the admin client
|
||||
require.Len(t, prompts.Prompts, 2)
|
||||
|
||||
// Check workspace1 parameters
|
||||
build1Prompt := prompts.Prompts[workspace1.LatestBuild.ID.String()]
|
||||
require.Equal(t, "value2a", build1Prompt)
|
||||
|
||||
// Check workspace2 parameters
|
||||
build2Prompt := prompts.Prompts[workspace2.LatestBuild.ID.String()]
|
||||
require.Equal(t, "value2b", build2Prompt)
|
||||
|
||||
experimentalAdminClient := codersdk.NewExperimentalClient(adminClient)
|
||||
// Test parameters endpoint as admin
|
||||
// we expect 3 prompts because the admin client has access to all workspaces
|
||||
prompts, err = experimentalAdminClient.AITaskPrompts(ctx, allBuildIDs)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, prompts.Prompts, 3)
|
||||
|
||||
// Check workspace3 parameters
|
||||
build3Prompt := prompts.Prompts[workspace3.LatestBuild.ID.String()]
|
||||
require.Equal(t, "value2c", build3Prompt)
|
||||
})
|
||||
|
||||
t.Run("NonExistentBuildIDs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Test with non-existent build IDs
|
||||
nonExistentID := uuid.New()
|
||||
experimentalClient := codersdk.NewExperimentalClient(client)
|
||||
prompts, err := experimentalClient.AITaskPrompts(ctx, []uuid.UUID{nonExistentID})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, prompts.Prompts)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -66,6 +187,7 @@ func TestTasks(t *testing.T) {
|
||||
{
|
||||
Type: &proto.Response_Plan{
|
||||
Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
},
|
||||
},
|
||||
@@ -136,9 +258,6 @@ func TestTasks(t *testing.T) {
|
||||
// Wait for the workspace to be built.
|
||||
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
if assert.True(t, workspace.TaskID.Valid, "task id should be set on workspace") {
|
||||
assert.Equal(t, task.ID, workspace.TaskID.UUID, "workspace task id should match")
|
||||
}
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// List tasks via experimental API and verify the prompt and status mapping.
|
||||
@@ -177,9 +296,6 @@ func TestTasks(t *testing.T) {
|
||||
// Get the workspace and wait for it to be ready.
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
if assert.True(t, ws.TaskID.Valid, "task id should be set on workspace") {
|
||||
assert.Equal(t, task.ID, ws.TaskID.UUID, "workspace task id should match")
|
||||
}
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
ws = coderdtest.MustWorkspace(t, client, task.WorkspaceID.UUID)
|
||||
// Assert invariant: the workspace has exactly one resource with one agent with one app.
|
||||
@@ -254,23 +370,24 @@ func TestTasks(t *testing.T) {
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
if assert.True(t, ws.TaskID.Valid, "task id should be set on workspace") {
|
||||
assert.Equal(t, task.ID, ws.TaskID.UUID, "workspace task id should match")
|
||||
}
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
err = exp.DeleteTask(ctx, "me", task.ID)
|
||||
require.NoError(t, err, "delete task request should be accepted")
|
||||
|
||||
// Poll until the workspace is deleted.
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
|
||||
for {
|
||||
dws, derr := client.DeletedWorkspace(ctx, task.WorkspaceID.UUID)
|
||||
if !assert.NoError(t, derr, "expected to fetch deleted workspace before deadline") {
|
||||
return false
|
||||
if derr == nil && dws.LatestBuild.Status == codersdk.WorkspaceStatusDeleted {
|
||||
break
|
||||
}
|
||||
t.Logf("workspace latest_build status: %q", dws.LatestBuild.Status)
|
||||
return dws.LatestBuild.Status == codersdk.WorkspaceStatusDeleted
|
||||
}, testutil.IntervalMedium, "workspace should be deleted before deadline")
|
||||
if ctx.Err() != nil {
|
||||
require.NoError(t, derr, "expected to fetch deleted workspace before deadline")
|
||||
require.Equal(t, codersdk.WorkspaceStatusDeleted, dws.LatestBuild.Status, "workspace should be deleted before deadline")
|
||||
break
|
||||
}
|
||||
time.Sleep(testutil.IntervalMedium)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
@@ -303,9 +420,6 @@ func TestTasks(t *testing.T) {
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
ws := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
if assert.False(t, ws.TaskID.Valid, "task id should not be set on non-task workspace") {
|
||||
assert.Zero(t, ws.TaskID, "non-task workspace task id should be empty")
|
||||
}
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
@@ -354,32 +468,6 @@ func TestTasks(t *testing.T) {
|
||||
t.Fatalf("unexpected status code: %d (expected 403 or 404)", authErr.StatusCode())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoWorkspace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
template := createAITemplate(t, client, user)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: "delete me",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
// Delete the task workspace
|
||||
coderdtest.MustTransitionWorkspace(t, client, ws.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionDelete)
|
||||
// We should still be able to fetch the task after deleting its workspace
|
||||
task, err = exp.TaskByID(ctx, task.ID)
|
||||
require.NoError(t, err, "fetching a task should still work after deleting its related workspace")
|
||||
err = exp.DeleteTask(ctx, task.OwnerID.String(), task.ID)
|
||||
require.NoError(t, err, "should be possible to delete a task with no workspace")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Send", func(t *testing.T) {
|
||||
@@ -694,51 +782,6 @@ func TestTasksCreate(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
task, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: taskPrompt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, task.WorkspaceID.Valid)
|
||||
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
assert.NotEmpty(t, task.Name)
|
||||
assert.Equal(t, template.ID, task.TemplateID)
|
||||
|
||||
parameters, err := client.WorkspaceBuildParameters(ctx, ws.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parameters, 0)
|
||||
})
|
||||
|
||||
t.Run("OK AIPromptBackCompat", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
taskPrompt = "Some task prompt"
|
||||
)
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Given: A template with an "AI Prompt" parameter
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
@@ -818,6 +861,7 @@ func TestTasksCreate(t *testing.T) {
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
@@ -933,6 +977,7 @@ func TestTasksCreate(t *testing.T) {
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
@@ -992,6 +1037,7 @@ func TestTasksCreate(t *testing.T) {
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
@@ -1028,6 +1074,7 @@ func TestTasksCreate(t *testing.T) {
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
@@ -1080,6 +1127,7 @@ func TestTasksCreate(t *testing.T) {
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
@@ -1092,6 +1140,7 @@ func TestTasksCreate(t *testing.T) {
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
|
||||
Generated
+6
-59
@@ -85,7 +85,7 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/aibridge/interceptions": {
|
||||
"/api/experimental/aibridge/interceptions": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
@@ -11668,35 +11668,12 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeBedrockConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"access_key": {
|
||||
"type": "string"
|
||||
},
|
||||
"access_key_secret": {
|
||||
"type": "string"
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"region": {
|
||||
"type": "string"
|
||||
},
|
||||
"small_fast_model": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"anthropic": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig"
|
||||
},
|
||||
"bedrock": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeBedrockConfig"
|
||||
},
|
||||
"enabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -11708,10 +11685,6 @@ const docTemplate = `{
|
||||
"codersdk.AIBridgeInterception": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -12523,13 +12496,6 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
@@ -13753,13 +13719,6 @@ const docTemplate = `{
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific to the organization the role belongs to.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific to the organization the role belongs to.",
|
||||
"type": "array",
|
||||
@@ -14316,9 +14275,11 @@ const docTemplate = `{
|
||||
"web-push",
|
||||
"oauth2",
|
||||
"mcp-server-http",
|
||||
"workspace-sharing"
|
||||
"workspace-sharing",
|
||||
"aibridge"
|
||||
],
|
||||
"x-enum-comments": {
|
||||
"ExperimentAIBridge": "Enables AI Bridge functionality.",
|
||||
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
|
||||
"ExperimentExample": "This isn't used for anything.",
|
||||
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
|
||||
@@ -14336,7 +14297,8 @@ const docTemplate = `{
|
||||
"ExperimentWebPush",
|
||||
"ExperimentOAuth2",
|
||||
"ExperimentMCPServerHTTP",
|
||||
"ExperimentWorkspaceSharing"
|
||||
"ExperimentWorkspaceSharing",
|
||||
"ExperimentAIBridge"
|
||||
]
|
||||
},
|
||||
"codersdk.ExternalAPIKeyScopes": {
|
||||
@@ -17524,13 +17486,6 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
@@ -19712,14 +19667,6 @@ const docTemplate = `{
|
||||
"description": "OwnerName is the username of the owner of the workspace.",
|
||||
"type": "string"
|
||||
},
|
||||
"task_id": {
|
||||
"description": "TaskID, if set, indicates that the workspace is relevant to the given codersdk.Task.",
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/uuid.NullUUID"
|
||||
}
|
||||
]
|
||||
},
|
||||
"template_active_version_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
|
||||
Generated
+6
-59
@@ -65,7 +65,7 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/aibridge/interceptions": {
|
||||
"/api/experimental/aibridge/interceptions": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
@@ -10364,35 +10364,12 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeBedrockConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"access_key": {
|
||||
"type": "string"
|
||||
},
|
||||
"access_key_secret": {
|
||||
"type": "string"
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"region": {
|
||||
"type": "string"
|
||||
},
|
||||
"small_fast_model": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"anthropic": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig"
|
||||
},
|
||||
"bedrock": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeBedrockConfig"
|
||||
},
|
||||
"enabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -10404,10 +10381,6 @@
|
||||
"codersdk.AIBridgeInterception": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -11205,13 +11178,6 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
@@ -12367,13 +12333,6 @@
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific to the organization the role belongs to.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific to the organization the role belongs to.",
|
||||
"type": "array",
|
||||
@@ -12923,9 +12882,11 @@
|
||||
"web-push",
|
||||
"oauth2",
|
||||
"mcp-server-http",
|
||||
"workspace-sharing"
|
||||
"workspace-sharing",
|
||||
"aibridge"
|
||||
],
|
||||
"x-enum-comments": {
|
||||
"ExperimentAIBridge": "Enables AI Bridge functionality.",
|
||||
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
|
||||
"ExperimentExample": "This isn't used for anything.",
|
||||
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
|
||||
@@ -12943,7 +12904,8 @@
|
||||
"ExperimentWebPush",
|
||||
"ExperimentOAuth2",
|
||||
"ExperimentMCPServerHTTP",
|
||||
"ExperimentWorkspaceSharing"
|
||||
"ExperimentWorkspaceSharing",
|
||||
"ExperimentAIBridge"
|
||||
]
|
||||
},
|
||||
"codersdk.ExternalAPIKeyScopes": {
|
||||
@@ -16016,13 +15978,6 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"organization_member_permissions": {
|
||||
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Permission"
|
||||
}
|
||||
},
|
||||
"organization_permissions": {
|
||||
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
|
||||
"type": "array",
|
||||
@@ -18098,14 +18053,6 @@
|
||||
"description": "OwnerName is the username of the owner of the workspace.",
|
||||
"type": "string"
|
||||
},
|
||||
"task_id": {
|
||||
"description": "TaskID, if set, indicates that the workspace is relevant to the given codersdk.Task.",
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/uuid.NullUUID"
|
||||
}
|
||||
]
|
||||
},
|
||||
"template_active_version_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
|
||||
+2
-2
@@ -509,11 +509,11 @@ func (api *API) auditLogResourceLink(ctx context.Context, alog database.GetAudit
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
user, err := api.Database.GetUserByID(ctx, task.OwnerID)
|
||||
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("/tasks/%s/%s", user.Username, task.ID)
|
||||
return fmt.Sprintf("/tasks/%s/%s", workspace.OwnerName, task.Name)
|
||||
|
||||
default:
|
||||
return ""
|
||||
|
||||
+10
-11
@@ -50,13 +50,6 @@ func TestCheckPermissions(t *testing.T) {
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
readOrgWorkspaces: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: codersdk.ResourceWorkspace,
|
||||
OrganizationID: adminUser.OrganizationID.String(),
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
readMyself: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: codersdk.ResourceUser,
|
||||
@@ -65,10 +58,16 @@ func TestCheckPermissions(t *testing.T) {
|
||||
Action: "read",
|
||||
},
|
||||
readOwnWorkspaces: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: codersdk.ResourceWorkspace,
|
||||
OwnerID: "me",
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
readOrgWorkspaces: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: codersdk.ResourceWorkspace,
|
||||
OrganizationID: adminUser.OrganizationID.String(),
|
||||
OwnerID: "me",
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
@@ -93,9 +92,9 @@ func TestCheckPermissions(t *testing.T) {
|
||||
UserID: adminUser.UserID,
|
||||
Check: map[string]bool{
|
||||
readAllUsers: true,
|
||||
readOrgWorkspaces: true,
|
||||
readMyself: true,
|
||||
readOwnWorkspaces: true,
|
||||
readOrgWorkspaces: true,
|
||||
updateSpecificTemplate: true,
|
||||
},
|
||||
},
|
||||
@@ -105,9 +104,9 @@ func TestCheckPermissions(t *testing.T) {
|
||||
UserID: orgAdminUser.ID,
|
||||
Check: map[string]bool{
|
||||
readAllUsers: true,
|
||||
readOrgWorkspaces: true,
|
||||
readMyself: true,
|
||||
readOwnWorkspaces: true,
|
||||
readOrgWorkspaces: true,
|
||||
updateSpecificTemplate: true,
|
||||
},
|
||||
},
|
||||
@@ -117,9 +116,9 @@ func TestCheckPermissions(t *testing.T) {
|
||||
UserID: memberUser.ID,
|
||||
Check: map[string]bool{
|
||||
readAllUsers: false,
|
||||
readOrgWorkspaces: false,
|
||||
readMyself: true,
|
||||
readOwnWorkspaces: true,
|
||||
readOrgWorkspaces: false,
|
||||
updateSpecificTemplate: false,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1764,175 +1764,3 @@ func TestExecutorAutostartSkipsWhenNoProvisionersAvailable(t *testing.T) {
|
||||
|
||||
assert.Len(t, stats.Transitions, 1, "should create builds when provisioners are available")
|
||||
}
|
||||
|
||||
func TestExecutorTaskWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
createTaskTemplate := func(t *testing.T, client *codersdk.Client, orgID uuid.UUID, ctx context.Context, defaultTTL time.Duration) codersdk.Template {
|
||||
t.Helper()
|
||||
|
||||
taskAppID := uuid.New()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{
|
||||
Type: &proto.Response_Plan{
|
||||
Plan: &proto.PlanComplete{HasAiTasks: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
ProvisionApply: []*proto.Response{
|
||||
{
|
||||
Type: &proto.Response_Apply{
|
||||
Apply: &proto.ApplyComplete{
|
||||
Resources: []*proto.Resource{
|
||||
{
|
||||
Agents: []*proto.Agent{
|
||||
{
|
||||
Id: uuid.NewString(),
|
||||
Name: "dev",
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: uuid.NewString(),
|
||||
},
|
||||
Apps: []*proto.App{
|
||||
{
|
||||
Id: taskAppID.String(),
|
||||
Slug: "task-app",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
AiTasks: []*proto.AITask{
|
||||
{
|
||||
AppId: taskAppID.String(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
|
||||
|
||||
if defaultTTL > 0 {
|
||||
_, err := client.UpdateTemplateMeta(ctx, template.ID, codersdk.UpdateTemplateMeta{
|
||||
DefaultTTLMillis: defaultTTL.Milliseconds(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
createTaskWorkspace := func(t *testing.T, client *codersdk.Client, template codersdk.Template, ctx context.Context, input string) codersdk.Workspace {
|
||||
t.Helper()
|
||||
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: input,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace")
|
||||
|
||||
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
return workspace
|
||||
}
|
||||
|
||||
t.Run("Autostart", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
sched = mustSchedule(t, "CRON_TZ=UTC 0 * * * *")
|
||||
tickCh = make(chan time.Time)
|
||||
statsCh = make(chan autobuild.Stats)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
AutobuildTicker: tickCh,
|
||||
IncludeProvisionerDaemon: true,
|
||||
AutobuildStats: statsCh,
|
||||
})
|
||||
admin = coderdtest.CreateFirstUser(t, client)
|
||||
)
|
||||
|
||||
// Given: A task workspace
|
||||
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 0)
|
||||
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostart")
|
||||
|
||||
// Given: The task workspace has an autostart schedule
|
||||
err := client.UpdateWorkspaceAutostart(ctx, workspace.ID, codersdk.UpdateWorkspaceAutostartRequest{
|
||||
Schedule: ptr.Ref(sched.String()),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Given: That the workspace is in a stopped state.
|
||||
workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop)
|
||||
|
||||
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// When: the autobuild executor ticks after the scheduled time
|
||||
go func() {
|
||||
tickTime := sched.Next(workspace.LatestBuild.CreatedAt)
|
||||
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
|
||||
tickCh <- tickTime
|
||||
close(tickCh)
|
||||
}()
|
||||
|
||||
// Then: We expect to see a start transition
|
||||
stats := <-statsCh
|
||||
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
|
||||
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
|
||||
assert.Equal(t, database.WorkspaceTransitionStart, stats.Transitions[workspace.ID], "should autostart the workspace")
|
||||
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
|
||||
})
|
||||
|
||||
t.Run("Autostop", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
tickCh = make(chan time.Time)
|
||||
statsCh = make(chan autobuild.Stats)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
AutobuildTicker: tickCh,
|
||||
IncludeProvisionerDaemon: true,
|
||||
AutobuildStats: statsCh,
|
||||
})
|
||||
admin = coderdtest.CreateFirstUser(t, client)
|
||||
)
|
||||
|
||||
// Given: A task workspace with an 8 hour deadline
|
||||
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 8*time.Hour)
|
||||
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostop")
|
||||
|
||||
// Given: The workspace is currently running
|
||||
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
|
||||
require.Equal(t, codersdk.WorkspaceTransitionStart, workspace.LatestBuild.Transition)
|
||||
require.NotZero(t, workspace.LatestBuild.Deadline, "workspace should have a deadline for autostop")
|
||||
|
||||
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// When: the autobuild executor ticks after the deadline
|
||||
go func() {
|
||||
tickTime := workspace.LatestBuild.Deadline.Time.Add(time.Minute)
|
||||
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
|
||||
tickCh <- tickTime
|
||||
close(tickCh)
|
||||
}()
|
||||
|
||||
// Then: We expect to see a stop transition
|
||||
stats := <-statsCh
|
||||
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
|
||||
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
|
||||
assert.Equal(t, database.WorkspaceTransitionStop, stats.Transitions[workspace.ID], "should autostop the workspace")
|
||||
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
|
||||
})
|
||||
}
|
||||
|
||||
+4
-1
@@ -1021,7 +1021,10 @@ func New(options *Options) *API {
|
||||
apiRateLimiter,
|
||||
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
|
||||
)
|
||||
|
||||
r.Route("/aitasks", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Get("/prompts", api.aiTasksPrompts)
|
||||
})
|
||||
r.Route("/tasks", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ const (
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsAiTaskSidebarAppIDRequired CheckConstraint = "workspace_builds_ai_task_sidebar_app_id_required" // workspace_builds
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
)
|
||||
|
||||
@@ -714,13 +714,12 @@ func RBACRole(role rbac.Role) codersdk.Role {
|
||||
|
||||
orgPerms := role.ByOrgID[slim.OrganizationID]
|
||||
return codersdk.Role{
|
||||
Name: slim.Name,
|
||||
OrganizationID: slim.OrganizationID,
|
||||
DisplayName: slim.DisplayName,
|
||||
SitePermissions: List(role.Site, RBACPermission),
|
||||
UserPermissions: List(role.User, RBACPermission),
|
||||
OrganizationPermissions: List(orgPerms.Org, RBACPermission),
|
||||
OrganizationMemberPermissions: List(orgPerms.Member, RBACPermission),
|
||||
Name: slim.Name,
|
||||
OrganizationID: slim.OrganizationID,
|
||||
DisplayName: slim.DisplayName,
|
||||
SitePermissions: List(role.Site, RBACPermission),
|
||||
OrganizationPermissions: List(orgPerms.Org, RBACPermission),
|
||||
UserPermissions: List(role.User, RBACPermission),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -735,8 +734,8 @@ func Role(role database.CustomRole) codersdk.Role {
|
||||
OrganizationID: orgID,
|
||||
DisplayName: role.DisplayName,
|
||||
SitePermissions: List(role.SitePermissions, Permission),
|
||||
UserPermissions: List(role.UserPermissions, Permission),
|
||||
OrganizationPermissions: List(role.OrgPermissions, Permission),
|
||||
UserPermissions: List(role.UserPermissions, Permission),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -963,7 +962,7 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
|
||||
// created_at ASC
|
||||
return sdkToolUsages[i].CreatedAt.Before(sdkToolUsages[j].CreatedAt)
|
||||
})
|
||||
intc := codersdk.AIBridgeInterception{
|
||||
return codersdk.AIBridgeInterception{
|
||||
ID: interception.ID,
|
||||
Initiator: MinimalUserFromVisibleUser(initiator),
|
||||
Provider: interception.Provider,
|
||||
@@ -974,10 +973,6 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
|
||||
UserPrompts: sdkUserPrompts,
|
||||
ToolUsages: sdkToolUsages,
|
||||
}
|
||||
if interception.EndedAt.Valid {
|
||||
intc.EndedAt = &interception.EndedAt.Time
|
||||
}
|
||||
return intc
|
||||
}
|
||||
|
||||
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
|
||||
|
||||
@@ -254,7 +254,6 @@ var (
|
||||
rbac.ResourceFile.Type: {policy.ActionRead}, // Required to read terraform files
|
||||
rbac.ResourceNotificationMessage.Type: {policy.ActionCreate, policy.ActionRead},
|
||||
rbac.ResourceSystem.Type: {policy.WildcardSymbol},
|
||||
rbac.ResourceTask.Type: {policy.ActionRead, policy.ActionUpdate},
|
||||
rbac.ResourceTemplate.Type: {policy.ActionRead, policy.ActionUpdate},
|
||||
rbac.ResourceUser.Type: {policy.ActionRead},
|
||||
rbac.ResourceWorkspace.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStart, policy.ActionWorkspaceStop},
|
||||
@@ -396,13 +395,11 @@ var (
|
||||
Identifier: rbac.RoleIdentifier{Name: "subagentapi"},
|
||||
DisplayName: "Sub Agent API",
|
||||
Site: []rbac.Permission{},
|
||||
User: []rbac.Permission{},
|
||||
User: rbac.Permissions(map[string][]policy.Action{
|
||||
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreateAgent, policy.ActionDeleteAgent},
|
||||
}),
|
||||
ByOrgID: map[string]rbac.OrgPermissions{
|
||||
orgID.String(): {
|
||||
Member: rbac.Permissions(map[string][]policy.Action{
|
||||
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreateAgent, policy.ActionDeleteAgent},
|
||||
}),
|
||||
},
|
||||
orgID.String(): {},
|
||||
},
|
||||
},
|
||||
}),
|
||||
@@ -1293,17 +1290,14 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
|
||||
return xerrors.Errorf("invalid role: %w", err)
|
||||
}
|
||||
|
||||
if len(rbacRole.ByOrgID) > 0 && (len(rbacRole.Site) > 0 || len(rbacRole.User) > 0) {
|
||||
// This is a choice to keep roles simple. If we allow mixing site and org
|
||||
// scoped perms, then knowing who can do what gets more complicated. Roles
|
||||
// should either be entirely org-scoped or entirely unrelated to
|
||||
// organizations.
|
||||
return xerrors.Errorf("invalid custom role, cannot assign both org-scoped and site/user permissions at the same time")
|
||||
if len(rbacRole.ByOrgID) > 0 && len(rbacRole.Site) > 0 {
|
||||
// This is a choice to keep roles simple. If we allow mixing site and org scoped perms, then knowing who can
|
||||
// do what gets more complicated.
|
||||
return xerrors.Errorf("invalid custom role, cannot assign both org and site permissions at the same time")
|
||||
}
|
||||
|
||||
if len(rbacRole.ByOrgID) > 1 {
|
||||
// Again to avoid more complexity in our roles. Roles are limited to one
|
||||
// organization.
|
||||
// Again to avoid more complexity in our roles
|
||||
return xerrors.Errorf("invalid custom role, cannot assign permissions to more than 1 org at a time")
|
||||
}
|
||||
|
||||
@@ -1319,18 +1313,7 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
|
||||
for _, orgPerm := range perms.Org {
|
||||
err := q.customRoleEscalationCheck(ctx, act, orgPerm, rbac.Object{OrgID: orgID, Type: orgPerm.ResourceType})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("org=%q: org: %w", orgID, err)
|
||||
}
|
||||
}
|
||||
for _, memberPerm := range perms.Member {
|
||||
// The person giving the permission should still be required to have
|
||||
// the permissions throughout the org in order to give individuals the
|
||||
// same permission among their own resources, since the role can be given
|
||||
// to anyone. The `Owner` is intentionally omitted from the `Object` to
|
||||
// enforce this.
|
||||
err := q.customRoleEscalationCheck(ctx, act, memberPerm, rbac.Object{OrgID: orgID, Type: memberPerm.ResourceType})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("org=%q: member: %w", orgID, err)
|
||||
return xerrors.Errorf("org=%q: %w", orgID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1348,8 +1331,8 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
|
||||
func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.ProvisionerJob) error {
|
||||
switch job.Type {
|
||||
case database.ProvisionerJobTypeWorkspaceBuild:
|
||||
// Authorized call to get workspace build. If we can read the build, we can
|
||||
// read the job.
|
||||
// Authorized call to get workspace build. If we can read the build, we
|
||||
// can read the job.
|
||||
_, err := q.GetWorkspaceBuildByJobID(ctx, job.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fetch related workspace build: %w", err)
|
||||
@@ -1392,8 +1375,8 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.Activi
|
||||
}
|
||||
|
||||
func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
|
||||
// Although this technically only reads users, only system-related functions
|
||||
// should be allowed to call this.
|
||||
// Although this technically only reads users, only system-related functions should be
|
||||
// allowed to call this.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1412,8 +1395,8 @@ func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg databas
|
||||
}
|
||||
|
||||
func (q *querier) BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error {
|
||||
// Could be any workspace and checking auth to each workspace is overkill for
|
||||
// the purpose of this function.
|
||||
// Could be any workspace and checking auth to each workspace is overkill for the purpose
|
||||
// of this function.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceWorkspace.All()); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1441,13 +1424,6 @@ func (q *querier) BulkMarkNotificationMessagesSent(ctx context.Context, arg data
|
||||
return q.db.BulkMarkNotificationMessagesSent(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
|
||||
return database.CalculateAIBridgeInterceptionsTelemetrySummaryRow{}, err
|
||||
}
|
||||
return q.db.CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
|
||||
empty := database.ClaimPrebuiltWorkspaceRow{}
|
||||
|
||||
@@ -1747,13 +1723,6 @@ func (q *querier) DeleteOldProvisionerDaemons(ctx context.Context) error {
|
||||
return q.db.DeleteOldProvisionerDaemons(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldTelemetryLocks(ctx context.Context, beforeTime time.Time) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteOldTelemetryLocks(ctx, beforeTime)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldWorkspaceAgentLogs(ctx context.Context, threshold time.Time) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
@@ -2649,13 +2618,6 @@ func (q *querier) GetOrganizationsByUserID(ctx context.Context, userID database.
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationsByUserID)(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOrganization.All()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetOrganizationsWithPrebuildStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
version, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
|
||||
if err != nil {
|
||||
@@ -4250,13 +4212,6 @@ func (q *querier) InsertTelemetryItemIfNotExists(ctx context.Context, arg databa
|
||||
return q.db.InsertTelemetryItemIfNotExists(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.InsertTelemetryLock(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
|
||||
obj := rbac.ResourceTemplate.InOrg(arg.OrganizationID)
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, obj); err != nil {
|
||||
@@ -4568,13 +4523,6 @@ func (q *querier) ListAIBridgeInterceptions(ctx context.Context, arg database.Li
|
||||
return q.db.ListAuthorizedAIBridgeInterceptions(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
// This function is a system function until we implement a join for aibridge interceptions.
|
||||
// Matches the behavior of the workspaces listing endpoint.
|
||||
@@ -4763,13 +4711,6 @@ func (q *querier) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error {
|
||||
return update(q.log, q.auth, fetch, q.db.UnfavoriteWorkspace)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateAIBridgeInterceptionEnded(ctx context.Context, params database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, params.ID); err != nil {
|
||||
return database.AIBridgeInterception{}, err
|
||||
}
|
||||
return q.db.UpdateAIBridgeInterceptionEnded(ctx, params)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
|
||||
fetch := func(ctx context.Context, arg database.UpdateAPIKeyByIDParams) (database.APIKey, error) {
|
||||
return q.db.GetAPIKeyByID(ctx, arg.ID)
|
||||
@@ -4941,10 +4882,10 @@ func (q *querier) UpdateOrganizationDeletedByID(ctx context.Context, arg databas
|
||||
return deleteQ(q.log, q.auth, q.db.GetOrganizationByID, deleteF)(ctx, arg.ID)
|
||||
}
|
||||
|
||||
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
|
||||
// Prebuild operation for canceling pending prebuild jobs from non-active template versions
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourcePrebuiltWorkspace); err != nil {
|
||||
return []database.UpdatePrebuildProvisionerJobWithCancelRow{}, err
|
||||
return []uuid.UUID{}, err
|
||||
}
|
||||
return q.db.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -646,13 +646,10 @@ func (s *MethodTestSuite) TestProvisionerJob() {
|
||||
PresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
Now: dbtime.Now(),
|
||||
}
|
||||
canceledJobs := []database.UpdatePrebuildProvisionerJobWithCancelRow{
|
||||
{ID: uuid.New(), WorkspaceID: uuid.New(), TemplateID: uuid.New(), TemplateVersionPresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
{ID: uuid.New(), WorkspaceID: uuid.New(), TemplateID: uuid.New(), TemplateVersionPresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
}
|
||||
jobIDs := []uuid.UUID{uuid.New(), uuid.New()}
|
||||
|
||||
dbm.EXPECT().UpdatePrebuildProvisionerJobWithCancel(gomock.Any(), arg).Return(canceledJobs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourcePrebuiltWorkspace, policy.ActionUpdate).Returns(canceledJobs)
|
||||
dbm.EXPECT().UpdatePrebuildProvisionerJobWithCancel(gomock.Any(), arg).Return(jobIDs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourcePrebuiltWorkspace, policy.ActionUpdate).Returns(jobIDs)
|
||||
}))
|
||||
s.Run("GetProvisionerJobsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
org := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
@@ -3759,14 +3756,6 @@ func (s *MethodTestSuite) TestPrebuilds() {
|
||||
dbm.EXPECT().GetPrebuildMetrics(gomock.Any()).Return([]database.GetPrebuildMetricsRow{}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetOrganizationsWithPrebuildStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetOrganizationsWithPrebuildStatusParams{
|
||||
UserID: uuid.New(),
|
||||
GroupName: "test",
|
||||
}
|
||||
dbm.EXPECT().GetOrganizationsWithPrebuildStatus(gomock.Any(), arg).Return([]database.GetOrganizationsWithPrebuildStatusRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceOrganization.All(), policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPrebuildsSettings", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetPrebuildsSettings(gomock.Any()).Return("{}", nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
@@ -4628,35 +4617,4 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
db.EXPECT().ListAIBridgeToolUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeToolUsage{}, nil).AnyTimes()
|
||||
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
|
||||
}))
|
||||
|
||||
s.Run("UpdateAIBridgeInterceptionEnded", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
intcID := uuid.UUID{1}
|
||||
params := database.UpdateAIBridgeInterceptionEndedParams{ID: intcID}
|
||||
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intcID})
|
||||
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intcID).Return(intc, nil).AnyTimes() // Validation.
|
||||
db.EXPECT().UpdateAIBridgeInterceptionEnded(gomock.Any(), params).Return(intc, nil).AnyTimes()
|
||||
check.Args(params).Asserts(intc, policy.ActionUpdate).Returns(intc)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestTelemetry() {
|
||||
s.Run("InsertTelemetryLock", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
db.EXPECT().InsertTelemetryLock(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
check.Args(database.InsertTelemetryLockParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate)
|
||||
}))
|
||||
|
||||
s.Run("DeleteOldTelemetryLocks", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
db.EXPECT().DeleteOldTelemetryLocks(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
check.Args(time.Time{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeInterceptionsTelemetrySummaries", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
db.EXPECT().ListAIBridgeInterceptionsTelemetrySummaries(gomock.Any(), gomock.Any()).Return([]database.ListAIBridgeInterceptionsTelemetrySummariesRow{}, nil).AnyTimes()
|
||||
check.Args(database.ListAIBridgeInterceptionsTelemetrySummariesParams{}).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead)
|
||||
}))
|
||||
|
||||
s.Run("CalculateAIBridgeInterceptionsTelemetrySummary", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
db.EXPECT().CalculateAIBridgeInterceptionsTelemetrySummary(gomock.Any(), gomock.Any()).Return(database.CalculateAIBridgeInterceptionsTelemetrySummaryRow{}, nil).AnyTimes()
|
||||
check.Args(database.CalculateAIBridgeInterceptionsTelemetrySummaryParams{}).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead)
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -41,7 +41,6 @@ type WorkspaceResponse struct {
|
||||
Build database.WorkspaceBuild
|
||||
AgentToken string
|
||||
TemplateVersionResponse
|
||||
Task database.Task
|
||||
}
|
||||
|
||||
// WorkspaceBuildBuilder generates workspace builds and associated
|
||||
@@ -58,7 +57,6 @@ type WorkspaceBuildBuilder struct {
|
||||
agentToken string
|
||||
jobStatus database.ProvisionerJobStatus
|
||||
taskAppID uuid.UUID
|
||||
taskSeed database.TaskTable
|
||||
}
|
||||
|
||||
// WorkspaceBuild generates a workspace build for the provided workspace.
|
||||
@@ -117,28 +115,25 @@ func (b WorkspaceBuildBuilder) WithAgent(mutations ...func([]*sdkproto.Agent) []
|
||||
return b
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) WithTask(taskSeed database.TaskTable, appSeed *sdkproto.App) WorkspaceBuildBuilder {
|
||||
//nolint:revive // returns modified struct
|
||||
b.taskSeed = taskSeed
|
||||
|
||||
if appSeed == nil {
|
||||
appSeed = &sdkproto.App{}
|
||||
func (b WorkspaceBuildBuilder) WithTask(seed *sdkproto.App) WorkspaceBuildBuilder {
|
||||
if seed == nil {
|
||||
seed = &sdkproto.App{}
|
||||
}
|
||||
|
||||
var err error
|
||||
//nolint: revive // returns modified struct
|
||||
b.taskAppID, err = uuid.Parse(takeFirst(appSeed.Id, uuid.NewString()))
|
||||
b.taskAppID, err = uuid.Parse(takeFirst(seed.Id, uuid.NewString()))
|
||||
require.NoError(b.t, err)
|
||||
|
||||
return b.Params(database.WorkspaceBuildParameter{
|
||||
Name: codersdk.AITaskPromptParameterName,
|
||||
Value: b.taskSeed.Prompt,
|
||||
Value: "list me",
|
||||
}).WithAgent(func(a []*sdkproto.Agent) []*sdkproto.Agent {
|
||||
a[0].Apps = []*sdkproto.App{
|
||||
{
|
||||
Id: b.taskAppID.String(),
|
||||
Slug: takeFirst(appSeed.Slug, "task-app"),
|
||||
Url: takeFirst(appSeed.Url, ""),
|
||||
Slug: takeFirst(seed.Slug, "task-app"),
|
||||
Url: takeFirst(seed.Url, ""),
|
||||
},
|
||||
}
|
||||
return a
|
||||
@@ -166,19 +161,6 @@ func (b WorkspaceBuildBuilder) Canceled() WorkspaceBuildBuilder {
|
||||
// Workspace will be optionally populated if no ID is set on the provided
|
||||
// workspace.
|
||||
func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
|
||||
var resp WorkspaceResponse
|
||||
// Use transaction, like real wsbuilder.
|
||||
err := b.db.InTx(func(tx database.Store) error {
|
||||
//nolint:revive // calls do on modified struct
|
||||
b.db = tx
|
||||
resp = b.doInTX()
|
||||
return nil
|
||||
}, nil)
|
||||
require.NoError(b.t, err)
|
||||
return resp
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
|
||||
b.t.Helper()
|
||||
jobID := uuid.New()
|
||||
b.seed.ID = uuid.New()
|
||||
@@ -230,37 +212,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
|
||||
b.seed.WorkspaceID = b.ws.ID
|
||||
b.seed.InitiatorID = takeFirst(b.seed.InitiatorID, b.ws.OwnerID)
|
||||
|
||||
// If a task was requested, ensure it exists and is associated with this
|
||||
// workspace.
|
||||
if b.taskAppID != uuid.Nil {
|
||||
b.logger.Debug(context.Background(), "creating or updating task", "task_id", b.taskSeed.ID)
|
||||
b.taskSeed.OrganizationID = takeFirst(b.taskSeed.OrganizationID, b.ws.OrganizationID)
|
||||
b.taskSeed.OwnerID = takeFirst(b.taskSeed.OwnerID, b.ws.OwnerID)
|
||||
b.taskSeed.Name = takeFirst(b.taskSeed.Name, b.ws.Name)
|
||||
b.taskSeed.WorkspaceID = uuid.NullUUID{UUID: takeFirst(b.taskSeed.WorkspaceID.UUID, b.ws.ID), Valid: true}
|
||||
b.taskSeed.TemplateVersionID = takeFirst(b.taskSeed.TemplateVersionID, b.seed.TemplateVersionID)
|
||||
|
||||
// Try to fetch existing task and update its workspace ID.
|
||||
if task, err := b.db.GetTaskByID(ownerCtx, b.taskSeed.ID); err == nil {
|
||||
if !task.WorkspaceID.Valid {
|
||||
b.logger.Info(context.Background(), "updating task workspace id", "task_id", b.taskSeed.ID, "workspace_id", b.ws.ID)
|
||||
_, err = b.db.UpdateTaskWorkspaceID(ownerCtx, database.UpdateTaskWorkspaceIDParams{
|
||||
ID: b.taskSeed.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: b.ws.ID, Valid: true},
|
||||
})
|
||||
require.NoError(b.t, err, "update task workspace id")
|
||||
} else if task.WorkspaceID.UUID != b.ws.ID {
|
||||
require.Fail(b.t, "task already has a workspace id, mismatch", task.WorkspaceID.UUID, b.ws.ID)
|
||||
}
|
||||
} else if errors.Is(err, sql.ErrNoRows) {
|
||||
task := dbgen.Task(b.t, b.db, b.taskSeed)
|
||||
b.taskSeed.ID = task.ID
|
||||
b.logger.Info(context.Background(), "created new task", "task_id", b.taskSeed.ID)
|
||||
} else {
|
||||
require.NoError(b.t, err, "get task by id")
|
||||
}
|
||||
}
|
||||
|
||||
// Create a provisioner job for the build!
|
||||
payload, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{
|
||||
WorkspaceBuildID: b.seed.ID,
|
||||
@@ -373,11 +324,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
|
||||
b.logger.Debug(context.Background(), "linked task to workspace build",
|
||||
slog.F("task_id", task.ID),
|
||||
slog.F("build_number", resp.Build.BuildNumber))
|
||||
|
||||
// Update task after linking.
|
||||
task, err = b.db.GetTaskByID(ownerCtx, task.ID)
|
||||
require.NoError(b.t, err, "get task by id")
|
||||
resp.Task = task
|
||||
}
|
||||
|
||||
for i := range b.params {
|
||||
|
||||
@@ -1495,7 +1495,7 @@ func ClaimPrebuild(
|
||||
return claimedWorkspace
|
||||
}
|
||||
|
||||
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams, endedAt *time.Time) database.AIBridgeInterception {
|
||||
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams) database.AIBridgeInterception {
|
||||
interception, err := db.InsertAIBridgeInterception(genCtx, database.InsertAIBridgeInterceptionParams{
|
||||
ID: takeFirst(seed.ID, uuid.New()),
|
||||
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
|
||||
@@ -1504,13 +1504,6 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
|
||||
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
|
||||
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
|
||||
})
|
||||
if endedAt != nil {
|
||||
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: interception.ID,
|
||||
EndedAt: *endedAt,
|
||||
})
|
||||
require.NoError(t, err, "insert aibridge interception")
|
||||
}
|
||||
require.NoError(t, err, "insert aibridge interception")
|
||||
return interception
|
||||
}
|
||||
@@ -1576,7 +1569,6 @@ func Task(t testing.TB, db database.Store, orig database.TaskTable) database.Tas
|
||||
}
|
||||
|
||||
task, err := db.InsertTask(genCtx, database.InsertTaskParams{
|
||||
ID: takeFirst(orig.ID, uuid.New()),
|
||||
OrganizationID: orig.OrganizationID,
|
||||
OwnerID: orig.OwnerID,
|
||||
Name: takeFirst(orig.Name, taskname.GenerateFallback()),
|
||||
|
||||
@@ -158,13 +158,6 @@ func (m queryMetricsStore) BulkMarkNotificationMessagesSent(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("CalculateAIBridgeInterceptionsTelemetrySummary").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ClaimPrebuiltWorkspace(ctx, arg)
|
||||
@@ -410,13 +403,6 @@ func (m queryMetricsStore) DeleteOldProvisionerDaemons(ctx context.Context) erro
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteOldTelemetryLocks(ctx, periodEndingAtBefore)
|
||||
m.queryLatencies.WithLabelValues("DeleteOldTelemetryLocks").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldWorkspaceAgentLogs(ctx context.Context, arg time.Time) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteOldWorkspaceAgentLogs(ctx, arg)
|
||||
@@ -1243,13 +1229,6 @@ func (m queryMetricsStore) GetOrganizationsByUserID(ctx context.Context, userID
|
||||
return organizations, err
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetOrganizationsWithPrebuildStatus(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetOrganizationsWithPrebuildStatus").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
start := time.Now()
|
||||
schemas, err := m.s.GetParameterSchemasByJobID(ctx, jobID)
|
||||
@@ -2538,13 +2517,6 @@ func (m queryMetricsStore) InsertTelemetryItemIfNotExists(ctx context.Context, a
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.InsertTelemetryLock(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertTelemetryLock").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
|
||||
start := time.Now()
|
||||
err := m.s.InsertTemplate(ctx, arg)
|
||||
@@ -2762,13 +2734,6 @@ func (m queryMetricsStore) ListAIBridgeInterceptions(ctx context.Context, arg da
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("ListAIBridgeInterceptionsTelemetrySummaries").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
|
||||
@@ -2923,13 +2888,6 @@ func (m queryMetricsStore) UnfavoriteWorkspace(ctx context.Context, arg uuid.UUI
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, id database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateAIBridgeInterceptionEnded(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("UpdateAIBridgeInterceptionEnded").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
|
||||
start := time.Now()
|
||||
err := m.s.UpdateAPIKeyByID(ctx, arg)
|
||||
@@ -3049,7 +3007,7 @@ func (m queryMetricsStore) UpdateOrganizationDeletedByID(ctx context.Context, ar
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
func (m queryMetricsStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdatePrebuildProvisionerJobWithCancel").Observe(time.Since(start).Seconds())
|
||||
|
||||
@@ -190,21 +190,6 @@ func (mr *MockStoreMockRecorder) BulkMarkNotificationMessagesSent(ctx, arg any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkMarkNotificationMessagesSent", reflect.TypeOf((*MockStore)(nil).BulkMarkNotificationMessagesSent), ctx, arg)
|
||||
}
|
||||
|
||||
// CalculateAIBridgeInterceptionsTelemetrySummary mocks base method.
|
||||
func (m *MockStore) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CalculateAIBridgeInterceptionsTelemetrySummary", ctx, arg)
|
||||
ret0, _ := ret[0].(database.CalculateAIBridgeInterceptionsTelemetrySummaryRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CalculateAIBridgeInterceptionsTelemetrySummary indicates an expected call of CalculateAIBridgeInterceptionsTelemetrySummary.
|
||||
func (mr *MockStoreMockRecorder) CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CalculateAIBridgeInterceptionsTelemetrySummary", reflect.TypeOf((*MockStore)(nil).CalculateAIBridgeInterceptionsTelemetrySummary), ctx, arg)
|
||||
}
|
||||
|
||||
// ClaimPrebuiltWorkspace mocks base method.
|
||||
func (m *MockStore) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -751,20 +736,6 @@ func (mr *MockStoreMockRecorder) DeleteOldProvisionerDaemons(ctx any) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldProvisionerDaemons", reflect.TypeOf((*MockStore)(nil).DeleteOldProvisionerDaemons), ctx)
|
||||
}
|
||||
|
||||
// DeleteOldTelemetryLocks mocks base method.
|
||||
func (m *MockStore) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldTelemetryLocks", ctx, periodEndingAtBefore)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteOldTelemetryLocks indicates an expected call of DeleteOldTelemetryLocks.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldTelemetryLocks(ctx, periodEndingAtBefore any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldTelemetryLocks", reflect.TypeOf((*MockStore)(nil).DeleteOldTelemetryLocks), ctx, periodEndingAtBefore)
|
||||
}
|
||||
|
||||
// DeleteOldWorkspaceAgentLogs mocks base method.
|
||||
func (m *MockStore) DeleteOldWorkspaceAgentLogs(ctx context.Context, threshold time.Time) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2622,21 +2593,6 @@ func (mr *MockStoreMockRecorder) GetOrganizationsByUserID(ctx, arg any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsByUserID", reflect.TypeOf((*MockStore)(nil).GetOrganizationsByUserID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetOrganizationsWithPrebuildStatus mocks base method.
|
||||
func (m *MockStore) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOrganizationsWithPrebuildStatus", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetOrganizationsWithPrebuildStatusRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrganizationsWithPrebuildStatus indicates an expected call of GetOrganizationsWithPrebuildStatus.
|
||||
func (mr *MockStoreMockRecorder) GetOrganizationsWithPrebuildStatus(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsWithPrebuildStatus", reflect.TypeOf((*MockStore)(nil).GetOrganizationsWithPrebuildStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// GetParameterSchemasByJobID mocks base method.
|
||||
func (m *MockStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5436,20 +5392,6 @@ func (mr *MockStoreMockRecorder) InsertTelemetryItemIfNotExists(ctx, arg any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTelemetryItemIfNotExists", reflect.TypeOf((*MockStore)(nil).InsertTelemetryItemIfNotExists), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertTelemetryLock mocks base method.
|
||||
func (m *MockStore) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertTelemetryLock", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// InsertTelemetryLock indicates an expected call of InsertTelemetryLock.
|
||||
func (mr *MockStoreMockRecorder) InsertTelemetryLock(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTelemetryLock", reflect.TypeOf((*MockStore)(nil).InsertTelemetryLock), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertTemplate mocks base method.
|
||||
func (m *MockStore) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5905,21 +5847,6 @@ func (mr *MockStoreMockRecorder) ListAIBridgeInterceptions(ctx, arg any) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptions), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeInterceptionsTelemetrySummaries mocks base method.
|
||||
func (m *MockStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAIBridgeInterceptionsTelemetrySummaries", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeInterceptionsTelemetrySummariesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAIBridgeInterceptionsTelemetrySummaries indicates an expected call of ListAIBridgeInterceptionsTelemetrySummaries.
|
||||
func (mr *MockStoreMockRecorder) ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptionsTelemetrySummaries", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptionsTelemetrySummaries), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
|
||||
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6289,21 +6216,6 @@ func (mr *MockStoreMockRecorder) UnfavoriteWorkspace(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnfavoriteWorkspace", reflect.TypeOf((*MockStore)(nil).UnfavoriteWorkspace), ctx, id)
|
||||
}
|
||||
|
||||
// UpdateAIBridgeInterceptionEnded mocks base method.
|
||||
func (m *MockStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAIBridgeInterceptionEnded", ctx, arg)
|
||||
ret0, _ := ret[0].(database.AIBridgeInterception)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateAIBridgeInterceptionEnded indicates an expected call of UpdateAIBridgeInterceptionEnded.
|
||||
func (mr *MockStoreMockRecorder) UpdateAIBridgeInterceptionEnded(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAIBridgeInterceptionEnded", reflect.TypeOf((*MockStore)(nil).UpdateAIBridgeInterceptionEnded), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateAPIKeyByID mocks base method.
|
||||
func (m *MockStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6555,10 +6467,10 @@ func (mr *MockStoreMockRecorder) UpdateOrganizationDeletedByID(ctx, arg any) *go
|
||||
}
|
||||
|
||||
// UpdatePrebuildProvisionerJobWithCancel mocks base method.
|
||||
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdatePrebuildProvisionerJobWithCancel", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.UpdatePrebuildProvisionerJobWithCancelRow)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
@@ -24,12 +24,6 @@ const (
|
||||
// but we won't touch the `connection_logs` table.
|
||||
maxAuditLogConnectionEventAge = 90 * 24 * time.Hour // 90 days
|
||||
auditLogConnectionEventBatchSize = 1000
|
||||
// Telemetry heartbeats are used to deduplicate events across replicas. We
|
||||
// don't need to persist heartbeat rows for longer than 24 hours, as they
|
||||
// are only used for deduplication across replicas. The time needs to be
|
||||
// long enough to cover the maximum interval of a heartbeat event (currently
|
||||
// 1 hour) plus some buffer.
|
||||
maxTelemetryHeartbeatAge = 24 * time.Hour
|
||||
)
|
||||
|
||||
// New creates a new periodically purging database instance.
|
||||
@@ -77,10 +71,6 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, clk quartz.
|
||||
if err := tx.ExpirePrebuildsAPIKeys(ctx, dbtime.Time(start)); err != nil {
|
||||
return xerrors.Errorf("failed to expire prebuilds user api keys: %w", err)
|
||||
}
|
||||
deleteOldTelemetryLocksBefore := start.Add(-maxTelemetryHeartbeatAge)
|
||||
if err := tx.DeleteOldTelemetryLocks(ctx, deleteOldTelemetryLocksBefore); err != nil {
|
||||
return xerrors.Errorf("failed to delete old telemetry locks: %w", err)
|
||||
}
|
||||
|
||||
deleteOldAuditLogConnectionEventsBefore := start.Add(-maxAuditLogConnectionEventAge)
|
||||
if err := tx.DeleteOldAuditLogConnectionEvents(ctx, database.DeleteOldAuditLogConnectionEventsParams{
|
||||
|
||||
@@ -704,56 +704,3 @@ func TestExpireOldAPIKeys(t *testing.T) {
|
||||
// Out of an abundance of caution, we do not expire explicitly named prebuilds API keys.
|
||||
assertKeyActive(namedPrebuildsAPIKey.ID)
|
||||
}
|
||||
|
||||
func TestDeleteOldTelemetryHeartbeats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
clk := quartz.NewMock(t)
|
||||
now := clk.Now().UTC()
|
||||
|
||||
// Insert telemetry heartbeats.
|
||||
err := db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
|
||||
EventType: "aibridge_interceptions_summary",
|
||||
PeriodEndingAt: now.Add(-25 * time.Hour), // should be purged
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
|
||||
EventType: "aibridge_interceptions_summary",
|
||||
PeriodEndingAt: now.Add(-23 * time.Hour), // should be kept
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
|
||||
EventType: "aibridge_interceptions_summary",
|
||||
PeriodEndingAt: now, // should be kept
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, clk)
|
||||
defer closer.Close()
|
||||
<-done // doTick() has now run.
|
||||
|
||||
require.Eventuallyf(t, func() bool {
|
||||
// We use an SQL queries directly here because we don't expose queries
|
||||
// for deleting heartbeats in the application code.
|
||||
var totalCount int
|
||||
err := sqlDB.QueryRowContext(ctx, `
|
||||
SELECT COUNT(*) FROM telemetry_locks;
|
||||
`).Scan(&totalCount)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var oldCount int
|
||||
err = sqlDB.QueryRowContext(ctx, `
|
||||
SELECT COUNT(*) FROM telemetry_locks WHERE period_ending_at < $1;
|
||||
`, now.Add(-24*time.Hour)).Scan(&oldCount)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Expect 2 heartbeats remaining and none older than 24 hours.
|
||||
t.Logf("eventually: total count: %d, old count: %d", totalCount, oldCount)
|
||||
return totalCount == 2 && oldCount == 0
|
||||
}, testutil.WaitShort, testutil.IntervalFast, "it should delete old telemetry heartbeats")
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -47,8 +45,6 @@ func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error
|
||||
host = defaultConnectionParams.Host
|
||||
port = defaultConnectionParams.Port
|
||||
)
|
||||
packageName := getTestPackageName(t)
|
||||
testName := t.Name()
|
||||
|
||||
// Use a time-based prefix to make it easier to find the database
|
||||
// when debugging.
|
||||
@@ -59,9 +55,9 @@ func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error
|
||||
}
|
||||
dbName := now + "_" + dbSuffix
|
||||
|
||||
// TODO: add package and test name
|
||||
_, err = b.coderTestingDB.Exec(
|
||||
"INSERT INTO test_databases (name, process_uuid, test_package, test_name) VALUES ($1, $2, $3, $4)",
|
||||
dbName, b.uuid, packageName, testName)
|
||||
"INSERT INTO test_databases (name, process_uuid) VALUES ($1, $2)", dbName, b.uuid)
|
||||
if err != nil {
|
||||
return ConnectionParams{}, xerrors.Errorf("insert test_database row: %w", err)
|
||||
}
|
||||
@@ -108,10 +104,10 @@ func (b *Broker) clean(t TBSubset, dbName string) func() {
|
||||
func (b *Broker) init(t TBSubset) error {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
b.refCount++
|
||||
t.Cleanup(b.decRef)
|
||||
if b.coderTestingDB != nil {
|
||||
// already initialized
|
||||
b.refCount++
|
||||
t.Cleanup(b.decRef)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -128,8 +124,8 @@ func (b *Broker) init(t TBSubset) error {
|
||||
return xerrors.Errorf("open postgres connection: %w", err)
|
||||
}
|
||||
|
||||
// coderTestingSQLInit is idempotent, so we can run it every time.
|
||||
_, err = coderTestingDB.Exec(coderTestingSQLInit)
|
||||
// creating the db can succeed even if the database doesn't exist. Ping it to find out.
|
||||
err = coderTestingDB.Ping()
|
||||
var pqErr *pq.Error
|
||||
if xerrors.As(err, &pqErr) && pqErr.Code == "3D000" {
|
||||
// database does not exist.
|
||||
@@ -149,8 +145,6 @@ func (b *Broker) init(t TBSubset) error {
|
||||
return xerrors.Errorf("ping '%s' database: %w", CoderTestingDBName, err)
|
||||
}
|
||||
b.coderTestingDB = coderTestingDB
|
||||
b.refCount++
|
||||
t.Cleanup(b.decRef)
|
||||
|
||||
if b.uuid == uuid.Nil {
|
||||
b.uuid = uuid.New()
|
||||
@@ -192,42 +186,3 @@ func (b *Broker) decRef() {
|
||||
b.coderTestingDB = nil
|
||||
}
|
||||
}
|
||||
|
||||
// getTestPackageName returns the package name of the test that called it.
|
||||
func getTestPackageName(t TBSubset) string {
|
||||
packageName := "unknown"
|
||||
// Ask runtime.Callers for up to 100 program counters, including runtime.Callers itself.
|
||||
pc := make([]uintptr, 100)
|
||||
n := runtime.Callers(0, pc)
|
||||
if n == 0 {
|
||||
// No PCs available. This can happen if the first argument to
|
||||
// runtime.Callers is large.
|
||||
//
|
||||
// Return now to avoid processing the zero Frame that would
|
||||
// otherwise be returned by frames.Next below.
|
||||
t.Logf("could not determine test package name: no PCs available")
|
||||
return packageName
|
||||
}
|
||||
|
||||
pc = pc[:n] // pass only valid pcs to runtime.CallersFrames
|
||||
frames := runtime.CallersFrames(pc)
|
||||
|
||||
// Loop to get frames.
|
||||
// A fixed number of PCs can expand to an indefinite number of Frames.
|
||||
for {
|
||||
frame, more := frames.Next()
|
||||
|
||||
if strings.HasPrefix(frame.Function, "github.com/coder/coder/v2/") {
|
||||
packageName = strings.SplitN(strings.TrimPrefix(frame.Function, "github.com/coder/coder/v2/"), ".", 2)[0]
|
||||
}
|
||||
if strings.HasPrefix(frame.Function, "testing") {
|
||||
break
|
||||
}
|
||||
|
||||
// Check whether there are more frames to process after this one.
|
||||
if !more {
|
||||
break
|
||||
}
|
||||
}
|
||||
return packageName
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package dbtestutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetTestPackageName(t *testing.T) {
|
||||
t.Parallel()
|
||||
packageName := getTestPackageName(t)
|
||||
require.Equal(t, "coderd/database/dbtestutil", packageName)
|
||||
}
|
||||
@@ -1,6 +1,3 @@
|
||||
BEGIN TRANSACTION;
|
||||
SELECT pg_advisory_xact_lock(7283699);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS test_databases (
|
||||
name text PRIMARY KEY,
|
||||
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
@@ -9,10 +6,3 @@ CREATE TABLE IF NOT EXISTS test_databases (
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS test_databases_process_uuid ON test_databases (process_uuid, dropped_at);
|
||||
|
||||
ALTER TABLE test_databases ADD COLUMN IF NOT EXISTS test_name text;
|
||||
COMMENT ON COLUMN test_databases.test_name IS 'Name of the test that created the database';
|
||||
ALTER TABLE test_databases ADD COLUMN IF NOT EXISTS test_package text;
|
||||
COMMENT ON COLUMN test_databases.test_package IS 'Package of the test that created the database';
|
||||
|
||||
COMMIT;
|
||||
|
||||
Generated
+14
-41
@@ -1828,15 +1828,6 @@ CREATE TABLE tasks (
|
||||
deleted_at timestamp with time zone
|
||||
);
|
||||
|
||||
CREATE VIEW visible_users AS
|
||||
SELECT users.id,
|
||||
users.username,
|
||||
users.name,
|
||||
users.avatar_url
|
||||
FROM users;
|
||||
|
||||
COMMENT ON VIEW visible_users IS 'Visible fields of users are allowed to be joined with other tables for including context of other resources.';
|
||||
|
||||
CREATE TABLE workspace_agents (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
@@ -1987,16 +1978,8 @@ CREATE VIEW tasks_with_status AS
|
||||
END AS status,
|
||||
task_app.workspace_build_number,
|
||||
task_app.workspace_agent_id,
|
||||
task_app.workspace_app_id,
|
||||
task_owner.owner_username,
|
||||
task_owner.owner_name,
|
||||
task_owner.owner_avatar_url
|
||||
FROM (((((tasks
|
||||
CROSS JOIN LATERAL ( SELECT vu.username AS owner_username,
|
||||
vu.name AS owner_name,
|
||||
vu.avatar_url AS owner_avatar_url
|
||||
FROM visible_users vu
|
||||
WHERE (vu.id = tasks.owner_id)) task_owner)
|
||||
task_app.workspace_app_id
|
||||
FROM ((((tasks
|
||||
LEFT JOIN LATERAL ( SELECT task_app_1.workspace_build_number,
|
||||
task_app_1.workspace_agent_id,
|
||||
task_app_1.workspace_app_id
|
||||
@@ -2029,18 +2012,6 @@ CREATE TABLE telemetry_items (
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE telemetry_locks (
|
||||
event_type text NOT NULL,
|
||||
period_ending_at timestamp with time zone NOT NULL,
|
||||
CONSTRAINT telemetry_lock_event_type_constraint CHECK ((event_type = 'aibridge_interceptions_summary'::text))
|
||||
);
|
||||
|
||||
COMMENT ON TABLE telemetry_locks IS 'Telemetry lock tracking table for deduplication of heartbeat events across replicas.';
|
||||
|
||||
COMMENT ON COLUMN telemetry_locks.event_type IS 'The type of event that was sent.';
|
||||
|
||||
COMMENT ON COLUMN telemetry_locks.period_ending_at IS 'The heartbeat period end timestamp.';
|
||||
|
||||
CREATE TABLE template_usage_stats (
|
||||
start_time timestamp with time zone NOT NULL,
|
||||
end_time timestamp with time zone NOT NULL,
|
||||
@@ -2227,6 +2198,15 @@ COMMENT ON COLUMN template_versions.external_auth_providers IS 'IDs of External
|
||||
|
||||
COMMENT ON COLUMN template_versions.message IS 'Message describing the changes in this version of the template, similar to a Git commit message. Like a commit message, this should be a short, high-level description of the changes in this version of the template. This message is immutable and should not be updated after the fact.';
|
||||
|
||||
CREATE VIEW visible_users AS
|
||||
SELECT users.id,
|
||||
users.username,
|
||||
users.name,
|
||||
users.avatar_url
|
||||
FROM users;
|
||||
|
||||
COMMENT ON VIEW visible_users IS 'Visible fields of users are allowed to be joined with other tables for including context of other resources.';
|
||||
|
||||
CREATE VIEW template_version_with_user AS
|
||||
SELECT template_versions.id,
|
||||
template_versions.template_id,
|
||||
@@ -2922,13 +2902,11 @@ CREATE VIEW workspaces_expanded AS
|
||||
templates.name AS template_name,
|
||||
templates.display_name AS template_display_name,
|
||||
templates.icon AS template_icon,
|
||||
templates.description AS template_description,
|
||||
tasks.id AS task_id
|
||||
FROM ((((workspaces
|
||||
templates.description AS template_description
|
||||
FROM (((workspaces
|
||||
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
|
||||
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
|
||||
JOIN templates ON ((workspaces.template_id = templates.id)))
|
||||
LEFT JOIN tasks ON ((workspaces.id = tasks.workspace_id)));
|
||||
JOIN templates ON ((workspaces.template_id = templates.id)));
|
||||
|
||||
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
|
||||
|
||||
@@ -3112,9 +3090,6 @@ ALTER TABLE ONLY tasks
|
||||
ALTER TABLE ONLY telemetry_items
|
||||
ADD CONSTRAINT telemetry_items_pkey PRIMARY KEY (key);
|
||||
|
||||
ALTER TABLE ONLY telemetry_locks
|
||||
ADD CONSTRAINT telemetry_locks_pkey PRIMARY KEY (event_type, period_ending_at);
|
||||
|
||||
ALTER TABLE ONLY template_usage_stats
|
||||
ADD CONSTRAINT template_usage_stats_pkey PRIMARY KEY (start_time, template_id, user_id);
|
||||
|
||||
@@ -3340,8 +3315,6 @@ CREATE INDEX idx_tailnet_tunnels_dst_id ON tailnet_tunnels USING hash (dst_id);
|
||||
|
||||
CREATE INDEX idx_tailnet_tunnels_src_id ON tailnet_tunnels USING hash (src_id);
|
||||
|
||||
CREATE INDEX idx_telemetry_locks_period_ending_at ON telemetry_locks USING btree (period_ending_at);
|
||||
|
||||
CREATE UNIQUE INDEX idx_template_version_presets_default ON template_version_presets USING btree (template_version_id) WHERE (is_default = true);
|
||||
|
||||
CREATE INDEX idx_template_versions_has_ai_task ON template_versions USING btree (has_ai_task);
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
DROP TABLE telemetry_locks;
|
||||
@@ -1,12 +0,0 @@
|
||||
CREATE TABLE telemetry_locks (
|
||||
event_type TEXT NOT NULL CONSTRAINT telemetry_lock_event_type_constraint CHECK (event_type IN ('aibridge_interceptions_summary')),
|
||||
period_ending_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
|
||||
PRIMARY KEY (event_type, period_ending_at)
|
||||
);
|
||||
|
||||
COMMENT ON TABLE telemetry_locks IS 'Telemetry lock tracking table for deduplication of heartbeat events across replicas.';
|
||||
COMMENT ON COLUMN telemetry_locks.event_type IS 'The type of event that was sent.';
|
||||
COMMENT ON COLUMN telemetry_locks.period_ending_at IS 'The heartbeat period end timestamp.';
|
||||
|
||||
CREATE INDEX idx_telemetry_locks_period_ending_at ON telemetry_locks (period_ending_at);
|
||||
@@ -1,74 +0,0 @@
|
||||
-- Drop view from 000390_tasks_with_status_user_fields.up.sql.
|
||||
DROP VIEW IF EXISTS tasks_with_status;
|
||||
|
||||
-- Restore from 000382_add_columns_to_tasks_with_status.up.sql.
|
||||
CREATE VIEW
|
||||
tasks_with_status
|
||||
AS
|
||||
SELECT
|
||||
tasks.*,
|
||||
CASE
|
||||
WHEN tasks.workspace_id IS NULL OR latest_build.job_status IS NULL THEN 'pending'::task_status
|
||||
|
||||
WHEN latest_build.job_status = 'failed' THEN 'error'::task_status
|
||||
|
||||
WHEN latest_build.transition IN ('stop', 'delete')
|
||||
AND latest_build.job_status = 'succeeded' THEN 'paused'::task_status
|
||||
|
||||
WHEN latest_build.transition = 'start'
|
||||
AND latest_build.job_status = 'pending' THEN 'initializing'::task_status
|
||||
|
||||
WHEN latest_build.transition = 'start' AND latest_build.job_status IN ('running', 'succeeded') THEN
|
||||
CASE
|
||||
WHEN agent_status.none THEN 'initializing'::task_status
|
||||
WHEN agent_status.connecting THEN 'initializing'::task_status
|
||||
WHEN agent_status.connected THEN
|
||||
CASE
|
||||
WHEN app_status.any_unhealthy THEN 'error'::task_status
|
||||
WHEN app_status.any_initializing THEN 'initializing'::task_status
|
||||
WHEN app_status.all_healthy_or_disabled THEN 'active'::task_status
|
||||
ELSE 'unknown'::task_status
|
||||
END
|
||||
ELSE 'unknown'::task_status
|
||||
END
|
||||
|
||||
ELSE 'unknown'::task_status
|
||||
END AS status,
|
||||
task_app.*
|
||||
FROM
|
||||
tasks
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT workspace_build_number, workspace_agent_id, workspace_app_id
|
||||
FROM task_workspace_apps task_app
|
||||
WHERE task_id = tasks.id
|
||||
ORDER BY workspace_build_number DESC
|
||||
LIMIT 1
|
||||
) task_app ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
workspace_build.transition,
|
||||
provisioner_job.job_status,
|
||||
workspace_build.job_id
|
||||
FROM workspace_builds workspace_build
|
||||
JOIN provisioner_jobs provisioner_job ON provisioner_job.id = workspace_build.job_id
|
||||
WHERE workspace_build.workspace_id = tasks.workspace_id
|
||||
AND workspace_build.build_number = task_app.workspace_build_number
|
||||
) latest_build ON TRUE
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
COUNT(*) = 0 AS none,
|
||||
bool_or(workspace_agent.lifecycle_state IN ('created', 'starting')) AS connecting,
|
||||
bool_and(workspace_agent.lifecycle_state = 'ready') AS connected
|
||||
FROM workspace_agents workspace_agent
|
||||
WHERE workspace_agent.id = task_app.workspace_agent_id
|
||||
) agent_status
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
bool_or(workspace_app.health = 'unhealthy') AS any_unhealthy,
|
||||
bool_or(workspace_app.health = 'initializing') AS any_initializing,
|
||||
bool_and(workspace_app.health IN ('healthy', 'disabled')) AS all_healthy_or_disabled
|
||||
FROM workspace_apps workspace_app
|
||||
WHERE workspace_app.id = task_app.workspace_app_id
|
||||
) app_status
|
||||
WHERE
|
||||
tasks.deleted_at IS NULL;
|
||||
@@ -1,84 +0,0 @@
|
||||
-- Drop view from 00037_add_columns_to_tasks_with_status.up.sql.
|
||||
DROP VIEW IF EXISTS tasks_with_status;
|
||||
|
||||
-- Add owner_name, owner_avatar_url columns.
|
||||
CREATE VIEW
|
||||
tasks_with_status
|
||||
AS
|
||||
SELECT
|
||||
tasks.*,
|
||||
CASE
|
||||
WHEN tasks.workspace_id IS NULL OR latest_build.job_status IS NULL THEN 'pending'::task_status
|
||||
|
||||
WHEN latest_build.job_status = 'failed' THEN 'error'::task_status
|
||||
|
||||
WHEN latest_build.transition IN ('stop', 'delete')
|
||||
AND latest_build.job_status = 'succeeded' THEN 'paused'::task_status
|
||||
|
||||
WHEN latest_build.transition = 'start'
|
||||
AND latest_build.job_status = 'pending' THEN 'initializing'::task_status
|
||||
|
||||
WHEN latest_build.transition = 'start' AND latest_build.job_status IN ('running', 'succeeded') THEN
|
||||
CASE
|
||||
WHEN agent_status.none THEN 'initializing'::task_status
|
||||
WHEN agent_status.connecting THEN 'initializing'::task_status
|
||||
WHEN agent_status.connected THEN
|
||||
CASE
|
||||
WHEN app_status.any_unhealthy THEN 'error'::task_status
|
||||
WHEN app_status.any_initializing THEN 'initializing'::task_status
|
||||
WHEN app_status.all_healthy_or_disabled THEN 'active'::task_status
|
||||
ELSE 'unknown'::task_status
|
||||
END
|
||||
ELSE 'unknown'::task_status
|
||||
END
|
||||
|
||||
ELSE 'unknown'::task_status
|
||||
END AS status,
|
||||
task_app.*,
|
||||
task_owner.*
|
||||
FROM
|
||||
tasks
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
vu.username AS owner_username,
|
||||
vu.name AS owner_name,
|
||||
vu.avatar_url AS owner_avatar_url
|
||||
FROM visible_users vu
|
||||
WHERE vu.id = tasks.owner_id
|
||||
) task_owner
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT workspace_build_number, workspace_agent_id, workspace_app_id
|
||||
FROM task_workspace_apps task_app
|
||||
WHERE task_id = tasks.id
|
||||
ORDER BY workspace_build_number DESC
|
||||
LIMIT 1
|
||||
) task_app ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT
|
||||
workspace_build.transition,
|
||||
provisioner_job.job_status,
|
||||
workspace_build.job_id
|
||||
FROM workspace_builds workspace_build
|
||||
JOIN provisioner_jobs provisioner_job ON provisioner_job.id = workspace_build.job_id
|
||||
WHERE workspace_build.workspace_id = tasks.workspace_id
|
||||
AND workspace_build.build_number = task_app.workspace_build_number
|
||||
) latest_build ON TRUE
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
COUNT(*) = 0 AS none,
|
||||
bool_or(workspace_agent.lifecycle_state IN ('created', 'starting')) AS connecting,
|
||||
bool_and(workspace_agent.lifecycle_state = 'ready') AS connected
|
||||
FROM workspace_agents workspace_agent
|
||||
WHERE workspace_agent.id = task_app.workspace_agent_id
|
||||
) agent_status
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
bool_or(workspace_app.health = 'unhealthy') AS any_unhealthy,
|
||||
bool_or(workspace_app.health = 'initializing') AS any_initializing,
|
||||
bool_and(workspace_app.health IN ('healthy', 'disabled')) AS all_healthy_or_disabled
|
||||
FROM workspace_apps workspace_app
|
||||
WHERE workspace_app.id = task_app.workspace_app_id
|
||||
) app_status
|
||||
WHERE
|
||||
tasks.deleted_at IS NULL;
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
UPDATE notification_templates
|
||||
SET enabled_by_default = true
|
||||
WHERE id IN (
|
||||
'8c5a4d12-9f7e-4b3a-a1c8-6e4f2d9b5a7c',
|
||||
'3b7e8f1a-4c2d-49a6-b5e9-7f3a1c8d6b4e',
|
||||
'bd4b7168-d05e-4e19-ad0f-3593b77aa90f',
|
||||
'd4a6271c-cced-4ed0-84ad-afd02a9c7799'
|
||||
);
|
||||
@@ -1,8 +0,0 @@
|
||||
UPDATE notification_templates
|
||||
SET enabled_by_default = false
|
||||
WHERE id IN (
|
||||
'8c5a4d12-9f7e-4b3a-a1c8-6e4f2d9b5a7c',
|
||||
'3b7e8f1a-4c2d-49a6-b5e9-7f3a1c8d6b4e',
|
||||
'bd4b7168-d05e-4e19-ad0f-3593b77aa90f',
|
||||
'd4a6271c-cced-4ed0-84ad-afd02a9c7799'
|
||||
);
|
||||
@@ -1,39 +0,0 @@
|
||||
DROP VIEW workspaces_expanded;
|
||||
|
||||
-- Recreate the view from 000354_workspace_acl.up.sql
|
||||
CREATE VIEW workspaces_expanded AS
|
||||
SELECT workspaces.id,
|
||||
workspaces.created_at,
|
||||
workspaces.updated_at,
|
||||
workspaces.owner_id,
|
||||
workspaces.organization_id,
|
||||
workspaces.template_id,
|
||||
workspaces.deleted,
|
||||
workspaces.name,
|
||||
workspaces.autostart_schedule,
|
||||
workspaces.ttl,
|
||||
workspaces.last_used_at,
|
||||
workspaces.dormant_at,
|
||||
workspaces.deleting_at,
|
||||
workspaces.automatic_updates,
|
||||
workspaces.favorite,
|
||||
workspaces.next_start_at,
|
||||
workspaces.group_acl,
|
||||
workspaces.user_acl,
|
||||
visible_users.avatar_url AS owner_avatar_url,
|
||||
visible_users.username AS owner_username,
|
||||
visible_users.name AS owner_name,
|
||||
organizations.name AS organization_name,
|
||||
organizations.display_name AS organization_display_name,
|
||||
organizations.icon AS organization_icon,
|
||||
organizations.description AS organization_description,
|
||||
templates.name AS template_name,
|
||||
templates.display_name AS template_display_name,
|
||||
templates.icon AS template_icon,
|
||||
templates.description AS template_description
|
||||
FROM (((workspaces
|
||||
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
|
||||
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
|
||||
JOIN templates ON ((workspaces.template_id = templates.id)));
|
||||
|
||||
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
|
||||
@@ -1,42 +0,0 @@
|
||||
DROP VIEW workspaces_expanded;
|
||||
|
||||
-- Add nullable task_id to workspaces_expanded view
|
||||
CREATE VIEW workspaces_expanded AS
|
||||
SELECT workspaces.id,
|
||||
workspaces.created_at,
|
||||
workspaces.updated_at,
|
||||
workspaces.owner_id,
|
||||
workspaces.organization_id,
|
||||
workspaces.template_id,
|
||||
workspaces.deleted,
|
||||
workspaces.name,
|
||||
workspaces.autostart_schedule,
|
||||
workspaces.ttl,
|
||||
workspaces.last_used_at,
|
||||
workspaces.dormant_at,
|
||||
workspaces.deleting_at,
|
||||
workspaces.automatic_updates,
|
||||
workspaces.favorite,
|
||||
workspaces.next_start_at,
|
||||
workspaces.group_acl,
|
||||
workspaces.user_acl,
|
||||
visible_users.avatar_url AS owner_avatar_url,
|
||||
visible_users.username AS owner_username,
|
||||
visible_users.name AS owner_name,
|
||||
organizations.name AS organization_name,
|
||||
organizations.display_name AS organization_display_name,
|
||||
organizations.icon AS organization_icon,
|
||||
organizations.description AS organization_description,
|
||||
templates.name AS template_name,
|
||||
templates.display_name AS template_display_name,
|
||||
templates.icon AS template_icon,
|
||||
templates.description AS template_description,
|
||||
tasks.id AS task_id
|
||||
FROM ((((workspaces
|
||||
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
|
||||
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
|
||||
JOIN templates ON ((workspaces.template_id = templates.id)))
|
||||
LEFT JOIN tasks ON ((workspaces.id = tasks.workspace_id)));
|
||||
|
||||
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
INSERT INTO telemetry_locks (
|
||||
event_type,
|
||||
period_ending_at
|
||||
)
|
||||
VALUES (
|
||||
'aibridge_interceptions_summary',
|
||||
'2025-01-01 00:00:00+00'::timestamptz
|
||||
);
|
||||
@@ -208,7 +208,6 @@ func (s APIKeyScopes) expandRBACScope() (rbac.Scope, error) {
|
||||
for orgID, perms := range expanded.ByOrgID {
|
||||
orgPerms := merged.ByOrgID[orgID]
|
||||
orgPerms.Org = append(orgPerms.Org, perms.Org...)
|
||||
orgPerms.Member = append(orgPerms.Member, perms.Member...)
|
||||
merged.ByOrgID[orgID] = orgPerms
|
||||
}
|
||||
merged.User = append(merged.User, expanded.User...)
|
||||
@@ -221,7 +220,6 @@ func (s APIKeyScopes) expandRBACScope() (rbac.Scope, error) {
|
||||
merged.User = rbac.DeduplicatePermissions(merged.User)
|
||||
for orgID, perms := range merged.ByOrgID {
|
||||
perms.Org = rbac.DeduplicatePermissions(perms.Org)
|
||||
perms.Member = rbac.DeduplicatePermissions(perms.Member)
|
||||
merged.ByOrgID[orgID] = perms
|
||||
}
|
||||
|
||||
|
||||
@@ -321,7 +321,6 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
&i.TemplateVersionID,
|
||||
&i.TemplateVersionName,
|
||||
&i.LatestBuildCompletedAt,
|
||||
@@ -329,6 +328,7 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
|
||||
&i.LatestBuildError,
|
||||
&i.LatestBuildTransition,
|
||||
&i.LatestBuildStatus,
|
||||
&i.LatestBuildHasAITask,
|
||||
&i.LatestBuildHasExternalAgent,
|
||||
&i.Count,
|
||||
); err != nil {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.27.0
|
||||
|
||||
package database
|
||||
|
||||
@@ -4221,9 +4221,6 @@ type Task struct {
|
||||
WorkspaceBuildNumber sql.NullInt32 `db:"workspace_build_number" json:"workspace_build_number"`
|
||||
WorkspaceAgentID uuid.NullUUID `db:"workspace_agent_id" json:"workspace_agent_id"`
|
||||
WorkspaceAppID uuid.NullUUID `db:"workspace_app_id" json:"workspace_app_id"`
|
||||
OwnerUsername string `db:"owner_username" json:"owner_username"`
|
||||
OwnerName string `db:"owner_name" json:"owner_name"`
|
||||
OwnerAvatarUrl string `db:"owner_avatar_url" json:"owner_avatar_url"`
|
||||
}
|
||||
|
||||
type TaskTable struct {
|
||||
@@ -4253,14 +4250,6 @@ type TelemetryItem struct {
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
// Telemetry lock tracking table for deduplication of heartbeat events across replicas.
|
||||
type TelemetryLock struct {
|
||||
// The type of event that was sent.
|
||||
EventType string `db:"event_type" json:"event_type"`
|
||||
// The heartbeat period end timestamp.
|
||||
PeriodEndingAt time.Time `db:"period_ending_at" json:"period_ending_at"`
|
||||
}
|
||||
|
||||
// Joins in the display name information such as username, avatar, and organization name.
|
||||
type Template struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
@@ -4663,7 +4652,6 @@ type Workspace struct {
|
||||
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
|
||||
TemplateIcon string `db:"template_icon" json:"template_icon"`
|
||||
TemplateDescription string `db:"template_description" json:"template_description"`
|
||||
TaskID uuid.NullUUID `db:"task_id" json:"task_id"`
|
||||
}
|
||||
|
||||
type WorkspaceAgent struct {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.27.0
|
||||
|
||||
package database
|
||||
|
||||
@@ -60,9 +60,6 @@ type sqlcQuerier interface {
|
||||
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
|
||||
BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error)
|
||||
BulkMarkNotificationMessagesSent(ctx context.Context, arg BulkMarkNotificationMessagesSentParams) (int64, error)
|
||||
// Calculates the telemetry summary for a given provider, model, and client
|
||||
// combination for telemetry reporting.
|
||||
CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg CalculateAIBridgeInterceptionsTelemetrySummaryParams) (CalculateAIBridgeInterceptionsTelemetrySummaryRow, error)
|
||||
ClaimPrebuiltWorkspace(ctx context.Context, arg ClaimPrebuiltWorkspaceParams) (ClaimPrebuiltWorkspaceRow, error)
|
||||
CleanTailnetCoordinators(ctx context.Context) error
|
||||
CleanTailnetLostPeers(ctx context.Context) error
|
||||
@@ -110,8 +107,6 @@ type sqlcQuerier interface {
|
||||
// A provisioner daemon with "zeroed" last_seen_at column indicates possible
|
||||
// connectivity issues (no provisioner daemon activity since registration).
|
||||
DeleteOldProvisionerDaemons(ctx context.Context) error
|
||||
// Deletes old telemetry locks from the telemetry_locks table.
|
||||
DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error
|
||||
// If an agent hasn't connected in the last 7 days, we purge it's logs.
|
||||
// Exception: if the logs are related to the latest build, we keep those around.
|
||||
// Logs can take up a lot of space, so it's important we clean up frequently.
|
||||
@@ -269,9 +264,6 @@ type sqlcQuerier interface {
|
||||
GetOrganizationResourceCountByID(ctx context.Context, organizationID uuid.UUID) (GetOrganizationResourceCountByIDRow, error)
|
||||
GetOrganizations(ctx context.Context, arg GetOrganizationsParams) ([]Organization, error)
|
||||
GetOrganizationsByUserID(ctx context.Context, arg GetOrganizationsByUserIDParams) ([]Organization, error)
|
||||
// GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
|
||||
// membership status for the prebuilds system user (org membership, group existence, group membership).
|
||||
GetOrganizationsWithPrebuildStatus(ctx context.Context, arg GetOrganizationsWithPrebuildStatusParams) ([]GetOrganizationsWithPrebuildStatusRow, error)
|
||||
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
|
||||
GetPrebuildMetrics(ctx context.Context) ([]GetPrebuildMetricsRow, error)
|
||||
GetPrebuildsSettings(ctx context.Context) (string, error)
|
||||
@@ -567,12 +559,6 @@ type sqlcQuerier interface {
|
||||
InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error)
|
||||
InsertTask(ctx context.Context, arg InsertTaskParams) (TaskTable, error)
|
||||
InsertTelemetryItemIfNotExists(ctx context.Context, arg InsertTelemetryItemIfNotExistsParams) error
|
||||
// Inserts a new lock row into the telemetry_locks table. Replicas should call
|
||||
// this function prior to attempting to generate or publish a heartbeat event to
|
||||
// the telemetry service.
|
||||
// If the query returns a duplicate primary key error, the replica should not
|
||||
// attempt to generate or publish the event to the telemetry service.
|
||||
InsertTelemetryLock(ctx context.Context, arg InsertTelemetryLockParams) error
|
||||
InsertTemplate(ctx context.Context, arg InsertTemplateParams) error
|
||||
InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) error
|
||||
InsertTemplateVersionParameter(ctx context.Context, arg InsertTemplateVersionParameterParams) (TemplateVersionParameter, error)
|
||||
@@ -609,9 +595,6 @@ type sqlcQuerier interface {
|
||||
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
|
||||
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
|
||||
ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]ListAIBridgeInterceptionsRow, error)
|
||||
// Finds all unique AIBridge interception telemetry summaries combinations
|
||||
// (provider, model, client) in the given timeframe for telemetry reporting.
|
||||
ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error)
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
|
||||
@@ -649,7 +632,6 @@ type sqlcQuerier interface {
|
||||
// This will always work regardless of the current state of the template version.
|
||||
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
|
||||
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
|
||||
UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error)
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
|
||||
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
|
||||
@@ -670,7 +652,7 @@ type sqlcQuerier interface {
|
||||
// Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
|
||||
// inactive template version.
|
||||
// This is an optimization to clean up stale pending jobs.
|
||||
UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]UpdatePrebuildProvisionerJobWithCancelRow, error)
|
||||
UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error)
|
||||
UpdatePresetPrebuildStatus(ctx context.Context, arg UpdatePresetPrebuildStatusParams) error
|
||||
UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error
|
||||
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
|
||||
|
||||
@@ -7248,9 +7248,7 @@ func TestTaskNameUniqueness(t *testing.T) {
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
taskID := uuid.New()
|
||||
task, err := db.InsertTask(ctx, database.InsertTaskParams{
|
||||
ID: taskID,
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: tt.ownerID,
|
||||
Name: tt.taskName,
|
||||
@@ -7265,7 +7263,6 @@ func TestTaskNameUniqueness(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, task.ID)
|
||||
require.NotEqual(t, task1.ID, task.ID)
|
||||
require.Equal(t, taskID, task.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -7727,68 +7724,3 @@ func TestUpdateTaskWorkspaceID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
t.Run("NonExistingInterception", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
got, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: uuid.New(),
|
||||
EndedAt: time.Now(),
|
||||
})
|
||||
require.ErrorContains(t, err, "no rows in result set")
|
||||
require.EqualValues(t, database.AIBridgeInterception{}, got)
|
||||
})
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
interceptions := []database.AIBridgeInterception{}
|
||||
|
||||
for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
|
||||
insertParams := database.InsertAIBridgeInterceptionParams{
|
||||
ID: uid,
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
}
|
||||
|
||||
intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uid, intc.ID)
|
||||
require.False(t, intc.EndedAt.Valid)
|
||||
interceptions = append(interceptions, intc)
|
||||
}
|
||||
|
||||
intc0 := interceptions[0]
|
||||
endedAt := time.Now()
|
||||
// Mark first interception as done
|
||||
updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: intc0.ID,
|
||||
EndedAt: endedAt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, updated.ID, intc0.ID)
|
||||
require.True(t, updated.EndedAt.Valid)
|
||||
require.WithinDuration(t, endedAt, updated.EndedAt.Time, 5*time.Second)
|
||||
|
||||
// Updating first interception again should fail
|
||||
updated, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: intc0.ID,
|
||||
EndedAt: endedAt.Add(time.Hour),
|
||||
})
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
// Other interceptions should not have ended_at set
|
||||
for _, intc := range interceptions[1:] {
|
||||
got, err := db.GetAIBridgeInterceptionByID(ctx, intc.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, got.EndedAt.Valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+48
-431
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.27.0
|
||||
|
||||
package database
|
||||
|
||||
@@ -111,164 +111,6 @@ func (q *sqlQuerier) ActivityBumpWorkspace(ctx context.Context, arg ActivityBump
|
||||
return err
|
||||
}
|
||||
|
||||
const calculateAIBridgeInterceptionsTelemetrySummary = `-- name: CalculateAIBridgeInterceptionsTelemetrySummary :one
|
||||
WITH interceptions_in_range AS (
|
||||
-- Get all matching interceptions in the given timeframe.
|
||||
SELECT
|
||||
id,
|
||||
initiator_id,
|
||||
(ended_at - started_at) AS duration
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
provider = $1::text
|
||||
AND model = $2::text
|
||||
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
|
||||
AND 'unknown' = $3::text
|
||||
AND ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
|
||||
AND ended_at >= $4::timestamptz
|
||||
AND ended_at < $5::timestamptz
|
||||
),
|
||||
interception_counts AS (
|
||||
SELECT
|
||||
COUNT(id) AS interception_count,
|
||||
COUNT(DISTINCT initiator_id) AS unique_initiator_count
|
||||
FROM
|
||||
interceptions_in_range
|
||||
),
|
||||
duration_percentiles AS (
|
||||
SELECT
|
||||
(COALESCE(PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p50_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.90) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p90_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p95_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p99_millis
|
||||
FROM
|
||||
interceptions_in_range
|
||||
),
|
||||
token_aggregates AS (
|
||||
SELECT
|
||||
COALESCE(SUM(tu.input_tokens), 0) AS token_count_input,
|
||||
COALESCE(SUM(tu.output_tokens), 0) AS token_count_output,
|
||||
-- Cached tokens are stored in metadata JSON, extract if available.
|
||||
-- Read tokens may be stored in:
|
||||
-- - cache_read_input (Anthropic)
|
||||
-- - prompt_cached (OpenAI)
|
||||
COALESCE(SUM(
|
||||
COALESCE((tu.metadata->>'cache_read_input')::bigint, 0) +
|
||||
COALESCE((tu.metadata->>'prompt_cached')::bigint, 0)
|
||||
), 0) AS token_count_cached_read,
|
||||
-- Written tokens may be stored in:
|
||||
-- - cache_creation_input (Anthropic)
|
||||
-- Note that cache_ephemeral_5m_input and cache_ephemeral_1h_input on
|
||||
-- Anthropic are included in the cache_creation_input field.
|
||||
COALESCE(SUM(
|
||||
COALESCE((tu.metadata->>'cache_creation_input')::bigint, 0)
|
||||
), 0) AS token_count_cached_written,
|
||||
COUNT(tu.id) AS token_usages_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_token_usages tu ON i.id = tu.interception_id
|
||||
),
|
||||
prompt_aggregates AS (
|
||||
SELECT
|
||||
COUNT(up.id) AS user_prompts_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_user_prompts up ON i.id = up.interception_id
|
||||
),
|
||||
tool_aggregates AS (
|
||||
SELECT
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = true) AS tool_calls_count_injected,
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = false) AS tool_calls_count_non_injected,
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = true AND tu.invocation_error IS NOT NULL) AS injected_tool_call_error_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_tool_usages tu ON i.id = tu.interception_id
|
||||
)
|
||||
SELECT
|
||||
ic.interception_count::bigint AS interception_count,
|
||||
dp.interception_duration_p50_millis::bigint AS interception_duration_p50_millis,
|
||||
dp.interception_duration_p90_millis::bigint AS interception_duration_p90_millis,
|
||||
dp.interception_duration_p95_millis::bigint AS interception_duration_p95_millis,
|
||||
dp.interception_duration_p99_millis::bigint AS interception_duration_p99_millis,
|
||||
ic.unique_initiator_count::bigint AS unique_initiator_count,
|
||||
pa.user_prompts_count::bigint AS user_prompts_count,
|
||||
tok_agg.token_usages_count::bigint AS token_usages_count,
|
||||
tok_agg.token_count_input::bigint AS token_count_input,
|
||||
tok_agg.token_count_output::bigint AS token_count_output,
|
||||
tok_agg.token_count_cached_read::bigint AS token_count_cached_read,
|
||||
tok_agg.token_count_cached_written::bigint AS token_count_cached_written,
|
||||
tool_agg.tool_calls_count_injected::bigint AS tool_calls_count_injected,
|
||||
tool_agg.tool_calls_count_non_injected::bigint AS tool_calls_count_non_injected,
|
||||
tool_agg.injected_tool_call_error_count::bigint AS injected_tool_call_error_count
|
||||
FROM
|
||||
interception_counts ic,
|
||||
duration_percentiles dp,
|
||||
token_aggregates tok_agg,
|
||||
prompt_aggregates pa,
|
||||
tool_aggregates tool_agg
|
||||
`
|
||||
|
||||
type CalculateAIBridgeInterceptionsTelemetrySummaryParams struct {
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
EndedAtAfter time.Time `db:"ended_at_after" json:"ended_at_after"`
|
||||
EndedAtBefore time.Time `db:"ended_at_before" json:"ended_at_before"`
|
||||
}
|
||||
|
||||
type CalculateAIBridgeInterceptionsTelemetrySummaryRow struct {
|
||||
InterceptionCount int64 `db:"interception_count" json:"interception_count"`
|
||||
InterceptionDurationP50Millis int64 `db:"interception_duration_p50_millis" json:"interception_duration_p50_millis"`
|
||||
InterceptionDurationP90Millis int64 `db:"interception_duration_p90_millis" json:"interception_duration_p90_millis"`
|
||||
InterceptionDurationP95Millis int64 `db:"interception_duration_p95_millis" json:"interception_duration_p95_millis"`
|
||||
InterceptionDurationP99Millis int64 `db:"interception_duration_p99_millis" json:"interception_duration_p99_millis"`
|
||||
UniqueInitiatorCount int64 `db:"unique_initiator_count" json:"unique_initiator_count"`
|
||||
UserPromptsCount int64 `db:"user_prompts_count" json:"user_prompts_count"`
|
||||
TokenUsagesCount int64 `db:"token_usages_count" json:"token_usages_count"`
|
||||
TokenCountInput int64 `db:"token_count_input" json:"token_count_input"`
|
||||
TokenCountOutput int64 `db:"token_count_output" json:"token_count_output"`
|
||||
TokenCountCachedRead int64 `db:"token_count_cached_read" json:"token_count_cached_read"`
|
||||
TokenCountCachedWritten int64 `db:"token_count_cached_written" json:"token_count_cached_written"`
|
||||
ToolCallsCountInjected int64 `db:"tool_calls_count_injected" json:"tool_calls_count_injected"`
|
||||
ToolCallsCountNonInjected int64 `db:"tool_calls_count_non_injected" json:"tool_calls_count_non_injected"`
|
||||
InjectedToolCallErrorCount int64 `db:"injected_tool_call_error_count" json:"injected_tool_call_error_count"`
|
||||
}
|
||||
|
||||
// Calculates the telemetry summary for a given provider, model, and client
|
||||
// combination for telemetry reporting.
|
||||
func (q *sqlQuerier) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg CalculateAIBridgeInterceptionsTelemetrySummaryParams) (CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
|
||||
row := q.db.QueryRowContext(ctx, calculateAIBridgeInterceptionsTelemetrySummary,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.EndedAtAfter,
|
||||
arg.EndedAtBefore,
|
||||
)
|
||||
var i CalculateAIBridgeInterceptionsTelemetrySummaryRow
|
||||
err := row.Scan(
|
||||
&i.InterceptionCount,
|
||||
&i.InterceptionDurationP50Millis,
|
||||
&i.InterceptionDurationP90Millis,
|
||||
&i.InterceptionDurationP95Millis,
|
||||
&i.InterceptionDurationP99Millis,
|
||||
&i.UniqueInitiatorCount,
|
||||
&i.UserPromptsCount,
|
||||
&i.TokenUsagesCount,
|
||||
&i.TokenCountInput,
|
||||
&i.TokenCountOutput,
|
||||
&i.TokenCountCachedRead,
|
||||
&i.TokenCountCachedWritten,
|
||||
&i.ToolCallsCountInjected,
|
||||
&i.ToolCallsCountNonInjected,
|
||||
&i.InjectedToolCallErrorCount,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const countAIBridgeInterceptions = `-- name: CountAIBridgeInterceptions :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
@@ -805,57 +647,6 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listAIBridgeInterceptionsTelemetrySummaries = `-- name: ListAIBridgeInterceptionsTelemetrySummaries :many
|
||||
SELECT
|
||||
DISTINCT ON (provider, model, client)
|
||||
provider,
|
||||
model,
|
||||
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
|
||||
'unknown' AS client
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
|
||||
AND ended_at >= $1::timestamptz
|
||||
AND ended_at < $2::timestamptz
|
||||
`
|
||||
|
||||
type ListAIBridgeInterceptionsTelemetrySummariesParams struct {
|
||||
EndedAtAfter time.Time `db:"ended_at_after" json:"ended_at_after"`
|
||||
EndedAtBefore time.Time `db:"ended_at_before" json:"ended_at_before"`
|
||||
}
|
||||
|
||||
type ListAIBridgeInterceptionsTelemetrySummariesRow struct {
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
}
|
||||
|
||||
// Finds all unique AIBridge interception telemetry summaries combinations
|
||||
// (provider, model, client) in the given timeframe for telemetry reporting.
|
||||
func (q *sqlQuerier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listAIBridgeInterceptionsTelemetrySummaries, arg.EndedAtAfter, arg.EndedAtBefore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListAIBridgeInterceptionsTelemetrySummariesRow
|
||||
for rows.Next() {
|
||||
var i ListAIBridgeInterceptionsTelemetrySummariesRow
|
||||
if err := rows.Scan(&i.Provider, &i.Model, &i.Client); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listAIBridgeTokenUsagesByInterceptionIDs = `-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many
|
||||
SELECT
|
||||
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
|
||||
@@ -987,35 +778,6 @@ func (q *sqlQuerier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Contex
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateAIBridgeInterceptionEnded = `-- name: UpdateAIBridgeInterceptionEnded :one
|
||||
UPDATE aibridge_interceptions
|
||||
SET ended_at = $1::timestamptz
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
AND ended_at IS NULL
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at
|
||||
`
|
||||
|
||||
type UpdateAIBridgeInterceptionEndedParams struct {
|
||||
EndedAt time.Time `db:"ended_at" json:"ended_at"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateAIBridgeInterceptionEnded, arg.EndedAt, arg.ID)
|
||||
var i AIBridgeInterception
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.InitiatorID,
|
||||
&i.Provider,
|
||||
&i.Model,
|
||||
&i.StartedAt,
|
||||
&i.Metadata,
|
||||
&i.EndedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteAPIKeyByID = `-- name: DeleteAPIKeyByID :exec
|
||||
DELETE FROM
|
||||
api_keys
|
||||
@@ -8285,93 +8047,6 @@ func (q *sqlQuerier) FindMatchingPresetID(ctx context.Context, arg FindMatchingP
|
||||
return template_version_preset_id, err
|
||||
}
|
||||
|
||||
const getOrganizationsWithPrebuildStatus = `-- name: GetOrganizationsWithPrebuildStatus :many
|
||||
WITH orgs_with_prebuilds AS (
|
||||
-- Get unique organizations that have presets with prebuilds configured
|
||||
SELECT DISTINCT o.id, o.name
|
||||
FROM organizations o
|
||||
INNER JOIN templates t ON t.organization_id = o.id
|
||||
INNER JOIN template_versions tv ON tv.template_id = t.id
|
||||
INNER JOIN template_version_presets tvp ON tvp.template_version_id = tv.id
|
||||
WHERE tvp.desired_instances IS NOT NULL
|
||||
),
|
||||
prebuild_user_membership AS (
|
||||
-- Check if the user is a member of the organizations
|
||||
SELECT om.organization_id
|
||||
FROM organization_members om
|
||||
INNER JOIN orgs_with_prebuilds owp ON owp.id = om.organization_id
|
||||
WHERE om.user_id = $1::uuid
|
||||
),
|
||||
prebuild_groups AS (
|
||||
-- Check if the organizations have the prebuilds group
|
||||
SELECT g.organization_id, g.id as group_id
|
||||
FROM groups g
|
||||
INNER JOIN orgs_with_prebuilds owp ON owp.id = g.organization_id
|
||||
WHERE g.name = $2::text
|
||||
),
|
||||
prebuild_group_membership AS (
|
||||
-- Check if the user is in the prebuilds group
|
||||
SELECT pg.organization_id
|
||||
FROM prebuild_groups pg
|
||||
INNER JOIN group_members gm ON gm.group_id = pg.group_id
|
||||
WHERE gm.user_id = $1::uuid
|
||||
)
|
||||
SELECT
|
||||
owp.id AS organization_id,
|
||||
owp.name AS organization_name,
|
||||
(pum.organization_id IS NOT NULL)::boolean AS has_prebuild_user,
|
||||
pg.group_id AS prebuilds_group_id,
|
||||
(pgm.organization_id IS NOT NULL)::boolean AS has_prebuild_user_in_group
|
||||
FROM orgs_with_prebuilds owp
|
||||
LEFT JOIN prebuild_groups pg ON pg.organization_id = owp.id
|
||||
LEFT JOIN prebuild_user_membership pum ON pum.organization_id = owp.id
|
||||
LEFT JOIN prebuild_group_membership pgm ON pgm.organization_id = owp.id
|
||||
`
|
||||
|
||||
type GetOrganizationsWithPrebuildStatusParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
GroupName string `db:"group_name" json:"group_name"`
|
||||
}
|
||||
|
||||
type GetOrganizationsWithPrebuildStatusRow struct {
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
OrganizationName string `db:"organization_name" json:"organization_name"`
|
||||
HasPrebuildUser bool `db:"has_prebuild_user" json:"has_prebuild_user"`
|
||||
PrebuildsGroupID uuid.NullUUID `db:"prebuilds_group_id" json:"prebuilds_group_id"`
|
||||
HasPrebuildUserInGroup bool `db:"has_prebuild_user_in_group" json:"has_prebuild_user_in_group"`
|
||||
}
|
||||
|
||||
// GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
|
||||
// membership status for the prebuilds system user (org membership, group existence, group membership).
|
||||
func (q *sqlQuerier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg GetOrganizationsWithPrebuildStatusParams) ([]GetOrganizationsWithPrebuildStatusRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getOrganizationsWithPrebuildStatus, arg.UserID, arg.GroupName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetOrganizationsWithPrebuildStatusRow
|
||||
for rows.Next() {
|
||||
var i GetOrganizationsWithPrebuildStatusRow
|
||||
if err := rows.Scan(
|
||||
&i.OrganizationID,
|
||||
&i.OrganizationName,
|
||||
&i.HasPrebuildUser,
|
||||
&i.PrebuildsGroupID,
|
||||
&i.HasPrebuildUserInGroup,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getPrebuildMetrics = `-- name: GetPrebuildMetrics :many
|
||||
SELECT
|
||||
t.name as template_name,
|
||||
@@ -8774,8 +8449,12 @@ func (q *sqlQuerier) GetTemplatePresetsWithPrebuilds(ctx context.Context, templa
|
||||
}
|
||||
|
||||
const updatePrebuildProvisionerJobWithCancel = `-- name: UpdatePrebuildProvisionerJobWithCancel :many
|
||||
WITH jobs_to_cancel AS (
|
||||
SELECT pj.id, w.id AS workspace_id, w.template_id, wpb.template_version_preset_id
|
||||
UPDATE provisioner_jobs
|
||||
SET
|
||||
canceled_at = $1::timestamptz,
|
||||
completed_at = $1::timestamptz
|
||||
WHERE id IN (
|
||||
SELECT pj.id
|
||||
FROM provisioner_jobs pj
|
||||
INNER JOIN workspace_prebuild_builds wpb ON wpb.job_id = pj.id
|
||||
INNER JOIN workspaces w ON w.id = wpb.workspace_id
|
||||
@@ -8794,13 +8473,7 @@ WITH jobs_to_cancel AS (
|
||||
AND pj.canceled_at IS NULL
|
||||
AND pj.completed_at IS NULL
|
||||
)
|
||||
UPDATE provisioner_jobs
|
||||
SET
|
||||
canceled_at = $1::timestamptz,
|
||||
completed_at = $1::timestamptz
|
||||
FROM jobs_to_cancel
|
||||
WHERE provisioner_jobs.id = jobs_to_cancel.id
|
||||
RETURNING jobs_to_cancel.id, jobs_to_cancel.workspace_id, jobs_to_cancel.template_id, jobs_to_cancel.template_version_preset_id
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
type UpdatePrebuildProvisionerJobWithCancelParams struct {
|
||||
@@ -8808,34 +8481,22 @@ type UpdatePrebuildProvisionerJobWithCancelParams struct {
|
||||
PresetID uuid.NullUUID `db:"preset_id" json:"preset_id"`
|
||||
}
|
||||
|
||||
type UpdatePrebuildProvisionerJobWithCancelRow struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
TemplateID uuid.UUID `db:"template_id" json:"template_id"`
|
||||
TemplateVersionPresetID uuid.NullUUID `db:"template_version_preset_id" json:"template_version_preset_id"`
|
||||
}
|
||||
|
||||
// Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
|
||||
// inactive template version.
|
||||
// This is an optimization to clean up stale pending jobs.
|
||||
func (q *sqlQuerier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
func (q *sqlQuerier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
|
||||
rows, err := q.db.QueryContext(ctx, updatePrebuildProvisionerJobWithCancel, arg.Now, arg.PresetID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []UpdatePrebuildProvisionerJobWithCancelRow
|
||||
var items []uuid.UUID
|
||||
for rows.Next() {
|
||||
var i UpdatePrebuildProvisionerJobWithCancelRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.WorkspaceID,
|
||||
&i.TemplateID,
|
||||
&i.TemplateVersionPresetID,
|
||||
); err != nil {
|
||||
var id uuid.UUID
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
items = append(items, id)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
@@ -13065,7 +12726,7 @@ func (q *sqlQuerier) DeleteTask(ctx context.Context, arg DeleteTaskParams) (Task
|
||||
}
|
||||
|
||||
const getTaskByID = `-- name: GetTaskByID :one
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE id = $1::uuid
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status WHERE id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error) {
|
||||
@@ -13086,15 +12747,12 @@ func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error
|
||||
&i.WorkspaceBuildNumber,
|
||||
&i.WorkspaceAgentID,
|
||||
&i.WorkspaceAppID,
|
||||
&i.OwnerUsername,
|
||||
&i.OwnerName,
|
||||
&i.OwnerAvatarUrl,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTaskByWorkspaceID = `-- name: GetTaskByWorkspaceID :one
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE workspace_id = $1::uuid
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status WHERE workspace_id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error) {
|
||||
@@ -13115,9 +12773,6 @@ func (q *sqlQuerier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.
|
||||
&i.WorkspaceBuildNumber,
|
||||
&i.WorkspaceAgentID,
|
||||
&i.WorkspaceAppID,
|
||||
&i.OwnerUsername,
|
||||
&i.OwnerName,
|
||||
&i.OwnerAvatarUrl,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -13126,12 +12781,11 @@ const insertTask = `-- name: InsertTask :one
|
||||
INSERT INTO tasks
|
||||
(id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
(gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at
|
||||
`
|
||||
|
||||
type InsertTaskParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
@@ -13144,7 +12798,6 @@ type InsertTaskParams struct {
|
||||
|
||||
func (q *sqlQuerier) InsertTask(ctx context.Context, arg InsertTaskParams) (TaskTable, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertTask,
|
||||
arg.ID,
|
||||
arg.OrganizationID,
|
||||
arg.OwnerID,
|
||||
arg.Name,
|
||||
@@ -13171,7 +12824,7 @@ func (q *sqlQuerier) InsertTask(ctx context.Context, arg InsertTaskParams) (Task
|
||||
}
|
||||
|
||||
const listTasks = `-- name: ListTasks :many
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status tws
|
||||
SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id FROM tasks_with_status tws
|
||||
WHERE tws.deleted_at IS NULL
|
||||
AND CASE WHEN $1::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.owner_id = $1::UUID ELSE TRUE END
|
||||
AND CASE WHEN $2::UUID != '00000000-0000-0000-0000-000000000000' THEN tws.organization_id = $2::UUID ELSE TRUE END
|
||||
@@ -13209,9 +12862,6 @@ func (q *sqlQuerier) ListTasks(ctx context.Context, arg ListTasksParams) ([]Task
|
||||
&i.WorkspaceBuildNumber,
|
||||
&i.WorkspaceAgentID,
|
||||
&i.WorkspaceAppID,
|
||||
&i.OwnerUsername,
|
||||
&i.OwnerName,
|
||||
&i.OwnerAvatarUrl,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -13385,41 +13035,6 @@ func (q *sqlQuerier) UpsertTelemetryItem(ctx context.Context, arg UpsertTelemetr
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteOldTelemetryLocks = `-- name: DeleteOldTelemetryLocks :exec
|
||||
DELETE FROM
|
||||
telemetry_locks
|
||||
WHERE
|
||||
period_ending_at < $1::timestamptz
|
||||
`
|
||||
|
||||
// Deletes old telemetry locks from the telemetry_locks table.
|
||||
func (q *sqlQuerier) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteOldTelemetryLocks, periodEndingAtBefore)
|
||||
return err
|
||||
}
|
||||
|
||||
const insertTelemetryLock = `-- name: InsertTelemetryLock :exec
|
||||
INSERT INTO
|
||||
telemetry_locks (event_type, period_ending_at)
|
||||
VALUES
|
||||
($1, $2)
|
||||
`
|
||||
|
||||
type InsertTelemetryLockParams struct {
|
||||
EventType string `db:"event_type" json:"event_type"`
|
||||
PeriodEndingAt time.Time `db:"period_ending_at" json:"period_ending_at"`
|
||||
}
|
||||
|
||||
// Inserts a new lock row into the telemetry_locks table. Replicas should call
|
||||
// this function prior to attempting to generate or publish a heartbeat event to
|
||||
// the telemetry service.
|
||||
// If the query returns a duplicate primary key error, the replica should not
|
||||
// attempt to generate or publish the event to the telemetry service.
|
||||
func (q *sqlQuerier) InsertTelemetryLock(ctx context.Context, arg InsertTelemetryLockParams) error {
|
||||
_, err := q.db.ExecContext(ctx, insertTelemetryLock, arg.EventType, arg.PeriodEndingAt)
|
||||
return err
|
||||
}
|
||||
|
||||
const getTemplateAverageBuildTime = `-- name: GetTemplateAverageBuildTime :one
|
||||
WITH build_times AS (
|
||||
SELECT
|
||||
@@ -21927,7 +21542,7 @@ func (q *sqlQuerier) GetWorkspaceACLByID(ctx context.Context, id uuid.UUID) (Get
|
||||
|
||||
const getWorkspaceByAgentID = `-- name: GetWorkspaceByAgentID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
|
||||
FROM
|
||||
workspaces_expanded as workspaces
|
||||
WHERE
|
||||
@@ -21988,14 +21603,13 @@ func (q *sqlQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUI
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceByID = `-- name: GetWorkspaceByID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
|
||||
FROM
|
||||
workspaces_expanded
|
||||
WHERE
|
||||
@@ -22037,14 +21651,13 @@ func (q *sqlQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Worksp
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceByOwnerIDAndName = `-- name: GetWorkspaceByOwnerIDAndName :one
|
||||
SELECT
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
|
||||
FROM
|
||||
workspaces_expanded as workspaces
|
||||
WHERE
|
||||
@@ -22093,14 +21706,13 @@ func (q *sqlQuerier) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWo
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceByResourceID = `-- name: GetWorkspaceByResourceID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
|
||||
FROM
|
||||
workspaces_expanded as workspaces
|
||||
WHERE
|
||||
@@ -22156,14 +21768,13 @@ func (q *sqlQuerier) GetWorkspaceByResourceID(ctx context.Context, resourceID uu
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceByWorkspaceAppID = `-- name: GetWorkspaceByWorkspaceAppID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description, task_id
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at, dormant_at, deleting_at, automatic_updates, favorite, next_start_at, group_acl, user_acl, owner_avatar_url, owner_username, owner_name, organization_name, organization_display_name, organization_icon, organization_description, template_name, template_display_name, template_icon, template_description
|
||||
FROM
|
||||
workspaces_expanded as workspaces
|
||||
WHERE
|
||||
@@ -22231,7 +21842,6 @@ func (q *sqlQuerier) GetWorkspaceByWorkspaceAppID(ctx context.Context, workspace
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -22281,7 +21891,7 @@ SELECT
|
||||
),
|
||||
filtered_workspaces AS (
|
||||
SELECT
|
||||
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl, workspaces.owner_avatar_url, workspaces.owner_username, workspaces.owner_name, workspaces.organization_name, workspaces.organization_display_name, workspaces.organization_icon, workspaces.organization_description, workspaces.template_name, workspaces.template_display_name, workspaces.template_icon, workspaces.template_description, workspaces.task_id,
|
||||
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl, workspaces.owner_avatar_url, workspaces.owner_username, workspaces.owner_name, workspaces.organization_name, workspaces.organization_display_name, workspaces.organization_icon, workspaces.organization_description, workspaces.template_name, workspaces.template_display_name, workspaces.template_icon, workspaces.template_description,
|
||||
latest_build.template_version_id,
|
||||
latest_build.template_version_name,
|
||||
latest_build.completed_at as latest_build_completed_at,
|
||||
@@ -22289,6 +21899,7 @@ SELECT
|
||||
latest_build.error as latest_build_error,
|
||||
latest_build.transition as latest_build_transition,
|
||||
latest_build.job_status as latest_build_status,
|
||||
latest_build.has_ai_task as latest_build_has_ai_task,
|
||||
latest_build.has_external_agent as latest_build_has_external_agent
|
||||
FROM
|
||||
workspaces_expanded as workspaces
|
||||
@@ -22522,19 +22133,25 @@ WHERE
|
||||
(latest_build.template_version_id = template.active_version_id) = $18 :: boolean
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by has_ai_task, checks if this is a task workspace.
|
||||
-- Filter by has_ai_task in latest build
|
||||
AND CASE
|
||||
WHEN $19::boolean IS NOT NULL
|
||||
THEN $19::boolean = EXISTS (
|
||||
SELECT
|
||||
1
|
||||
FROM
|
||||
tasks
|
||||
WHERE
|
||||
-- Consider all tasks, deleting a task does not turn the
|
||||
-- workspace into a non-task workspace.
|
||||
tasks.workspace_id = workspaces.id
|
||||
)
|
||||
WHEN $19 :: boolean IS NOT NULL THEN
|
||||
(COALESCE(latest_build.has_ai_task, false) OR (
|
||||
-- If the build has no AI task, it means that the provisioner job is in progress
|
||||
-- and we don't know if it has an AI task yet. In this case, we optimistically
|
||||
-- assume that it has an AI task if the AI Prompt parameter is not empty. This
|
||||
-- lets the AI Task frontend spawn a task and see it immediately after instead of
|
||||
-- having to wait for the build to complete.
|
||||
latest_build.has_ai_task IS NULL AND
|
||||
latest_build.completed_at IS NULL AND
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM workspace_build_parameters
|
||||
WHERE workspace_build_parameters.workspace_build_id = latest_build.id
|
||||
AND workspace_build_parameters.name = 'AI Prompt'
|
||||
AND workspace_build_parameters.value != ''
|
||||
)
|
||||
)) = ($19 :: boolean)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by has_external_agent in latest build
|
||||
@@ -22565,7 +22182,7 @@ WHERE
|
||||
-- @authorize_filter
|
||||
), filtered_workspaces_order AS (
|
||||
SELECT
|
||||
fw.id, fw.created_at, fw.updated_at, fw.owner_id, fw.organization_id, fw.template_id, fw.deleted, fw.name, fw.autostart_schedule, fw.ttl, fw.last_used_at, fw.dormant_at, fw.deleting_at, fw.automatic_updates, fw.favorite, fw.next_start_at, fw.group_acl, fw.user_acl, fw.owner_avatar_url, fw.owner_username, fw.owner_name, fw.organization_name, fw.organization_display_name, fw.organization_icon, fw.organization_description, fw.template_name, fw.template_display_name, fw.template_icon, fw.template_description, fw.task_id, fw.template_version_id, fw.template_version_name, fw.latest_build_completed_at, fw.latest_build_canceled_at, fw.latest_build_error, fw.latest_build_transition, fw.latest_build_status, fw.latest_build_has_external_agent
|
||||
fw.id, fw.created_at, fw.updated_at, fw.owner_id, fw.organization_id, fw.template_id, fw.deleted, fw.name, fw.autostart_schedule, fw.ttl, fw.last_used_at, fw.dormant_at, fw.deleting_at, fw.automatic_updates, fw.favorite, fw.next_start_at, fw.group_acl, fw.user_acl, fw.owner_avatar_url, fw.owner_username, fw.owner_name, fw.organization_name, fw.organization_display_name, fw.organization_icon, fw.organization_description, fw.template_name, fw.template_display_name, fw.template_icon, fw.template_description, fw.template_version_id, fw.template_version_name, fw.latest_build_completed_at, fw.latest_build_canceled_at, fw.latest_build_error, fw.latest_build_transition, fw.latest_build_status, fw.latest_build_has_ai_task, fw.latest_build_has_external_agent
|
||||
FROM
|
||||
filtered_workspaces fw
|
||||
ORDER BY
|
||||
@@ -22586,7 +22203,7 @@ WHERE
|
||||
$25
|
||||
), filtered_workspaces_order_with_summary AS (
|
||||
SELECT
|
||||
fwo.id, fwo.created_at, fwo.updated_at, fwo.owner_id, fwo.organization_id, fwo.template_id, fwo.deleted, fwo.name, fwo.autostart_schedule, fwo.ttl, fwo.last_used_at, fwo.dormant_at, fwo.deleting_at, fwo.automatic_updates, fwo.favorite, fwo.next_start_at, fwo.group_acl, fwo.user_acl, fwo.owner_avatar_url, fwo.owner_username, fwo.owner_name, fwo.organization_name, fwo.organization_display_name, fwo.organization_icon, fwo.organization_description, fwo.template_name, fwo.template_display_name, fwo.template_icon, fwo.template_description, fwo.task_id, fwo.template_version_id, fwo.template_version_name, fwo.latest_build_completed_at, fwo.latest_build_canceled_at, fwo.latest_build_error, fwo.latest_build_transition, fwo.latest_build_status, fwo.latest_build_has_external_agent
|
||||
fwo.id, fwo.created_at, fwo.updated_at, fwo.owner_id, fwo.organization_id, fwo.template_id, fwo.deleted, fwo.name, fwo.autostart_schedule, fwo.ttl, fwo.last_used_at, fwo.dormant_at, fwo.deleting_at, fwo.automatic_updates, fwo.favorite, fwo.next_start_at, fwo.group_acl, fwo.user_acl, fwo.owner_avatar_url, fwo.owner_username, fwo.owner_name, fwo.organization_name, fwo.organization_display_name, fwo.organization_icon, fwo.organization_description, fwo.template_name, fwo.template_display_name, fwo.template_icon, fwo.template_description, fwo.template_version_id, fwo.template_version_name, fwo.latest_build_completed_at, fwo.latest_build_canceled_at, fwo.latest_build_error, fwo.latest_build_transition, fwo.latest_build_status, fwo.latest_build_has_ai_task, fwo.latest_build_has_external_agent
|
||||
FROM
|
||||
filtered_workspaces_order fwo
|
||||
-- Return a technical summary row with total count of workspaces.
|
||||
@@ -22622,7 +22239,6 @@ WHERE
|
||||
'', -- template_display_name
|
||||
'', -- template_icon
|
||||
'', -- template_description
|
||||
'00000000-0000-0000-0000-000000000000'::uuid, -- task_id
|
||||
-- Extra columns added to ` + "`" + `filtered_workspaces` + "`" + `
|
||||
'00000000-0000-0000-0000-000000000000'::uuid, -- template_version_id
|
||||
'', -- template_version_name
|
||||
@@ -22631,6 +22247,7 @@ WHERE
|
||||
'', -- latest_build_error
|
||||
'start'::workspace_transition, -- latest_build_transition
|
||||
'unknown'::provisioner_job_status, -- latest_build_status
|
||||
false, -- latest_build_has_ai_task
|
||||
false -- latest_build_has_external_agent
|
||||
WHERE
|
||||
$27 :: boolean = true
|
||||
@@ -22641,7 +22258,7 @@ WHERE
|
||||
filtered_workspaces
|
||||
)
|
||||
SELECT
|
||||
fwos.id, fwos.created_at, fwos.updated_at, fwos.owner_id, fwos.organization_id, fwos.template_id, fwos.deleted, fwos.name, fwos.autostart_schedule, fwos.ttl, fwos.last_used_at, fwos.dormant_at, fwos.deleting_at, fwos.automatic_updates, fwos.favorite, fwos.next_start_at, fwos.group_acl, fwos.user_acl, fwos.owner_avatar_url, fwos.owner_username, fwos.owner_name, fwos.organization_name, fwos.organization_display_name, fwos.organization_icon, fwos.organization_description, fwos.template_name, fwos.template_display_name, fwos.template_icon, fwos.template_description, fwos.task_id, fwos.template_version_id, fwos.template_version_name, fwos.latest_build_completed_at, fwos.latest_build_canceled_at, fwos.latest_build_error, fwos.latest_build_transition, fwos.latest_build_status, fwos.latest_build_has_external_agent,
|
||||
fwos.id, fwos.created_at, fwos.updated_at, fwos.owner_id, fwos.organization_id, fwos.template_id, fwos.deleted, fwos.name, fwos.autostart_schedule, fwos.ttl, fwos.last_used_at, fwos.dormant_at, fwos.deleting_at, fwos.automatic_updates, fwos.favorite, fwos.next_start_at, fwos.group_acl, fwos.user_acl, fwos.owner_avatar_url, fwos.owner_username, fwos.owner_name, fwos.organization_name, fwos.organization_display_name, fwos.organization_icon, fwos.organization_description, fwos.template_name, fwos.template_display_name, fwos.template_icon, fwos.template_description, fwos.template_version_id, fwos.template_version_name, fwos.latest_build_completed_at, fwos.latest_build_canceled_at, fwos.latest_build_error, fwos.latest_build_transition, fwos.latest_build_status, fwos.latest_build_has_ai_task, fwos.latest_build_has_external_agent,
|
||||
tc.count
|
||||
FROM
|
||||
filtered_workspaces_order_with_summary fwos
|
||||
@@ -22709,7 +22326,6 @@ type GetWorkspacesRow struct {
|
||||
TemplateDisplayName string `db:"template_display_name" json:"template_display_name"`
|
||||
TemplateIcon string `db:"template_icon" json:"template_icon"`
|
||||
TemplateDescription string `db:"template_description" json:"template_description"`
|
||||
TaskID uuid.NullUUID `db:"task_id" json:"task_id"`
|
||||
TemplateVersionID uuid.UUID `db:"template_version_id" json:"template_version_id"`
|
||||
TemplateVersionName sql.NullString `db:"template_version_name" json:"template_version_name"`
|
||||
LatestBuildCompletedAt sql.NullTime `db:"latest_build_completed_at" json:"latest_build_completed_at"`
|
||||
@@ -22717,6 +22333,7 @@ type GetWorkspacesRow struct {
|
||||
LatestBuildError sql.NullString `db:"latest_build_error" json:"latest_build_error"`
|
||||
LatestBuildTransition WorkspaceTransition `db:"latest_build_transition" json:"latest_build_transition"`
|
||||
LatestBuildStatus ProvisionerJobStatus `db:"latest_build_status" json:"latest_build_status"`
|
||||
LatestBuildHasAITask sql.NullBool `db:"latest_build_has_ai_task" json:"latest_build_has_ai_task"`
|
||||
LatestBuildHasExternalAgent sql.NullBool `db:"latest_build_has_external_agent" json:"latest_build_has_external_agent"`
|
||||
Count int64 `db:"count" json:"count"`
|
||||
}
|
||||
@@ -22791,7 +22408,6 @@ func (q *sqlQuerier) GetWorkspaces(ctx context.Context, arg GetWorkspacesParams)
|
||||
&i.TemplateDisplayName,
|
||||
&i.TemplateIcon,
|
||||
&i.TemplateDescription,
|
||||
&i.TaskID,
|
||||
&i.TemplateVersionID,
|
||||
&i.TemplateVersionName,
|
||||
&i.LatestBuildCompletedAt,
|
||||
@@ -22799,6 +22415,7 @@ func (q *sqlQuerier) GetWorkspaces(ctx context.Context, arg GetWorkspacesParams)
|
||||
&i.LatestBuildError,
|
||||
&i.LatestBuildTransition,
|
||||
&i.LatestBuildStatus,
|
||||
&i.LatestBuildHasAITask,
|
||||
&i.LatestBuildHasExternalAgent,
|
||||
&i.Count,
|
||||
); err != nil {
|
||||
|
||||
@@ -6,14 +6,6 @@ INSERT INTO aibridge_interceptions (
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateAIBridgeInterceptionEnded :one
|
||||
UPDATE aibridge_interceptions
|
||||
SET ended_at = @ended_at::timestamptz
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
AND ended_at IS NULL
|
||||
RETURNING *;
|
||||
|
||||
-- name: InsertAIBridgeTokenUsage :one
|
||||
INSERT INTO aibridge_token_usages (
|
||||
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
|
||||
@@ -207,122 +199,3 @@ WHERE
|
||||
ORDER BY
|
||||
created_at ASC,
|
||||
id ASC;
|
||||
|
||||
-- name: ListAIBridgeInterceptionsTelemetrySummaries :many
|
||||
-- Finds all unique AIBridge interception telemetry summaries combinations
|
||||
-- (provider, model, client) in the given timeframe for telemetry reporting.
|
||||
SELECT
|
||||
DISTINCT ON (provider, model, client)
|
||||
provider,
|
||||
model,
|
||||
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
|
||||
'unknown' AS client
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
|
||||
AND ended_at >= @ended_at_after::timestamptz
|
||||
AND ended_at < @ended_at_before::timestamptz;
|
||||
|
||||
-- name: CalculateAIBridgeInterceptionsTelemetrySummary :one
|
||||
-- Calculates the telemetry summary for a given provider, model, and client
|
||||
-- combination for telemetry reporting.
|
||||
WITH interceptions_in_range AS (
|
||||
-- Get all matching interceptions in the given timeframe.
|
||||
SELECT
|
||||
id,
|
||||
initiator_id,
|
||||
(ended_at - started_at) AS duration
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
provider = @provider::text
|
||||
AND model = @model::text
|
||||
-- TODO: use the client value once we have it (see https://github.com/coder/aibridge/issues/31)
|
||||
AND 'unknown' = @client::text
|
||||
AND ended_at IS NOT NULL -- incomplete interceptions are not included in summaries
|
||||
AND ended_at >= @ended_at_after::timestamptz
|
||||
AND ended_at < @ended_at_before::timestamptz
|
||||
),
|
||||
interception_counts AS (
|
||||
SELECT
|
||||
COUNT(id) AS interception_count,
|
||||
COUNT(DISTINCT initiator_id) AS unique_initiator_count
|
||||
FROM
|
||||
interceptions_in_range
|
||||
),
|
||||
duration_percentiles AS (
|
||||
SELECT
|
||||
(COALESCE(PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p50_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.90) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p90_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p95_millis,
|
||||
(COALESCE(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM duration)), 0) * 1000)::bigint AS interception_duration_p99_millis
|
||||
FROM
|
||||
interceptions_in_range
|
||||
),
|
||||
token_aggregates AS (
|
||||
SELECT
|
||||
COALESCE(SUM(tu.input_tokens), 0) AS token_count_input,
|
||||
COALESCE(SUM(tu.output_tokens), 0) AS token_count_output,
|
||||
-- Cached tokens are stored in metadata JSON, extract if available.
|
||||
-- Read tokens may be stored in:
|
||||
-- - cache_read_input (Anthropic)
|
||||
-- - prompt_cached (OpenAI)
|
||||
COALESCE(SUM(
|
||||
COALESCE((tu.metadata->>'cache_read_input')::bigint, 0) +
|
||||
COALESCE((tu.metadata->>'prompt_cached')::bigint, 0)
|
||||
), 0) AS token_count_cached_read,
|
||||
-- Written tokens may be stored in:
|
||||
-- - cache_creation_input (Anthropic)
|
||||
-- Note that cache_ephemeral_5m_input and cache_ephemeral_1h_input on
|
||||
-- Anthropic are included in the cache_creation_input field.
|
||||
COALESCE(SUM(
|
||||
COALESCE((tu.metadata->>'cache_creation_input')::bigint, 0)
|
||||
), 0) AS token_count_cached_written,
|
||||
COUNT(tu.id) AS token_usages_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_token_usages tu ON i.id = tu.interception_id
|
||||
),
|
||||
prompt_aggregates AS (
|
||||
SELECT
|
||||
COUNT(up.id) AS user_prompts_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_user_prompts up ON i.id = up.interception_id
|
||||
),
|
||||
tool_aggregates AS (
|
||||
SELECT
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = true) AS tool_calls_count_injected,
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = false) AS tool_calls_count_non_injected,
|
||||
COUNT(tu.id) FILTER (WHERE tu.injected = true AND tu.invocation_error IS NOT NULL) AS injected_tool_call_error_count
|
||||
FROM
|
||||
interceptions_in_range i
|
||||
LEFT JOIN
|
||||
aibridge_tool_usages tu ON i.id = tu.interception_id
|
||||
)
|
||||
SELECT
|
||||
ic.interception_count::bigint AS interception_count,
|
||||
dp.interception_duration_p50_millis::bigint AS interception_duration_p50_millis,
|
||||
dp.interception_duration_p90_millis::bigint AS interception_duration_p90_millis,
|
||||
dp.interception_duration_p95_millis::bigint AS interception_duration_p95_millis,
|
||||
dp.interception_duration_p99_millis::bigint AS interception_duration_p99_millis,
|
||||
ic.unique_initiator_count::bigint AS unique_initiator_count,
|
||||
pa.user_prompts_count::bigint AS user_prompts_count,
|
||||
tok_agg.token_usages_count::bigint AS token_usages_count,
|
||||
tok_agg.token_count_input::bigint AS token_count_input,
|
||||
tok_agg.token_count_output::bigint AS token_count_output,
|
||||
tok_agg.token_count_cached_read::bigint AS token_count_cached_read,
|
||||
tok_agg.token_count_cached_written::bigint AS token_count_cached_written,
|
||||
tool_agg.tool_calls_count_injected::bigint AS tool_calls_count_injected,
|
||||
tool_agg.tool_calls_count_non_injected::bigint AS tool_calls_count_non_injected,
|
||||
tool_agg.injected_tool_call_error_count::bigint AS injected_tool_call_error_count
|
||||
FROM
|
||||
interception_counts ic,
|
||||
duration_percentiles dp,
|
||||
token_aggregates tok_agg,
|
||||
prompt_aggregates pa,
|
||||
tool_aggregates tool_agg
|
||||
;
|
||||
|
||||
@@ -300,8 +300,12 @@ GROUP BY wpb.template_version_preset_id;
|
||||
-- Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
|
||||
-- inactive template version.
|
||||
-- This is an optimization to clean up stale pending jobs.
|
||||
WITH jobs_to_cancel AS (
|
||||
SELECT pj.id, w.id AS workspace_id, w.template_id, wpb.template_version_preset_id
|
||||
UPDATE provisioner_jobs
|
||||
SET
|
||||
canceled_at = @now::timestamptz,
|
||||
completed_at = @now::timestamptz
|
||||
WHERE id IN (
|
||||
SELECT pj.id
|
||||
FROM provisioner_jobs pj
|
||||
INNER JOIN workspace_prebuild_builds wpb ON wpb.job_id = pj.id
|
||||
INNER JOIN workspaces w ON w.id = wpb.workspace_id
|
||||
@@ -320,54 +324,4 @@ WITH jobs_to_cancel AS (
|
||||
AND pj.canceled_at IS NULL
|
||||
AND pj.completed_at IS NULL
|
||||
)
|
||||
UPDATE provisioner_jobs
|
||||
SET
|
||||
canceled_at = @now::timestamptz,
|
||||
completed_at = @now::timestamptz
|
||||
FROM jobs_to_cancel
|
||||
WHERE provisioner_jobs.id = jobs_to_cancel.id
|
||||
RETURNING jobs_to_cancel.id, jobs_to_cancel.workspace_id, jobs_to_cancel.template_id, jobs_to_cancel.template_version_preset_id;
|
||||
|
||||
-- name: GetOrganizationsWithPrebuildStatus :many
|
||||
-- GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
|
||||
-- membership status for the prebuilds system user (org membership, group existence, group membership).
|
||||
WITH orgs_with_prebuilds AS (
|
||||
-- Get unique organizations that have presets with prebuilds configured
|
||||
SELECT DISTINCT o.id, o.name
|
||||
FROM organizations o
|
||||
INNER JOIN templates t ON t.organization_id = o.id
|
||||
INNER JOIN template_versions tv ON tv.template_id = t.id
|
||||
INNER JOIN template_version_presets tvp ON tvp.template_version_id = tv.id
|
||||
WHERE tvp.desired_instances IS NOT NULL
|
||||
),
|
||||
prebuild_user_membership AS (
|
||||
-- Check if the user is a member of the organizations
|
||||
SELECT om.organization_id
|
||||
FROM organization_members om
|
||||
INNER JOIN orgs_with_prebuilds owp ON owp.id = om.organization_id
|
||||
WHERE om.user_id = @user_id::uuid
|
||||
),
|
||||
prebuild_groups AS (
|
||||
-- Check if the organizations have the prebuilds group
|
||||
SELECT g.organization_id, g.id as group_id
|
||||
FROM groups g
|
||||
INNER JOIN orgs_with_prebuilds owp ON owp.id = g.organization_id
|
||||
WHERE g.name = @group_name::text
|
||||
),
|
||||
prebuild_group_membership AS (
|
||||
-- Check if the user is in the prebuilds group
|
||||
SELECT pg.organization_id
|
||||
FROM prebuild_groups pg
|
||||
INNER JOIN group_members gm ON gm.group_id = pg.group_id
|
||||
WHERE gm.user_id = @user_id::uuid
|
||||
)
|
||||
SELECT
|
||||
owp.id AS organization_id,
|
||||
owp.name AS organization_name,
|
||||
(pum.organization_id IS NOT NULL)::boolean AS has_prebuild_user,
|
||||
pg.group_id AS prebuilds_group_id,
|
||||
(pgm.organization_id IS NOT NULL)::boolean AS has_prebuild_user_in_group
|
||||
FROM orgs_with_prebuilds owp
|
||||
LEFT JOIN prebuild_groups pg ON pg.organization_id = owp.id
|
||||
LEFT JOIN prebuild_user_membership pum ON pum.organization_id = owp.id
|
||||
LEFT JOIN prebuild_group_membership pgm ON pgm.organization_id = owp.id;
|
||||
RETURNING id;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
INSERT INTO tasks
|
||||
(id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
(gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateTaskWorkspaceID :one
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
-- name: InsertTelemetryLock :exec
|
||||
-- Inserts a new lock row into the telemetry_locks table. Replicas should call
|
||||
-- this function prior to attempting to generate or publish a heartbeat event to
|
||||
-- the telemetry service.
|
||||
-- If the query returns a duplicate primary key error, the replica should not
|
||||
-- attempt to generate or publish the event to the telemetry service.
|
||||
INSERT INTO
|
||||
telemetry_locks (event_type, period_ending_at)
|
||||
VALUES
|
||||
($1, $2);
|
||||
|
||||
-- name: DeleteOldTelemetryLocks :exec
|
||||
-- Deletes old telemetry locks from the telemetry_locks table.
|
||||
DELETE FROM
|
||||
telemetry_locks
|
||||
WHERE
|
||||
period_ending_at < @period_ending_at_before::timestamptz;
|
||||
@@ -117,6 +117,7 @@ SELECT
|
||||
latest_build.error as latest_build_error,
|
||||
latest_build.transition as latest_build_transition,
|
||||
latest_build.job_status as latest_build_status,
|
||||
latest_build.has_ai_task as latest_build_has_ai_task,
|
||||
latest_build.has_external_agent as latest_build_has_external_agent
|
||||
FROM
|
||||
workspaces_expanded as workspaces
|
||||
@@ -350,19 +351,25 @@ WHERE
|
||||
(latest_build.template_version_id = template.active_version_id) = sqlc.narg('using_active') :: boolean
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by has_ai_task, checks if this is a task workspace.
|
||||
-- Filter by has_ai_task in latest build
|
||||
AND CASE
|
||||
WHEN sqlc.narg('has_ai_task')::boolean IS NOT NULL
|
||||
THEN sqlc.narg('has_ai_task')::boolean = EXISTS (
|
||||
SELECT
|
||||
1
|
||||
FROM
|
||||
tasks
|
||||
WHERE
|
||||
-- Consider all tasks, deleting a task does not turn the
|
||||
-- workspace into a non-task workspace.
|
||||
tasks.workspace_id = workspaces.id
|
||||
)
|
||||
WHEN sqlc.narg('has_ai_task') :: boolean IS NOT NULL THEN
|
||||
(COALESCE(latest_build.has_ai_task, false) OR (
|
||||
-- If the build has no AI task, it means that the provisioner job is in progress
|
||||
-- and we don't know if it has an AI task yet. In this case, we optimistically
|
||||
-- assume that it has an AI task if the AI Prompt parameter is not empty. This
|
||||
-- lets the AI Task frontend spawn a task and see it immediately after instead of
|
||||
-- having to wait for the build to complete.
|
||||
latest_build.has_ai_task IS NULL AND
|
||||
latest_build.completed_at IS NULL AND
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM workspace_build_parameters
|
||||
WHERE workspace_build_parameters.workspace_build_id = latest_build.id
|
||||
AND workspace_build_parameters.name = 'AI Prompt'
|
||||
AND workspace_build_parameters.value != ''
|
||||
)
|
||||
)) = (sqlc.narg('has_ai_task') :: boolean)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by has_external_agent in latest build
|
||||
@@ -450,7 +457,6 @@ WHERE
|
||||
'', -- template_display_name
|
||||
'', -- template_icon
|
||||
'', -- template_description
|
||||
'00000000-0000-0000-0000-000000000000'::uuid, -- task_id
|
||||
-- Extra columns added to `filtered_workspaces`
|
||||
'00000000-0000-0000-0000-000000000000'::uuid, -- template_version_id
|
||||
'', -- template_version_name
|
||||
@@ -459,6 +465,7 @@ WHERE
|
||||
'', -- latest_build_error
|
||||
'start'::workspace_transition, -- latest_build_transition
|
||||
'unknown'::provisioner_job_status, -- latest_build_status
|
||||
false, -- latest_build_has_ai_task
|
||||
false -- latest_build_has_external_agent
|
||||
WHERE
|
||||
@with_summary :: boolean = true
|
||||
|
||||
@@ -62,7 +62,6 @@ const (
|
||||
UniqueTaskWorkspaceAppsPkey UniqueConstraint = "task_workspace_apps_pkey" // ALTER TABLE ONLY task_workspace_apps ADD CONSTRAINT task_workspace_apps_pkey PRIMARY KEY (task_id, workspace_build_number);
|
||||
UniqueTasksPkey UniqueConstraint = "tasks_pkey" // ALTER TABLE ONLY tasks ADD CONSTRAINT tasks_pkey PRIMARY KEY (id);
|
||||
UniqueTelemetryItemsPkey UniqueConstraint = "telemetry_items_pkey" // ALTER TABLE ONLY telemetry_items ADD CONSTRAINT telemetry_items_pkey PRIMARY KEY (key);
|
||||
UniqueTelemetryLocksPkey UniqueConstraint = "telemetry_locks_pkey" // ALTER TABLE ONLY telemetry_locks ADD CONSTRAINT telemetry_locks_pkey PRIMARY KEY (event_type, period_ending_at);
|
||||
UniqueTemplateUsageStatsPkey UniqueConstraint = "template_usage_stats_pkey" // ALTER TABLE ONLY template_usage_stats ADD CONSTRAINT template_usage_stats_pkey PRIMARY KEY (start_time, template_id, user_id);
|
||||
UniqueTemplateVersionParametersTemplateVersionIDNameKey UniqueConstraint = "template_version_parameters_template_version_id_name_key" // ALTER TABLE ONLY template_version_parameters ADD CONSTRAINT template_version_parameters_template_version_id_name_key UNIQUE (template_version_id, name);
|
||||
UniqueTemplateVersionPresetParametersPkey UniqueConstraint = "template_version_preset_parameters_pkey" // ALTER TABLE ONLY template_version_preset_parameters ADD CONSTRAINT template_version_preset_parameters_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -37,18 +37,13 @@ type ReconciliationOrchestrator interface {
|
||||
TrackResourceReplacement(ctx context.Context, workspaceID, buildID uuid.UUID, replacements []*sdkproto.ResourceReplacement)
|
||||
}
|
||||
|
||||
// ReconcileStats contains statistics about a reconciliation cycle.
|
||||
type ReconcileStats struct {
|
||||
Elapsed time.Duration
|
||||
}
|
||||
|
||||
type Reconciler interface {
|
||||
StateSnapshotter
|
||||
|
||||
// ReconcileAll orchestrates the reconciliation of all prebuilds across all templates.
|
||||
// It takes a global snapshot of the system state and then reconciles each preset
|
||||
// in parallel, creating or deleting prebuilds as needed to reach their desired states.
|
||||
ReconcileAll(ctx context.Context) (ReconcileStats, error)
|
||||
ReconcileAll(ctx context.Context) error
|
||||
}
|
||||
|
||||
// StateSnapshotter defines the operations necessary to capture workspace prebuilds state.
|
||||
|
||||
@@ -17,11 +17,7 @@ func (NoopReconciler) Run(context.Context) {}
|
||||
func (NoopReconciler) Stop(context.Context, error) {}
|
||||
func (NoopReconciler) TrackResourceReplacement(context.Context, uuid.UUID, uuid.UUID, []*sdkproto.ResourceReplacement) {
|
||||
}
|
||||
|
||||
func (NoopReconciler) ReconcileAll(context.Context) (ReconcileStats, error) {
|
||||
return ReconcileStats{}, nil
|
||||
}
|
||||
|
||||
func (NoopReconciler) ReconcileAll(context.Context) error { return nil }
|
||||
func (NoopReconciler) SnapshotState(context.Context, database.Store) (*GlobalSnapshot, error) {
|
||||
return &GlobalSnapshot{}, nil
|
||||
}
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
# Rego authorization policy
|
||||
|
||||
## Code style
|
||||
|
||||
It's a good idea to consult the [Rego style guide](https://docs.styra.com/opa/rego-style-guide). The "Variables and Data Types" section in particular has some helpful and non-obvious advice in it.
|
||||
|
||||
## Debugging
|
||||
|
||||
Open Policy Agent provides a CLI and a playground that can be used for evaluating, formatting, testing, and linting policies.
|
||||
|
||||
### CLI
|
||||
|
||||
Below are some helpful commands you can use for debugging.
|
||||
|
||||
For full evaluation, run:
|
||||
|
||||
```sh
|
||||
opa eval --format=pretty 'data.authz.allow' -d policy.rego -i input.json
|
||||
```
|
||||
|
||||
For partial evaluation, run:
|
||||
|
||||
```sh
|
||||
opa eval --partial --format=pretty 'data.authz.allow' -d policy.rego \
|
||||
--unknowns input.object.owner --unknowns input.object.org_owner \
|
||||
--unknowns input.object.acl_user_list --unknowns input.object.acl_group_list \
|
||||
-i input.json
|
||||
```
|
||||
|
||||
### Playground
|
||||
|
||||
Use the [Open Policy Agent Playground](https://play.openpolicyagent.org/) while editing to getting linting, code formatting, and help debugging!
|
||||
|
||||
You can use the contents of input.json as a starting point for your own testing input. Paste the contents of policy.rego into the left-hand side of the playground, and the contents of input.json into the "Input" section. Click "Evaluate" and you should see something like the following in the output.
|
||||
|
||||
```json
|
||||
{
|
||||
"allow": true,
|
||||
"check_scope_allow_list": true,
|
||||
"org": 0,
|
||||
"org_member": 0,
|
||||
"org_memberships": [],
|
||||
"permission_allow": true,
|
||||
"role_allow": true,
|
||||
"scope_allow": true,
|
||||
"scope_org": 0,
|
||||
"scope_org_member": 0,
|
||||
"scope_site": 1,
|
||||
"scope_user": 0,
|
||||
"site": 1,
|
||||
"user": 0
|
||||
}
|
||||
```
|
||||
|
||||
## Levels
|
||||
|
||||
Permissions are evaluated at four levels: site, user, org, org_member.
|
||||
|
||||
For each level, two checks are performed:
|
||||
- Do the subject's permissions allow them to perform this action?
|
||||
- Does the subject's scope allow them to perform this action?
|
||||
|
||||
Each of these checks gets a "vote", which must one of three values:
|
||||
- -1 to deny (usually because of a negative permission)
|
||||
- 0 to abstain (no matching permission)
|
||||
- 1 to allow
|
||||
|
||||
If a level abstains, then the decision gets deferred to the next level. When
|
||||
there is no "next" level to defer to it is equivalent to being denied.
|
||||
|
||||
### Scope
|
||||
Additionally, each input has a "scope" that can be thought of as a second set of permissions, where each permission belongs to one of the four levels–exactly the same as role permissions. An action is only allowed if it is allowed by both the subject's permissions _and_ their current scope. This is to allow issuing tokens for a subject that have a subset of the full subjects permissions.
|
||||
|
||||
For example, you may have a scope like...
|
||||
|
||||
```json
|
||||
{
|
||||
"by_org_id": {
|
||||
"<org_id>": {
|
||||
"member": [{ "resource_type": "workspace", "action": "*" }]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
...to limit the token to only accessing workspaces owned by the user within a specific org. This provides some assurances for an admin user, that the token can only access intended resources, rather than having full access to everything.
|
||||
|
||||
The final policy decision is determined by evaluating each of these checks in their proper precedence order from the `allow` rule.
|
||||
|
||||
## Unknown values
|
||||
|
||||
This policy is specifically constructed to compress to a set of queries if 'input.object.owner' and 'input.object.org_owner' are unknown. There is no specific set of rules that will guarantee that this policy has this property, however, there are some tricks. We have tests that enforce this property, so any changes that pass the tests will be okay.
|
||||
|
||||
Some general rules to follow:
|
||||
|
||||
1. Do not use unknown values in any [comprehensions](https://www.openpolicyagent.org/docs/latest/policy-language/#comprehensions) or iterations.
|
||||
|
||||
2. Use the unknown values as minimally as possible.
|
||||
|
||||
3. Avoid making code branches based on the value of the unknown field.
|
||||
|
||||
Unknown values are like a "set" of possible values (which is why rule 1 usually breaks things).
|
||||
|
||||
For example, in the org level rules, we calculate the "vote" for all orgs, rather than just the `input.object.org_owner`. This way, if the `org_owner` changes, then we don't need to recompute any votes; we already have it for the changed value. This means we don't need branching, because the end result is just a lookup table.
|
||||
+35
-83
@@ -58,68 +58,22 @@ This can be represented by the following truth table, where Y represents _positi
|
||||
- `+site.app.*.read`: allowed to perform the `read` action against all objects of type `app` in a given Coder deployment.
|
||||
- `-user.workspace.*.create`: user is not allowed to create workspaces.
|
||||
|
||||
## Levels
|
||||
|
||||
A user can be given (or deprived) a permission at several levels. Currently,
|
||||
those levels are:
|
||||
|
||||
- Site-wide level
|
||||
- Organization level
|
||||
- User level
|
||||
- Organization member level
|
||||
|
||||
The site-wide level is the most authoritative. Any permission granted or denied at the side-wide level is absolute. After checking the site-wide level, depending of if the resource is owned by an organization or not, it will check the other levels.
|
||||
|
||||
- If the resource is owned by an organization, the next most authoritative level is the organization level. It acts like the site-wide level, but only for resources within the corresponding organization. The user can use that permission on any resource within that organization.
|
||||
- After the organization level is the member level. This level only applies to resources that are owned by both the organization _and_ the user.
|
||||
|
||||
- If the resource is not owned by an organization, the next level to check is the user level. This level only applies to resources owned by the user and that are not owned by any organization.
|
||||
|
||||
```
|
||||
┌──────────┐
|
||||
│ Site │
|
||||
└─────┬────┘
|
||||
┌──────────┴───────────┐
|
||||
┌──┤ Owned by an org? ├──┐
|
||||
│ └──────────────────────┘ │
|
||||
┌──┴──┐ ┌──┴─┐
|
||||
│ Yes │ │ No │
|
||||
└──┬──┘ └──┬─┘
|
||||
┌────────┴─────────┐ ┌─────┴────┐
|
||||
│ Organization │ │ User │
|
||||
└────────┬─────────┘ └──────────┘
|
||||
┌─────┴──────┐
|
||||
│ Member │
|
||||
└────────────┘
|
||||
```
|
||||
|
||||
## Roles
|
||||
|
||||
A _role_ is a set of permissions. When evaluating a role's permission to form an action, all the relevant permissions for the role are combined at each level. Permissions at a higher level override permissions at a lower level.
|
||||
|
||||
The following tables show the per-level role evaluation. Y indicates that the role provides positive permissions, N indicates the role provides negative permissions, and _indicates the role does not provide positive or negative permissions. YN_ indicates that the value in the cell does not matter for the access result. The table varies depending on if the resource belongs to an organization or not.
|
||||
The following table shows the per-level role evaluation.
|
||||
Y indicates that the role provides positive permissions, N indicates the role provides negative permissions, and _indicates the role does not provide positive or negative permissions. YN_ indicates that the value in the cell does not matter for the access result.
|
||||
|
||||
If the resource is owned by an organization, such as a template or a workspace:
|
||||
|
||||
| Role (example) | Site | Org | OrgMember | Result |
|
||||
|--------------------------|------|------|-----------|--------|
|
||||
| site-admin | Y | YN\_ | YN\_ | Y |
|
||||
| negative-site-permission | N | YN\_ | YN\_ | N |
|
||||
| org-admin | \_ | Y | YN\_ | Y |
|
||||
| non-org-member | \_ | N | YN\_ | N |
|
||||
| member-owned | \_ | \_ | Y | Y |
|
||||
| not-member-owned | \_ | \_ | N | N |
|
||||
| unauthenticated | \_ | \_ | \_ | N |
|
||||
|
||||
If the resource is not owned by an organization:
|
||||
|
||||
| Role (example) | Site | User | Result |
|
||||
|--------------------------|------|------|--------|
|
||||
| site-admin | Y | YN\_ | Y |
|
||||
| negative-site-permission | N | YN\_ | N |
|
||||
| user-owned | \_ | Y | Y |
|
||||
| not-user-owned | \_ | N | N |
|
||||
| unauthenticated | \_ | \_ | N |
|
||||
| Role (example) | Site | Org | User | Result |
|
||||
|-----------------|------|------|------|--------|
|
||||
| site-admin | Y | YN\_ | YN\_ | Y |
|
||||
| no-permission | N | YN\_ | YN\_ | N |
|
||||
| org-admin | \_ | Y | YN\_ | Y |
|
||||
| non-org-member | \_ | N | YN\_ | N |
|
||||
| user | \_ | \_ | Y | Y |
|
||||
| | \_ | \_ | N | N |
|
||||
| unauthenticated | \_ | \_ | \_ | N |
|
||||
|
||||
## Scopes
|
||||
|
||||
@@ -137,17 +91,15 @@ The use case for specifying this type of permission in a role is limited, and do
|
||||
Example of a scope for a workspace agent token, using an `allow_list` containing a single resource id.
|
||||
|
||||
```javascript
|
||||
{
|
||||
"scope": {
|
||||
"name": "workspace_agent",
|
||||
"display_name": "Workspace_Agent",
|
||||
// The ID of the given workspace the agent token correlates to.
|
||||
"allow_list": ["10d03e62-7703-4df5-a358-4f76577d4e2f"],
|
||||
"site": [/* ... perms ... */],
|
||||
"org": {/* ... perms ... */},
|
||||
"user": [/* ... perms ... */]
|
||||
}
|
||||
}
|
||||
"scope": {
|
||||
"name": "workspace_agent",
|
||||
"display_name": "Workspace_Agent",
|
||||
// The ID of the given workspace the agent token correlates to.
|
||||
"allow_list": ["10d03e62-7703-4df5-a358-4f76577d4e2f"],
|
||||
"site": [/* ... perms ... */],
|
||||
"org": {/* ... perms ... */},
|
||||
"user": [/* ... perms ... */]
|
||||
}
|
||||
```
|
||||
|
||||
## OPA (Open Policy Agent)
|
||||
@@ -172,31 +124,31 @@ To learn more about OPA and Rego, see https://www.openpolicyagent.org/docs.
|
||||
There are two types of evaluation in OPA:
|
||||
|
||||
- **Full evaluation**: Produces a decision that can be enforced.
|
||||
This is the default evaluation mode, where OPA evaluates the policy using `input` data that contains all known values and returns output data with the `allow` variable.
|
||||
This is the default evaluation mode, where OPA evaluates the policy using `input` data that contains all known values and returns output data with the `allow` variable.
|
||||
- **Partial evaluation**: Produces a new policy that can be evaluated later when the _unknowns_ become _known_.
|
||||
This is an optimization in OPA where it evaluates as much of the policy as possible without resolving expressions that depend on _unknown_ values from the `input`.
|
||||
To learn more about partial evaluation, see this [OPA blog post](https://blog.openpolicyagent.org/partial-evaluation-162750eaf422).
|
||||
This is an optimization in OPA where it evaluates as much of the policy as possible without resolving expressions that depend on _unknown_ values from the `input`.
|
||||
To learn more about partial evaluation, see this [OPA blog post](https://blog.openpolicyagent.org/partial-evaluation-162750eaf422).
|
||||
|
||||
Application of Full and Partial evaluation in `rbac` package:
|
||||
|
||||
- **Full Evaluation** is handled by the `RegoAuthorizer.Authorize()` method in [`authz.go`](authz.go).
|
||||
This method determines whether a subject (user) can perform a specific action on an object.
|
||||
It performs a full evaluation of the Rego policy, which returns the `allow` variable to decide whether access is granted (`true`) or denied (`false` or undefined).
|
||||
This method determines whether a subject (user) can perform a specific action on an object.
|
||||
It performs a full evaluation of the Rego policy, which returns the `allow` variable to decide whether access is granted (`true`) or denied (`false` or undefined).
|
||||
- **Partial Evaluation** is handled by the `RegoAuthorizer.Prepare()` method in [`authz.go`](authz.go).
|
||||
This method compiles OPA’s partial evaluation queries into `SQL WHERE` clauses.
|
||||
These clauses are then used to enforce authorization directly in database queries, rather than in application code.
|
||||
This method compiles OPA’s partial evaluation queries into `SQL WHERE` clauses.
|
||||
These clauses are then used to enforce authorization directly in database queries, rather than in application code.
|
||||
|
||||
Authorization Patterns:
|
||||
|
||||
- Fetch-then-authorize: an object is first retrieved from the database, and a single authorization check is performed using full evaluation via `Authorize()`.
|
||||
- Authorize-while-fetching: Partial evaluation via `Prepare()` is used to inject SQL filters directly into queries, allowing efficient authorization of many objects of the same type.
|
||||
`dbauthz` methods that enforce authorization directly in the SQL query are prefixed with `Authorized`, for example, `GetAuthorizedWorkspaces`.
|
||||
`dbauthz` methods that enforce authorization directly in the SQL query are prefixed with `Authorized`, for example, `GetAuthorizedWorkspaces`.
|
||||
|
||||
## Testing
|
||||
|
||||
- OPA Playground: https://play.openpolicyagent.org/
|
||||
- OPA CLI (`opa eval`): useful for experimenting with different inputs and understanding how the policy behaves under various conditions.
|
||||
`opa eval` returns the constraints that must be satisfied for a rule to evaluate to `true`.
|
||||
`opa eval` returns the constraints that must be satisfied for a rule to evaluate to `true`.
|
||||
- `opa eval` requires an `input.json` file containing the input data to run the policy against.
|
||||
You can generate this file using the [gen_input.go](../../scripts/rbac-authz/gen_input.go) script.
|
||||
Note: the script currently produces a fixed input. You may need to tweak it for your specific use case.
|
||||
@@ -244,12 +196,12 @@ The script [`benchmark_authz.sh`](../../scripts/rbac-authz/benchmark_authz.sh) r
|
||||
|
||||
- To run benchmark on the current branch:
|
||||
|
||||
```bash
|
||||
benchmark_authz.sh --single
|
||||
```
|
||||
```bash
|
||||
benchmark_authz.sh --single
|
||||
```
|
||||
|
||||
- To compare benchmarks between 2 branches:
|
||||
|
||||
```bash
|
||||
benchmark_authz.sh --compare main prebuild_policy
|
||||
```
|
||||
```bash
|
||||
benchmark_authz.sh --compare main prebuild_policy
|
||||
```
|
||||
|
||||
@@ -165,10 +165,6 @@ func (role Role) regoValue() ast.Value {
|
||||
ast.StringTerm("org"),
|
||||
ast.NewTerm(regoSlice(p.Org)),
|
||||
},
|
||||
[2]*ast.Term{
|
||||
ast.StringTerm("member"),
|
||||
ast.NewTerm(regoSlice(p.Member)),
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -287,7 +287,7 @@ func TestFilter(t *testing.T) {
|
||||
func TestAuthorizeDomain(t *testing.T) {
|
||||
t.Parallel()
|
||||
defOrg := uuid.New()
|
||||
unusedID := uuid.New()
|
||||
unuseID := uuid.New()
|
||||
allUsersGroup := "Everyone"
|
||||
|
||||
// orphanedUser has no organization
|
||||
@@ -318,21 +318,21 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
|
||||
testAuthorize(t, "UserACLList", user, []authTestCase{
|
||||
{
|
||||
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(unusedID).WithACLUserList(map[string][]policy.Action{
|
||||
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(unuseID).WithACLUserList(map[string][]policy.Action{
|
||||
user.ID: ResourceWorkspace.AvailableActions(),
|
||||
}),
|
||||
actions: ResourceWorkspace.AvailableActions(),
|
||||
allow: true,
|
||||
},
|
||||
{
|
||||
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(unusedID).WithACLUserList(map[string][]policy.Action{
|
||||
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(unuseID).WithACLUserList(map[string][]policy.Action{
|
||||
user.ID: {policy.WildcardSymbol},
|
||||
}),
|
||||
actions: ResourceWorkspace.AvailableActions(),
|
||||
allow: true,
|
||||
},
|
||||
{
|
||||
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(unusedID).WithACLUserList(map[string][]policy.Action{
|
||||
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(unuseID).WithACLUserList(map[string][]policy.Action{
|
||||
user.ID: {policy.ActionRead, policy.ActionUpdate},
|
||||
}),
|
||||
actions: []policy.Action{policy.ActionCreate, policy.ActionDelete},
|
||||
@@ -350,21 +350,21 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
|
||||
testAuthorize(t, "GroupACLList", user, []authTestCase{
|
||||
{
|
||||
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
|
||||
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
|
||||
allUsersGroup: ResourceWorkspace.AvailableActions(),
|
||||
}),
|
||||
actions: ResourceWorkspace.AvailableActions(),
|
||||
allow: true,
|
||||
},
|
||||
{
|
||||
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
|
||||
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
|
||||
allUsersGroup: {policy.WildcardSymbol},
|
||||
}),
|
||||
actions: ResourceWorkspace.AvailableActions(),
|
||||
allow: true,
|
||||
},
|
||||
{
|
||||
resource: ResourceWorkspace.WithOwner(unusedID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
|
||||
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(defOrg).WithGroupACL(map[string][]policy.Action{
|
||||
allUsersGroup: {policy.ActionRead, policy.ActionUpdate},
|
||||
}),
|
||||
actions: []policy.Action{policy.ActionCreate, policy.ActionDelete},
|
||||
@@ -389,14 +389,13 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.AnyOrganization().WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
{resource: ResourceTemplate.AnyOrganization(), actions: []policy.Action{policy.ActionCreate}, allow: false},
|
||||
|
||||
// No org + me
|
||||
{resource: ResourceWorkspace.WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
|
||||
{resource: ResourceWorkspace.All(), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
// Other org + me
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
// Other org + other user
|
||||
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
@@ -404,8 +403,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
// Other org + other us
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
})
|
||||
@@ -436,8 +435,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.All(), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
// Other org + me
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
// Other org + other user
|
||||
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
@@ -445,8 +444,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
// Other org + other use
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
})
|
||||
@@ -456,7 +455,6 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
Scope: must(ExpandScope(ScopeAll)),
|
||||
Roles: Roles{
|
||||
must(RoleByName(ScopedRoleOrgAdmin(defOrg))),
|
||||
must(RoleByName(ScopedRoleOrgMember(defOrg))),
|
||||
must(RoleByName(RoleMember())),
|
||||
},
|
||||
}
|
||||
@@ -471,14 +469,13 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.InOrg(defOrg), actions: workspaceExceptConnect, allow: true},
|
||||
{resource: ResourceWorkspace.InOrg(defOrg), actions: workspaceConnect, allow: false},
|
||||
|
||||
// No org + me
|
||||
{resource: ResourceWorkspace.WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
|
||||
{resource: ResourceWorkspace.All(), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
// Other org + me
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
// Other org + other user
|
||||
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), actions: workspaceExceptConnect, allow: true},
|
||||
@@ -486,9 +483,9 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
// Other org + other user
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
// Other org + other use
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: false},
|
||||
})
|
||||
@@ -515,8 +512,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.All(), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
|
||||
// Other org + me
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
|
||||
// Other org + other user
|
||||
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
@@ -524,8 +521,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
|
||||
// Other org + other use
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), actions: ResourceWorkspace.AvailableActions(), allow: true},
|
||||
})
|
||||
@@ -549,14 +546,13 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner(user.ID), allow: true},
|
||||
{resource: ResourceWorkspace.InOrg(defOrg), allow: false},
|
||||
|
||||
// No org + me
|
||||
{resource: ResourceWorkspace.WithOwner(user.ID), allow: false},
|
||||
{resource: ResourceWorkspace.WithOwner(user.ID), allow: true},
|
||||
|
||||
{resource: ResourceWorkspace.All(), allow: false},
|
||||
|
||||
// Other org + me
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), allow: false},
|
||||
|
||||
// Other org + other user
|
||||
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), allow: false},
|
||||
@@ -564,8 +560,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), allow: false},
|
||||
|
||||
// Other org + other use
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), allow: false},
|
||||
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), allow: false},
|
||||
}),
|
||||
@@ -584,8 +580,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.All()},
|
||||
|
||||
// Other org + me
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID)},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID)},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID)},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID)},
|
||||
|
||||
// Other org + other user
|
||||
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me")},
|
||||
@@ -593,8 +589,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.WithOwner("not-me")},
|
||||
|
||||
// Other org + other use
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me")},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID)},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me")},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID)},
|
||||
|
||||
{resource: ResourceWorkspace.WithOwner("not-me")},
|
||||
}),
|
||||
@@ -613,8 +609,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceTemplate.All()},
|
||||
|
||||
// Other org + me
|
||||
{resource: ResourceTemplate.InOrg(unusedID).WithOwner(user.ID)},
|
||||
{resource: ResourceTemplate.InOrg(unusedID)},
|
||||
{resource: ResourceTemplate.InOrg(unuseID).WithOwner(user.ID)},
|
||||
{resource: ResourceTemplate.InOrg(unuseID)},
|
||||
|
||||
// Other org + other user
|
||||
{resource: ResourceTemplate.InOrg(defOrg).WithOwner("not-me")},
|
||||
@@ -622,8 +618,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceTemplate.WithOwner("not-me")},
|
||||
|
||||
// Other org + other use
|
||||
{resource: ResourceTemplate.InOrg(unusedID).WithOwner("not-me")},
|
||||
{resource: ResourceTemplate.InOrg(unusedID)},
|
||||
{resource: ResourceTemplate.InOrg(unuseID).WithOwner("not-me")},
|
||||
{resource: ResourceTemplate.InOrg(unuseID)},
|
||||
|
||||
{resource: ResourceTemplate.WithOwner("not-me")},
|
||||
}),
|
||||
@@ -651,7 +647,6 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
ResourceType: "*",
|
||||
Action: policy.ActionRead,
|
||||
}},
|
||||
Member: []Permission{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -673,8 +668,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.All(), allow: false},
|
||||
|
||||
// Other org + me
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), allow: false},
|
||||
|
||||
// Other org + other user
|
||||
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), allow: true},
|
||||
@@ -682,8 +677,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), allow: false},
|
||||
|
||||
// Other org + other use
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me"), allow: false},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID), allow: false},
|
||||
|
||||
{resource: ResourceWorkspace.WithOwner("not-me"), allow: false},
|
||||
}),
|
||||
@@ -704,8 +699,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.All()},
|
||||
|
||||
// Other org + me
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.ID)},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID)},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner(user.ID)},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID)},
|
||||
|
||||
// Other org + other user
|
||||
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me")},
|
||||
@@ -713,8 +708,8 @@ func TestAuthorizeDomain(t *testing.T) {
|
||||
{resource: ResourceWorkspace.WithOwner("not-me")},
|
||||
|
||||
// Other org + other use
|
||||
{resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me")},
|
||||
{resource: ResourceWorkspace.InOrg(unusedID)},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID).WithOwner("not-me")},
|
||||
{resource: ResourceWorkspace.InOrg(unuseID)},
|
||||
|
||||
{resource: ResourceWorkspace.WithOwner("not-me")},
|
||||
}))
|
||||
@@ -742,7 +737,6 @@ func TestAuthorizeLevels(t *testing.T) {
|
||||
Action: "*",
|
||||
},
|
||||
},
|
||||
Member: []Permission{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1156,7 +1150,6 @@ func TestAuthorizeScope(t *testing.T) {
|
||||
Org: Permissions(map[string][]policy.Action{
|
||||
ResourceWorkspace.Type: {policy.ActionRead},
|
||||
}),
|
||||
Member: []Permission{},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1323,9 +1316,9 @@ type authTestCase struct {
|
||||
func testAuthorize(t *testing.T, name string, subject Subject, sets ...[]authTestCase) {
|
||||
t.Helper()
|
||||
authorizer := NewAuthorizer(prometheus.NewRegistry())
|
||||
for i, cases := range sets {
|
||||
for j, c := range cases {
|
||||
caseName := fmt.Sprintf("%s/Set%d/Case%d", name, i, j)
|
||||
for _, cases := range sets {
|
||||
for i, c := range cases {
|
||||
caseName := fmt.Sprintf("%s/%d", name, i)
|
||||
t.Run(caseName, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, a := range c.actions {
|
||||
|
||||
+4
-15
@@ -23,13 +23,8 @@
|
||||
"action": "*"
|
||||
}
|
||||
],
|
||||
"user": [],
|
||||
"by_org_id": {
|
||||
"bf7b72bd-a2b1-4ef2-962c-1d698e0483f6": {
|
||||
"org": [],
|
||||
"member": []
|
||||
}
|
||||
}
|
||||
"org": {},
|
||||
"user": []
|
||||
}
|
||||
],
|
||||
"groups": ["b617a647-b5d0-4cbe-9e40-26f89710bf18"],
|
||||
@@ -43,19 +38,13 @@
|
||||
"action": "*"
|
||||
}
|
||||
],
|
||||
"org": {},
|
||||
"user": [],
|
||||
"by_org_id": {
|
||||
"bf7b72bd-a2b1-4ef2-962c-1d698e0483f6": {
|
||||
"org": [],
|
||||
"member": []
|
||||
}
|
||||
},
|
||||
"allow_list": [
|
||||
{
|
||||
"type": "workspace",
|
||||
"id": "*"
|
||||
}
|
||||
]
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+289
-323
@@ -2,426 +2,392 @@ package authz
|
||||
|
||||
import rego.v1
|
||||
|
||||
# Check the POLICY.md file before editing this!
|
||||
#
|
||||
# https://play.openpolicyagent.org/
|
||||
#
|
||||
# A great playground: https://play.openpolicyagent.org/
|
||||
# Helpful cli commands to debug.
|
||||
# opa eval --format=pretty 'data.authz.allow' -d policy.rego -i input.json
|
||||
# opa eval --partial --format=pretty 'data.authz.allow' -d policy.rego --unknowns input.object.owner --unknowns input.object.org_owner --unknowns input.object.acl_user_list --unknowns input.object.acl_group_list -i input.json
|
||||
|
||||
#==============================================================================#
|
||||
# Site level rules #
|
||||
#==============================================================================#
|
||||
#
|
||||
# This policy is specifically constructed to compress to a set of queries if the
|
||||
# object's 'owner' and 'org_owner' fields are unknown. There is no specific set
|
||||
# of rules that will guarantee that this policy has this property. However, there
|
||||
# are some tricks. A unit test will enforce this property, so any edits that pass
|
||||
# the unit test will be ok.
|
||||
#
|
||||
# Tricks: (It's hard to really explain this, fiddling is required)
|
||||
# 1. Do not use unknown fields in any comprehension or iteration.
|
||||
# 2. Use the unknown fields as minimally as possible.
|
||||
# 3. Avoid making code branches based on the value of the unknown field.
|
||||
# Unknown values are like a "set" of possible values.
|
||||
# (This is why rule 1 usually breaks things)
|
||||
# For example:
|
||||
# In the org section, we calculate the 'allow' number for all orgs, rather
|
||||
# than just the input.object.org_owner. This is because if the org_owner
|
||||
# changes, then we don't need to recompute any 'allow' sets. We already have
|
||||
# the 'allow' for the changed value. So the answer is in a lookup table.
|
||||
# The final statement 'num := allow[input.object.org_owner]' does not have
|
||||
# different code branches based on the org_owner. 'num's value does, but
|
||||
# that is the whole point of partial evaluation.
|
||||
|
||||
# Site level permissions allow the subject to use that permission on any object.
|
||||
# For example, a site-level workspace.read permission means that the subject can
|
||||
# see every workspace in the deployment, regardless of organization or owner.
|
||||
# bool_flip(b) returns the logical negation of a boolean value 'b'.
|
||||
# You cannot do 'x := !false', but you can do 'x := bool_flip(false)'
|
||||
bool_flip(b) := false if {
|
||||
b
|
||||
}
|
||||
|
||||
bool_flip(b) := true if {
|
||||
not b
|
||||
}
|
||||
|
||||
# number(set) maps a set of boolean values to one of the following numbers:
|
||||
# -1: deny (if 'false' value is in the set) => set is {true, false} or {false}
|
||||
# 0: no decision (if the set is empty) => set is {}
|
||||
# 1: allow (if only 'true' values are in the set) => set is {true}
|
||||
|
||||
# Return -1 if the set contains any 'false' value (i.e., an explicit deny)
|
||||
number(set) := -1 if {
|
||||
false in set
|
||||
}
|
||||
|
||||
# Return 0 if the set is empty (no matching permissions)
|
||||
number(set) := 0 if {
|
||||
count(set) == 0
|
||||
}
|
||||
|
||||
# Return 1 if the set is non-empty and contains no 'false' values (i.e., only allows)
|
||||
number(set) := 1 if {
|
||||
not false in set
|
||||
set[_]
|
||||
}
|
||||
|
||||
# Permission evaluation is structured into three levels: site, org, and user.
|
||||
# For each level, two variables are computed:
|
||||
# - <level>: the decision based on the subject's full set of roles for that level
|
||||
# - scope_<level>: the decision based on the subject's scoped roles for that level
|
||||
#
|
||||
# Each of these variables is assigned one of three values:
|
||||
# -1 => negative (deny)
|
||||
# 0 => abstain (no matching permission)
|
||||
# 1 => positive (allow)
|
||||
#
|
||||
# These values are computed by calling the corresponding <level>_allow functions.
|
||||
# The final decision is derived from combining these values (see 'allow' rule).
|
||||
|
||||
# -------------------
|
||||
# Site Level Rules
|
||||
# -------------------
|
||||
|
||||
default site := 0
|
||||
|
||||
site := check_site_permissions(input.subject.roles)
|
||||
site := site_allow(input.subject.roles)
|
||||
|
||||
default scope_site := 0
|
||||
scope_site := site_allow([input.subject.scope])
|
||||
|
||||
scope_site := check_site_permissions([input.subject.scope])
|
||||
|
||||
check_site_permissions(roles) := vote if {
|
||||
# site_allow receives a list of roles and returns a single number:
|
||||
# -1 if any matching permission denies access
|
||||
# 1 if there's at least one allow and no denies
|
||||
# 0 if there are no matching permissions
|
||||
site_allow(roles) := num if {
|
||||
# allow is a set of boolean values (sets don't contain duplicates)
|
||||
allow := {is_allowed |
|
||||
# Iterate over all site permissions in all roles, and check which ones match
|
||||
# the action and object type.
|
||||
# Iterate over all site permissions in all roles
|
||||
perm := roles[_].site[_]
|
||||
perm.action in [input.action, "*"]
|
||||
perm.resource_type in [input.object.type, "*"]
|
||||
|
||||
# If a negative matching permission was found, then we vote to disallow it.
|
||||
# If the permission is not negative, then we vote to allow it.
|
||||
# is_allowed is either 'true' or 'false' if a matching permission exists.
|
||||
is_allowed := bool_flip(perm.negate)
|
||||
}
|
||||
vote := to_vote(allow)
|
||||
num := number(allow)
|
||||
}
|
||||
|
||||
#==============================================================================#
|
||||
# User level rules #
|
||||
#==============================================================================#
|
||||
# -------------------
|
||||
# Org Level Rules
|
||||
# -------------------
|
||||
|
||||
# User level rules apply to all objects owned by the subject which are not also
|
||||
# owned by an org. Permissions for objects which are "jointly" owned by an org
|
||||
# instead defer to the org member level rules.
|
||||
# org_members is the list of organizations the actor is apart of.
|
||||
# TODO: Should there be an org_members for the scope too? Without it,
|
||||
# the membership is determined by the user's roles, not their scope permissions.
|
||||
# So if an owner (who is not an org member) has an org scope, that org scope
|
||||
# will fail to return '1'. Since we assume all non members return '-1' for org
|
||||
# level permissions.
|
||||
# Adding a second org_members set might affect the partial evaluation.
|
||||
# This is being left until org scopes are used.
|
||||
org_members := {orgID |
|
||||
input.subject.roles[_].by_org_id[orgID]
|
||||
}
|
||||
|
||||
default user := 0
|
||||
# 'org' is the same as 'site' except we need to iterate over each organization
|
||||
# that the actor is a member of.
|
||||
default org := 0
|
||||
org := org_allow(input.subject.roles, "org")
|
||||
|
||||
user := check_user_permissions(input.subject.roles)
|
||||
default scope_org := 0
|
||||
scope_org := org_allow([input.subject.scope], "org")
|
||||
|
||||
default scope_user := 0
|
||||
# org_allow_set is a helper function that iterates over all orgs that the actor
|
||||
# is a member of. For each organization it sets the numerical allow value
|
||||
# for the given object + action if the object is in the organization.
|
||||
# The resulting value is a map that looks something like:
|
||||
# {"10d03e62-7703-4df5-a358-4f76577d4e2f": 1, "5750d635-82e0-4681-bd44-815b18669d65": 1}
|
||||
# The caller can use this output[<object.org_owner>] to get the final allow value.
|
||||
#
|
||||
# The reason we calculate this for all orgs, and not just the input.object.org_owner
|
||||
# is that sometimes the input.object.org_owner is unknown. In those cases
|
||||
# we have a list of org_ids that can we use in a SQL 'WHERE' clause.
|
||||
org_allow_set(roles, key) := allow_set if {
|
||||
allow_set := {id: num |
|
||||
id := org_members[_]
|
||||
set := {is_allowed |
|
||||
# Iterate over all org permissions in all roles
|
||||
perm := roles[_].by_org_id[id][key][_]
|
||||
perm.action in [input.action, "*"]
|
||||
perm.resource_type in [input.object.type, "*"]
|
||||
|
||||
scope_user := check_user_permissions([input.subject.scope])
|
||||
# is_allowed is either 'true' or 'false' if a matching permission exists.
|
||||
is_allowed := bool_flip(perm.negate)
|
||||
}
|
||||
num := number(set)
|
||||
}
|
||||
}
|
||||
|
||||
check_user_permissions(roles) := vote if {
|
||||
# The object must be owned by the subject.
|
||||
input.subject.id = input.object.owner
|
||||
org_allow(roles, key) := num if {
|
||||
# If the object has "any_org" set to true, then use the other
|
||||
# org_allow block.
|
||||
not input.object.any_org
|
||||
allow := org_allow_set(roles, key)
|
||||
|
||||
# If there is an org, use org_member permissions instead
|
||||
# Return only the org value of the input's org.
|
||||
# The reason why we do not do this up front, is that we need to make sure
|
||||
# this policy compresses down to simple queries. One way to ensure this is
|
||||
# to keep unknown values out of comprehensions.
|
||||
# (https://www.openpolicyagent.org/docs/latest/policy-language/#comprehensions)
|
||||
num := allow[input.object.org_owner]
|
||||
}
|
||||
|
||||
# This block states if "object.any_org" is set to true, then disregard the
|
||||
# organization id the object is associated with. Instead, we check if the user
|
||||
# can do the action on any organization.
|
||||
# This is useful for UI elements when we want to conclude, "Can the user create
|
||||
# a new template in any organization?"
|
||||
# It is easier than iterating over every organization the user is apart of.
|
||||
org_allow(roles, key) := num if {
|
||||
input.object.any_org # if this is false, this code block is not used
|
||||
allow := org_allow_set(roles, key)
|
||||
|
||||
# allow is a map of {"<org_id>": <number>}. We only care about values
|
||||
# that are 1, and ignore the rest.
|
||||
num := number([
|
||||
keep |
|
||||
# for every value in the mapping
|
||||
value := allow[_]
|
||||
|
||||
# only keep values > 0.
|
||||
# 1 = allow, 0 = abstain, -1 = deny
|
||||
# We only need 1 explicit allow to allow the action.
|
||||
# deny's and abstains are intentionally ignored.
|
||||
value > 0
|
||||
|
||||
# result set is a set of [true,false,...]
|
||||
# which "number()" will convert to a number.
|
||||
keep := true
|
||||
])
|
||||
}
|
||||
|
||||
# 'org_mem' is set to true if the user is an org member
|
||||
# If 'any_org' is set to true, use the other block to determine org membership.
|
||||
org_mem if {
|
||||
not input.object.any_org
|
||||
input.object.org_owner != ""
|
||||
input.object.org_owner in org_members
|
||||
}
|
||||
|
||||
org_mem if {
|
||||
input.object.any_org
|
||||
count(org_members) > 0
|
||||
}
|
||||
|
||||
org_ok if {
|
||||
org_mem
|
||||
}
|
||||
|
||||
# If the object has no organization, then the user is also considered part of
|
||||
# the non-existent org.
|
||||
org_ok if {
|
||||
input.object.org_owner == ""
|
||||
not input.object.any_org
|
||||
}
|
||||
|
||||
# -------------------
|
||||
# User Level Rules
|
||||
# -------------------
|
||||
|
||||
# 'user' is the same as 'site', except it only applies if the user owns the object and
|
||||
# the user is apart of the org (if the object has an org).
|
||||
default user := 0
|
||||
user := user_allow(input.subject.roles)
|
||||
|
||||
default scope_user := 0
|
||||
scope_user := user_allow([input.subject.scope])
|
||||
|
||||
user_allow(roles) := num if {
|
||||
input.object.owner != ""
|
||||
input.subject.id = input.object.owner
|
||||
|
||||
allow := {is_allowed |
|
||||
# Iterate over all user permissions in all roles, and check which ones match
|
||||
# the action and object type.
|
||||
# Iterate over all user permissions in all roles
|
||||
perm := roles[_].user[_]
|
||||
perm.action in [input.action, "*"]
|
||||
perm.resource_type in [input.object.type, "*"]
|
||||
|
||||
# If a negative matching permission was found, then we vote to disallow it.
|
||||
# If the permission is not negative, then we vote to allow it.
|
||||
# is_allowed is either 'true' or 'false' if a matching permission exists.
|
||||
is_allowed := bool_flip(perm.negate)
|
||||
}
|
||||
vote := to_vote(allow)
|
||||
num := number(allow)
|
||||
}
|
||||
|
||||
#==============================================================================#
|
||||
# Org level rules #
|
||||
#==============================================================================#
|
||||
|
||||
# Org level permissions are similar to `site`, except we need to iterate over
|
||||
# each organization that the subject is a member of, and check against the
|
||||
# organization that the object belongs to.
|
||||
# For example, an organization-level workspace.read permission means that the
|
||||
# subject can see every workspace in the organization, regardless of owner.
|
||||
|
||||
# org_memberships is the set of organizations the subject is apart of.
|
||||
org_memberships := {org_id |
|
||||
input.subject.roles[_].by_org_id[org_id]
|
||||
# Scope allow_list is a list of resource (Type, ID) tuples explicitly allowed by the scope.
|
||||
# If the list contains `(*,*)`, then all resources are allowed.
|
||||
scope_allow_list if {
|
||||
input.subject.scope.allow_list[_] == {"type": "*", "id": "*"}
|
||||
}
|
||||
|
||||
# TODO: Should there be a scope_org_memberships too? Without it, the membership
|
||||
# is determined by the user's roles, not their scope permissions.
|
||||
#
|
||||
# If an owner (who is not an org member) has an org scope, that org scope will
|
||||
# fail to return '1', since we assume all non-members return '-1' for org level
|
||||
# permissions. Adding a second set of org memberships might affect the partial
|
||||
# evaluation. This is being left until org scopes are used.
|
||||
|
||||
default org := 0
|
||||
|
||||
org := check_org_permissions(input.subject.roles, "org")
|
||||
|
||||
default scope_org := 0
|
||||
|
||||
scope_org := check_org_permissions([input.subject.scope], "org")
|
||||
|
||||
# check_all_org_permissions creates a map from org ids to votes at each org
|
||||
# level, for each org that the subject is a member of. It doesn't actually check
|
||||
# if the object is in the same org. Instead we look up the correct vote from
|
||||
# this map based on the object's org id in `check_org_permissions`.
|
||||
# For example, the `org_map` will look something like this:
|
||||
#
|
||||
# {"<org_id_a>": 1, "<org_id_b>": 0, "<org_id_c>": -1}
|
||||
#
|
||||
# The caller then uses `output[input.object.org_owner]` to get the correct vote.
|
||||
#
|
||||
# We have to create this map, rather than just getting the vote of the object's
|
||||
# org id because the org id _might_ be unknown. In order to make sure that this
|
||||
# policy compresses down to simple queries we need to keep unknown values out of
|
||||
# comprehensions.
|
||||
check_all_org_permissions(roles, key) := {org_id: vote |
|
||||
org_id := org_memberships[_]
|
||||
allow := {is_allowed |
|
||||
# Iterate over all site permissions in all roles, and check which ones match
|
||||
# the action and object type.
|
||||
perm := roles[_].by_org_id[org_id][key][_]
|
||||
perm.action in [input.action, "*"]
|
||||
perm.resource_type in [input.object.type, "*"]
|
||||
|
||||
# If a negative matching permission was found, then we vote to disallow it.
|
||||
# If the permission is not negative, then we vote to allow it.
|
||||
is_allowed := bool_flip(perm.negate)
|
||||
}
|
||||
vote := to_vote(allow)
|
||||
# This is a shortcut if the allow_list contains (type, *), then allow all IDs of that type.
|
||||
scope_allow_list if {
|
||||
input.subject.scope.allow_list[_] == {"type": input.object.type, "id": "*"}
|
||||
}
|
||||
|
||||
# This check handles the case where the org id is known.
|
||||
check_org_permissions(roles, key) := vote if {
|
||||
# Disallow setting any_org at the same time as an org id.
|
||||
not input.object.any_org
|
||||
# A comprehension that iterates over the allow_list and checks if the
|
||||
# (object.type, object.id) is in the allowed ids.
|
||||
scope_allow_list if {
|
||||
# If the wildcard is listed in the allow_list, we do not care about the
|
||||
# object.id. This line is included to prevent partial compilations from
|
||||
# ever needing to include the object.id.
|
||||
not {"type": "*", "id": "*"} in input.subject.scope.allow_list
|
||||
# This is equivalent to the above line, as `type` is known at partial query time.
|
||||
not {"type": input.object.type, "id": "*"} in input.subject.scope.allow_list
|
||||
|
||||
allow_map := check_all_org_permissions(roles, key)
|
||||
# allows_ids is the set of all ids allowed for the given object.type
|
||||
allowed_ids := {allowed_id |
|
||||
# Iterate over all allow list elements
|
||||
ele := input.subject.scope.allow_list[_]
|
||||
ele.type in [input.object.type, "*"]
|
||||
allowed_id := ele.id
|
||||
}
|
||||
|
||||
# Return only the vote of the object's org.
|
||||
vote := allow_map[input.object.org_owner]
|
||||
# Return if the object.id is in the allowed ids
|
||||
# This rule is evaluated at the end so the partial query can use the object.id
|
||||
# against this precomputed set of allowed ids.
|
||||
input.object.id in allowed_ids
|
||||
}
|
||||
|
||||
# This check handles the case where we want to know if the user has the
|
||||
# appropriate permission for any organization, without needing to know which.
|
||||
# This is used in several places in the UI to determine if certain parts of the
|
||||
# app should be accessible.
|
||||
# For example, can the user create a new template in any organization? If yes,
|
||||
# then we should show the "New template" button.
|
||||
check_org_permissions(roles, key) := vote if {
|
||||
# Require `any_org` to be set
|
||||
input.object.any_org
|
||||
# -------------------
|
||||
# Role-Specific Rules
|
||||
# -------------------
|
||||
|
||||
allow_map := check_all_org_permissions(roles, key)
|
||||
|
||||
# Since we're checking if the subject has the permission in _any_ org, we're
|
||||
# essentially trying to find the highest vote from any org.
|
||||
vote := max({vote |
|
||||
some vote in allow_map
|
||||
})
|
||||
}
|
||||
|
||||
# is_org_member checks if the subject belong to the same organization as the
|
||||
# object.
|
||||
is_org_member if {
|
||||
not input.object.any_org
|
||||
input.object.org_owner != ""
|
||||
input.object.org_owner in org_memberships
|
||||
}
|
||||
|
||||
# ...if 'any_org' is set to true, we check if the subject is a member of any
|
||||
# org.
|
||||
is_org_member if {
|
||||
input.object.any_org
|
||||
count(org_memberships) > 0
|
||||
}
|
||||
|
||||
#==============================================================================#
|
||||
# Org member level rules #
|
||||
#==============================================================================#
|
||||
|
||||
# Org member level permissions apply to all objects owned by the subject _and_
|
||||
# the corresponding org. Permissions for objects which are not owned by an
|
||||
# organization instead defer to the user level rules.
|
||||
#
|
||||
# The rules for this level are very similar to the rules for the organization
|
||||
# level, and so we reuse the `check_org_permissions` function from those rules.
|
||||
|
||||
default org_member := 0
|
||||
|
||||
org_member := vote if {
|
||||
# Object must be jointly owned by the user
|
||||
input.object.owner != ""
|
||||
input.subject.id = input.object.owner
|
||||
vote := check_org_permissions(input.subject.roles, "member")
|
||||
}
|
||||
|
||||
default scope_org_member := 0
|
||||
|
||||
scope_org_member := vote if {
|
||||
# Object must be jointly owned by the user
|
||||
input.object.owner != ""
|
||||
input.subject.id = input.object.owner
|
||||
vote := check_org_permissions([input.subject.scope], "member")
|
||||
}
|
||||
|
||||
#==============================================================================#
|
||||
# Role rules #
|
||||
#==============================================================================#
|
||||
|
||||
# role_allow specifies all of the conditions under which a role can grant
|
||||
# permission. These rules intentionally use the "unification" operator rather
|
||||
# than the equality and inequality operators, because those operators do not
|
||||
# work on partial values.
|
||||
# https://www.openpolicyagent.org/docs/policy-language#unification-
|
||||
|
||||
# Site level authorization
|
||||
role_allow if {
|
||||
site = 1
|
||||
}
|
||||
|
||||
# User level authorization
|
||||
role_allow if {
|
||||
not site = -1
|
||||
|
||||
user = 1
|
||||
}
|
||||
|
||||
# Org level authorization
|
||||
role_allow if {
|
||||
not site = -1
|
||||
|
||||
org = 1
|
||||
}
|
||||
|
||||
# Org member authorization
|
||||
role_allow if {
|
||||
not site = -1
|
||||
not org = -1
|
||||
|
||||
org_member = 1
|
||||
# If we are not a member of an org, and the object has an org, then we are
|
||||
# not authorized. This is an "implied -1" for not being in the org.
|
||||
org_ok
|
||||
user = 1
|
||||
}
|
||||
|
||||
#==============================================================================#
|
||||
# Scope rules #
|
||||
#==============================================================================#
|
||||
# -------------------
|
||||
# Scope-Specific Rules
|
||||
# -------------------
|
||||
|
||||
# scope_allow specifies all of the conditions under which a scope can grant
|
||||
# permission. These rules intentionally use the "unification" (=) operator
|
||||
# rather than the equality (==) and inequality (!=) operators, because those
|
||||
# operators do not work on partial values.
|
||||
# https://www.openpolicyagent.org/docs/policy-language#unification-
|
||||
|
||||
# Site level scope enforcement
|
||||
scope_allow if {
|
||||
object_is_included_in_scope_allow_list
|
||||
scope_allow_list
|
||||
scope_site = 1
|
||||
}
|
||||
|
||||
# User level scope enforcement
|
||||
scope_allow if {
|
||||
# User scope permissions must be allowed by the scope, and not denied
|
||||
# by the site. The object *must not* be owned by an organization.
|
||||
object_is_included_in_scope_allow_list
|
||||
scope_allow_list
|
||||
not scope_site = -1
|
||||
|
||||
scope_user = 1
|
||||
}
|
||||
|
||||
# Org level scope enforcement
|
||||
scope_allow if {
|
||||
# Org member scope permissions must be allowed by the scope, and not denied
|
||||
# by the site. The object *must* be owned by an organization.
|
||||
object_is_included_in_scope_allow_list
|
||||
not scope_site = -1
|
||||
|
||||
scope_org = 1
|
||||
}
|
||||
|
||||
# Org member level scope enforcement
|
||||
scope_allow if {
|
||||
# Org member scope permissions must be allowed by the scope, and not denied
|
||||
# by the site or org. The object *must* be owned by an organization.
|
||||
object_is_included_in_scope_allow_list
|
||||
scope_allow_list
|
||||
not scope_site = -1
|
||||
not scope_org = -1
|
||||
|
||||
scope_org_member = 1
|
||||
# If we are not a member of an org, and the object has an org, then we are
|
||||
# not authorized. This is an "implied -1" for not being in the org.
|
||||
org_ok
|
||||
scope_user = 1
|
||||
}
|
||||
|
||||
# If *.* is allowed, then all objects are in scope.
|
||||
object_is_included_in_scope_allow_list if {
|
||||
{"type": "*", "id": "*"} in input.subject.scope.allow_list
|
||||
}
|
||||
|
||||
# If <type>.* is allowed, then all objects of that type are in scope.
|
||||
object_is_included_in_scope_allow_list if {
|
||||
{"type": input.object.type, "id": "*"} in input.subject.scope.allow_list
|
||||
}
|
||||
|
||||
# Check if the object type and ID match one of the allow list entries.
|
||||
object_is_included_in_scope_allow_list if {
|
||||
# Check that the wildcard rules do not apply. This prevents partial inputs
|
||||
# from needing to include `input.object.id`.
|
||||
not {"type": "*", "id": "*"} in input.subject.scope.allow_list
|
||||
not {"type": input.object.type, "id": "*"} in input.subject.scope.allow_list
|
||||
|
||||
# Check which IDs from the allow list match the object type
|
||||
allowed_ids_for_object_type := {it.id |
|
||||
some it in input.subject.scope.allow_list
|
||||
it.type in [input.object.type, "*"]
|
||||
}
|
||||
|
||||
# Check if the input object ID is in the set of allowed IDs for the same
|
||||
# object type. We do this at the end to keep `input.object.id` out of the
|
||||
# comprehension because it might be unknown.
|
||||
input.object.id in allowed_ids_for_object_type
|
||||
}
|
||||
|
||||
#==============================================================================#
|
||||
# ACL rules #
|
||||
#==============================================================================#
|
||||
# -------------------
|
||||
# ACL-Specific Rules
|
||||
# Access Control List
|
||||
# -------------------
|
||||
|
||||
# ACL for users
|
||||
acl_allow if {
|
||||
# TODO: Should you have to be a member of the org too?
|
||||
# Should you have to be a member of the org too?
|
||||
perms := input.object.acl_user_list[input.subject.id]
|
||||
|
||||
# Check if either the action or * is allowed
|
||||
some action in [input.action, "*"]
|
||||
action in perms
|
||||
# Either the input action or wildcard
|
||||
[input.action, "*"][_] in perms
|
||||
}
|
||||
|
||||
# ACL for groups
|
||||
acl_allow if {
|
||||
# If there is no organization owner, the object cannot be owned by an
|
||||
# org-scoped group.
|
||||
is_org_member
|
||||
some group in input.subject.groups
|
||||
# org_scoped team.
|
||||
org_mem
|
||||
group := input.subject.groups[_]
|
||||
perms := input.object.acl_group_list[group]
|
||||
|
||||
# Check if either the action or * is allowed
|
||||
some action in [input.action, "*"]
|
||||
action in perms
|
||||
# Either the input action or wildcard
|
||||
[input.action, "*"][_] in perms
|
||||
}
|
||||
|
||||
# ACL for the special "Everyone" groups
|
||||
# ACL for 'all_users' special group
|
||||
acl_allow if {
|
||||
# If there is no organization owner, the object cannot be owned by an
|
||||
# org-scoped group.
|
||||
is_org_member
|
||||
org_mem
|
||||
perms := input.object.acl_group_list[input.object.org_owner]
|
||||
|
||||
# Check if either the action or * is allowed
|
||||
some action in [input.action, "*"]
|
||||
action in perms
|
||||
[input.action, "*"][_] in perms
|
||||
}
|
||||
|
||||
#==============================================================================#
|
||||
# Allow #
|
||||
#==============================================================================#
|
||||
|
||||
# The `allow` block is quite simple. Any check that voted no will cascade down.
|
||||
# Authorization looks for any `allow` statement that is true. Multiple can be
|
||||
# true! Note that the absence of `allow` means "unauthorized". An explicit
|
||||
# `"allow": true` is required.
|
||||
# -------------------
|
||||
# Final Allow
|
||||
#
|
||||
# We check both the subject's permissions (given by their roles or by ACL) and
|
||||
# the subject's scope. (The default scope is "*:*", allowing all actions.) Both
|
||||
# a permission check (either from roles or ACL) and the scope check must vote to
|
||||
# allow or the action is not authorized.
|
||||
|
||||
# A subject can be given permission by a role
|
||||
permission_allow if role_allow
|
||||
|
||||
# A subject can be given permission by ACL
|
||||
permission_allow if acl_allow
|
||||
# The 'allow' block is quite simple. Any set with `-1` cascades down in levels.
|
||||
# Authorization looks for any `allow` statement that is true. Multiple can be true!
|
||||
# Note that the absence of `allow` means "unauthorized".
|
||||
# An explicit `"allow": true` is required.
|
||||
#
|
||||
# Scope is also applied. The default scope is "wildcard:wildcard" allowing
|
||||
# all actions. If the scope is not "1", then the action is not authorized.
|
||||
#
|
||||
# Allow query:
|
||||
# data.authz.role_allow = true
|
||||
# data.authz.scope_allow = true
|
||||
# -------------------
|
||||
|
||||
# The role or the ACL must allow the action. Scopes can be used to limit,
|
||||
# so scope_allow must always be true.
|
||||
allow if {
|
||||
# Must be allowed by the subject's permissions
|
||||
permission_allow
|
||||
|
||||
# ...and allowed by the scope
|
||||
role_allow
|
||||
scope_allow
|
||||
}
|
||||
|
||||
#==============================================================================#
|
||||
# Utilities #
|
||||
#==============================================================================#
|
||||
|
||||
# bool_flip returns the logical negation of a boolean value. You can't do
|
||||
# 'x := not false', but you can do 'x := bool_flip(false)'
|
||||
bool_flip(b) := false if {
|
||||
b
|
||||
}
|
||||
|
||||
bool_flip(b) if {
|
||||
not b
|
||||
}
|
||||
|
||||
# to_vote gives you a voting value from a set or list of booleans.
|
||||
# {false,..} => deny (-1)
|
||||
# {} => abstain (0)
|
||||
# {true} => allow (1)
|
||||
|
||||
# Any set which contains a `false` should be considered a vote to deny.
|
||||
to_vote(set) := -1 if {
|
||||
false in set
|
||||
}
|
||||
|
||||
# A set which is empty should be considered abstaining.
|
||||
to_vote(set) := 0 if {
|
||||
count(set) == 0
|
||||
}
|
||||
|
||||
# A set which only contains true should be considered a vote to allow.
|
||||
to_vote(set) := 1 if {
|
||||
not false in set
|
||||
true in set
|
||||
# ACL list must also have the scope_allow to pass
|
||||
allow if {
|
||||
acl_allow
|
||||
scope_allow
|
||||
}
|
||||
|
||||
+10
-28
@@ -295,11 +295,15 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
ResourceOauth2App.Type: {policy.ActionRead},
|
||||
ResourceWorkspaceProxy.Type: {policy.ActionRead},
|
||||
}),
|
||||
User: append(allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember),
|
||||
User: append(allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceUser, ResourceOrganizationMember),
|
||||
Permissions(map[string][]policy.Action{
|
||||
// Reduced permission set on dormant workspaces. No build, ssh, or exec
|
||||
ResourceWorkspaceDormant.Type: {policy.ActionRead, policy.ActionDelete, policy.ActionCreate, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent, policy.ActionDeleteAgent},
|
||||
// Users cannot do create/update/delete on themselves, but they
|
||||
// can read their own details.
|
||||
ResourceUser.Type: {policy.ActionRead, policy.ActionReadPersonal, policy.ActionUpdatePersonal},
|
||||
// Can read their own organization member record
|
||||
ResourceOrganizationMember.Type: {policy.ActionRead},
|
||||
// Users can create provisioner daemons scoped to themselves.
|
||||
ResourceProvisionerDaemon.Type: {policy.ActionRead, policy.ActionCreate, policy.ActionRead, policy.ActionUpdate},
|
||||
})...,
|
||||
@@ -427,7 +431,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
// Note: even without PrebuiltWorkspace permissions, access is still granted via Workspace permissions.
|
||||
ResourcePrebuiltWorkspace.Type: {policy.ActionUpdate, policy.ActionDelete},
|
||||
})...),
|
||||
Member: []Permission{},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -451,16 +454,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
// Can read available roles.
|
||||
ResourceAssignOrgRole.Type: {policy.ActionRead},
|
||||
}),
|
||||
Member: append(allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceUser, ResourceOrganizationMember),
|
||||
Permissions(map[string][]policy.Action{
|
||||
// Reduced permission set on dormant workspaces. No build, ssh, or exec
|
||||
ResourceWorkspaceDormant.Type: {policy.ActionRead, policy.ActionDelete, policy.ActionCreate, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent, policy.ActionDeleteAgent},
|
||||
// Can read their own organization member record
|
||||
ResourceOrganizationMember.Type: {policy.ActionRead},
|
||||
// Users can create provisioner daemons scoped to themselves.
|
||||
ResourceProvisionerDaemon.Type: {policy.ActionRead, policy.ActionCreate, policy.ActionRead, policy.ActionUpdate},
|
||||
})...,
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -483,7 +476,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
ResourceOrganization.Type: {policy.ActionRead},
|
||||
ResourceOrganizationMember.Type: {policy.ActionRead},
|
||||
}),
|
||||
Member: []Permission{},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -510,7 +502,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
ResourceGroupMember.Type: ResourceGroupMember.AvailableActions(),
|
||||
ResourceIdpsyncSettings.Type: {policy.ActionRead, policy.ActionUpdate},
|
||||
}),
|
||||
Member: []Permission{},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -540,7 +531,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
ResourceProvisionerDaemon.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
ResourceProvisionerJobs.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreate},
|
||||
}),
|
||||
Member: []Permission{},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -578,7 +568,6 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
Action: policy.ActionDeleteAgent,
|
||||
},
|
||||
},
|
||||
Member: []Permission{},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -691,10 +680,9 @@ func (perm Permission) Valid() error {
|
||||
}
|
||||
|
||||
// Role is a set of permissions at multiple levels:
|
||||
// - Site permissions apply EVERYWHERE
|
||||
// - Org permissions apply to EVERYTHING in a given ORG
|
||||
// - User permissions apply to all resources the user owns
|
||||
// - OrgMember permissions apply to resources in the given org that the user owns
|
||||
// - Site level permissions apply EVERYWHERE
|
||||
// - Org level permissions apply to EVERYTHING in a given ORG
|
||||
// - User level permissions are the lowest
|
||||
// This is the type passed into the rego as a json payload.
|
||||
// Users of this package should instead **only** use the role names, and
|
||||
// this package will expand the role names into their json payloads.
|
||||
@@ -715,8 +703,7 @@ type Role struct {
|
||||
}
|
||||
|
||||
type OrgPermissions struct {
|
||||
Org []Permission `json:"org"`
|
||||
Member []Permission `json:"member"`
|
||||
Org []Permission `json:"org"`
|
||||
}
|
||||
|
||||
// Valid will check all it's permissions and ensure they are all correct
|
||||
@@ -733,12 +720,7 @@ func (role Role) Valid() error {
|
||||
for orgID, orgPermissions := range role.ByOrgID {
|
||||
for _, perm := range orgPermissions.Org {
|
||||
if err := perm.Valid(); err != nil {
|
||||
errs = append(errs, xerrors.Errorf("org=%q: org %w", orgID, err))
|
||||
}
|
||||
}
|
||||
for _, perm := range orgPermissions.Member {
|
||||
if err := perm.Valid(); err != nil {
|
||||
errs = append(errs, xerrors.Errorf("org=%q: member: %w", orgID, err))
|
||||
errs = append(errs, xerrors.Errorf("org=%q: %w", orgID, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,11 +33,10 @@ func BenchmarkRBACValueAllocation(b *testing.B) {
|
||||
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
|
||||
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
|
||||
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
|
||||
}).
|
||||
WithACLUserList(map[string][]policy.Action{
|
||||
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
|
||||
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
|
||||
})
|
||||
}).WithACLUserList(map[string][]policy.Action{
|
||||
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
|
||||
uuid.NewString(): {policy.ActionRead, policy.ActionCreate},
|
||||
})
|
||||
|
||||
jsonSubject := authSubject{
|
||||
ID: actor.ID,
|
||||
@@ -108,7 +107,7 @@ func TestRegoInputValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// This is the input that would be passed to the rego policy.
|
||||
jsonInput := map[string]any{
|
||||
jsonInput := map[string]interface{}{
|
||||
"subject": authSubject{
|
||||
ID: actor.ID,
|
||||
Roles: must(actor.Roles.Expand()),
|
||||
@@ -139,7 +138,7 @@ func TestRegoInputValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// This is the input that would be passed to the rego policy.
|
||||
jsonInput := map[string]any{
|
||||
jsonInput := map[string]interface{}{
|
||||
"subject": authSubject{
|
||||
ID: actor.ID,
|
||||
Roles: must(actor.Roles.Expand()),
|
||||
@@ -147,7 +146,7 @@ func TestRegoInputValue(t *testing.T) {
|
||||
Scope: must(actor.Scope.Expand()),
|
||||
},
|
||||
"action": action,
|
||||
"object": map[string]any{
|
||||
"object": map[string]interface{}{
|
||||
"type": obj.Type,
|
||||
},
|
||||
}
|
||||
@@ -283,6 +282,5 @@ func equalRoles(t *testing.T, a, b Role) {
|
||||
bv, ok := b.ByOrgID[ak]
|
||||
require.True(t, ok, "org permissions missing: %s", ak)
|
||||
require.ElementsMatchf(t, av.Org, bv.Org, "org %s permissions", ak)
|
||||
require.ElementsMatchf(t, av.Member, bv.Member, "member %s permissions", ak)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
clitelemetry "github.com/coder/coder/v2/cli/telemetry"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -35,7 +36,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -48,7 +48,6 @@ type Options struct {
|
||||
Disabled bool
|
||||
Database database.Store
|
||||
Logger slog.Logger
|
||||
Clock quartz.Clock
|
||||
// URL is an endpoint to direct telemetry towards!
|
||||
URL *url.URL
|
||||
Experiments codersdk.Experiments
|
||||
@@ -66,9 +65,6 @@ type Options struct {
|
||||
// Duplicate data will be sent, it's on the server-side to index by UUID.
|
||||
// Data is anonymized prior to being sent!
|
||||
func New(options Options) (Reporter, error) {
|
||||
if options.Clock == nil {
|
||||
options.Clock = quartz.NewReal()
|
||||
}
|
||||
if options.SnapshotFrequency == 0 {
|
||||
// Report once every 30mins by default!
|
||||
options.SnapshotFrequency = 30 * time.Minute
|
||||
@@ -90,7 +86,7 @@ func New(options Options) (Reporter, error) {
|
||||
options: options,
|
||||
deploymentURL: deploymentURL,
|
||||
snapshotURL: snapshotURL,
|
||||
startedAt: dbtime.Time(options.Clock.Now()).UTC(),
|
||||
startedAt: dbtime.Now(),
|
||||
client: &http.Client{},
|
||||
}
|
||||
go reporter.runSnapshotter()
|
||||
@@ -170,7 +166,7 @@ func (r *remoteReporter) Close() {
|
||||
return
|
||||
}
|
||||
close(r.closed)
|
||||
now := dbtime.Time(r.options.Clock.Now()).UTC()
|
||||
now := dbtime.Now()
|
||||
r.shutdownAt = &now
|
||||
if r.Enabled() {
|
||||
// Report a final collection of telemetry prior to close!
|
||||
@@ -416,7 +412,7 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
|
||||
ctx = r.ctx
|
||||
// For resources that grow in size very quickly (like workspace builds),
|
||||
// we only report events that occurred within the past hour.
|
||||
createdAfter = dbtime.Time(r.options.Clock.Now().Add(-1 * time.Hour)).UTC()
|
||||
createdAfter = dbtime.Now().Add(-1 * time.Hour)
|
||||
eg errgroup.Group
|
||||
snapshot = &Snapshot{
|
||||
DeploymentID: r.options.DeploymentID,
|
||||
@@ -748,14 +744,6 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
summaries, err := r.generateAIBridgeInterceptionsSummaries(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("generate AIBridge interceptions telemetry summaries: %w", err)
|
||||
}
|
||||
snapshot.AIBridgeInterceptionsSummaries = summaries
|
||||
return nil
|
||||
})
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
@@ -764,76 +752,6 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
func (r *remoteReporter) generateAIBridgeInterceptionsSummaries(ctx context.Context) ([]AIBridgeInterceptionsSummary, error) {
|
||||
// Get the current timeframe, which is the previous hour.
|
||||
now := dbtime.Time(r.options.Clock.Now()).UTC()
|
||||
endedAtBefore := now.Truncate(time.Hour)
|
||||
endedAtAfter := endedAtBefore.Add(-1 * time.Hour)
|
||||
|
||||
// Note: we don't use a transaction for this function since we do tolerate
|
||||
// some errors, like duplicate lock rows, and we also calculate
|
||||
// summaries in parallel.
|
||||
|
||||
// Claim the heartbeat lock row for this hour.
|
||||
err := r.options.Database.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
|
||||
EventType: "aibridge_interceptions_summary",
|
||||
PeriodEndingAt: endedAtBefore,
|
||||
})
|
||||
if database.IsUniqueViolation(err, database.UniqueTelemetryLocksPkey) {
|
||||
// Another replica has already claimed the lock row for this hour.
|
||||
r.options.Logger.Debug(ctx, "aibridge interceptions telemetry lock already claimed for this hour by another replica, skipping", slog.F("period_ending_at", endedAtBefore))
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert AIBridge interceptions telemetry lock (period_ending_at=%q): %w", endedAtBefore, err)
|
||||
}
|
||||
|
||||
// List the summary categories that need to be calculated.
|
||||
summaryCategories, err := r.options.Database.ListAIBridgeInterceptionsTelemetrySummaries(ctx, database.ListAIBridgeInterceptionsTelemetrySummariesParams{
|
||||
EndedAtAfter: endedAtAfter, // inclusive
|
||||
EndedAtBefore: endedAtBefore, // exclusive
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("list AIBridge interceptions telemetry summaries (startedAtAfter=%q, endedAtBefore=%q): %w", endedAtAfter, endedAtBefore, err)
|
||||
}
|
||||
|
||||
// Calculate and convert the summaries for all categories.
|
||||
var (
|
||||
eg, egCtx = errgroup.WithContext(ctx)
|
||||
mu sync.Mutex
|
||||
summaries = make([]AIBridgeInterceptionsSummary, 0, len(summaryCategories))
|
||||
)
|
||||
for _, category := range summaryCategories {
|
||||
eg.Go(func() error {
|
||||
summary, err := r.options.Database.CalculateAIBridgeInterceptionsTelemetrySummary(egCtx, database.CalculateAIBridgeInterceptionsTelemetrySummaryParams{
|
||||
Provider: category.Provider,
|
||||
Model: category.Model,
|
||||
Client: category.Client,
|
||||
EndedAtAfter: endedAtAfter,
|
||||
EndedAtBefore: endedAtBefore,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("calculate AIBridge interceptions telemetry summary (provider=%q, model=%q, client=%q, startedAtAfter=%q, endedAtBefore=%q): %w", category.Provider, category.Model, category.Client, endedAtAfter, endedAtBefore, err)
|
||||
}
|
||||
|
||||
// Double check that at least one interception was found in the
|
||||
// timeframe.
|
||||
if summary.InterceptionCount == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
converted := ConvertAIBridgeInterceptionsSummary(endedAtBefore, category.Provider, category.Model, category.Client, summary)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
summaries = append(summaries, converted)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return summaries, eg.Wait()
|
||||
}
|
||||
|
||||
// ConvertAPIKey anonymizes an API key.
|
||||
func ConvertAPIKey(apiKey database.APIKey) APIKey {
|
||||
a := APIKey{
|
||||
@@ -1305,7 +1223,6 @@ type Snapshot struct {
|
||||
TelemetryItems []TelemetryItem `json:"telemetry_items"`
|
||||
UserTailnetConnections []UserTailnetConnection `json:"user_tailnet_connections"`
|
||||
PrebuiltWorkspaces []PrebuiltWorkspace `json:"prebuilt_workspaces"`
|
||||
AIBridgeInterceptionsSummaries []AIBridgeInterceptionsSummary `json:"aibridge_interceptions_summaries"`
|
||||
}
|
||||
|
||||
// Deployment contains information about the host running Coder.
|
||||
@@ -1942,89 +1859,6 @@ type PrebuiltWorkspace struct {
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
type AIBridgeInterceptionsSummaryDurationMillis struct {
|
||||
P50 int64 `json:"p50"`
|
||||
P90 int64 `json:"p90"`
|
||||
P95 int64 `json:"p95"`
|
||||
P99 int64 `json:"p99"`
|
||||
}
|
||||
|
||||
type AIBridgeInterceptionsSummaryTokenCount struct {
|
||||
Input int64 `json:"input"`
|
||||
Output int64 `json:"output"`
|
||||
CachedRead int64 `json:"cached_read"`
|
||||
CachedWritten int64 `json:"cached_written"`
|
||||
}
|
||||
|
||||
type AIBridgeInterceptionsSummaryToolCallsCount struct {
|
||||
Injected int64 `json:"injected"`
|
||||
NonInjected int64 `json:"non_injected"`
|
||||
}
|
||||
|
||||
// AIBridgeInterceptionsSummary is a summary of aggregated AI Bridge
|
||||
// interception data over a period of 1 hour. We send a summary each hour for
|
||||
// each unique provider + model + client combination.
|
||||
type AIBridgeInterceptionsSummary struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
|
||||
// The end of the hour for which the summary is taken. This will always be a
|
||||
// UTC timestamp truncated to the hour.
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Client string `json:"client"`
|
||||
|
||||
InterceptionCount int64 `json:"interception_count"`
|
||||
InterceptionDurationMillis AIBridgeInterceptionsSummaryDurationMillis `json:"interception_duration_millis"`
|
||||
|
||||
// Map of route to number of interceptions.
|
||||
// e.g. "/v1/chat/completions:blocking", "/v1/chat/completions:streaming"
|
||||
InterceptionsByRoute map[string]int64 `json:"interceptions_by_route"`
|
||||
|
||||
UniqueInitiatorCount int64 `json:"unique_initiator_count"`
|
||||
|
||||
UserPromptsCount int64 `json:"user_prompts_count"`
|
||||
|
||||
TokenUsagesCount int64 `json:"token_usages_count"`
|
||||
TokenCount AIBridgeInterceptionsSummaryTokenCount `json:"token_count"`
|
||||
|
||||
ToolCallsCount AIBridgeInterceptionsSummaryToolCallsCount `json:"tool_calls_count"`
|
||||
InjectedToolCallErrorCount int64 `json:"injected_tool_call_error_count"`
|
||||
}
|
||||
|
||||
func ConvertAIBridgeInterceptionsSummary(endTime time.Time, provider, model, client string, summary database.CalculateAIBridgeInterceptionsTelemetrySummaryRow) AIBridgeInterceptionsSummary {
|
||||
return AIBridgeInterceptionsSummary{
|
||||
ID: uuid.New(),
|
||||
Timestamp: endTime,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Client: client,
|
||||
InterceptionCount: summary.InterceptionCount,
|
||||
InterceptionDurationMillis: AIBridgeInterceptionsSummaryDurationMillis{
|
||||
P50: summary.InterceptionDurationP50Millis,
|
||||
P90: summary.InterceptionDurationP90Millis,
|
||||
P95: summary.InterceptionDurationP95Millis,
|
||||
P99: summary.InterceptionDurationP99Millis,
|
||||
},
|
||||
// TODO: currently we don't track by route
|
||||
InterceptionsByRoute: make(map[string]int64),
|
||||
UniqueInitiatorCount: summary.UniqueInitiatorCount,
|
||||
UserPromptsCount: summary.UserPromptsCount,
|
||||
TokenUsagesCount: summary.TokenUsagesCount,
|
||||
TokenCount: AIBridgeInterceptionsSummaryTokenCount{
|
||||
Input: summary.TokenCountInput,
|
||||
Output: summary.TokenCountOutput,
|
||||
CachedRead: summary.TokenCountCachedRead,
|
||||
CachedWritten: summary.TokenCountCachedWritten,
|
||||
},
|
||||
ToolCallsCount: AIBridgeInterceptionsSummaryToolCallsCount{
|
||||
Injected: summary.ToolCallsCountInjected,
|
||||
NonInjected: summary.ToolCallsCountNonInjected,
|
||||
},
|
||||
InjectedToolCallErrorCount: summary.InjectedToolCallErrorCount,
|
||||
}
|
||||
}
|
||||
|
||||
type noopReporter struct{}
|
||||
|
||||
func (*noopReporter) Report(_ *Snapshot) {}
|
||||
|
||||
@@ -28,7 +28,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -45,7 +44,6 @@ func TestTelemetry(t *testing.T) {
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
now := dbtime.Now()
|
||||
|
||||
org, err := db.GetDefaultOrganization(ctx)
|
||||
require.NoError(t, err)
|
||||
@@ -210,88 +208,12 @@ func TestTelemetry(t *testing.T) {
|
||||
AgentID: wsagent.ID,
|
||||
})
|
||||
|
||||
previousAIBridgeInterceptionPeriod := now.Truncate(time.Hour)
|
||||
user2 := dbgen.User(t, db, database.User{})
|
||||
aiBridgeInterception1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: user.ID,
|
||||
Provider: "anthropic",
|
||||
Model: "deanseek",
|
||||
StartedAt: previousAIBridgeInterceptionPeriod.Add(-30 * time.Minute),
|
||||
}, nil)
|
||||
_ = dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
|
||||
InterceptionID: aiBridgeInterception1.ID,
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
|
||||
})
|
||||
_ = dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
|
||||
InterceptionID: aiBridgeInterception1.ID,
|
||||
})
|
||||
_ = dbgen.AIBridgeToolUsage(t, db, database.InsertAIBridgeToolUsageParams{
|
||||
InterceptionID: aiBridgeInterception1.ID,
|
||||
Injected: true,
|
||||
InvocationError: sql.NullString{String: "error1", Valid: true},
|
||||
})
|
||||
_, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: aiBridgeInterception1.ID,
|
||||
EndedAt: aiBridgeInterception1.StartedAt.Add(1 * time.Minute), // 1 minute duration
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiBridgeInterception2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: user2.ID,
|
||||
Provider: aiBridgeInterception1.Provider,
|
||||
Model: aiBridgeInterception1.Model,
|
||||
StartedAt: aiBridgeInterception1.StartedAt,
|
||||
}, nil)
|
||||
_ = dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
|
||||
InterceptionID: aiBridgeInterception2.ID,
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
Metadata: json.RawMessage(`{"cache_read_input":300,"cache_creation_input":400}`),
|
||||
})
|
||||
_ = dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
|
||||
InterceptionID: aiBridgeInterception2.ID,
|
||||
})
|
||||
_ = dbgen.AIBridgeToolUsage(t, db, database.InsertAIBridgeToolUsageParams{
|
||||
InterceptionID: aiBridgeInterception2.ID,
|
||||
Injected: false,
|
||||
})
|
||||
_, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: aiBridgeInterception2.ID,
|
||||
EndedAt: aiBridgeInterception2.StartedAt.Add(2 * time.Minute), // 2 minute duration
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiBridgeInterception3 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: user2.ID,
|
||||
Provider: "openai",
|
||||
Model: "gpt-5",
|
||||
StartedAt: aiBridgeInterception1.StartedAt,
|
||||
}, nil)
|
||||
_, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: aiBridgeInterception3.ID,
|
||||
EndedAt: aiBridgeInterception3.StartedAt.Add(3 * time.Minute), // 3 minute duration
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_ = dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: user2.ID,
|
||||
Provider: "openai",
|
||||
Model: "gpt-5",
|
||||
StartedAt: aiBridgeInterception1.StartedAt,
|
||||
}, nil)
|
||||
// not ended, so it should not affect summaries
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
clock.Set(now)
|
||||
|
||||
_, snapshot := collectSnapshot(ctx, t, db, func(opts telemetry.Options) telemetry.Options {
|
||||
opts.Clock = clock
|
||||
return opts
|
||||
})
|
||||
_, snapshot := collectSnapshot(ctx, t, db, nil)
|
||||
require.Len(t, snapshot.ProvisionerJobs, 2)
|
||||
require.Len(t, snapshot.Licenses, 1)
|
||||
require.Len(t, snapshot.Templates, 2)
|
||||
require.Len(t, snapshot.TemplateVersions, 3)
|
||||
require.Len(t, snapshot.Users, 2)
|
||||
require.Len(t, snapshot.Users, 1)
|
||||
require.Len(t, snapshot.Groups, 2)
|
||||
// 1 member in the everyone group + 1 member in the custom group
|
||||
require.Len(t, snapshot.GroupMembers, 2)
|
||||
@@ -365,53 +287,6 @@ func TestTelemetry(t *testing.T) {
|
||||
for _, entity := range snapshot.Templates {
|
||||
require.Equal(t, entity.OrganizationID, org.ID)
|
||||
}
|
||||
|
||||
// 2 unique provider + model + client combinations
|
||||
require.Len(t, snapshot.AIBridgeInterceptionsSummaries, 2)
|
||||
snapshot1 := snapshot.AIBridgeInterceptionsSummaries[0]
|
||||
snapshot2 := snapshot.AIBridgeInterceptionsSummaries[1]
|
||||
if snapshot1.Provider != aiBridgeInterception1.Provider {
|
||||
snapshot1, snapshot2 = snapshot2, snapshot1
|
||||
}
|
||||
|
||||
require.Equal(t, snapshot1.Provider, aiBridgeInterception1.Provider)
|
||||
require.Equal(t, snapshot1.Model, aiBridgeInterception1.Model)
|
||||
require.Equal(t, snapshot1.Client, "unknown") // no client info yet
|
||||
require.EqualValues(t, snapshot1.InterceptionCount, 2)
|
||||
require.EqualValues(t, snapshot1.InterceptionsByRoute, map[string]int64{}) // no route info yet
|
||||
require.EqualValues(t, snapshot1.InterceptionDurationMillis.P50, 90_000)
|
||||
require.EqualValues(t, snapshot1.InterceptionDurationMillis.P90, 114_000)
|
||||
require.EqualValues(t, snapshot1.InterceptionDurationMillis.P95, 117_000)
|
||||
require.EqualValues(t, snapshot1.InterceptionDurationMillis.P99, 119_400)
|
||||
require.EqualValues(t, snapshot1.UniqueInitiatorCount, 2)
|
||||
require.EqualValues(t, snapshot1.UserPromptsCount, 2)
|
||||
require.EqualValues(t, snapshot1.TokenUsagesCount, 2)
|
||||
require.EqualValues(t, snapshot1.TokenCount.Input, 200)
|
||||
require.EqualValues(t, snapshot1.TokenCount.Output, 400)
|
||||
require.EqualValues(t, snapshot1.TokenCount.CachedRead, 600)
|
||||
require.EqualValues(t, snapshot1.TokenCount.CachedWritten, 800)
|
||||
require.EqualValues(t, snapshot1.ToolCallsCount.Injected, 1)
|
||||
require.EqualValues(t, snapshot1.ToolCallsCount.NonInjected, 1)
|
||||
require.EqualValues(t, snapshot1.InjectedToolCallErrorCount, 1)
|
||||
|
||||
require.Equal(t, snapshot2.Provider, aiBridgeInterception3.Provider)
|
||||
require.Equal(t, snapshot2.Model, aiBridgeInterception3.Model)
|
||||
require.Equal(t, snapshot2.Client, "unknown") // no client info yet
|
||||
require.EqualValues(t, snapshot2.InterceptionCount, 1)
|
||||
require.EqualValues(t, snapshot2.InterceptionsByRoute, map[string]int64{}) // no route info yet
|
||||
require.EqualValues(t, snapshot2.InterceptionDurationMillis.P50, 180_000)
|
||||
require.EqualValues(t, snapshot2.InterceptionDurationMillis.P90, 180_000)
|
||||
require.EqualValues(t, snapshot2.InterceptionDurationMillis.P95, 180_000)
|
||||
require.EqualValues(t, snapshot2.InterceptionDurationMillis.P99, 180_000)
|
||||
require.EqualValues(t, snapshot2.UniqueInitiatorCount, 1)
|
||||
require.EqualValues(t, snapshot2.UserPromptsCount, 0)
|
||||
require.EqualValues(t, snapshot2.TokenUsagesCount, 0)
|
||||
require.EqualValues(t, snapshot2.TokenCount.Input, 0)
|
||||
require.EqualValues(t, snapshot2.TokenCount.Output, 0)
|
||||
require.EqualValues(t, snapshot2.TokenCount.CachedRead, 0)
|
||||
require.EqualValues(t, snapshot2.TokenCount.CachedWritten, 0)
|
||||
require.EqualValues(t, snapshot2.ToolCallsCount.Injected, 0)
|
||||
require.EqualValues(t, snapshot2.ToolCallsCount.NonInjected, 0)
|
||||
})
|
||||
t.Run("HashedEmail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -324,19 +324,11 @@ func (p *DBTokenProvider) authorizeRequest(ctx context.Context, roles *rbac.Subj
|
||||
// rbacResourceOwned is for the level "authenticated". We still need to
|
||||
// make sure the API key has permissions to connect to the actor's own
|
||||
// workspace. Scopes would prevent this.
|
||||
// TODO: This is an odd repercussion of the org_member permission level.
|
||||
// This Object used to not specify an org restriction, and `InOrg` would
|
||||
// actually have a significantly different meaning (only sharing with
|
||||
// other authenticated users in the same org, whereas the existing behavior
|
||||
// is to share with any authenticated user). Because workspaces are always
|
||||
// jointly owned by an organization, there _must_ be an org restriction on
|
||||
// the object to check the proper permissions. AnyOrg is almost the same,
|
||||
// but technically excludes users who are not in any organization. This is
|
||||
// the closest we can get though without more significant refactoring.
|
||||
rbacResourceOwned rbac.Object = rbac.ResourceWorkspace.WithOwner(roles.ID).AnyOrganization()
|
||||
rbacResourceOwned rbac.Object = rbac.ResourceWorkspace.WithOwner(roles.ID)
|
||||
)
|
||||
if dbReq.AccessMethod == AccessMethodTerminal {
|
||||
rbacAction = policy.ActionSSH
|
||||
rbacResourceOwned = rbac.ResourceWorkspace.WithOwner(roles.ID)
|
||||
}
|
||||
|
||||
// Do a standard RBAC check. This accounts for share level "owner" and any
|
||||
|
||||
@@ -2654,7 +2654,6 @@ func convertWorkspace(
|
||||
Favorite: requesterFavorite,
|
||||
NextStartAt: nextStartAt,
|
||||
IsPrebuild: workspace.IsPrebuild(),
|
||||
TaskID: workspace.TaskID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
+66
-40
@@ -4700,16 +4700,11 @@ func TestWorkspaceFilterHasAITask(t *testing.T) {
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Helper function to create workspace with optional task.
|
||||
createWorkspace := func(jobCompleted, createTask bool, prompt string) uuid.UUID {
|
||||
// TODO(mafredri): The bellow comment is based on deprecated logic and
|
||||
// kept only present to test that the old observable behavior works as
|
||||
// intended.
|
||||
//
|
||||
// Helper function to create workspace with AI task configuration
|
||||
createWorkspaceWithAIConfig := func(hasAITask sql.NullBool, jobCompleted bool, aiTaskPrompt *string) database.WorkspaceTable {
|
||||
// When a provisioner job uses these tags, no provisioner will match it.
|
||||
// We do this so jobs will always be stuck in "pending", allowing us to
|
||||
// exercise the intermediary state when has_ai_task is nil and we
|
||||
// compensate by looking at pending provisioning jobs.
|
||||
// We do this so jobs will always be stuck in "pending", allowing us to exercise the intermediary state when
|
||||
// has_ai_task is nil and we compensate by looking at pending provisioning jobs.
|
||||
// See GetWorkspaces clauses.
|
||||
unpickableTags := database.StringMap{"custom": "true"}
|
||||
|
||||
@@ -4728,71 +4723,102 @@ func TestWorkspaceFilterHasAITask(t *testing.T) {
|
||||
jobConfig.CompletedAt = sql.NullTime{Time: time.Now(), Valid: true}
|
||||
}
|
||||
job := dbgen.ProvisionerJob(t, db, pubsub, jobConfig)
|
||||
|
||||
res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{JobID: job.ID})
|
||||
agnt := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ResourceID: res.ID})
|
||||
taskApp := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{AgentID: agnt.ID})
|
||||
|
||||
var sidebarAppID uuid.UUID
|
||||
if hasAITask.Bool {
|
||||
sidebarApp := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{AgentID: agnt.ID})
|
||||
sidebarAppID = sidebarApp.ID
|
||||
}
|
||||
|
||||
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
WorkspaceID: ws.ID,
|
||||
TemplateVersionID: version.ID,
|
||||
InitiatorID: user.UserID,
|
||||
JobID: job.ID,
|
||||
BuildNumber: 1,
|
||||
AITaskSidebarAppID: uuid.NullUUID{UUID: taskApp.ID, Valid: createTask},
|
||||
HasAITask: hasAITask,
|
||||
AITaskSidebarAppID: uuid.NullUUID{UUID: sidebarAppID, Valid: sidebarAppID != uuid.Nil},
|
||||
})
|
||||
|
||||
if createTask {
|
||||
task := dbgen.Task(t, db, database.TaskTable{
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
TemplateVersionID: version.ID,
|
||||
Prompt: prompt,
|
||||
})
|
||||
dbgen.TaskWorkspaceApp(t, db, database.TaskWorkspaceApp{
|
||||
TaskID: task.ID,
|
||||
WorkspaceBuildNumber: build.BuildNumber,
|
||||
WorkspaceAgentID: uuid.NullUUID{UUID: agnt.ID, Valid: true},
|
||||
WorkspaceAppID: uuid.NullUUID{UUID: taskApp.ID, Valid: true},
|
||||
if aiTaskPrompt != nil {
|
||||
err := db.InsertWorkspaceBuildParameters(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceBuildParametersParams{
|
||||
WorkspaceBuildID: build.ID,
|
||||
Name: []string{provider.TaskPromptParameterName},
|
||||
Value: []string{*aiTaskPrompt},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return ws.ID
|
||||
return ws
|
||||
}
|
||||
|
||||
// Create workspaces with tasks.
|
||||
wsWithTask1 := createWorkspace(true, true, "Build me a web app")
|
||||
wsWithTask2 := createWorkspace(false, true, "Another task")
|
||||
// Create test workspaces with different AI task configurations
|
||||
wsWithAITask := createWorkspaceWithAIConfig(sql.NullBool{Bool: true, Valid: true}, true, nil)
|
||||
wsWithoutAITask := createWorkspaceWithAIConfig(sql.NullBool{Bool: false, Valid: true}, false, nil)
|
||||
|
||||
// Create workspaces without tasks
|
||||
wsWithoutTask1 := createWorkspace(true, false, "")
|
||||
wsWithoutTask2 := createWorkspace(false, false, "")
|
||||
aiTaskPrompt := "Build me a web app"
|
||||
wsWithAITaskParam := createWorkspaceWithAIConfig(sql.NullBool{Valid: false}, false, &aiTaskPrompt)
|
||||
|
||||
anotherTaskPrompt := "Another task"
|
||||
wsCompletedWithAITaskParam := createWorkspaceWithAIConfig(sql.NullBool{Valid: false}, true, &anotherTaskPrompt)
|
||||
|
||||
emptyPrompt := ""
|
||||
wsWithEmptyAITaskParam := createWorkspaceWithAIConfig(sql.NullBool{Valid: false}, false, &emptyPrompt)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Debug: Check all workspaces without filter first
|
||||
allRes, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{})
|
||||
require.NoError(t, err)
|
||||
t.Logf("Total workspaces created: %d", len(allRes.Workspaces))
|
||||
for i, ws := range allRes.Workspaces {
|
||||
t.Logf("All Workspace %d: ID=%s, Name=%s, Build ID=%s, Job ID=%s", i, ws.ID, ws.Name, ws.LatestBuild.ID, ws.LatestBuild.Job.ID)
|
||||
}
|
||||
|
||||
// Test filtering for workspaces with AI tasks
|
||||
// Should include: wsWithTask1 and wsWithTask2
|
||||
// Should include: wsWithAITask (has_ai_task=true) and wsWithAITaskParam (null + incomplete + param)
|
||||
res, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{
|
||||
FilterQuery: "has-ai-task:true",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Logf("Expected 2 workspaces for has-ai-task:true, got %d", len(res.Workspaces))
|
||||
t.Logf("Expected workspaces: %s, %s", wsWithAITask.ID, wsWithAITaskParam.ID)
|
||||
for i, ws := range res.Workspaces {
|
||||
t.Logf("AI Task True Workspace %d: ID=%s, Name=%s", i, ws.ID, ws.Name)
|
||||
}
|
||||
require.Len(t, res.Workspaces, 2)
|
||||
workspaceIDs := []uuid.UUID{res.Workspaces[0].ID, res.Workspaces[1].ID}
|
||||
require.Contains(t, workspaceIDs, wsWithTask1)
|
||||
require.Contains(t, workspaceIDs, wsWithTask2)
|
||||
require.Contains(t, workspaceIDs, wsWithAITask.ID)
|
||||
require.Contains(t, workspaceIDs, wsWithAITaskParam.ID)
|
||||
|
||||
// Test filtering for workspaces without AI tasks
|
||||
// Should include: wsWithoutTask1, wsWithoutTask2, wsWithoutTask3
|
||||
// Should include: wsWithoutAITask, wsCompletedWithAITaskParam, wsWithEmptyAITaskParam
|
||||
res, err = client.Workspaces(ctx, codersdk.WorkspaceFilter{
|
||||
FilterQuery: "has-ai-task:false",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Workspaces, 2)
|
||||
workspaceIDs = []uuid.UUID{res.Workspaces[0].ID, res.Workspaces[1].ID}
|
||||
require.Contains(t, workspaceIDs, wsWithoutTask1)
|
||||
require.Contains(t, workspaceIDs, wsWithoutTask2)
|
||||
|
||||
// Debug: print what we got
|
||||
t.Logf("Expected 3 workspaces for has-ai-task:false, got %d", len(res.Workspaces))
|
||||
for i, ws := range res.Workspaces {
|
||||
t.Logf("Workspace %d: ID=%s, Name=%s", i, ws.ID, ws.Name)
|
||||
}
|
||||
t.Logf("Expected IDs: %s, %s, %s", wsWithoutAITask.ID, wsCompletedWithAITaskParam.ID, wsWithEmptyAITaskParam.ID)
|
||||
|
||||
require.Len(t, res.Workspaces, 3)
|
||||
workspaceIDs = []uuid.UUID{res.Workspaces[0].ID, res.Workspaces[1].ID, res.Workspaces[2].ID}
|
||||
require.Contains(t, workspaceIDs, wsWithoutAITask.ID)
|
||||
require.Contains(t, workspaceIDs, wsCompletedWithAITaskParam.ID)
|
||||
require.Contains(t, workspaceIDs, wsWithEmptyAITaskParam.ID)
|
||||
|
||||
// Test no filter returns all
|
||||
res, err = client.Workspaces(ctx, codersdk.WorkspaceFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Workspaces, 4)
|
||||
require.Len(t, res.Workspaces, 5)
|
||||
}
|
||||
|
||||
func TestWorkspaceAppUpsertRestart(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
package agentsdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// SocketClient provides a client for communicating with the agent socket
|
||||
type SocketClient struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// SocketConfig holds configuration for the socket client
|
||||
type SocketConfig struct {
|
||||
Path string // Socket path (optional, will auto-discover if not set)
|
||||
}
|
||||
|
||||
// NewSocketClient creates a new socket client
|
||||
func NewSocketClient(config SocketConfig) (*SocketClient, error) {
|
||||
path := config.Path
|
||||
if path == "" {
|
||||
var err error
|
||||
path, err = discoverSocketPath()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("discover socket path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("connect to socket: %w", err)
|
||||
}
|
||||
|
||||
return &SocketClient{
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the socket connection
|
||||
func (c *SocketClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Ping sends a ping request to the agent
|
||||
func (c *SocketClient) Ping(ctx context.Context) (*PingResponse, error) {
|
||||
req := &Request{
|
||||
Version: "1.0",
|
||||
Method: "ping",
|
||||
ID: generateRequestID(),
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, xerrors.Errorf("ping error: %s", resp.Error.Message)
|
||||
}
|
||||
|
||||
var pingResp PingResponse
|
||||
if err := json.Unmarshal(resp.Result, &pingResp); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal ping response: %w", err)
|
||||
}
|
||||
|
||||
return &pingResp, nil
|
||||
}
|
||||
|
||||
// Health sends a health check request to the agent
|
||||
func (c *SocketClient) Health(ctx context.Context) (*HealthResponse, error) {
|
||||
req := &Request{
|
||||
Version: "1.0",
|
||||
Method: "health",
|
||||
ID: generateRequestID(),
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, xerrors.Errorf("health error: %s", resp.Error.Message)
|
||||
}
|
||||
|
||||
var healthResp HealthResponse
|
||||
if err := json.Unmarshal(resp.Result, &healthResp); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal health response: %w", err)
|
||||
}
|
||||
|
||||
return &healthResp, nil
|
||||
}
|
||||
|
||||
// AgentInfo sends an agent info request
|
||||
func (c *SocketClient) AgentInfo(ctx context.Context) (*AgentInfo, error) {
|
||||
req := &Request{
|
||||
Version: "1.0",
|
||||
Method: "agent.info",
|
||||
ID: generateRequestID(),
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, xerrors.Errorf("agent info error: %s", resp.Error.Message)
|
||||
}
|
||||
|
||||
var agentInfo AgentInfo
|
||||
if err := json.Unmarshal(resp.Result, &agentInfo); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal agent info response: %w", err)
|
||||
}
|
||||
|
||||
return &agentInfo, nil
|
||||
}
|
||||
|
||||
// ListMethods lists available methods
|
||||
func (c *SocketClient) ListMethods(ctx context.Context) ([]string, error) {
|
||||
req := &Request{
|
||||
Version: "1.0",
|
||||
Method: "methods.list",
|
||||
ID: generateRequestID(),
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, xerrors.Errorf("list methods error: %s", resp.Error.Message)
|
||||
}
|
||||
|
||||
var methods []string
|
||||
if err := json.Unmarshal(resp.Result, &methods); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal methods response: %w", err)
|
||||
}
|
||||
|
||||
return methods, nil
|
||||
}
|
||||
|
||||
// sendRequest sends a request and returns the response
|
||||
func (c *SocketClient) sendRequest(_ context.Context, req *Request) (*Response, error) {
|
||||
// Set write deadline
|
||||
if err := c.conn.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
return nil, xerrors.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
|
||||
// Send request
|
||||
if err := json.NewEncoder(c.conn).Encode(req); err != nil {
|
||||
return nil, xerrors.Errorf("send request: %w", err)
|
||||
}
|
||||
|
||||
// Set read deadline
|
||||
if err := c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
return nil, xerrors.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
|
||||
// Read response
|
||||
var resp Response
|
||||
if err := json.NewDecoder(c.conn).Decode(&resp); err != nil {
|
||||
return nil, xerrors.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// discoverSocketPath discovers the agent socket path
|
||||
func discoverSocketPath() (string, error) {
|
||||
// Check environment variable first
|
||||
if path := os.Getenv("CODER_AGENT_SOCKET_PATH"); path != "" {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// Try common socket paths
|
||||
paths := []string{
|
||||
// XDG runtime directory
|
||||
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "coder-agent.sock"),
|
||||
// User-specific temp directory
|
||||
filepath.Join(os.TempDir(), fmt.Sprintf("coder-agent-%d.sock", os.Getuid())),
|
||||
// Fallback temp directory
|
||||
filepath.Join(os.TempDir(), "coder-agent.sock"),
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", xerrors.New("agent socket not found")
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID
|
||||
func generateRequestID() string {
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// Request represents a socket request
|
||||
type Request struct {
|
||||
Version string `json:"version"`
|
||||
Method string `json:"method"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// Response represents a socket response
|
||||
type Response struct {
|
||||
Version string `json:"version"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Error represents a socket error
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PingResponse represents a ping response
|
||||
type PingResponse struct {
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// HealthResponse represents a health check response
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
|
||||
// AgentInfo represents agent information
|
||||
type AgentInfo struct {
|
||||
ID string `json:"id"`
|
||||
Version string `json:"version"`
|
||||
Status string `json:"status"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user