Compare commits
45 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4f6fc49a87 | |||
| dba9f68b11 | |||
| 245ce91199 | |||
| 5d0734e005 | |||
| 43a1af3cd6 | |||
| 3d5d58ec2b | |||
| 37d937554e | |||
| 796190d435 | |||
| c1474c7ee2 | |||
| a8e7cc10b6 | |||
| 82f965a0ae | |||
| acbfb90c30 | |||
| c344d7c00e | |||
| 53350377b3 | |||
| 147df5c971 | |||
| 9e4c283370 | |||
| 145817e8d3 | |||
| 956f6b2473 | |||
| d2afda8191 | |||
| c389c2bc5c | |||
| 4c9e37b659 | |||
| 3b268c95d3 | |||
| 138bc41563 | |||
| 80a172f932 | |||
| 86d8b6daee | |||
| 470e6c7217 | |||
| ed19a3a08e | |||
| 975373704f | |||
| 522288c9d5 | |||
| edd13482a0 | |||
| ef14654078 | |||
| ea37f1ff86 | |||
| c49170b6b3 | |||
| ee9b46fe08 | |||
| 1ad3c898a0 | |||
| b8e09d09b0 | |||
| 0900a44ff3 | |||
| 4537413315 | |||
| ab86ed0df8 | |||
| f2b9d5f8f7 | |||
| b73983e309 | |||
| c11cc0ba30 | |||
| 3163e74b77 | |||
| eca2257c26 | |||
| 75f5b60eb6 |
@@ -45,7 +45,7 @@ jobs:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
- name: check changed files
|
||||
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
|
||||
uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
PR_NUMBER: ${{ steps.pr_info.outputs.PR_NUMBER }}
|
||||
|
||||
- name: Check changed files
|
||||
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
|
||||
uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
id: filter
|
||||
with:
|
||||
base: ${{ github.ref }}
|
||||
|
||||
@@ -63,116 +63,3 @@ jobs:
|
||||
--data "{\"content\": \"$msg\"}" \
|
||||
"${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
|
||||
|
||||
trivy:
|
||||
permissions:
|
||||
security-events: write
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Go
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Setup Node
|
||||
uses: ./.github/actions/setup-node
|
||||
|
||||
- name: Setup sqlc
|
||||
uses: ./.github/actions/setup-sqlc
|
||||
|
||||
- name: Install cosign
|
||||
uses: ./.github/actions/install-cosign
|
||||
|
||||
- name: Install syft
|
||||
uses: ./.github/actions/install-syft
|
||||
|
||||
- name: Install yq
|
||||
run: go run github.com/mikefarah/yq/v4@v4.44.3
|
||||
- name: Install mockgen
|
||||
run: ./.github/scripts/retry.sh -- go install go.uber.org/mock/mockgen@v0.6.0
|
||||
- name: Install protoc-gen-go
|
||||
run: ./.github/scripts/retry.sh -- go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.30
|
||||
- name: Install protoc-gen-go-drpc
|
||||
run: ./.github/scripts/retry.sh -- go install storj.io/drpc/cmd/protoc-gen-go-drpc@v0.0.34
|
||||
- name: Install Protoc
|
||||
run: |
|
||||
# protoc must be in lockstep with our dogfood Dockerfile or the
|
||||
# version in the comments will differ. This is also defined in
|
||||
# ci.yaml.
|
||||
set -euxo pipefail
|
||||
cd dogfood/coder
|
||||
mkdir -p /usr/local/bin
|
||||
mkdir -p /usr/local/include
|
||||
|
||||
DOCKER_BUILDKIT=1 docker build . --target proto -t protoc
|
||||
protoc_path=/usr/local/bin/protoc
|
||||
docker run --rm --entrypoint cat protoc /tmp/bin/protoc > $protoc_path
|
||||
chmod +x $protoc_path
|
||||
protoc --version
|
||||
# Copy the generated files to the include directory.
|
||||
docker run --rm -v /usr/local/include:/target protoc cp -r /tmp/include/google /target/
|
||||
ls -la /usr/local/include/google/protobuf/
|
||||
stat /usr/local/include/google/protobuf/timestamp.proto
|
||||
|
||||
- name: Build Coder linux amd64 Docker image
|
||||
id: build
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
version="$(./scripts/version.sh)"
|
||||
image_job="build/coder_${version}_linux_amd64.tag"
|
||||
|
||||
# This environment variable force make to not build packages and
|
||||
# archives (which the Docker image depends on due to technical reasons
|
||||
# related to concurrent FS writes).
|
||||
export DOCKER_IMAGE_NO_PREREQUISITES=true
|
||||
# This environment variables forces scripts/build_docker.sh to build
|
||||
# the base image tag locally instead of using the cached version from
|
||||
# the registry.
|
||||
CODER_IMAGE_BUILD_BASE_TAG="$(CODER_IMAGE_BASE=coder-base ./scripts/image_tag.sh --version "$version")"
|
||||
export CODER_IMAGE_BUILD_BASE_TAG
|
||||
|
||||
# We would like to use make -j here, but it doesn't work with the some recent additions
|
||||
# to our code generation.
|
||||
make "$image_job"
|
||||
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # v0.34.0
|
||||
with:
|
||||
image-ref: ${{ steps.build.outputs.image }}
|
||||
format: sarif
|
||||
output: trivy-results.sarif
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
- name: Upload Trivy scan results to GitHub Security tab
|
||||
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
|
||||
with:
|
||||
sarif_file: trivy-results.sarif
|
||||
category: "Trivy"
|
||||
|
||||
- name: Upload Trivy scan results as an artifact
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: trivy
|
||||
path: trivy-results.sarif
|
||||
retention-days: 7
|
||||
|
||||
- name: Send Slack notification on failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
msg="❌ Trivy Failed\n\nhttps://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
|
||||
curl \
|
||||
-qfsSL \
|
||||
-X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
--data "{\"content\": \"$msg\"}" \
|
||||
"${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
|
||||
|
||||
@@ -1343,6 +1343,7 @@ test-js: site/node_modules/.installed
|
||||
|
||||
test-storybook: site/node_modules/.installed
|
||||
cd site/
|
||||
pnpm playwright:install
|
||||
pnpm exec vitest run --project=storybook
|
||||
.PHONY: test-storybook
|
||||
|
||||
|
||||
+1
-2
@@ -16,7 +16,6 @@ import (
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -1877,7 +1876,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
sort.Float64s(durations)
|
||||
slices.Sort(durations)
|
||||
durationsLength := len(durations)
|
||||
switch {
|
||||
case durationsLength == 0:
|
||||
|
||||
@@ -433,7 +433,7 @@ func convertDockerInspect(raw []byte) ([]codersdk.WorkspaceAgentContainer, []str
|
||||
}
|
||||
portKeys := maps.Keys(in.NetworkSettings.Ports)
|
||||
// Sort the ports for deterministic output.
|
||||
sort.Strings(portKeys)
|
||||
slices.Sort(portKeys)
|
||||
// If we see the same port bound to both ipv4 and ipv6 loopback or unspecified
|
||||
// interfaces to the same container port, there is no point in adding it multiple times.
|
||||
loopbackHostPortContainerPorts := make(map[int]uint16, 0)
|
||||
|
||||
+31
-46
@@ -2,7 +2,6 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -13,6 +12,7 @@ import (
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
@@ -26,9 +26,9 @@ type DesktopAction struct {
|
||||
Duration *int `json:"duration,omitempty"`
|
||||
ScrollAmount *int `json:"scroll_amount,omitempty"`
|
||||
ScrollDirection *string `json:"scroll_direction,omitempty"`
|
||||
// ScaledWidth and ScaledHeight are the coordinate space the
|
||||
// model is using. When provided, coordinates are linearly
|
||||
// mapped from scaled → native before dispatching.
|
||||
// ScaledWidth and ScaledHeight describe the declared model-facing desktop
|
||||
// geometry. When provided, input coordinates are mapped from declared space
|
||||
// to native desktop pixels before dispatching.
|
||||
ScaledWidth *int `json:"scaled_width,omitempty"`
|
||||
ScaledHeight *int `json:"scaled_height,omitempty"`
|
||||
}
|
||||
@@ -144,17 +144,8 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
|
||||
// Helper to scale a coordinate pair from the model's space to
|
||||
// native display pixels.
|
||||
scaleXY := func(x, y int) (int, int) {
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
|
||||
x = scaleCoordinate(x, *action.ScaledWidth, cfg.Width)
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
|
||||
y = scaleCoordinate(y, *action.ScaledHeight, cfg.Height)
|
||||
}
|
||||
return x, y
|
||||
}
|
||||
geometry := desktopGeometryForAction(cfg, action)
|
||||
scaleXY := geometry.DeclaredPointToNative
|
||||
|
||||
var resp DesktopActionResponse
|
||||
|
||||
@@ -192,7 +183,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
resp.Output = "type action performed"
|
||||
|
||||
case "cursor_position":
|
||||
x, y, err := a.desktop.CursorPosition(ctx)
|
||||
nativeX, nativeY, err := a.desktop.CursorPosition(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Cursor position failed.",
|
||||
@@ -200,6 +191,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y := geometry.NativePointToDeclared(nativeX, nativeY)
|
||||
resp.Output = "x=" + strconv.Itoa(x) + ",y=" + strconv.Itoa(y)
|
||||
|
||||
case "mouse_move":
|
||||
@@ -447,14 +439,10 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
resp.Output = "hold_key action performed"
|
||||
|
||||
case "screenshot":
|
||||
var opts ScreenshotOptions
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
|
||||
opts.TargetWidth = *action.ScaledWidth
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
|
||||
opts.TargetHeight = *action.ScaledHeight
|
||||
}
|
||||
result, err := a.desktop.Screenshot(ctx, opts)
|
||||
result, err := a.desktop.Screenshot(ctx, ScreenshotOptions{
|
||||
TargetWidth: geometry.DeclaredWidth,
|
||||
TargetHeight: geometry.DeclaredHeight,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Screenshot failed.",
|
||||
@@ -464,16 +452,8 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
resp.Output = "screenshot"
|
||||
resp.ScreenshotData = result.Data
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 && *action.ScaledWidth != cfg.Width {
|
||||
resp.ScreenshotWidth = *action.ScaledWidth
|
||||
} else {
|
||||
resp.ScreenshotWidth = cfg.Width
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 && *action.ScaledHeight != cfg.Height {
|
||||
resp.ScreenshotHeight = *action.ScaledHeight
|
||||
} else {
|
||||
resp.ScreenshotHeight = cfg.Height
|
||||
}
|
||||
resp.ScreenshotWidth = geometry.DeclaredWidth
|
||||
resp.ScreenshotHeight = geometry.DeclaredHeight
|
||||
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
@@ -512,6 +492,23 @@ func coordFromAction(action DesktopAction) (x, y int, err error) {
|
||||
return action.Coordinate[0], action.Coordinate[1], nil
|
||||
}
|
||||
|
||||
func desktopGeometryForAction(cfg DisplayConfig, action DesktopAction) workspacesdk.DesktopGeometry {
|
||||
declaredWidth := cfg.Width
|
||||
declaredHeight := cfg.Height
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
|
||||
declaredWidth = *action.ScaledWidth
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
|
||||
declaredHeight = *action.ScaledHeight
|
||||
}
|
||||
return workspacesdk.NewDesktopGeometryWithDeclared(
|
||||
cfg.Width,
|
||||
cfg.Height,
|
||||
declaredWidth,
|
||||
declaredHeight,
|
||||
)
|
||||
}
|
||||
|
||||
// missingFieldError is returned when a required field is absent from
|
||||
// a DesktopAction.
|
||||
type missingFieldError struct {
|
||||
@@ -522,15 +519,3 @@ type missingFieldError struct {
|
||||
func (e *missingFieldError) Error() string {
|
||||
return "Missing \"" + e.field + "\" for " + e.action + " action."
|
||||
}
|
||||
|
||||
// scaleCoordinate maps a coordinate from scaled → native space.
|
||||
func scaleCoordinate(scaled, scaledDim, nativeDim int) int {
|
||||
if scaledDim == 0 || scaledDim == nativeDim {
|
||||
return scaled
|
||||
}
|
||||
native := (float64(scaled)+0.5)*float64(nativeDim)/float64(scaledDim) - 0.5
|
||||
// Clamp to valid range.
|
||||
native = math.Max(native, 0)
|
||||
native = math.Min(native, float64(nativeDim-1))
|
||||
return int(native)
|
||||
}
|
||||
|
||||
+125
-16
@@ -27,10 +27,12 @@ var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
|
||||
// fakeDesktop is a minimal Desktop implementation for unit tests.
|
||||
type fakeDesktop struct {
|
||||
startErr error
|
||||
cursorPos [2]int
|
||||
startCfg agentdesktop.DisplayConfig
|
||||
vncConnErr error
|
||||
screenshotErr error
|
||||
screenshotRes agentdesktop.ScreenshotResult
|
||||
lastShotOpts agentdesktop.ScreenshotOptions
|
||||
closed bool
|
||||
|
||||
// Track calls for assertions.
|
||||
@@ -51,7 +53,8 @@ func (f *fakeDesktop) VNCConn(context.Context) (net.Conn, error) {
|
||||
return nil, f.vncConnErr
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Screenshot(_ context.Context, _ agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
|
||||
func (f *fakeDesktop) Screenshot(_ context.Context, opts agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
|
||||
f.lastShotOpts = opts
|
||||
return f.screenshotRes, f.screenshotErr
|
||||
}
|
||||
|
||||
@@ -100,8 +103,8 @@ func (f *fakeDesktop) Type(_ context.Context, text string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
|
||||
return 10, 20, nil
|
||||
func (f *fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
|
||||
return f.cursorPos[0], f.cursorPos[1], nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Close() error {
|
||||
@@ -135,8 +138,12 @@ func TestHandleAction_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
geometry := workspacesdk.DefaultDesktopGeometry()
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: workspacesdk.DesktopDisplayWidth, Height: workspacesdk.DesktopDisplayHeight},
|
||||
startCfg: agentdesktop.DisplayConfig{
|
||||
Width: geometry.NativeWidth,
|
||||
Height: geometry.NativeHeight,
|
||||
},
|
||||
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
@@ -158,11 +165,52 @@ func TestHandleAction_Screenshot(t *testing.T) {
|
||||
var result agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
// Dimensions come from DisplayConfig, not the screenshot CLI.
|
||||
assert.Equal(t, "screenshot", result.Output)
|
||||
assert.Equal(t, "base64data", result.ScreenshotData)
|
||||
assert.Equal(t, workspacesdk.DesktopDisplayWidth, result.ScreenshotWidth)
|
||||
assert.Equal(t, workspacesdk.DesktopDisplayHeight, result.ScreenshotHeight)
|
||||
assert.Equal(t, geometry.NativeWidth, result.ScreenshotWidth)
|
||||
assert.Equal(t, geometry.NativeHeight, result.ScreenshotHeight)
|
||||
assert.Equal(t, agentdesktop.ScreenshotOptions{
|
||||
TargetWidth: geometry.NativeWidth,
|
||||
TargetHeight: geometry.NativeHeight,
|
||||
}, fake.lastShotOpts)
|
||||
}
|
||||
|
||||
func TestHandleAction_ScreenshotUsesDeclaredDimensionsFromRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
sw := 1280
|
||||
sh := 720
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "screenshot",
|
||||
ScaledWidth: &sw,
|
||||
ScaledHeight: &sh,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, agentdesktop.ScreenshotOptions{TargetWidth: 1280, TargetHeight: 720}, fake.lastShotOpts)
|
||||
|
||||
var result agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1280, result.ScreenshotWidth)
|
||||
assert.Equal(t, 720, result.ScreenshotHeight)
|
||||
}
|
||||
|
||||
func TestHandleAction_LeftClick(t *testing.T) {
|
||||
@@ -315,7 +363,6 @@ func TestHandleAction_HoldKey(t *testing.T) {
|
||||
handler.ServeHTTP(rr, req)
|
||||
}()
|
||||
|
||||
// Wait for the timer to be created, then advance past it.
|
||||
trap.MustWait(req.Context()).MustRelease(req.Context())
|
||||
mClk.Advance(time.Duration(dur) * time.Millisecond).MustWait(req.Context())
|
||||
|
||||
@@ -389,7 +436,6 @@ func TestHandleAction_ScrollDown(t *testing.T) {
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
// dy should be positive 5 for "down".
|
||||
assert.Equal(t, [4]int{500, 400, 0, 5}, fake.lastScroll)
|
||||
}
|
||||
|
||||
@@ -398,13 +444,11 @@ func TestHandleAction_CoordinateScaling(t *testing.T) {
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
// Native display is 1920x1080.
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
// Model is working in a 1280x720 coordinate space.
|
||||
sw := 1280
|
||||
sh := 720
|
||||
body := agentdesktop.DesktopAction{
|
||||
@@ -424,12 +468,43 @@ func TestHandleAction_CoordinateScaling(t *testing.T) {
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
// 640 in 1280-space → 960 in 1920-space (midpoint maps to
|
||||
// midpoint).
|
||||
assert.Equal(t, 960, fake.lastMove[0])
|
||||
assert.Equal(t, 540, fake.lastMove[1])
|
||||
}
|
||||
|
||||
func TestHandleAction_CoordinateScalingClampsToLastPixel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
sw := 1366
|
||||
sh := 768
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "mouse_move",
|
||||
Coordinate: &[2]int{1365, 767},
|
||||
ScaledWidth: &sw,
|
||||
ScaledHeight: &sh,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, 1919, fake.lastMove[0])
|
||||
assert.Equal(t, 1079, fake.lastMove[1])
|
||||
}
|
||||
|
||||
func TestClose_DelegatesToDesktop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -446,15 +521,12 @@ func TestClose_PreventsNewSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// After Close(), Start() will return an error because the
|
||||
// underlying Desktop is closed.
|
||||
fake := &fakeDesktop{}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
|
||||
err := api.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate the closed desktop returning an error on Start().
|
||||
fake.startErr = xerrors.New("desktop is closed")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
@@ -465,3 +537,40 @@ func TestClose_PreventsNewSessions(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
}
|
||||
|
||||
func TestHandleAction_CursorPositionReturnsDeclaredCoordinates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
cursorPos: [2]int{960, 540},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
sw := 1280
|
||||
sh := 720
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "cursor_position",
|
||||
ScaledWidth: &sw,
|
||||
ScaledHeight: &sh,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var resp agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
// Native (960,540) in 1920x1080 should map to declared space in 1280x720.
|
||||
assert.Equal(t, "x=640,y=360", resp.Output)
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
|
||||
|
||||
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
|
||||
cmd := p.execer.CommandContext(sessionCtx, p.binPath, "up", "--json",
|
||||
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight))
|
||||
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopNativeWidth, workspacesdk.DesktopNativeHeight))
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
sessionCancel()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package agentgit
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -99,7 +99,7 @@ func (ps *PathStore) GetPaths(chatID uuid.UUID) []string {
|
||||
for p := range m {
|
||||
out = append(out, p)
|
||||
}
|
||||
sort.Strings(out)
|
||||
slices.Sort(out)
|
||||
return out
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -228,6 +228,6 @@ func resultPaths(results []filefinder.Result) []string {
|
||||
for i, r := range results {
|
||||
paths[i] = r.Path
|
||||
}
|
||||
sort.Strings(paths)
|
||||
slices.Sort(paths)
|
||||
return paths
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -376,8 +376,8 @@ func Test_sshConfigOptions_addOption(t *testing.T) {
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
sort.Strings(tt.Expect)
|
||||
sort.Strings(o.sshOptions)
|
||||
slices.Sort(tt.Expect)
|
||||
slices.Sort(o.sshOptions)
|
||||
require.Equal(t, tt.Expect, o.sshOptions)
|
||||
})
|
||||
}
|
||||
|
||||
+2
-2
@@ -24,7 +24,7 @@ import (
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -2825,7 +2825,7 @@ func ReadExternalAuthProvidersFromEnv(environ []string) ([]codersdk.ExternalAuth
|
||||
// parsing of `GITAUTH` environment variables.
|
||||
func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]codersdk.ExternalAuthConfig, error) {
|
||||
// The index numbers must be in-order.
|
||||
sort.Strings(environ)
|
||||
slices.Sort(environ)
|
||||
|
||||
var providers []codersdk.ExternalAuthConfig
|
||||
for _, v := range serpent.ParseEnviron(environ, prefix) {
|
||||
|
||||
+3
-3
@@ -7,7 +7,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"slices"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -31,7 +31,7 @@ func (*RootCmd) templateInit() *serpent.Command {
|
||||
for _, ex := range exampleList {
|
||||
templateIDs = append(templateIDs, ex.ID)
|
||||
}
|
||||
sort.Strings(templateIDs)
|
||||
slices.Sort(templateIDs)
|
||||
cmd := &serpent.Command{
|
||||
Use: "init [directory]",
|
||||
Short: "Get started with a templated template.",
|
||||
@@ -50,7 +50,7 @@ func (*RootCmd) templateInit() *serpent.Command {
|
||||
optsToID[name] = example.ID
|
||||
}
|
||||
opts := maps.Keys(optsToID)
|
||||
sort.Strings(opts)
|
||||
slices.Sort(opts)
|
||||
_, _ = fmt.Fprintln(
|
||||
inv.Stdout,
|
||||
pretty.Sprint(
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -47,7 +47,7 @@ func TestTemplateList(t *testing.T) {
|
||||
|
||||
// expect that templates are listed alphabetically
|
||||
templatesList := []string{firstTemplate.Name, secondTemplate.Name}
|
||||
sort.Strings(templatesList)
|
||||
slices.Sort(templatesList)
|
||||
|
||||
require.NoError(t, <-errC)
|
||||
|
||||
|
||||
+2
-3
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -194,7 +193,7 @@ func joinScopes(scopes []codersdk.APIKeyScope) string {
|
||||
return ""
|
||||
}
|
||||
vals := slice.ToStrings(scopes)
|
||||
sort.Strings(vals)
|
||||
slices.Sort(vals)
|
||||
return strings.Join(vals, ", ")
|
||||
}
|
||||
|
||||
@@ -206,7 +205,7 @@ func joinAllowList(entries []codersdk.APIAllowListTarget) string {
|
||||
for i, entry := range entries {
|
||||
vals[i] = entry.String()
|
||||
}
|
||||
sort.Strings(vals)
|
||||
slices.Sort(vals)
|
||||
return strings.Join(vals, ", ")
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -773,7 +773,7 @@ func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if statusResp.Status != agentapisdk.StatusStable {
|
||||
return httperror.NewResponseError(http.StatusBadGateway, codersdk.Response{
|
||||
return httperror.NewResponseError(http.StatusConflict, codersdk.Response{
|
||||
Message: "Task app is not ready to accept input.",
|
||||
Detail: fmt.Sprintf("Status: %s", statusResp.Status),
|
||||
})
|
||||
|
||||
@@ -789,6 +789,11 @@ func TestTasks(t *testing.T) {
|
||||
})
|
||||
require.Error(t, err, "wanted error due to bad status")
|
||||
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
require.Contains(t, sdkErr.Message, "not ready to accept input")
|
||||
|
||||
statusResponse = agentapisdk.StatusStable
|
||||
|
||||
//nolint:tparallel // Not intended to run in parallel.
|
||||
|
||||
Generated
+123
-23
@@ -163,6 +163,57 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/aibridge/sessions": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"AI Bridge"
|
||||
],
|
||||
"summary": "List AI Bridge sessions",
|
||||
"operationId": "list-ai-bridge-sessions",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
|
||||
"name": "q",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Page limit",
|
||||
"name": "limit",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Cursor pagination after session ID (cannot be used with offset)",
|
||||
"name": "after_session_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Offset pagination (cannot be used with after_session_id)",
|
||||
"name": "offset",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/appearance": {
|
||||
"get": {
|
||||
"produces": [
|
||||
@@ -7971,29 +8022,6 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/me/session/token-to-cookie": {
|
||||
"post": {
|
||||
"description": "Converts the current session token into a Set-Cookie response.\nThis is used by embedded iframes (e.g. VS Code chat) that\nreceive a session token out-of-band via postMessage but need\ncookie-based auth for WebSocket connections.",
|
||||
"tags": [
|
||||
"Authorization"
|
||||
],
|
||||
"summary": "Set session token cookie",
|
||||
"operationId": "set-session-token-cookie",
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/users/oauth2/github/callback": {
|
||||
"get": {
|
||||
"tags": [
|
||||
@@ -12801,6 +12829,20 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeListSessionsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"sessions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSession"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeOpenAIConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -12853,6 +12895,64 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSession": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"client": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"initiator": {
|
||||
"$ref": "#/definitions/codersdk.MinimalUser"
|
||||
},
|
||||
"last_prompt": {
|
||||
"type": "string"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {}
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"threads": {
|
||||
"type": "integer"
|
||||
},
|
||||
"token_usage_summary": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSessionTokenUsageSummary": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"output_tokens": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
Generated
+119
-21
@@ -136,6 +136,53 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/aibridge/sessions": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["AI Bridge"],
|
||||
"summary": "List AI Bridge sessions",
|
||||
"operationId": "list-ai-bridge-sessions",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
|
||||
"name": "q",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Page limit",
|
||||
"name": "limit",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Cursor pagination after session ID (cannot be used with offset)",
|
||||
"name": "after_session_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Offset pagination (cannot be used with after_session_id)",
|
||||
"name": "offset",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/appearance": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
@@ -7064,27 +7111,6 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/me/session/token-to-cookie": {
|
||||
"post": {
|
||||
"description": "Converts the current session token into a Set-Cookie response.\nThis is used by embedded iframes (e.g. VS Code chat) that\nreceive a session token out-of-band via postMessage but need\ncookie-based auth for WebSocket connections.",
|
||||
"tags": ["Authorization"],
|
||||
"summary": "Set session token cookie",
|
||||
"operationId": "set-session-token-cookie",
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/users/oauth2/github/callback": {
|
||||
"get": {
|
||||
"tags": ["Users"],
|
||||
@@ -11389,6 +11415,20 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeListSessionsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"sessions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSession"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeOpenAIConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -11441,6 +11481,64 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSession": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"client": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"initiator": {
|
||||
"$ref": "#/definitions/codersdk.MinimalUser"
|
||||
},
|
||||
"last_prompt": {
|
||||
"type": "string"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {}
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"threads": {
|
||||
"type": "integer"
|
||||
},
|
||||
"token_usage_summary": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSessionTokenUsageSummary": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"output_tokens": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
package chattool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestComputerUseTool_Info(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, nil, quartz.NewReal())
|
||||
info := tool.Info()
|
||||
assert.Equal(t, "computer", info.Name)
|
||||
assert.NotEmpty(t, info.Description)
|
||||
}
|
||||
|
||||
func TestComputerUseProviderTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
def := chattool.ComputerUseProviderTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight)
|
||||
pdt, ok := def.(fantasy.ProviderDefinedTool)
|
||||
require.True(t, ok, "ComputerUseProviderTool should return a ProviderDefinedTool")
|
||||
assert.Contains(t, pdt.ID, "computer")
|
||||
assert.Equal(t, "computer", pdt.Name)
|
||||
// Verify display dimensions are passed through.
|
||||
assert.Equal(t, int64(workspacesdk.DesktopDisplayWidth), pdt.Args["display_width_px"])
|
||||
assert.Equal(t, int64(workspacesdk.DesktopDisplayHeight), pdt.Args["display_height_px"])
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "base64png",
|
||||
ScreenshotWidth: 1024,
|
||||
ScreenshotHeight: 768,
|
||||
}, nil)
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-1",
|
||||
Name: "computer",
|
||||
Input: `{"action":"screenshot"}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, "image/png", resp.MediaType)
|
||||
assert.Equal(t, []byte("base64png"), resp.Data)
|
||||
assert.False(t, resp.IsError)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_LeftClick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
// Expect the action call first.
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "left_click performed",
|
||||
}, nil)
|
||||
|
||||
// Then expect a screenshot (auto-screenshot after action).
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "after-click",
|
||||
ScreenshotWidth: 1024,
|
||||
ScreenshotHeight: 768,
|
||||
}, nil)
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-2",
|
||||
Name: "computer",
|
||||
Input: `{"action":"left_click","coordinate":[100,200]}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, []byte("after-click"), resp.Data)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_Wait(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
// Expect a screenshot after the wait completes.
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "after-wait",
|
||||
ScreenshotWidth: 1024,
|
||||
ScreenshotHeight: 768,
|
||||
}, nil)
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-3",
|
||||
Name: "computer",
|
||||
Input: `{"action":"wait","duration":10}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, "image/png", resp.MediaType)
|
||||
assert.Equal(t, []byte("after-wait"), resp.Data)
|
||||
assert.False(t, resp.IsError)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_ConnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return nil, xerrors.New("workspace not available")
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-4",
|
||||
Name: "computer",
|
||||
Input: `{"action":"screenshot"}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.Contains(t, resp.Content, "workspace not available")
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return nil, xerrors.New("should not be called")
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-5",
|
||||
Name: "computer",
|
||||
Input: `{invalid json`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.Contains(t, resp.Content, "invalid computer use input")
|
||||
}
|
||||
+47
-43
@@ -51,7 +51,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/awsidentity"
|
||||
"github.com/coder/coder/v2/coderd/boundaryusage"
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/connectionlog"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -63,7 +62,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/files"
|
||||
"github.com/coder/coder/v2/coderd/gitsshkey"
|
||||
"github.com/coder/coder/v2/coderd/gitsync"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
@@ -94,6 +92,8 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/wsbuilder"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
"github.com/coder/coder/v2/coderd/x/gitsync"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||||
@@ -767,43 +767,45 @@ func New(options *Options) *API {
|
||||
}
|
||||
api.agentProvider = stn
|
||||
|
||||
maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
|
||||
if maxChatsPerAcquire > math.MaxInt32 {
|
||||
maxChatsPerAcquire = math.MaxInt32
|
||||
}
|
||||
if maxChatsPerAcquire < math.MinInt32 {
|
||||
maxChatsPerAcquire = math.MinInt32
|
||||
}
|
||||
{ // Experimental: agents — chat daemon and git sync worker initialization.
|
||||
maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
|
||||
if maxChatsPerAcquire > math.MaxInt32 {
|
||||
maxChatsPerAcquire = math.MaxInt32
|
||||
}
|
||||
if maxChatsPerAcquire < math.MinInt32 {
|
||||
maxChatsPerAcquire = math.MinInt32
|
||||
}
|
||||
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chatd"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
UsageTracker: options.WorkspaceUsageTracker,
|
||||
})
|
||||
gitSyncLogger := options.Logger.Named("gitsync")
|
||||
refresher := gitsync.NewRefresher(
|
||||
api.resolveGitProvider,
|
||||
api.resolveChatGitAccessToken,
|
||||
gitSyncLogger.Named("refresher"),
|
||||
quartz.NewReal(),
|
||||
)
|
||||
api.gitSyncWorker = gitsync.NewWorker(options.Database,
|
||||
refresher,
|
||||
api.chatDaemon.PublishDiffStatusChange,
|
||||
quartz.NewReal(),
|
||||
gitSyncLogger,
|
||||
)
|
||||
// nolint:gocritic // chat diff worker needs to be able to CRUD chats.
|
||||
go api.gitSyncWorker.Start(dbauthz.AsChatd(api.ctx))
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chatd"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
UsageTracker: options.WorkspaceUsageTracker,
|
||||
})
|
||||
gitSyncLogger := options.Logger.Named("gitsync")
|
||||
refresher := gitsync.NewRefresher(
|
||||
api.resolveGitProvider,
|
||||
api.resolveChatGitAccessToken,
|
||||
gitSyncLogger.Named("refresher"),
|
||||
quartz.NewReal(),
|
||||
)
|
||||
api.gitSyncWorker = gitsync.NewWorker(options.Database,
|
||||
refresher,
|
||||
api.chatDaemon.PublishDiffStatusChange,
|
||||
quartz.NewReal(),
|
||||
gitSyncLogger,
|
||||
)
|
||||
// nolint:gocritic // chat diff worker needs to be able to CRUD chats.
|
||||
go api.gitSyncWorker.Start(dbauthz.AsChatd(api.ctx))
|
||||
}
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(stn)
|
||||
api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry)
|
||||
@@ -1146,6 +1148,7 @@ func New(options *Options) *API {
|
||||
})
|
||||
})
|
||||
})
|
||||
// Experimental(agents): chat API routes gated by ExperimentAgents.
|
||||
r.Route("/chats", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
@@ -1177,6 +1180,9 @@ func New(options *Options) *API {
|
||||
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
|
||||
r.Get("/user-prompt", api.getUserChatCustomPrompt)
|
||||
r.Put("/user-prompt", api.putUserChatCustomPrompt)
|
||||
r.Get("/user-compaction-thresholds", api.getUserChatCompactionThresholds)
|
||||
r.Put("/user-compaction-thresholds/{modelConfig}", api.putUserChatCompactionThreshold)
|
||||
r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold)
|
||||
r.Get("/workspace-ttl", api.getChatWorkspaceTTL)
|
||||
r.Put("/workspace-ttl", api.putChatWorkspaceTTL)
|
||||
})
|
||||
@@ -1517,7 +1523,6 @@ func New(options *Options) *API {
|
||||
r.Post("/", api.postUser)
|
||||
r.Get("/", api.users)
|
||||
r.Post("/logout", api.postLogout)
|
||||
r.Post("/me/session/token-to-cookie", api.postSessionTokenCookie)
|
||||
r.Get("/oidc-claims", api.userOIDCClaims)
|
||||
// These routes query information about site wide roles.
|
||||
r.Route("/roles", func(r chi.Router) {
|
||||
@@ -2087,13 +2092,12 @@ type API struct {
|
||||
// dbRolluper rolls up template usage stats from raw agent and app
|
||||
// stats. This is used to provide insights in the WebUI.
|
||||
dbRolluper *dbrollup.Rolluper
|
||||
// chatDaemon handles background processing of pending chats.
|
||||
// Experimental(agents): chatDaemon handles background processing of pending chats.
|
||||
chatDaemon *chatd.Server
|
||||
// Experimental(agents): gitSyncWorker refreshes stale chat diff statuses in the background.
|
||||
gitSyncWorker *gitsync.Worker
|
||||
// AISeatTracker records AI seat usage.
|
||||
AISeatTracker aiseats.SeatTracker
|
||||
// gitSyncWorker refreshes stale chat diff statuses in the
|
||||
// background.
|
||||
gitSyncWorker *gitsync.Worker
|
||||
|
||||
// ProfileCollector abstracts the runtime/pprof and runtime/trace
|
||||
// calls used by the /debug/profile endpoint. Tests override this
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
@@ -28,6 +27,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
@@ -223,6 +223,7 @@ func UserFromGroupMember(member database.GroupMember) database.User {
|
||||
QuietHoursSchedule: member.UserQuietHoursSchedule,
|
||||
Name: member.UserName,
|
||||
GithubComUserID: member.UserGithubComUserID,
|
||||
IsServiceAccount: member.UserIsServiceAccount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,6 +252,7 @@ func UserFromGroupMemberRow(member database.GetGroupMembersByGroupIDPaginatedRow
|
||||
QuietHoursSchedule: member.UserQuietHoursSchedule,
|
||||
Name: member.UserName,
|
||||
GithubComUserID: member.UserGithubComUserID,
|
||||
IsServiceAccount: member.UserIsServiceAccount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1019,6 +1021,44 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
|
||||
return intc
|
||||
}
|
||||
|
||||
func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSession {
|
||||
session := codersdk.AIBridgeSession{
|
||||
ID: row.SessionID,
|
||||
Initiator: MinimalUserFromVisibleUser(database.VisibleUser{
|
||||
ID: row.UserID,
|
||||
Username: row.UserUsername,
|
||||
Name: row.UserName,
|
||||
AvatarURL: row.UserAvatarUrl,
|
||||
}),
|
||||
Providers: row.Providers,
|
||||
Models: row.Models,
|
||||
Metadata: jsonOrEmptyMap(pqtype.NullRawMessage{RawMessage: row.Metadata, Valid: len(row.Metadata) > 0}),
|
||||
StartedAt: row.StartedAt,
|
||||
Threads: row.Threads,
|
||||
TokenUsageSummary: codersdk.AIBridgeSessionTokenUsageSummary{
|
||||
InputTokens: row.InputTokens,
|
||||
OutputTokens: row.OutputTokens,
|
||||
},
|
||||
}
|
||||
// Ensure non-nil slices for JSON serialization.
|
||||
if session.Providers == nil {
|
||||
session.Providers = []string{}
|
||||
}
|
||||
if session.Models == nil {
|
||||
session.Models = []string{}
|
||||
}
|
||||
if row.Client != "" {
|
||||
session.Client = &row.Client
|
||||
}
|
||||
if !row.EndedAt.IsZero() {
|
||||
session.EndedAt = &row.EndedAt
|
||||
}
|
||||
if row.LastPrompt != "" {
|
||||
session.LastPrompt = &row.LastPrompt
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
|
||||
return codersdk.AIBridgeTokenUsage{
|
||||
ID: usage.ID,
|
||||
|
||||
@@ -1709,6 +1709,14 @@ func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.C
|
||||
return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
|
||||
// Shortcut if the user is an owner. The SQL filter is noticeable,
|
||||
// and this is an easy win for owners. Which is the common case.
|
||||
@@ -2118,6 +2126,17 @@ func (q *querier) DeleteTask(ctx context.Context, arg database.DeleteTaskParams)
|
||||
return q.db.DeleteTask(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserChatCompactionThreshold(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.GetUserSecret(ctx, id)
|
||||
@@ -3921,6 +3940,17 @@ func (q *querier) GetUserByID(ctx context.Context, id uuid.UUID) (database.User,
|
||||
return fetch(q.log, q.auth, q.db.GetUserByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatCompactionThreshold(ctx context.Context, arg database.GetUserChatCompactionThresholdParams) (string, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return q.db.GetUserChatCompactionThreshold(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) {
|
||||
u, err := q.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -5295,10 +5325,16 @@ func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBri
|
||||
return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
// This function is a system function until we implement a join for aibridge interceptions.
|
||||
// Matches the behavior of the workspaces listing endpoint.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -5306,9 +5342,7 @@ func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context,
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeToolUsage, error) {
|
||||
// This function is a system function until we implement a join for aibridge interceptions.
|
||||
// Matches the behavior of the workspaces listing endpoint.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -5316,9 +5350,7 @@ func (q *querier) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, i
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeUserPrompt, error) {
|
||||
// This function is a system function until we implement a join for aibridge interceptions.
|
||||
// Matches the behavior of the workspaces listing endpoint.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -5352,6 +5384,17 @@ func (q *querier) ListTasks(ctx context.Context, arg database.ListTasksParams) (
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListTasks)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]database.UserConfig, error) {
|
||||
u, err := q.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListUserChatCompactionThresholds(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(userID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
@@ -6212,6 +6255,17 @@ func (q *querier) UpdateUsageEventsPostPublish(ctx context.Context, arg database
|
||||
return q.db.UpdateUsageEventsPostPublish(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserChatCompactionThreshold(ctx context.Context, arg database.UpdateUserChatCompactionThresholdParams) (database.UserConfig, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return database.UserConfig{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return database.UserConfig{}, err
|
||||
}
|
||||
return q.db.UpdateUserChatCompactionThreshold(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.UpdateUserChatCustomPromptParams) (database.UserConfig, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
@@ -7084,6 +7138,14 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
|
||||
return q.ListAIBridgeModels(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
return q.GetChats(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -2278,6 +2278,35 @@ func (s *MethodTestSuite) TestUser() {
|
||||
dbm.EXPECT().UpdateUserChatCustomPrompt(gomock.Any(), arg).Return(uc, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(uc)
|
||||
}))
|
||||
s.Run("ListUserChatCompactionThresholds", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
uc := database.UserConfig{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001", Value: "75"}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().ListUserChatCompactionThresholds(gomock.Any(), u.ID).Return([]database.UserConfig{uc}, nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserConfig{uc})
|
||||
}))
|
||||
s.Run("GetUserChatCompactionThreshold", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.GetUserChatCompactionThresholdParams{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001"}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().GetUserChatCompactionThreshold(gomock.Any(), arg).Return("75", nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionReadPersonal).Returns("75")
|
||||
}))
|
||||
s.Run("UpdateUserChatCompactionThreshold", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
uc := database.UserConfig{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001", Value: "75"}
|
||||
arg := database.UpdateUserChatCompactionThresholdParams{UserID: u.ID, Key: uc.Key, ThresholdPercent: 75}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserChatCompactionThreshold(gomock.Any(), arg).Return(uc, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(uc)
|
||||
}))
|
||||
s.Run("DeleteUserChatCompactionThreshold", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserChatCompactionThresholdParams{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001"}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserChatCompactionThreshold(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal)
|
||||
}))
|
||||
s.Run("UpdateUserTaskNotificationAlertDismissed", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
userConfig := database.UserConfig{UserID: user.ID, Key: "task_notification_alert_dismissed", Value: "false"}
|
||||
@@ -5485,22 +5514,50 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeSessionsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeSessionsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("CountAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.CountAIBridgeSessionsParams{}
|
||||
db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("CountAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.CountAIBridgeSessionsParams{}
|
||||
db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ids := []uuid.UUID{{1}}
|
||||
db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes()
|
||||
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeTokenUsage{})
|
||||
check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeTokenUsage{})
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeUserPromptsByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ids := []uuid.UUID{{1}}
|
||||
db.EXPECT().ListAIBridgeUserPromptsByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeUserPrompt{}, nil).AnyTimes()
|
||||
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeUserPrompt{})
|
||||
check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeUserPrompt{})
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeToolUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ids := []uuid.UUID{{1}}
|
||||
db.EXPECT().ListAIBridgeToolUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeToolUsage{}, nil).AnyTimes()
|
||||
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
|
||||
check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
|
||||
}))
|
||||
|
||||
s.Run("UpdateAIBridgeInterceptionEnded", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -97,7 +97,7 @@ func (s *MethodTestSuite) TearDownSuite() {
|
||||
notCalled = append(notCalled, m)
|
||||
}
|
||||
}
|
||||
sort.Strings(notCalled)
|
||||
slices.Sort(notCalled)
|
||||
for _, m := range notCalled {
|
||||
t.Errorf("Method never called: %q", m)
|
||||
}
|
||||
|
||||
@@ -280,6 +280,14 @@ func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAIBridgeSessions(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("CountAIBridgeSessions").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAIBridgeSessions").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAuditLogs(ctx, arg)
|
||||
@@ -680,6 +688,14 @@ func (m queryMetricsStore) DeleteTask(ctx context.Context, arg database.DeleteTa
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserChatCompactionThreshold(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserChatCompactionThreshold").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatCompactionThreshold").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecret(ctx, id)
|
||||
@@ -2448,6 +2464,14 @@ func (m queryMetricsStore) GetUserByID(ctx context.Context, id uuid.UUID) (datab
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatCompactionThreshold(ctx context.Context, arg database.GetUserChatCompactionThresholdParams) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatCompactionThreshold(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetUserChatCompactionThreshold").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatCompactionThreshold").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatCustomPrompt(ctx, userID)
|
||||
@@ -3704,6 +3728,14 @@ func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeSessions(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("ListAIBridgeSessions").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeSessions").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
|
||||
@@ -3768,6 +3800,14 @@ func (m queryMetricsStore) ListTasks(ctx context.Context, arg database.ListTasks
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]database.UserConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserChatCompactionThresholds(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserChatCompactionThresholds").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserChatCompactionThresholds").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecrets(ctx, userID)
|
||||
@@ -4360,6 +4400,14 @@ func (m queryMetricsStore) UpdateUsageEventsPostPublish(ctx context.Context, arg
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserChatCompactionThreshold(ctx context.Context, arg database.UpdateUserChatCompactionThresholdParams) (database.UserConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserChatCompactionThreshold(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserChatCompactionThreshold").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatCompactionThreshold").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg database.UpdateUserChatCustomPromptParams) (database.UserConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserChatCustomPrompt(ctx, arg)
|
||||
@@ -5104,6 +5152,22 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAuthorizedAIBridgeSessions(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeSessions").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAuthorizedAIBridgeSessions(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("CountAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAuthorizedAIBridgeSessions").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)
|
||||
|
||||
@@ -363,6 +363,21 @@ func (mr *MockStoreMockRecorder) CountAIBridgeInterceptions(ctx, arg any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeInterceptions), ctx, arg)
|
||||
}
|
||||
|
||||
// CountAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountAIBridgeSessions", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountAIBridgeSessions indicates an expected call of CountAIBridgeSessions.
|
||||
func (mr *MockStoreMockRecorder) CountAIBridgeSessions(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeSessions), ctx, arg)
|
||||
}
|
||||
|
||||
// CountAuditLogs mocks base method.
|
||||
func (m *MockStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -393,6 +408,21 @@ func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeInterceptions(ctx, arg,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeInterceptions), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// CountAuthorizedAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountAuthorizedAIBridgeSessions", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountAuthorizedAIBridgeSessions indicates an expected call of CountAuthorizedAIBridgeSessions.
|
||||
func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeSessions), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// CountAuthorizedAuditLogs mocks base method.
|
||||
func (m *MockStore) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1126,6 +1156,20 @@ func (mr *MockStoreMockRecorder) DeleteTask(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTask", reflect.TypeOf((*MockStore)(nil).DeleteTask), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserChatCompactionThreshold mocks base method.
|
||||
func (m *MockStore) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserChatCompactionThreshold", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserChatCompactionThreshold indicates an expected call of DeleteUserChatCompactionThreshold.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecret mocks base method.
|
||||
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4564,6 +4608,21 @@ func (mr *MockStoreMockRecorder) GetUserByID(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserByID", reflect.TypeOf((*MockStore)(nil).GetUserByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetUserChatCompactionThreshold mocks base method.
|
||||
func (m *MockStore) GetUserChatCompactionThreshold(ctx context.Context, arg database.GetUserChatCompactionThresholdParams) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserChatCompactionThreshold", ctx, arg)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserChatCompactionThreshold indicates an expected call of GetUserChatCompactionThreshold.
|
||||
func (mr *MockStoreMockRecorder) GetUserChatCompactionThreshold(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).GetUserChatCompactionThreshold), ctx, arg)
|
||||
}
|
||||
|
||||
// GetUserChatCustomPrompt mocks base method.
|
||||
func (m *MockStore) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6918,6 +6977,21 @@ func (mr *MockStoreMockRecorder) ListAIBridgeModels(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModels), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAIBridgeSessions", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAIBridgeSessions indicates an expected call of ListAIBridgeSessions.
|
||||
func (mr *MockStoreMockRecorder) ListAIBridgeSessions(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeSessions), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
|
||||
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6993,6 +7067,21 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeSessions", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeSessions indicates an expected call of ListAuthorizedAIBridgeSessions.
|
||||
func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeSessions), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListChatUsageLimitGroupOverrides mocks base method.
|
||||
func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7068,6 +7157,21 @@ func (mr *MockStoreMockRecorder) ListTasks(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTasks", reflect.TypeOf((*MockStore)(nil).ListTasks), ctx, arg)
|
||||
}
|
||||
|
||||
// ListUserChatCompactionThresholds mocks base method.
|
||||
func (m *MockStore) ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]database.UserConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserChatCompactionThresholds", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListUserChatCompactionThresholds indicates an expected call of ListUserChatCompactionThresholds.
|
||||
func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserChatCompactionThresholds", reflect.TypeOf((*MockStore)(nil).ListUserChatCompactionThresholds), ctx, userID)
|
||||
}
|
||||
|
||||
// ListUserSecrets mocks base method.
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8173,6 +8277,21 @@ func (mr *MockStoreMockRecorder) UpdateUsageEventsPostPublish(ctx, arg any) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUsageEventsPostPublish", reflect.TypeOf((*MockStore)(nil).UpdateUsageEventsPostPublish), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserChatCompactionThreshold mocks base method.
|
||||
func (m *MockStore) UpdateUserChatCompactionThreshold(ctx context.Context, arg database.UpdateUserChatCompactionThresholdParams) (database.UserConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserChatCompactionThreshold", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserChatCompactionThreshold indicates an expected call of UpdateUserChatCompactionThreshold.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserChatCompactionThreshold(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCompactionThreshold), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserChatCustomPrompt mocks base method.
|
||||
func (m *MockStore) UpdateUserChatCustomPrompt(ctx context.Context, arg database.UpdateUserChatCustomPromptParams) (database.UserConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+11
-3
@@ -1099,7 +1099,8 @@ CREATE TABLE aibridge_interceptions (
|
||||
client character varying(64) DEFAULT 'Unknown'::character varying,
|
||||
thread_parent_id uuid,
|
||||
thread_root_id uuid,
|
||||
client_session_id character varying(256)
|
||||
client_session_id character varying(256),
|
||||
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
|
||||
@@ -1112,6 +1113,8 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception''s own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
@@ -1619,6 +1622,7 @@ CREATE VIEW group_members_expanded AS
|
||||
users.name AS user_name,
|
||||
users.github_com_user_id AS user_github_com_user_id,
|
||||
users.is_system AS user_is_system,
|
||||
users.is_service_account AS user_is_service_account,
|
||||
groups.organization_id,
|
||||
groups.name AS group_name,
|
||||
all_members.group_id
|
||||
@@ -1627,8 +1631,6 @@ CREATE VIEW group_members_expanded AS
|
||||
JOIN groups ON ((groups.id = all_members.group_id)))
|
||||
WHERE (users.deleted = false);
|
||||
|
||||
COMMENT ON VIEW group_members_expanded IS 'Joins group members with user information, organization ID, group name. Includes both regular group members and organization members (as part of the "Everyone" group).';
|
||||
|
||||
CREATE TABLE inbox_notifications (
|
||||
id uuid NOT NULL,
|
||||
user_id uuid NOT NULL,
|
||||
@@ -3655,6 +3657,10 @@ CREATE INDEX idx_aibridge_interceptions_model ON aibridge_interceptions USING bt
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_provider ON aibridge_interceptions USING btree (provider);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_session_id ON aibridge_interceptions USING btree (session_id) WHERE (ended_at IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_sessions_filter ON aibridge_interceptions USING btree (initiator_id, started_at DESC, id DESC) WHERE (ended_at IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_started_id_desc ON aibridge_interceptions USING btree (started_at DESC, id DESC);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptions USING btree (thread_parent_id);
|
||||
@@ -3673,6 +3679,8 @@ CREATE INDEX idx_aibridge_tool_usages_provider_tool_call_id ON aibridge_tool_usa
|
||||
|
||||
CREATE INDEX idx_aibridge_tool_usagesprovider_response_id ON aibridge_tool_usages USING btree (provider_response_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_user_prompts_interception_created ON aibridge_user_prompts USING btree (interception_id, created_at DESC, id DESC);
|
||||
|
||||
CREATE INDEX idx_aibridge_user_prompts_interception_id ON aibridge_user_prompts USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_user_prompts_provider_response_id ON aibridge_user_prompts USING btree (provider_response_id);
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
DROP VIEW group_members_expanded;
|
||||
|
||||
CREATE VIEW group_members_expanded AS
|
||||
WITH all_members AS (
|
||||
SELECT group_members.user_id,
|
||||
group_members.group_id
|
||||
FROM group_members
|
||||
UNION
|
||||
SELECT organization_members.user_id,
|
||||
organization_members.organization_id AS group_id
|
||||
FROM organization_members
|
||||
)
|
||||
SELECT users.id AS user_id,
|
||||
users.email AS user_email,
|
||||
users.username AS user_username,
|
||||
users.hashed_password AS user_hashed_password,
|
||||
users.created_at AS user_created_at,
|
||||
users.updated_at AS user_updated_at,
|
||||
users.status AS user_status,
|
||||
users.rbac_roles AS user_rbac_roles,
|
||||
users.login_type AS user_login_type,
|
||||
users.avatar_url AS user_avatar_url,
|
||||
users.deleted AS user_deleted,
|
||||
users.last_seen_at AS user_last_seen_at,
|
||||
users.quiet_hours_schedule AS user_quiet_hours_schedule,
|
||||
users.name AS user_name,
|
||||
users.github_com_user_id AS user_github_com_user_id,
|
||||
users.is_system AS user_is_system,
|
||||
groups.organization_id,
|
||||
groups.name AS group_name,
|
||||
all_members.group_id
|
||||
FROM ((all_members
|
||||
JOIN users ON ((users.id = all_members.user_id)))
|
||||
JOIN groups ON ((groups.id = all_members.group_id)))
|
||||
WHERE (users.deleted = false);
|
||||
@@ -0,0 +1,36 @@
|
||||
DROP VIEW group_members_expanded;
|
||||
|
||||
CREATE VIEW group_members_expanded AS
|
||||
WITH all_members AS (
|
||||
SELECT group_members.user_id,
|
||||
group_members.group_id
|
||||
FROM group_members
|
||||
UNION
|
||||
SELECT organization_members.user_id,
|
||||
organization_members.organization_id AS group_id
|
||||
FROM organization_members
|
||||
)
|
||||
SELECT users.id AS user_id,
|
||||
users.email AS user_email,
|
||||
users.username AS user_username,
|
||||
users.hashed_password AS user_hashed_password,
|
||||
users.created_at AS user_created_at,
|
||||
users.updated_at AS user_updated_at,
|
||||
users.status AS user_status,
|
||||
users.rbac_roles AS user_rbac_roles,
|
||||
users.login_type AS user_login_type,
|
||||
users.avatar_url AS user_avatar_url,
|
||||
users.deleted AS user_deleted,
|
||||
users.last_seen_at AS user_last_seen_at,
|
||||
users.quiet_hours_schedule AS user_quiet_hours_schedule,
|
||||
users.name AS user_name,
|
||||
users.github_com_user_id AS user_github_com_user_id,
|
||||
users.is_system AS user_is_system,
|
||||
users.is_service_account as user_is_service_account,
|
||||
groups.organization_id,
|
||||
groups.name AS group_name,
|
||||
all_members.group_id
|
||||
FROM ((all_members
|
||||
JOIN users ON ((users.id = all_members.user_id)))
|
||||
JOIN groups ON ((groups.id = all_members.group_id)))
|
||||
WHERE (users.deleted = false);
|
||||
@@ -0,0 +1,5 @@
|
||||
DROP INDEX IF EXISTS idx_aibridge_interceptions_session_id;
|
||||
DROP INDEX IF EXISTS idx_aibridge_user_prompts_interception_created;
|
||||
DROP INDEX IF EXISTS idx_aibridge_interceptions_sessions_filter;
|
||||
|
||||
ALTER TABLE aibridge_interceptions DROP COLUMN IF EXISTS session_id;
|
||||
@@ -0,0 +1,40 @@
|
||||
-- A "session" groups related interceptions together. See the COMMENT ON
|
||||
-- COLUMN below for the full business-logic description.
|
||||
ALTER TABLE aibridge_interceptions
|
||||
ADD COLUMN session_id TEXT NOT NULL
|
||||
GENERATED ALWAYS AS (
|
||||
COALESCE(
|
||||
client_session_id,
|
||||
thread_root_id::text,
|
||||
id::text
|
||||
)
|
||||
) STORED;
|
||||
|
||||
-- Searching and grouping on the resolved session ID will be common.
|
||||
CREATE INDEX idx_aibridge_interceptions_session_id
|
||||
ON aibridge_interceptions (session_id)
|
||||
WHERE ended_at IS NOT NULL;
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.session_id IS
|
||||
'Groups related interceptions into a logical session. '
|
||||
'Determined by a priority chain: '
|
||||
'(1) client_session_id — an explicit session identifier supplied by the '
|
||||
'calling client (e.g. Claude Code); '
|
||||
'(2) thread_root_id — the root of an agentic thread detected by Bridge '
|
||||
'through tool-call correlation, used when the client does not supply its '
|
||||
'own session ID; '
|
||||
'(3) id — the interception''s own ID, used as a last resort so every '
|
||||
'interception belongs to exactly one session even if it is standalone. '
|
||||
'This is a generated column stored on disk so it can be indexed and '
|
||||
'joined without recomputing the COALESCE on every query.';
|
||||
|
||||
-- Composite index for the most common filter path used by
|
||||
-- ListAIBridgeSessions: initiator_id equality + started_at range,
|
||||
-- with ended_at IS NOT NULL as a partial filter.
|
||||
CREATE INDEX idx_aibridge_interceptions_sessions_filter
|
||||
ON aibridge_interceptions (initiator_id, started_at DESC, id DESC)
|
||||
WHERE ended_at IS NOT NULL;
|
||||
|
||||
-- Supports lateral prompt lookup by interception + recency.
|
||||
CREATE INDEX idx_aibridge_user_prompts_interception_created
|
||||
ON aibridge_user_prompts (interception_id, created_at DESC, id DESC);
|
||||
@@ -806,6 +806,8 @@ type aibridgeQuerier interface {
|
||||
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error)
|
||||
CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error)
|
||||
ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error)
|
||||
CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) {
|
||||
@@ -852,6 +854,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
|
||||
&i.AIBridgeInterception.ThreadParentID,
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -939,6 +942,109 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeModels(ctx context.Context, arg ListA
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
|
||||
VariableConverter: regosql.AIBridgeInterceptionConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
filtered, err := insertAuthorizedFilter(listAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeSessions :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.AfterSessionID,
|
||||
arg.Offset,
|
||||
arg.Limit,
|
||||
arg.StartedAfter,
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListAIBridgeSessionsRow
|
||||
for rows.Next() {
|
||||
var i ListAIBridgeSessionsRow
|
||||
if err := rows.Scan(
|
||||
&i.SessionID,
|
||||
&i.UserID,
|
||||
&i.UserUsername,
|
||||
&i.UserName,
|
||||
&i.UserAvatarUrl,
|
||||
pq.Array(&i.Providers),
|
||||
pq.Array(&i.Models),
|
||||
&i.Client,
|
||||
&i.Metadata,
|
||||
&i.StartedAt,
|
||||
&i.EndedAt,
|
||||
&i.Threads,
|
||||
&i.InputTokens,
|
||||
&i.OutputTokens,
|
||||
&i.LastPrompt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
|
||||
VariableConverter: regosql.AIBridgeInterceptionConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
filtered, err := insertAuthorizedFilter(countAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("-- name: CountAuthorizedAIBridgeSessions :one\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.StartedAfter,
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var count int64
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
|
||||
if !strings.Contains(query, authorizedQueryPlaceholder) {
|
||||
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
|
||||
|
||||
@@ -4036,6 +4036,8 @@ type AIBridgeInterception struct {
|
||||
ThreadRootID uuid.NullUUID `db:"thread_root_id" json:"thread_root_id"`
|
||||
// The session ID supplied by the client (optional and not universally supported).
|
||||
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
|
||||
// Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception's own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
}
|
||||
|
||||
// Audit log of model thinking in intercepted requests in AI Bridge
|
||||
@@ -4394,7 +4396,6 @@ type Group struct {
|
||||
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
|
||||
}
|
||||
|
||||
// Joins group members with user information, organization ID, group name. Includes both regular group members and organization members (as part of the "Everyone" group).
|
||||
type GroupMember struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
UserEmail string `db:"user_email" json:"user_email"`
|
||||
@@ -4412,6 +4413,7 @@ type GroupMember struct {
|
||||
UserName string `db:"user_name" json:"user_name"`
|
||||
UserGithubComUserID sql.NullInt64 `db:"user_github_com_user_id" json:"user_github_com_user_id"`
|
||||
UserIsSystem bool `db:"user_is_system" json:"user_is_system"`
|
||||
UserIsServiceAccount bool `db:"user_is_service_account" json:"user_is_service_account"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
GroupName string `db:"group_name" json:"group_name"`
|
||||
GroupID uuid.UUID `db:"group_id" json:"group_id"`
|
||||
|
||||
@@ -76,6 +76,7 @@ type sqlcQuerier interface {
|
||||
CleanTailnetTunnels(ctx context.Context) error
|
||||
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
|
||||
CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
|
||||
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
|
||||
CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error)
|
||||
// Counts enabled, non-deleted model configs that lack both input and
|
||||
@@ -148,6 +149,7 @@ type sqlcQuerier interface {
|
||||
DeleteTailnetPeer(ctx context.Context, arg DeleteTailnetPeerParams) (DeleteTailnetPeerRow, error)
|
||||
DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error)
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
@@ -553,6 +555,7 @@ type sqlcQuerier interface {
|
||||
GetUserActivityInsights(ctx context.Context, arg GetUserActivityInsightsParams) ([]GetUserActivityInsightsRow, error)
|
||||
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
|
||||
GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
|
||||
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
|
||||
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
|
||||
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
@@ -757,6 +760,10 @@ type sqlcQuerier interface {
|
||||
// (provider, model, client) in the given timeframe for telemetry reporting.
|
||||
ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error)
|
||||
ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error)
|
||||
// Returns paginated sessions with aggregated metadata, token counts, and
|
||||
// the most recent user prompt. A "session" is a logical grouping of
|
||||
// interceptions that share the same session_id (set by the client).
|
||||
ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error)
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
|
||||
@@ -765,6 +772,7 @@ type sqlcQuerier interface {
|
||||
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
|
||||
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
|
||||
@@ -868,6 +876,7 @@ type sqlcQuerier interface {
|
||||
UpdateTemplateVersionFlagsByJobID(ctx context.Context, arg UpdateTemplateVersionFlagsByJobIDParams) error
|
||||
UpdateTemplateWorkspacesLastUsedAt(ctx context.Context, arg UpdateTemplateWorkspacesLastUsedAtParams) error
|
||||
UpdateUsageEventsPostPublish(ctx context.Context, arg UpdateUsageEventsPostPublishParams) error
|
||||
UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error)
|
||||
UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error)
|
||||
UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error
|
||||
UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
@@ -35,6 +34,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
|
||||
@@ -332,6 +332,77 @@ func (q *sqlQuerier) CountAIBridgeInterceptions(ctx context.Context, arg CountAI
|
||||
return count, err
|
||||
}
|
||||
|
||||
const countAIBridgeSessions = `-- name: CountAIBridgeSessions :one
|
||||
SELECT
|
||||
COUNT(DISTINCT (aibridge_interceptions.session_id, aibridge_interceptions.initiator_id))
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
-- Filter by time frame
|
||||
AND CASE
|
||||
WHEN $1::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $1::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN $2::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= $2::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
-- Filter initiator_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = $3::uuid
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN $6::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $6::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN aibridge_interceptions.session_id = $7::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions
|
||||
-- @authorize_filter
|
||||
`
|
||||
|
||||
type CountAIBridgeSessionsParams struct {
|
||||
StartedAfter time.Time `db:"started_after" json:"started_after"`
|
||||
StartedBefore time.Time `db:"started_before" json:"started_before"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error) {
|
||||
row := q.db.QueryRowContext(ctx, countAIBridgeSessions,
|
||||
arg.StartedAfter,
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
const deleteOldAIBridgeRecords = `-- name: DeleteOldAIBridgeRecords :one
|
||||
WITH
|
||||
-- We don't have FK relationships between the dependent tables and aibridge_interceptions, so we can't rely on DELETE CASCADE.
|
||||
@@ -384,7 +455,7 @@ func (q *sqlQuerier) DeleteOldAIBridgeRecords(ctx context.Context, beforeTime ti
|
||||
|
||||
const getAIBridgeInterceptionByID = `-- name: GetAIBridgeInterceptionByID :one
|
||||
SELECT
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
@@ -407,6 +478,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UU
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
&i.SessionID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -441,7 +513,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Cont
|
||||
|
||||
const getAIBridgeInterceptions = `-- name: GetAIBridgeInterceptions :many
|
||||
SELECT
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
`
|
||||
@@ -468,6 +540,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeIn
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
&i.SessionID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -618,7 +691,7 @@ INSERT INTO aibridge_interceptions (
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8, $9, $10::uuid, $11::uuid
|
||||
)
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id
|
||||
`
|
||||
|
||||
type InsertAIBridgeInterceptionParams struct {
|
||||
@@ -663,6 +736,7 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
&i.SessionID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -837,7 +911,7 @@ func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIB
|
||||
|
||||
const listAIBridgeInterceptions = `-- name: ListAIBridgeInterceptions :many
|
||||
SELECT
|
||||
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id,
|
||||
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id,
|
||||
visible_users.id, visible_users.username, visible_users.name, visible_users.avatar_url
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
@@ -949,6 +1023,7 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
|
||||
&i.AIBridgeInterception.ThreadParentID,
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -1071,6 +1146,229 @@ func (q *sqlQuerier) ListAIBridgeModels(ctx context.Context, arg ListAIBridgeMod
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listAIBridgeSessions = `-- name: ListAIBridgeSessions :many
|
||||
WITH filtered_interceptions AS (
|
||||
SELECT
|
||||
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
-- Filter by time frame
|
||||
AND CASE
|
||||
WHEN $4::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $4::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN $5::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= $5::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
-- Filter initiator_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = $6::uuid
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN aibridge_interceptions.provider = $7::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN aibridge_interceptions.model = $8::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN $9::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $9::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN $10::text != '' THEN aibridge_interceptions.session_id = $10::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions
|
||||
-- @authorize_filter
|
||||
),
|
||||
session_tokens AS (
|
||||
-- Aggregate token usage across all interceptions in each session.
|
||||
-- Group by (session_id, initiator_id) to avoid merging sessions from
|
||||
-- different users who happen to share the same client_session_id.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
|
||||
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
|
||||
-- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands.
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
LEFT JOIN
|
||||
aibridge_token_usages tu ON fi.id = tu.interception_id
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
),
|
||||
session_root AS (
|
||||
-- Build one summary row per session. Group by (session_id, initiator_id)
|
||||
-- to avoid merging sessions from different users who happen to share the
|
||||
-- same client_session_id. The ARRAY_AGG with ORDER BY picks values from
|
||||
-- the chronologically first interception for fields that should represent
|
||||
-- the session as a whole (client, metadata). Threads are counted as
|
||||
-- distinct root interception IDs: an interception with a NULL
|
||||
-- thread_root_id is itself a thread root.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
(ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client,
|
||||
(ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata,
|
||||
ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers,
|
||||
ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models,
|
||||
MIN(fi.started_at) AS started_at,
|
||||
MAX(fi.ended_at) AS ended_at,
|
||||
COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads,
|
||||
-- Collect IDs for lateral prompt lookup.
|
||||
ARRAY_AGG(fi.id) AS interception_ids
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
)
|
||||
SELECT
|
||||
sr.session_id,
|
||||
visible_users.id AS user_id,
|
||||
visible_users.username AS user_username,
|
||||
visible_users.name AS user_name,
|
||||
visible_users.avatar_url AS user_avatar_url,
|
||||
sr.providers::text[] AS providers,
|
||||
sr.models::text[] AS models,
|
||||
COALESCE(sr.client, '')::varchar(64) AS client,
|
||||
sr.metadata::jsonb AS metadata,
|
||||
sr.started_at::timestamptz AS started_at,
|
||||
sr.ended_at::timestamptz AS ended_at,
|
||||
sr.threads,
|
||||
COALESCE(st.input_tokens, 0)::bigint AS input_tokens,
|
||||
COALESCE(st.output_tokens, 0)::bigint AS output_tokens,
|
||||
COALESCE(slp.prompt, '') AS last_prompt
|
||||
FROM
|
||||
session_root sr
|
||||
JOIN
|
||||
visible_users ON visible_users.id = sr.initiator_id
|
||||
LEFT JOIN
|
||||
session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id
|
||||
LEFT JOIN LATERAL (
|
||||
-- Lateral join to efficiently fetch only the most recent user prompt
|
||||
-- across all interceptions in the session, avoiding a full aggregation.
|
||||
SELECT up.prompt
|
||||
FROM aibridge_user_prompts up
|
||||
WHERE up.interception_id = ANY(sr.interception_ids)
|
||||
ORDER BY up.created_at DESC, up.id DESC
|
||||
LIMIT 1
|
||||
) slp ON true
|
||||
WHERE
|
||||
-- Cursor pagination: uses a composite (started_at, session_id) cursor
|
||||
-- to support keyset pagination. The less-than comparison matches the
|
||||
-- DESC sort order so that rows after the cursor come later in results.
|
||||
CASE
|
||||
WHEN $1::text != '' THEN (
|
||||
(sr.started_at, sr.session_id) < (
|
||||
(SELECT started_at FROM session_root WHERE session_id = $1),
|
||||
$1::text
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
ORDER BY
|
||||
sr.started_at DESC,
|
||||
sr.session_id DESC
|
||||
LIMIT COALESCE(NULLIF($3::integer, 0), 100)
|
||||
OFFSET $2
|
||||
`
|
||||
|
||||
type ListAIBridgeSessionsParams struct {
|
||||
AfterSessionID string `db:"after_session_id" json:"after_session_id"`
|
||||
Offset int32 `db:"offset_" json:"offset_"`
|
||||
Limit int32 `db:"limit_" json:"limit_"`
|
||||
StartedAfter time.Time `db:"started_after" json:"started_after"`
|
||||
StartedBefore time.Time `db:"started_before" json:"started_before"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
}
|
||||
|
||||
type ListAIBridgeSessionsRow struct {
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
UserUsername string `db:"user_username" json:"user_username"`
|
||||
UserName string `db:"user_name" json:"user_name"`
|
||||
UserAvatarUrl string `db:"user_avatar_url" json:"user_avatar_url"`
|
||||
Providers []string `db:"providers" json:"providers"`
|
||||
Models []string `db:"models" json:"models"`
|
||||
Client string `db:"client" json:"client"`
|
||||
Metadata json.RawMessage `db:"metadata" json:"metadata"`
|
||||
StartedAt time.Time `db:"started_at" json:"started_at"`
|
||||
EndedAt time.Time `db:"ended_at" json:"ended_at"`
|
||||
Threads int64 `db:"threads" json:"threads"`
|
||||
InputTokens int64 `db:"input_tokens" json:"input_tokens"`
|
||||
OutputTokens int64 `db:"output_tokens" json:"output_tokens"`
|
||||
LastPrompt string `db:"last_prompt" json:"last_prompt"`
|
||||
}
|
||||
|
||||
// Returns paginated sessions with aggregated metadata, token counts, and
|
||||
// the most recent user prompt. A "session" is a logical grouping of
|
||||
// interceptions that share the same session_id (set by the client).
|
||||
func (q *sqlQuerier) ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listAIBridgeSessions,
|
||||
arg.AfterSessionID,
|
||||
arg.Offset,
|
||||
arg.Limit,
|
||||
arg.StartedAfter,
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListAIBridgeSessionsRow
|
||||
for rows.Next() {
|
||||
var i ListAIBridgeSessionsRow
|
||||
if err := rows.Scan(
|
||||
&i.SessionID,
|
||||
&i.UserID,
|
||||
&i.UserUsername,
|
||||
&i.UserName,
|
||||
&i.UserAvatarUrl,
|
||||
pq.Array(&i.Providers),
|
||||
pq.Array(&i.Models),
|
||||
&i.Client,
|
||||
&i.Metadata,
|
||||
&i.StartedAt,
|
||||
&i.EndedAt,
|
||||
&i.Threads,
|
||||
&i.InputTokens,
|
||||
&i.OutputTokens,
|
||||
&i.LastPrompt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listAIBridgeTokenUsagesByInterceptionIDs = `-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many
|
||||
SELECT
|
||||
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
|
||||
@@ -1209,7 +1507,7 @@ UPDATE aibridge_interceptions
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
AND ended_at IS NULL
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id
|
||||
`
|
||||
|
||||
type UpdateAIBridgeInterceptionEndedParams struct {
|
||||
@@ -1233,6 +1531,7 @@ func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg Up
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
&i.SessionID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -7293,7 +7592,7 @@ func (q *sqlQuerier) DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteG
|
||||
}
|
||||
|
||||
const getGroupMembers = `-- name: GetGroupMembers :many
|
||||
SELECT user_id, user_email, user_username, user_hashed_password, user_created_at, user_updated_at, user_status, user_rbac_roles, user_login_type, user_avatar_url, user_deleted, user_last_seen_at, user_quiet_hours_schedule, user_name, user_github_com_user_id, user_is_system, organization_id, group_name, group_id FROM group_members_expanded
|
||||
SELECT user_id, user_email, user_username, user_hashed_password, user_created_at, user_updated_at, user_status, user_rbac_roles, user_login_type, user_avatar_url, user_deleted, user_last_seen_at, user_quiet_hours_schedule, user_name, user_github_com_user_id, user_is_system, user_is_service_account, organization_id, group_name, group_id FROM group_members_expanded
|
||||
WHERE CASE
|
||||
WHEN $1::bool THEN TRUE
|
||||
ELSE
|
||||
@@ -7327,6 +7626,7 @@ func (q *sqlQuerier) GetGroupMembers(ctx context.Context, includeSystem bool) ([
|
||||
&i.UserName,
|
||||
&i.UserGithubComUserID,
|
||||
&i.UserIsSystem,
|
||||
&i.UserIsServiceAccount,
|
||||
&i.OrganizationID,
|
||||
&i.GroupName,
|
||||
&i.GroupID,
|
||||
@@ -7345,7 +7645,7 @@ func (q *sqlQuerier) GetGroupMembers(ctx context.Context, includeSystem bool) ([
|
||||
}
|
||||
|
||||
const getGroupMembersByGroupID = `-- name: GetGroupMembersByGroupID :many
|
||||
SELECT user_id, user_email, user_username, user_hashed_password, user_created_at, user_updated_at, user_status, user_rbac_roles, user_login_type, user_avatar_url, user_deleted, user_last_seen_at, user_quiet_hours_schedule, user_name, user_github_com_user_id, user_is_system, organization_id, group_name, group_id
|
||||
SELECT user_id, user_email, user_username, user_hashed_password, user_created_at, user_updated_at, user_status, user_rbac_roles, user_login_type, user_avatar_url, user_deleted, user_last_seen_at, user_quiet_hours_schedule, user_name, user_github_com_user_id, user_is_system, user_is_service_account, organization_id, group_name, group_id
|
||||
FROM group_members_expanded
|
||||
WHERE group_id = $1
|
||||
-- Filter by system type
|
||||
@@ -7387,6 +7687,7 @@ func (q *sqlQuerier) GetGroupMembersByGroupID(ctx context.Context, arg GetGroupM
|
||||
&i.UserName,
|
||||
&i.UserGithubComUserID,
|
||||
&i.UserIsSystem,
|
||||
&i.UserIsServiceAccount,
|
||||
&i.OrganizationID,
|
||||
&i.GroupName,
|
||||
&i.GroupID,
|
||||
@@ -7406,7 +7707,7 @@ func (q *sqlQuerier) GetGroupMembersByGroupID(ctx context.Context, arg GetGroupM
|
||||
|
||||
const getGroupMembersByGroupIDPaginated = `-- name: GetGroupMembersByGroupIDPaginated :many
|
||||
SELECT
|
||||
user_id, user_email, user_username, user_hashed_password, user_created_at, user_updated_at, user_status, user_rbac_roles, user_login_type, user_avatar_url, user_deleted, user_last_seen_at, user_quiet_hours_schedule, user_name, user_github_com_user_id, user_is_system, organization_id, group_name, group_id, COUNT(*) OVER() AS count
|
||||
user_id, user_email, user_username, user_hashed_password, user_created_at, user_updated_at, user_status, user_rbac_roles, user_login_type, user_avatar_url, user_deleted, user_last_seen_at, user_quiet_hours_schedule, user_name, user_github_com_user_id, user_is_system, user_is_service_account, organization_id, group_name, group_id, COUNT(*) OVER() AS count
|
||||
FROM
|
||||
group_members_expanded
|
||||
WHERE
|
||||
@@ -7544,6 +7845,7 @@ type GetGroupMembersByGroupIDPaginatedRow struct {
|
||||
UserName string `db:"user_name" json:"user_name"`
|
||||
UserGithubComUserID sql.NullInt64 `db:"user_github_com_user_id" json:"user_github_com_user_id"`
|
||||
UserIsSystem bool `db:"user_is_system" json:"user_is_system"`
|
||||
UserIsServiceAccount bool `db:"user_is_service_account" json:"user_is_service_account"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
GroupName string `db:"group_name" json:"group_name"`
|
||||
GroupID uuid.UUID `db:"group_id" json:"group_id"`
|
||||
@@ -7592,6 +7894,7 @@ func (q *sqlQuerier) GetGroupMembersByGroupIDPaginated(ctx context.Context, arg
|
||||
&i.UserName,
|
||||
&i.UserGithubComUserID,
|
||||
&i.UserIsSystem,
|
||||
&i.UserIsServiceAccount,
|
||||
&i.OrganizationID,
|
||||
&i.GroupName,
|
||||
&i.GroupID,
|
||||
@@ -16634,7 +16937,7 @@ FROM
|
||||
(
|
||||
-- Select all groups this user is a member of. This will also include
|
||||
-- the "Everyone" group for organizations the user is a member of.
|
||||
SELECT user_id, user_email, user_username, user_hashed_password, user_created_at, user_updated_at, user_status, user_rbac_roles, user_login_type, user_avatar_url, user_deleted, user_last_seen_at, user_quiet_hours_schedule, user_name, user_github_com_user_id, user_is_system, organization_id, group_name, group_id FROM group_members_expanded
|
||||
SELECT user_id, user_email, user_username, user_hashed_password, user_created_at, user_updated_at, user_status, user_rbac_roles, user_login_type, user_avatar_url, user_deleted, user_last_seen_at, user_quiet_hours_schedule, user_name, user_github_com_user_id, user_is_system, user_is_service_account, organization_id, group_name, group_id FROM group_members_expanded
|
||||
WHERE
|
||||
$1 = user_id AND
|
||||
$2 = group_members_expanded.organization_id
|
||||
@@ -21153,6 +21456,20 @@ func (q *sqlQuerier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const deleteUserChatCompactionThreshold = `-- name: DeleteUserChatCompactionThreshold :exec
|
||||
DELETE FROM user_configs WHERE user_id = $1 AND key = $2
|
||||
`
|
||||
|
||||
type DeleteUserChatCompactionThresholdParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Key string `db:"key" json:"key"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteUserChatCompactionThreshold, arg.UserID, arg.Key)
|
||||
return err
|
||||
}
|
||||
|
||||
const getActiveUserCount = `-- name: GetActiveUserCount :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
@@ -21333,6 +21650,23 @@ func (q *sqlQuerier) GetUserByID(ctx context.Context, id uuid.UUID) (User, error
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserChatCompactionThreshold = `-- name: GetUserChatCompactionThreshold :one
|
||||
SELECT value AS threshold_percent FROM user_configs
|
||||
WHERE user_id = $1 AND key = $2
|
||||
`
|
||||
|
||||
type GetUserChatCompactionThresholdParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Key string `db:"key" json:"key"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserChatCompactionThreshold, arg.UserID, arg.Key)
|
||||
var threshold_percent string
|
||||
err := row.Scan(&threshold_percent)
|
||||
return threshold_percent, err
|
||||
}
|
||||
|
||||
const getUserChatCustomPrompt = `-- name: GetUserChatCustomPrompt :one
|
||||
SELECT
|
||||
value as chat_custom_prompt
|
||||
@@ -21756,6 +22090,36 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listUserChatCompactionThresholds = `-- name: ListUserChatCompactionThresholds :many
|
||||
SELECT user_id, key, value FROM user_configs
|
||||
WHERE user_id = $1
|
||||
AND key LIKE 'chat\_compaction\_threshold\_pct:%'
|
||||
ORDER BY key
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listUserChatCompactionThresholds, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []UserConfig
|
||||
for rows.Next() {
|
||||
var i UserConfig
|
||||
if err := rows.Scan(&i.UserID, &i.Key, &i.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateInactiveUsersToDormant = `-- name: UpdateInactiveUsersToDormant :many
|
||||
UPDATE
|
||||
users
|
||||
@@ -21809,6 +22173,27 @@ func (q *sqlQuerier) UpdateInactiveUsersToDormant(ctx context.Context, arg Updat
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateUserChatCompactionThreshold = `-- name: UpdateUserChatCompactionThreshold :one
|
||||
INSERT INTO user_configs (user_id, key, value)
|
||||
VALUES ($1, $2, ($3::int)::text)
|
||||
ON CONFLICT ON CONSTRAINT user_configs_pkey
|
||||
DO UPDATE SET value = ($3::int)::text
|
||||
RETURNING user_id, key, value
|
||||
`
|
||||
|
||||
type UpdateUserChatCompactionThresholdParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Key string `db:"key" json:"key"`
|
||||
ThresholdPercent int32 `db:"threshold_percent" json:"threshold_percent"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserChatCompactionThreshold, arg.UserID, arg.Key, arg.ThresholdPercent)
|
||||
var i UserConfig
|
||||
err := row.Scan(&i.UserID, &i.Key, &i.Value)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateUserChatCustomPrompt = `-- name: UpdateUserChatCustomPrompt :one
|
||||
INSERT INTO
|
||||
user_configs (user_id, key, value)
|
||||
|
||||
@@ -404,6 +404,194 @@ SELECT (
|
||||
(SELECT COUNT(*) FROM interceptions)
|
||||
)::bigint as total_deleted;
|
||||
|
||||
-- name: CountAIBridgeSessions :one
|
||||
SELECT
|
||||
COUNT(DISTINCT (aibridge_interceptions.session_id, aibridge_interceptions.initiator_id))
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
-- Filter by time frame
|
||||
AND CASE
|
||||
WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
-- Filter initiator_id
|
||||
AND CASE
|
||||
WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider
|
||||
AND CASE
|
||||
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions
|
||||
-- @authorize_filter
|
||||
;
|
||||
|
||||
-- name: ListAIBridgeSessions :many
|
||||
-- Returns paginated sessions with aggregated metadata, token counts, and
|
||||
-- the most recent user prompt. A "session" is a logical grouping of
|
||||
-- interceptions that share the same session_id (set by the client).
|
||||
WITH filtered_interceptions AS (
|
||||
SELECT
|
||||
aibridge_interceptions.*
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
-- Filter by time frame
|
||||
AND CASE
|
||||
WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
-- Filter initiator_id
|
||||
AND CASE
|
||||
WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider
|
||||
AND CASE
|
||||
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions
|
||||
-- @authorize_filter
|
||||
),
|
||||
session_tokens AS (
|
||||
-- Aggregate token usage across all interceptions in each session.
|
||||
-- Group by (session_id, initiator_id) to avoid merging sessions from
|
||||
-- different users who happen to share the same client_session_id.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
|
||||
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
|
||||
-- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands.
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
LEFT JOIN
|
||||
aibridge_token_usages tu ON fi.id = tu.interception_id
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
),
|
||||
session_root AS (
|
||||
-- Build one summary row per session. Group by (session_id, initiator_id)
|
||||
-- to avoid merging sessions from different users who happen to share the
|
||||
-- same client_session_id. The ARRAY_AGG with ORDER BY picks values from
|
||||
-- the chronologically first interception for fields that should represent
|
||||
-- the session as a whole (client, metadata). Threads are counted as
|
||||
-- distinct root interception IDs: an interception with a NULL
|
||||
-- thread_root_id is itself a thread root.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
(ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client,
|
||||
(ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata,
|
||||
ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers,
|
||||
ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models,
|
||||
MIN(fi.started_at) AS started_at,
|
||||
MAX(fi.ended_at) AS ended_at,
|
||||
COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads,
|
||||
-- Collect IDs for lateral prompt lookup.
|
||||
ARRAY_AGG(fi.id) AS interception_ids
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
)
|
||||
SELECT
|
||||
sr.session_id,
|
||||
visible_users.id AS user_id,
|
||||
visible_users.username AS user_username,
|
||||
visible_users.name AS user_name,
|
||||
visible_users.avatar_url AS user_avatar_url,
|
||||
sr.providers::text[] AS providers,
|
||||
sr.models::text[] AS models,
|
||||
COALESCE(sr.client, '')::varchar(64) AS client,
|
||||
sr.metadata::jsonb AS metadata,
|
||||
sr.started_at::timestamptz AS started_at,
|
||||
sr.ended_at::timestamptz AS ended_at,
|
||||
sr.threads,
|
||||
COALESCE(st.input_tokens, 0)::bigint AS input_tokens,
|
||||
COALESCE(st.output_tokens, 0)::bigint AS output_tokens,
|
||||
COALESCE(slp.prompt, '') AS last_prompt
|
||||
FROM
|
||||
session_root sr
|
||||
JOIN
|
||||
visible_users ON visible_users.id = sr.initiator_id
|
||||
LEFT JOIN
|
||||
session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id
|
||||
LEFT JOIN LATERAL (
|
||||
-- Lateral join to efficiently fetch only the most recent user prompt
|
||||
-- across all interceptions in the session, avoiding a full aggregation.
|
||||
SELECT up.prompt
|
||||
FROM aibridge_user_prompts up
|
||||
WHERE up.interception_id = ANY(sr.interception_ids)
|
||||
ORDER BY up.created_at DESC, up.id DESC
|
||||
LIMIT 1
|
||||
) slp ON true
|
||||
WHERE
|
||||
-- Cursor pagination: uses a composite (started_at, session_id) cursor
|
||||
-- to support keyset pagination. The less-than comparison matches the
|
||||
-- DESC sort order so that rows after the cursor come later in results.
|
||||
CASE
|
||||
WHEN @after_session_id::text != '' THEN (
|
||||
(sr.started_at, sr.session_id) < (
|
||||
(SELECT started_at FROM session_root WHERE session_id = @after_session_id),
|
||||
@after_session_id::text
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
ORDER BY
|
||||
sr.started_at DESC,
|
||||
sr.session_id DESC
|
||||
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
|
||||
OFFSET @offset_
|
||||
;
|
||||
|
||||
-- name: ListAIBridgeModels :many
|
||||
SELECT
|
||||
model
|
||||
|
||||
@@ -193,6 +193,26 @@ WHERE user_configs.user_id = @user_id
|
||||
AND user_configs.key = 'chat_custom_prompt'
|
||||
RETURNING *;
|
||||
|
||||
-- name: ListUserChatCompactionThresholds :many
|
||||
SELECT user_id, key, value FROM user_configs
|
||||
WHERE user_id = @user_id
|
||||
AND key LIKE 'chat\_compaction\_threshold\_pct:%'
|
||||
ORDER BY key;
|
||||
|
||||
-- name: GetUserChatCompactionThreshold :one
|
||||
SELECT value AS threshold_percent FROM user_configs
|
||||
WHERE user_id = @user_id AND key = @key;
|
||||
|
||||
-- name: UpdateUserChatCompactionThreshold :one
|
||||
INSERT INTO user_configs (user_id, key, value)
|
||||
VALUES (@user_id, @key, (@threshold_percent::int)::text)
|
||||
ON CONFLICT ON CONSTRAINT user_configs_pkey
|
||||
DO UPDATE SET value = (@threshold_percent::int)::text
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserChatCompactionThreshold :exec
|
||||
DELETE FROM user_configs WHERE user_id = @user_id AND key = @key;
|
||||
|
||||
-- name: GetUserTaskNotificationAlertDismissed :one
|
||||
SELECT
|
||||
value::boolean as task_notification_alert_dismissed
|
||||
|
||||
@@ -3,7 +3,7 @@ package dynamicparameters
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"slices"
|
||||
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
|
||||
@@ -94,7 +94,7 @@ func (e *DiagnosticError) Response() (int, codersdk.Response) {
|
||||
for name := range e.KeyedDiagnostics {
|
||||
sortedNames = append(sortedNames, name)
|
||||
}
|
||||
sort.Strings(sortedNames)
|
||||
slices.Sort(sortedNames)
|
||||
|
||||
for _, name := range sortedNames {
|
||||
diag := e.KeyedDiagnostics[name]
|
||||
|
||||
@@ -28,14 +28,11 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
||||
"github.com/coder/coder/v2/coderd/gitsync"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpapi/httperror"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
@@ -46,6 +43,9 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/gitsync"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
"github.com/coder/websocket"
|
||||
@@ -2542,6 +2542,17 @@ func normalizeChatCompressionThreshold(
|
||||
return threshold, nil
|
||||
}
|
||||
|
||||
func parseCompactionThresholdKey(key string) (uuid.UUID, error) {
|
||||
if !strings.HasPrefix(key, codersdk.ChatCompactionThresholdKeyPrefix) {
|
||||
return uuid.Nil, xerrors.Errorf("invalid compaction threshold key: %q", key)
|
||||
}
|
||||
id, err := uuid.Parse(key[len(codersdk.ChatCompactionThresholdKeyPrefix):])
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("invalid model config ID in key %q: %w", key, err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
const (
|
||||
// maxChatFileSize is the maximum size of a chat file upload (10 MB).
|
||||
maxChatFileSize = 10 << 20
|
||||
@@ -2816,6 +2827,170 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Get user chat compaction thresholds
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getUserChatCompactionThresholds(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
rows, err := api.Database.ListUserChatCompactionThresholds(ctx, apiKey.UserID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Error listing user chat compaction thresholds.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp := codersdk.UserChatCompactionThresholds{
|
||||
Thresholds: make([]codersdk.UserChatCompactionThreshold, 0, len(rows)),
|
||||
}
|
||||
for _, row := range rows {
|
||||
modelConfigID, err := parseCompactionThresholdKey(row.Key)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "skipping malformed user chat compaction threshold key",
|
||||
slog.F("key", row.Key),
|
||||
slog.F("value", row.Value),
|
||||
slog.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
thresholdPercent, err := strconv.ParseInt(row.Value, 10, 32)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "skipping malformed user chat compaction threshold value",
|
||||
slog.F("key", row.Key),
|
||||
slog.F("value", row.Value),
|
||||
slog.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if thresholdPercent < int64(minChatContextCompressionThreshold) ||
|
||||
thresholdPercent > int64(maxChatContextCompressionThreshold) {
|
||||
api.Logger.Warn(ctx, "skipping out-of-range user chat compaction threshold",
|
||||
slog.F("key", row.Key),
|
||||
slog.F("value", row.Value),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
resp.Thresholds = append(resp.Thresholds, codersdk.UserChatCompactionThreshold{
|
||||
ModelConfigID: modelConfigID,
|
||||
ThresholdPercent: int32(thresholdPercent),
|
||||
})
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// @Summary Set user chat compaction threshold for a model config
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) putUserChatCompactionThreshold(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
modelConfigID, ok := parseChatModelConfigID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpdateUserChatCompactionThresholdRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if req.ThresholdPercent < minChatContextCompressionThreshold ||
|
||||
req.ThresholdPercent > maxChatContextCompressionThreshold {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "threshold_percent is out of range.",
|
||||
Detail: fmt.Sprintf(
|
||||
"threshold_percent must be between %d and %d, got %d.",
|
||||
minChatContextCompressionThreshold,
|
||||
maxChatContextCompressionThreshold,
|
||||
req.ThresholdPercent,
|
||||
),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Use system context because GetChatModelConfigByID requires
|
||||
// deployment-config read access, which non-admin users lack.
|
||||
// The user is only checking if the model exists and is enabled
|
||||
// before writing their own personal preference.
|
||||
//nolint:gocritic // Non-admin users need this lookup to save their own setting.
|
||||
modelConfig, err := api.Database.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), modelConfigID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) || httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat model config.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !modelConfig.Enabled {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Model config is disabled.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = api.Database.UpdateUserChatCompactionThreshold(ctx, database.UpdateUserChatCompactionThresholdParams{
|
||||
UserID: apiKey.UserID,
|
||||
Key: codersdk.CompactionThresholdKey(modelConfigID),
|
||||
ThresholdPercent: req.ThresholdPercent,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Error updating user chat compaction threshold.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCompactionThreshold{
|
||||
ModelConfigID: modelConfigID,
|
||||
ThresholdPercent: req.ThresholdPercent,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Delete user chat compaction threshold for a model config
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) deleteUserChatCompactionThreshold(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
modelConfigID, ok := parseChatModelConfigID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.Database.DeleteUserChatCompactionThreshold(ctx, database.DeleteUserChatCompactionThresholdParams{
|
||||
UserID: apiKey.UserID,
|
||||
Key: codersdk.CompactionThresholdKey(modelConfigID),
|
||||
}); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Error deleting user chat compaction threshold.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *API) resolvedChatSystemPrompt(ctx context.Context) string {
|
||||
custom, err := api.Database.GetChatSystemPrompt(ctx)
|
||||
if err != nil {
|
||||
File diff suppressed because it is too large
Load Diff
+10
-27
@@ -21,11 +21,14 @@ import (
|
||||
|
||||
func TestPostFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all sub-tests. Each sub-test
|
||||
// creates independent resources with unique IDs so parallel
|
||||
// execution is safe.
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
t.Run("BadContentType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
@@ -35,9 +38,6 @@ func TestPostFiles(t *testing.T) {
|
||||
|
||||
t.Run("Insert", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
@@ -47,9 +47,6 @@ func TestPostFiles(t *testing.T) {
|
||||
|
||||
t.Run("InsertWindowsZip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
@@ -59,9 +56,6 @@ func TestPostFiles(t *testing.T) {
|
||||
|
||||
t.Run("InsertAlreadyExists", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
@@ -73,9 +67,6 @@ func TestPostFiles(t *testing.T) {
|
||||
})
|
||||
t.Run("InsertConcurrent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
@@ -99,11 +90,12 @@ func TestPostFiles(t *testing.T) {
|
||||
|
||||
func TestDownload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Shared instance — see TestPostFiles for rationale.
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
@@ -115,9 +107,6 @@ func TestDownload(t *testing.T) {
|
||||
|
||||
t.Run("InsertTar_DownloadTar", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// given
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
@@ -139,9 +128,6 @@ func TestDownload(t *testing.T) {
|
||||
|
||||
t.Run("InsertZip_DownloadTar", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// given
|
||||
zipContent := archivetest.TestZipFileBytes()
|
||||
|
||||
@@ -164,9 +150,6 @@ func TestDownload(t *testing.T) {
|
||||
|
||||
t.Run("InsertTar_DownloadZip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// given
|
||||
tarball := archivetest.TestTarFileBytes()
|
||||
|
||||
|
||||
+50
-28
@@ -248,12 +248,9 @@ func PrecheckAPIKey(cfg ValidateAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
//
|
||||
// Returns (result, nil) on success or (nil, error) on failure.
|
||||
func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Request) (*ValidateAPIKeyResult, *ValidateAPIKeyError) {
|
||||
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
|
||||
if !ok {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: resp,
|
||||
}
|
||||
key, valErr := apiKeyFromRequestValidate(ctx, cfg.DB, cfg.SessionTokenFunc, r)
|
||||
if valErr != nil {
|
||||
return nil, valErr
|
||||
}
|
||||
|
||||
// Log the API key ID for all requests that have a valid key
|
||||
@@ -475,7 +472,7 @@ func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Reque
|
||||
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
|
||||
@@ -492,6 +489,15 @@ func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Reque
|
||||
}
|
||||
|
||||
func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
|
||||
key, valErr := apiKeyFromRequestValidate(ctx, db, sessionTokenFunc, r)
|
||||
if valErr != nil {
|
||||
return nil, valErr.Response, false
|
||||
}
|
||||
|
||||
return key, codersdk.Response{}, true
|
||||
}
|
||||
|
||||
func apiKeyFromRequestValidate(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, *ValidateAPIKeyError) {
|
||||
tokenFunc := APITokenFromRequest
|
||||
if sessionTokenFunc != nil {
|
||||
tokenFunc = sessionTokenFunc
|
||||
@@ -499,45 +505,61 @@ func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc
|
||||
|
||||
token := tokenFunc(r)
|
||||
if token == "" {
|
||||
return nil, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
|
||||
}, false
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
keyID, keySecret, err := SplitAPIToken(token)
|
||||
if err != nil {
|
||||
return nil, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "Invalid API key format: " + err.Error(),
|
||||
}, false
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "Invalid API key format: " + err.Error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocritic // System needs to fetch API key to check if it's valid.
|
||||
key, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "API key is invalid.",
|
||||
}, false
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "API key is invalid.",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
|
||||
}, false
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Checking to see if the secret is valid.
|
||||
if !apikey.ValidateHash(key.HashedSecret, keySecret) {
|
||||
return nil, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "API key secret is invalid.",
|
||||
}, false
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "API key secret is invalid.",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &key, codersdk.Response{}, true
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
// ExtractAPIKey requires authentication using a valid API key. It handles
|
||||
|
||||
@@ -19,12 +19,14 @@ import (
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
@@ -192,6 +194,31 @@ func TestAPIKey(t *testing.T) {
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("GetAPIKeyByIDInternalError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
id, secret, _ := randomAPIKeyParts()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret))
|
||||
|
||||
db.EXPECT().GetAPIKeyByID(gomock.Any(), id).Return(database.APIKey{}, xerrors.New("db unavailable"))
|
||||
|
||||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusInternalServerError, res.StatusCode)
|
||||
|
||||
var resp codersdk.Response
|
||||
require.NoError(t, json.NewDecoder(res.Body).Decode(&resp))
|
||||
require.NotEqual(t, httpmw.SignedOutErrorMessage, resp.Message)
|
||||
require.Contains(t, resp.Detail, "Internal error fetching API key by id")
|
||||
})
|
||||
|
||||
t.Run("UserLinkNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
|
||||
@@ -14,9 +14,13 @@ import (
|
||||
func TestInitScript(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all sub-tests. All operations
|
||||
// are read-only (fetching init scripts) so parallel execution
|
||||
// is safe.
|
||||
client := coderdtest.New(t, nil)
|
||||
|
||||
t.Run("OK Windows amd64", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
script, err := client.InitScript(context.Background(), "windows", "amd64")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, script)
|
||||
@@ -26,7 +30,6 @@ func TestInitScript(t *testing.T) {
|
||||
|
||||
t.Run("OK Windows arm64", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
script, err := client.InitScript(context.Background(), "windows", "arm64")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, script)
|
||||
@@ -36,7 +39,6 @@ func TestInitScript(t *testing.T) {
|
||||
|
||||
t.Run("OK Linux amd64", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
script, err := client.InitScript(context.Background(), "linux", "amd64")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, script)
|
||||
@@ -46,7 +48,6 @@ func TestInitScript(t *testing.T) {
|
||||
|
||||
t.Run("OK Linux arm64", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
script, err := client.InitScript(context.Background(), "linux", "arm64")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, script)
|
||||
@@ -56,7 +57,6 @@ func TestInitScript(t *testing.T) {
|
||||
|
||||
t.Run("BadRequest", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_, err := client.InitScript(context.Background(), "darwin", "armv7")
|
||||
require.Error(t, err)
|
||||
var apiErr *codersdk.Error
|
||||
|
||||
+97
-2
@@ -6,12 +6,16 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
@@ -107,9 +111,37 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
// Validate auth-type-dependent fields.
|
||||
switch req.AuthType {
|
||||
case "oauth2":
|
||||
if req.OAuth2ClientID == "" || req.OAuth2AuthURL == "" || req.OAuth2TokenURL == "" {
|
||||
// When the admin does not provide OAuth2 credentials, attempt
|
||||
// automatic discovery and Dynamic Client Registration (RFC 7591)
|
||||
// using the MCP server URL. This follows the MCP authorization
|
||||
// spec: discover the authorization server via Protected Resource
|
||||
// Metadata (RFC 9728) and Authorization Server Metadata
|
||||
// (RFC 8414), then register a client dynamically.
|
||||
if req.OAuth2ClientID == "" && req.OAuth2AuthURL == "" && req.OAuth2TokenURL == "" {
|
||||
callbackURL := fmt.Sprintf("%s/api/experimental/mcp/servers/{id}/oauth2/callback", api.AccessURL.String())
|
||||
result, err := discoverAndRegisterMCPOAuth2(ctx, strings.TrimSpace(req.URL), callbackURL)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "mcp oauth2 auto-discovery failed",
|
||||
slog.F("url", req.URL),
|
||||
slog.Error(err),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "OAuth2 auto-discovery failed. Provide oauth2_client_id, oauth2_auth_url, and oauth2_token_url manually, or ensure the MCP server supports RFC 9728 (Protected Resource Metadata) and RFC 7591 (Dynamic Client Registration).",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
req.OAuth2ClientID = result.clientID
|
||||
req.OAuth2ClientSecret = result.clientSecret
|
||||
req.OAuth2AuthURL = result.authURL
|
||||
req.OAuth2TokenURL = result.tokenURL
|
||||
if req.OAuth2Scopes == "" {
|
||||
req.OAuth2Scopes = result.scopes
|
||||
}
|
||||
} else if req.OAuth2ClientID == "" || req.OAuth2AuthURL == "" || req.OAuth2TokenURL == "" {
|
||||
// Partial manual config: all three fields are required together.
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "OAuth2 auth type requires oauth2_client_id, oauth2_auth_url, and oauth2_token_url.",
|
||||
Message: "OAuth2 auth type requires either all of oauth2_client_id, oauth2_auth_url, and oauth2_token_url (manual configuration), or none of them (automatic discovery via RFC 7591).",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -919,3 +951,66 @@ func coalesceStringSlice(ss []string) []string {
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
// mcpOAuth2Discovery holds the result of MCP OAuth2 auto-discovery
|
||||
// and Dynamic Client Registration.
|
||||
type mcpOAuth2Discovery struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
authURL string
|
||||
tokenURL string
|
||||
scopes string // space-separated
|
||||
}
|
||||
|
||||
// discoverAndRegisterMCPOAuth2 uses the mcp-go library's OAuthHandler to
|
||||
// perform the MCP OAuth2 discovery and Dynamic Client Registration flow:
|
||||
//
|
||||
// 1. Discover the authorization server via Protected Resource Metadata
|
||||
// (RFC 9728) and Authorization Server Metadata (RFC 8414).
|
||||
// 2. Register a client via Dynamic Client Registration (RFC 7591).
|
||||
// 3. Return the discovered endpoints and generated credentials.
|
||||
func discoverAndRegisterMCPOAuth2(ctx context.Context, mcpServerURL, callbackURL string) (*mcpOAuth2Discovery, error) {
|
||||
// Per the MCP spec, the authorization base URL is the MCP server
|
||||
// URL with the path component discarded (scheme + host only).
|
||||
parsed, err := url.Parse(mcpServerURL)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse MCP server URL: %w", err)
|
||||
}
|
||||
origin := fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host)
|
||||
|
||||
oauthHandler := transport.NewOAuthHandler(transport.OAuthConfig{
|
||||
RedirectURI: callbackURL,
|
||||
TokenStore: transport.NewMemoryTokenStore(),
|
||||
})
|
||||
oauthHandler.SetBaseURL(origin)
|
||||
|
||||
// Step 1: Discover authorization server metadata (RFC 9728 + RFC 8414).
|
||||
metadata, err := oauthHandler.GetServerMetadata(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("discover authorization server: %w", err)
|
||||
}
|
||||
if metadata.AuthorizationEndpoint == "" {
|
||||
return nil, xerrors.New("authorization server metadata missing authorization_endpoint")
|
||||
}
|
||||
if metadata.TokenEndpoint == "" {
|
||||
return nil, xerrors.New("authorization server metadata missing token_endpoint")
|
||||
}
|
||||
if metadata.RegistrationEndpoint == "" {
|
||||
return nil, xerrors.New("authorization server does not advertise a registration_endpoint (dynamic client registration may not be supported)")
|
||||
}
|
||||
|
||||
// Step 2: Register a client via Dynamic Client Registration (RFC 7591).
|
||||
if err := oauthHandler.RegisterClient(ctx, "Coder"); err != nil {
|
||||
return nil, xerrors.Errorf("dynamic client registration: %w", err)
|
||||
}
|
||||
|
||||
scopes := strings.Join(metadata.ScopesSupported, " ")
|
||||
|
||||
return &mcpOAuth2Discovery{
|
||||
clientID: oauthHandler.GetClientID(),
|
||||
clientSecret: oauthHandler.GetClientSecret(),
|
||||
authURL: metadata.AuthorizationEndpoint,
|
||||
tokenURL: metadata.TokenEndpoint,
|
||||
scopes: scopes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
+175
-4
@@ -3,6 +3,7 @@ package coderd_test
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -430,6 +431,174 @@ func TestMCPServerConfigsOAuth2Disconnect(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestMCPServerConfigsOAuth2AutoDiscovery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Stand up a mock auth server that serves RFC 8414 metadata and
|
||||
// a RFC 7591 dynamic client registration endpoint.
|
||||
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/oauth-authorization-server":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"issuer": "` + r.Host + `",
|
||||
"authorization_endpoint": "` + "http://" + r.Host + `/authorize",
|
||||
"token_endpoint": "` + "http://" + r.Host + `/token",
|
||||
"registration_endpoint": "` + "http://" + r.Host + `/register",
|
||||
"response_types_supported": ["code"],
|
||||
"scopes_supported": ["read", "write"]
|
||||
}`))
|
||||
case "/register":
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_, _ = w.Write([]byte(`{
|
||||
"client_id": "auto-discovered-client-id",
|
||||
"client_secret": "auto-discovered-client-secret"
|
||||
}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(authServer.Close)
|
||||
|
||||
// Stand up a mock MCP server that serves RFC 9728 Protected
|
||||
// Resource Metadata pointing to the auth server above.
|
||||
mcpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/.well-known/oauth-protected-resource" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"resource": "` + "http://" + r.Host + `",
|
||||
"authorization_servers": ["` + authServer.URL + `"]
|
||||
}`))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
t.Cleanup(mcpServer.Close)
|
||||
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create config with auth_type=oauth2 but no OAuth2 fields —
|
||||
// the server should auto-discover them.
|
||||
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Auto-Discovery Server",
|
||||
Slug: "auto-discovery",
|
||||
Transport: "streamable_http",
|
||||
URL: mcpServer.URL + "/v1/mcp",
|
||||
AuthType: "oauth2",
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "auto-discovered-client-id", created.OAuth2ClientID)
|
||||
require.True(t, created.HasOAuth2Secret)
|
||||
require.Equal(t, authServer.URL+"/authorize", created.OAuth2AuthURL)
|
||||
require.Equal(t, authServer.URL+"/token", created.OAuth2TokenURL)
|
||||
require.Equal(t, "read write", created.OAuth2Scopes)
|
||||
})
|
||||
|
||||
t.Run("PartialOAuth2FieldsRejected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Provide client_id but omit auth_url and token_url.
|
||||
_, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Partial Fields",
|
||||
Slug: "partial-oauth2",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/partial",
|
||||
AuthType: "oauth2",
|
||||
OAuth2ClientID: "only-client-id",
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
require.Contains(t, sdkErr.Message, "automatic discovery")
|
||||
})
|
||||
|
||||
t.Run("DiscoveryFailure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// MCP server that returns 404 for the well-known endpoint and
|
||||
// a non-401 status for the root — discovery has nothing to latch
|
||||
// onto.
|
||||
mcpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}))
|
||||
t.Cleanup(mcpServer.Close)
|
||||
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Will Fail",
|
||||
Slug: "discovery-fail",
|
||||
Transport: "streamable_http",
|
||||
URL: mcpServer.URL + "/v1/mcp",
|
||||
AuthType: "oauth2",
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
require.Contains(t, sdkErr.Message, "auto-discovery failed")
|
||||
})
|
||||
|
||||
t.Run("ManualConfigStillWorks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Providing all three OAuth2 fields bypasses discovery entirely.
|
||||
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Manual Config",
|
||||
Slug: "manual-oauth2",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/manual",
|
||||
AuthType: "oauth2",
|
||||
OAuth2ClientID: "manual-client-id",
|
||||
OAuth2AuthURL: "https://auth.example.com/authorize",
|
||||
OAuth2TokenURL: "https://auth.example.com/token",
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "manual-client-id", created.OAuth2ClientID)
|
||||
require.Equal(t, "https://auth.example.com/authorize", created.OAuth2AuthURL)
|
||||
require.Equal(t, "https://auth.example.com/token", created.OAuth2TokenURL)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatWithMCPServerIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -437,14 +606,16 @@ func TestChatWithMCPServerIDs(t *testing.T) {
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
// Create the chat model config required for creating a chat.
|
||||
_ = createChatModelConfigForMCP(t, client)
|
||||
_ = createChatModelConfigForMCP(t, expClient)
|
||||
|
||||
// Create an enabled MCP server config.
|
||||
mcpConfig := createMCPServerConfig(t, client, "chat-mcp-server", true)
|
||||
|
||||
// Create a chat referencing the MCP server.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -458,7 +629,7 @@ func TestChatWithMCPServerIDs(t *testing.T) {
|
||||
require.Contains(t, chat.MCPServerIDs, mcpConfig.ID)
|
||||
|
||||
// Fetch the chat and verify the MCP server IDs persist.
|
||||
fetched, err := client.GetChat(ctx, chat.ID)
|
||||
fetched, err := expClient.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, fetched.MCPServerIDs, mcpConfig.ID)
|
||||
}
|
||||
@@ -466,7 +637,7 @@ func TestChatWithMCPServerIDs(t *testing.T) {
|
||||
// createChatModelConfigForMCP sets up a chat provider and model
|
||||
// config so that CreateChat succeeds. This mirrors the helper in
|
||||
// chats_test.go but is defined here to avoid coupling.
|
||||
func createChatModelConfigForMCP(t testing.TB, client *codersdk.Client) codersdk.ChatModelConfig {
|
||||
func createChatModelConfigForMCP(t testing.TB, client *codersdk.ExperimentalClient) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -549,8 +548,8 @@ func TestExpiredLeaseIsRequeued(t *testing.T) {
|
||||
leasedIDs = append(leasedIDs, msg.ID.String())
|
||||
}
|
||||
|
||||
sort.Strings(msgs)
|
||||
sort.Strings(leasedIDs)
|
||||
slices.Sort(msgs)
|
||||
slices.Sort(leasedIDs)
|
||||
require.EqualValues(t, msgs, leasedIDs)
|
||||
|
||||
// Wait out the lease period; all messages should be eligible to be re-acquired.
|
||||
|
||||
@@ -18,12 +18,13 @@ import (
|
||||
func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps with unique client names.
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("RedirectURIValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
redirectURIs []string
|
||||
@@ -132,9 +133,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Run("ClientURIValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientURI string
|
||||
@@ -207,9 +205,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Run("LogoURIValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logoURI string
|
||||
@@ -272,9 +267,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Run("GrantTypeValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
grantTypes []codersdk.OAuth2ProviderGrantType
|
||||
@@ -347,9 +339,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Run("ResponseTypeValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
responseTypes []codersdk.OAuth2ProviderResponseType
|
||||
@@ -407,9 +396,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Run("TokenEndpointAuthMethodValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authMethod codersdk.OAuth2TokenEndpointAuthMethod
|
||||
@@ -479,6 +465,10 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
func TestOAuth2ClientNameValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps.
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientName string
|
||||
@@ -530,8 +520,6 @@ func TestOAuth2ClientNameValidation(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
@@ -554,6 +542,10 @@ func TestOAuth2ClientNameValidation(t *testing.T) {
|
||||
func TestOAuth2ClientScopeValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps.
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
scope string
|
||||
@@ -615,8 +607,6 @@ func TestOAuth2ClientScopeValidation(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
@@ -682,11 +672,13 @@ func TestOAuth2ClientMetadataDefaults(t *testing.T) {
|
||||
func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps with unique client names.
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("ExtremelyLongRedirectURI", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Create a very long but valid HTTPS URI
|
||||
@@ -709,8 +701,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
|
||||
t.Run("ManyRedirectURIs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Test with many redirect URIs
|
||||
@@ -732,8 +722,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
|
||||
t.Run("URIWithUnusualPort", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
@@ -748,8 +736,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
|
||||
t.Run("URIWithComplexPath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
@@ -764,8 +750,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
|
||||
t.Run("URIWithEncodedCharacters", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Test with URL-encoded characters
|
||||
|
||||
@@ -18,12 +18,13 @@ import (
|
||||
func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps with unique client names.
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("RedirectURIValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
redirectURIs []string
|
||||
@@ -132,9 +133,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Run("ClientURIValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientURI string
|
||||
@@ -207,9 +205,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Run("LogoURIValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logoURI string
|
||||
@@ -272,9 +267,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Run("GrantTypeValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
grantTypes []codersdk.OAuth2ProviderGrantType
|
||||
@@ -347,9 +339,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Run("ResponseTypeValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
responseTypes []codersdk.OAuth2ProviderResponseType
|
||||
@@ -407,9 +396,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
t.Run("TokenEndpointAuthMethodValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authMethod codersdk.OAuth2TokenEndpointAuthMethod
|
||||
@@ -479,6 +465,10 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
func TestOAuth2ClientNameValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps.
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientName string
|
||||
@@ -530,8 +520,6 @@ func TestOAuth2ClientNameValidation(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
@@ -554,6 +542,10 @@ func TestOAuth2ClientNameValidation(t *testing.T) {
|
||||
func TestOAuth2ClientScopeValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps.
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
scope string
|
||||
@@ -615,8 +607,6 @@ func TestOAuth2ClientScopeValidation(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
@@ -682,11 +672,13 @@ func TestOAuth2ClientMetadataDefaults(t *testing.T) {
|
||||
func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all sub-tests. Each registers independent OAuth2 apps with unique client names.
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("ExtremelyLongRedirectURI", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Create a very long but valid HTTPS URI
|
||||
@@ -709,8 +701,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
|
||||
t.Run("ManyRedirectURIs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Test with many redirect URIs
|
||||
@@ -732,8 +722,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
|
||||
t.Run("URIWithUnusualPort", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
@@ -748,8 +736,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
|
||||
t.Run("URIWithComplexPath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
@@ -764,8 +750,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
|
||||
t.Run("URIWithEncodedCharacters", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Test with URL-encoded characters
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package prometheusmetrics_test
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
@@ -134,7 +135,7 @@ func collectAndSortMetrics(t *testing.T, collector prometheus.Collector, count i
|
||||
|
||||
// Ensure always the same order of metrics
|
||||
sort.Slice(metrics, func(i, j int) bool {
|
||||
return sort.StringsAreSorted([]string{metrics[i].Label[0].GetValue(), metrics[j].Label[1].GetValue()})
|
||||
return slices.IsSorted([]string{metrics[i].Label[0].GetValue(), metrics[j].Label[1].GetValue()})
|
||||
})
|
||||
return metrics
|
||||
}
|
||||
|
||||
+19
-2
@@ -316,13 +316,16 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
denyPermissions...,
|
||||
),
|
||||
User: append(
|
||||
allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember, ResourceBoundaryUsage),
|
||||
allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember, ResourceBoundaryUsage, ResourceAibridgeInterception),
|
||||
Permissions(map[string][]policy.Action{
|
||||
// Users cannot do create/update/delete on themselves, but they
|
||||
// can read their own details.
|
||||
ResourceUser.Type: {policy.ActionRead, policy.ActionReadPersonal, policy.ActionUpdatePersonal},
|
||||
// Users can create provisioner daemons scoped to themselves.
|
||||
ResourceProvisionerDaemon.Type: {policy.ActionRead, policy.ActionCreate, policy.ActionRead, policy.ActionUpdate},
|
||||
// Members can create and update AI Bridge interceptions but
|
||||
// cannot read them back.
|
||||
ResourceAibridgeInterception.Type: {policy.ActionCreate, policy.ActionUpdate},
|
||||
})...,
|
||||
),
|
||||
ByOrgID: map[string]OrgPermissions{},
|
||||
@@ -345,7 +348,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
||||
// Allow auditors to query deployment stats and insights.
|
||||
ResourceDeploymentStats.Type: {policy.ActionRead},
|
||||
ResourceDeploymentConfig.Type: {policy.ActionRead},
|
||||
// Allow auditors to query aibridge interceptions.
|
||||
// Allow auditors to query AI Bridge interceptions.
|
||||
ResourceAibridgeInterception.Type: {policy.ActionRead},
|
||||
}),
|
||||
User: []Permission{},
|
||||
@@ -998,6 +1001,7 @@ func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
|
||||
ResourcePrebuiltWorkspace,
|
||||
ResourceUser,
|
||||
ResourceOrganizationMember,
|
||||
ResourceAibridgeInterception,
|
||||
),
|
||||
Permissions(map[string][]policy.Action{
|
||||
// Reduced permission set on dormant workspaces. No build,
|
||||
@@ -1016,6 +1020,12 @@ func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
|
||||
ResourceOrganizationMember.Type: {
|
||||
policy.ActionRead,
|
||||
},
|
||||
// Members can create and update AI Bridge interceptions but
|
||||
// cannot read them back.
|
||||
ResourceAibridgeInterception.Type: {
|
||||
policy.ActionCreate,
|
||||
policy.ActionUpdate,
|
||||
},
|
||||
})...,
|
||||
)
|
||||
|
||||
@@ -1073,6 +1083,7 @@ func OrgServiceAccountPermissions(org OrgSettings) OrgRolePermissions {
|
||||
ResourcePrebuiltWorkspace,
|
||||
ResourceUser,
|
||||
ResourceOrganizationMember,
|
||||
ResourceAibridgeInterception,
|
||||
),
|
||||
Permissions(map[string][]policy.Action{
|
||||
// Reduced permission set on dormant workspaces. No build,
|
||||
@@ -1091,6 +1102,12 @@ func OrgServiceAccountPermissions(org OrgSettings) OrgRolePermissions {
|
||||
ResourceOrganizationMember.Type: {
|
||||
policy.ActionRead,
|
||||
},
|
||||
// Service accounts can create and update AI Bridge
|
||||
// interceptions but cannot read them back.
|
||||
ResourceAibridgeInterception.Type: {
|
||||
policy.ActionCreate,
|
||||
policy.ActionUpdate,
|
||||
},
|
||||
})...,
|
||||
)
|
||||
|
||||
|
||||
@@ -1023,8 +1023,9 @@ func TestRolePermissions(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "AIBridgeInterceptions",
|
||||
Actions: []policy.Action{policy.ActionCreate, policy.ActionRead, policy.ActionUpdate},
|
||||
// Members can create/update records but can't read them afterwards.
|
||||
Name: "AIBridgeInterceptionsCreateUpdate",
|
||||
Actions: []policy.Action{policy.ActionCreate, policy.ActionUpdate},
|
||||
Resource: rbac.ResourceAibridgeInterception.WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, memberMe},
|
||||
@@ -1036,6 +1037,22 @@ func TestRolePermissions(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
// Only owners and site-wide auditors can view interceptions and their sub-resources.
|
||||
Name: "AIBridgeInterceptionsRead",
|
||||
Actions: []policy.Action{policy.ActionRead},
|
||||
Resource: rbac.ResourceAibridgeInterception.WithOwner(currentUser.String()),
|
||||
AuthorizeMap: map[bool][]hasAuthSubjects{
|
||||
true: {owner, auditor},
|
||||
false: {
|
||||
memberMe,
|
||||
orgAdmin, otherOrgAdmin,
|
||||
orgAuditor, otherOrgAuditor,
|
||||
templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin,
|
||||
userAdmin, orgUserAdmin, otherOrgUserAdmin,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "BoundaryUsage",
|
||||
Actions: []policy.Action{policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
|
||||
@@ -3,7 +3,6 @@ package rbac
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -176,7 +175,7 @@ func CompositeScopeNames() []string {
|
||||
for k := range compositePerms {
|
||||
out = append(out, string(k))
|
||||
}
|
||||
sort.Strings(out)
|
||||
slices.Sort(out)
|
||||
return out
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package rbac
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -16,7 +16,7 @@ func TestExternalScopeNames(t *testing.T) {
|
||||
|
||||
// Ensure sorted ascending
|
||||
sorted := append([]string(nil), names...)
|
||||
sort.Strings(sorted)
|
||||
slices.Sort(sorted)
|
||||
require.Equal(t, sorted, names)
|
||||
|
||||
// Ensure each entry expands to site-only
|
||||
|
||||
@@ -401,6 +401,49 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string,
|
||||
return filter, parser.Errors
|
||||
}
|
||||
|
||||
func AIBridgeSessions(ctx context.Context, db database.Store, query string, page codersdk.Pagination, actorID uuid.UUID, afterSessionID string) (database.ListAIBridgeSessionsParams, []codersdk.ValidationError) {
|
||||
// nolint:exhaustruct // Empty values just means "don't filter by that field".
|
||||
filter := database.ListAIBridgeSessionsParams{
|
||||
AfterSessionID: afterSessionID,
|
||||
// #nosec G115 - Safe conversion for pagination limit which is expected to be within int32 range
|
||||
Limit: int32(page.Limit),
|
||||
// #nosec G115 - Safe conversion for pagination offset which is expected to be within int32 range
|
||||
Offset: int32(page.Offset),
|
||||
}
|
||||
|
||||
if query == "" {
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
values, errors := searchTerms(query, func(string, url.Values) error {
|
||||
// Do not specify a default search key; let's be explicit to prevent user confusion.
|
||||
return xerrors.New("no search key specified")
|
||||
})
|
||||
if len(errors) > 0 {
|
||||
return filter, errors
|
||||
}
|
||||
|
||||
parser := httpapi.NewQueryParamParser()
|
||||
filter.InitiatorID = parseUser(ctx, db, parser, values, "initiator", actorID)
|
||||
filter.Provider = parser.String(values, "", "provider")
|
||||
filter.Model = parser.String(values, "", "model")
|
||||
filter.Client = parser.String(values, "", "client")
|
||||
filter.SessionID = parser.String(values, "", "session_id")
|
||||
|
||||
// Time must be between started_after and started_before.
|
||||
filter.StartedAfter = parser.Time3339Nano(values, time.Time{}, "started_after")
|
||||
filter.StartedBefore = parser.Time3339Nano(values, time.Time{}, "started_before")
|
||||
if !filter.StartedBefore.IsZero() && !filter.StartedAfter.IsZero() && !filter.StartedBefore.After(filter.StartedAfter) {
|
||||
parser.Errors = append(parser.Errors, codersdk.ValidationError{
|
||||
Field: "started_before",
|
||||
Detail: `Query param "started_before" has invalid value: "started_before" must be after "started_after" if set`,
|
||||
})
|
||||
}
|
||||
|
||||
parser.ErrorExcessParams(values)
|
||||
return filter, parser.Errors
|
||||
}
|
||||
|
||||
func AIBridgeModels(query string, page codersdk.Pagination) (database.ListAIBridgeModelsParams, []codersdk.ValidationError) {
|
||||
// nolint:exhaustruct // Empty values just means "don't filter by that field".
|
||||
filter := database.ListAIBridgeModelsParams{
|
||||
|
||||
@@ -1272,10 +1272,14 @@ func TestTemplateVersionsByTemplate(t *testing.T) {
|
||||
|
||||
func TestTemplateVersionByName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all sub-tests. Each sub-test
|
||||
// creates its own template version and template with unique
|
||||
// IDs so parallel execution is safe.
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
@@ -1290,8 +1294,6 @@ func TestTemplateVersionByName(t *testing.T) {
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
@@ -1935,10 +1937,12 @@ func TestPaginatedTemplateVersions(t *testing.T) {
|
||||
|
||||
func TestTemplateVersionByOrganizationTemplateAndName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Shared instance — see TestTemplateVersionByName for rationale.
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
@@ -1953,8 +1957,6 @@ func TestTemplateVersionByOrganizationTemplateAndName(t *testing.T) {
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
@@ -2204,10 +2206,14 @@ func TestTemplateVersionVariables(t *testing.T) {
|
||||
|
||||
func TestTemplateVersionPatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Single instance shared across all 9 sub-tests. Each sub-test
|
||||
// creates its own template version(s) and template(s) with
|
||||
// unique IDs so parallel execution is safe.
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
t.Run("Update the name", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
@@ -2226,8 +2232,6 @@ func TestTemplateVersionPatch(t *testing.T) {
|
||||
|
||||
t.Run("Update the message", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(req *codersdk.CreateTemplateVersionRequest) {
|
||||
req.Message = "Example message"
|
||||
})
|
||||
@@ -2247,8 +2251,6 @@ func TestTemplateVersionPatch(t *testing.T) {
|
||||
|
||||
t.Run("Remove the message", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(req *codersdk.CreateTemplateVersionRequest) {
|
||||
req.Message = "Example message"
|
||||
})
|
||||
@@ -2268,8 +2270,6 @@ func TestTemplateVersionPatch(t *testing.T) {
|
||||
|
||||
t.Run("Keep the message", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
wantMessage := "Example message"
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(req *codersdk.CreateTemplateVersionRequest) {
|
||||
req.Message = wantMessage
|
||||
@@ -2291,8 +2291,6 @@ func TestTemplateVersionPatch(t *testing.T) {
|
||||
|
||||
t.Run("Use the same name if a new name is not passed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
@@ -2306,9 +2304,6 @@ func TestTemplateVersionPatch(t *testing.T) {
|
||||
|
||||
t.Run("Use the same name for two different templates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
coderdtest.CreateTemplate(t, client, user.OrganizationID, version1.ID)
|
||||
version2 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
@@ -2334,8 +2329,6 @@ func TestTemplateVersionPatch(t *testing.T) {
|
||||
|
||||
t.Run("Use the same name for two versions for the same templates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(ctvr *codersdk.CreateTemplateVersionRequest) {
|
||||
ctvr.Name = "v1"
|
||||
})
|
||||
@@ -2356,8 +2349,6 @@ func TestTemplateVersionPatch(t *testing.T) {
|
||||
|
||||
t.Run("Rename the unassigned template", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
@@ -2373,8 +2364,6 @@ func TestTemplateVersionPatch(t *testing.T) {
|
||||
|
||||
t.Run("Use incorrect template version name", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
|
||||
+3
-40
@@ -7,7 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"sort"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -744,43 +744,6 @@ func (api *API) postLogout(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Set session token cookie
|
||||
// @Description Converts the current session token into a Set-Cookie response.
|
||||
// @Description This is used by embedded iframes (e.g. VS Code chat) that
|
||||
// @Description receive a session token out-of-band via postMessage but need
|
||||
// @Description cookie-based auth for WebSocket connections.
|
||||
// @ID set-session-token-cookie
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Authorization
|
||||
// @Success 204
|
||||
// @Router /users/me/session/token-to-cookie [post]
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) postSessionTokenCookie(rw http.ResponseWriter, r *http.Request) {
|
||||
// Only accept the token from the Coder-Session-Token header.
|
||||
// Other sources (query params, cookies) should not be allowed
|
||||
// to bootstrap a new cookie.
|
||||
token := r.Header.Get(codersdk.SessionTokenHeader)
|
||||
if token == "" {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Session token must be provided via the Coder-Session-Token header.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
cookie := api.DeploymentValues.HTTPCookies.Apply(&http.Cookie{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
// Expire the cookie when the underlying API key expires.
|
||||
Expires: apiKey.ExpiresAt,
|
||||
})
|
||||
http.SetCookie(rw, cookie)
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GithubOAuth2Team represents a team scoped to an organization.
|
||||
type GithubOAuth2Team struct {
|
||||
Organization string
|
||||
@@ -1626,7 +1589,7 @@ func claimFields(claims map[string]interface{}) []string {
|
||||
for field := range claims {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
sort.Strings(fields)
|
||||
slices.Sort(fields)
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -1639,7 +1602,7 @@ func blankFields(claims map[string]interface{}) []string {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
}
|
||||
sort.Strings(fields)
|
||||
slices.Sort(fields)
|
||||
return fields
|
||||
}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatcost"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -20,12 +21,6 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatcost"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatloop"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/chatd/mcpclient"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
@@ -34,6 +29,12 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/quartz"
|
||||
@@ -3150,8 +3151,14 @@ func (p *Server) runChat(
|
||||
// "Summarizing..." tool call with the "Summarized" tool
|
||||
// result.
|
||||
compactionToolCallID := "chat_summarized_" + uuid.NewString()
|
||||
effectiveThreshold := modelConfig.CompressionThreshold
|
||||
thresholdSource := "model_default"
|
||||
if override, ok := p.resolveUserCompactionThreshold(ctx, chat.OwnerID, modelConfig.ID); ok {
|
||||
effectiveThreshold = override
|
||||
thresholdSource = "user_override"
|
||||
}
|
||||
compactionOptions := &chatloop.CompactionOptions{
|
||||
ThresholdPercent: modelConfig.CompressionThreshold,
|
||||
ThresholdPercent: effectiveThreshold,
|
||||
ContextLimit: modelConfig.ContextLimit,
|
||||
Persist: func(
|
||||
persistCtx context.Context,
|
||||
@@ -3168,6 +3175,7 @@ func (p *Server) runChat(
|
||||
}
|
||||
logger.Info(persistCtx, "chat context summarized",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("threshold_source", thresholdSource),
|
||||
slog.F("threshold_percent", result.ThresholdPercent),
|
||||
slog.F("usage_percent", result.UsagePercent),
|
||||
slog.F("context_tokens", result.ContextTokens),
|
||||
@@ -3272,14 +3280,17 @@ func (p *Server) runChat(
|
||||
}
|
||||
|
||||
if isComputerUse {
|
||||
desktopGeometry := workspacesdk.DefaultDesktopGeometry()
|
||||
providerTools = append(providerTools, chatloop.ProviderTool{
|
||||
Definition: chattool.ComputerUseProviderTool(
|
||||
workspacesdk.DesktopDisplayWidth,
|
||||
workspacesdk.DesktopDisplayHeight),
|
||||
desktopGeometry.DeclaredWidth,
|
||||
desktopGeometry.DeclaredHeight,
|
||||
),
|
||||
Runner: chattool.NewComputerUseTool(
|
||||
workspacesdk.DesktopDisplayWidth,
|
||||
workspacesdk.DesktopDisplayHeight,
|
||||
workspaceCtx.getWorkspaceConn, quartz.NewReal(),
|
||||
desktopGeometry.DeclaredWidth,
|
||||
desktopGeometry.DeclaredHeight,
|
||||
workspaceCtx.getWorkspaceConn,
|
||||
quartz.NewReal(),
|
||||
),
|
||||
})
|
||||
}
|
||||
@@ -3715,6 +3726,34 @@ func (p *Server) resolveInstructions(
|
||||
return instruction
|
||||
}
|
||||
|
||||
// resolveUserCompactionThreshold looks up the user's per-model
|
||||
// compaction threshold override. Returns the override value and
|
||||
// true if one exists and is valid, or 0 and false otherwise.
|
||||
func (p *Server) resolveUserCompactionThreshold(ctx context.Context, userID uuid.UUID, modelConfigID uuid.UUID) (int32, bool) {
|
||||
raw, err := p.db.GetUserChatCompactionThreshold(ctx, database.GetUserChatCompactionThresholdParams{
|
||||
UserID: userID,
|
||||
Key: codersdk.CompactionThresholdKey(modelConfigID),
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, false
|
||||
}
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to fetch compaction threshold override",
|
||||
slog.F("user_id", userID),
|
||||
slog.F("model_config_id", modelConfigID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return 0, false
|
||||
}
|
||||
// Range 0..100 must stay in sync with handler validation in
|
||||
// coderd/chats.go.
|
||||
val, err := strconv.ParseInt(raw, 10, 32)
|
||||
if err != nil || val < 0 || val > 100 {
|
||||
return 0, false
|
||||
}
|
||||
return int32(val), true
|
||||
}
|
||||
|
||||
// resolveUserPrompt fetches the user's custom chat prompt from the
|
||||
// database and wraps it in <user-instructions> tags. Returns empty
|
||||
// string if no prompt is set.
|
||||
@@ -2,6 +2,7 @@ package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -606,6 +607,85 @@ func TestPublishToStream_DropWarnRateLimiting(t *testing.T) {
|
||||
requireFieldValue(t, subWarn[2], "dropped_count", int64(1))
|
||||
}
|
||||
|
||||
func TestResolveUserCompactionThreshold(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userID := uuid.New()
|
||||
modelConfigID := uuid.New()
|
||||
expectedKey := codersdk.CompactionThresholdKey(modelConfigID)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dbReturn string
|
||||
dbErr error
|
||||
wantVal int32
|
||||
wantOK bool
|
||||
wantWarnLog bool
|
||||
}{
|
||||
{
|
||||
name: "NoRowsReturnsDefault",
|
||||
dbErr: sql.ErrNoRows,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "ValidOverride",
|
||||
dbReturn: "75",
|
||||
wantVal: 75,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "OutOfRangeValue",
|
||||
dbReturn: "101",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "NonIntegerValue",
|
||||
dbReturn: "abc",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "UnexpectedDBError",
|
||||
dbErr: xerrors.New("connection refused"),
|
||||
wantOK: false,
|
||||
wantWarnLog: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockDB := dbmock.NewMockStore(ctrl)
|
||||
sink := testutil.NewFakeSink(t)
|
||||
|
||||
srv := &Server{
|
||||
db: mockDB,
|
||||
logger: sink.Logger(),
|
||||
}
|
||||
|
||||
mockDB.EXPECT().GetUserChatCompactionThreshold(gomock.Any(), database.GetUserChatCompactionThresholdParams{
|
||||
UserID: userID,
|
||||
Key: expectedKey,
|
||||
}).Return(tc.dbReturn, tc.dbErr)
|
||||
|
||||
val, ok := srv.resolveUserCompactionThreshold(context.Background(), userID, modelConfigID)
|
||||
require.Equal(t, tc.wantVal, val)
|
||||
require.Equal(t, tc.wantOK, ok)
|
||||
|
||||
warns := sink.Entries(func(e slog.SinkEntry) bool {
|
||||
return e.Level == slog.LevelWarn
|
||||
})
|
||||
if tc.wantWarnLog {
|
||||
require.NotEmpty(t, warns, "expected a warning log entry")
|
||||
return
|
||||
}
|
||||
require.Empty(t, warns, "unexpected warning log entry")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// requireFieldValue asserts that a SinkEntry contains a field with
|
||||
// the given name and value.
|
||||
func requireFieldValue(t *testing.T, entry slog.SinkEntry, name string, expected interface{}) {
|
||||
@@ -26,10 +26,6 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
@@ -41,6 +37,10 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
@@ -111,6 +111,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
agentToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
@@ -161,7 +162,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
||||
)
|
||||
})
|
||||
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
APIKey: "test-api-key",
|
||||
BaseURL: openAIURL,
|
||||
@@ -170,7 +171,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
@@ -179,7 +180,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a root chat whose first model call will spawn a subagent.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -193,7 +194,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
||||
// The root chat finishes first, then the chatd server
|
||||
// picks up and runs the child (subagent) chat.
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := client.GetChat(ctx, chat.ID)
|
||||
got, getErr := expClient.GetChat(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
@@ -901,15 +902,32 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
|
||||
acquireTrap := clock.Trap().NewTicker("chatd", "acquire")
|
||||
defer acquireTrap.Close()
|
||||
|
||||
assertPendingWithoutQueuedMessages := func(chatID uuid.UUID) {
|
||||
t.Helper()
|
||||
|
||||
queued, dbErr := db.GetChatQueuedMessages(ctx, chatID)
|
||||
require.NoError(t, dbErr)
|
||||
require.Empty(t, queued)
|
||||
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chatID)
|
||||
require.NoError(t, dbErr)
|
||||
require.Equal(t, database.ChatStatusPending, fromDB.Status)
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
}
|
||||
|
||||
streamStarted := make(chan struct{})
|
||||
interrupted := make(chan struct{})
|
||||
secondRequestStarted := make(chan struct{})
|
||||
thirdRequestStarted := make(chan struct{})
|
||||
allowFinish := make(chan struct{})
|
||||
var requestCount atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
if requestCount.Add(1) == 1 {
|
||||
|
||||
switch requestCount.Add(1) {
|
||||
case 1:
|
||||
chunks := make(chan chattest.OpenAIChunk, 1)
|
||||
go func() {
|
||||
defer close(chunks)
|
||||
@@ -928,7 +946,12 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
|
||||
<-allowFinish
|
||||
}()
|
||||
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
||||
case 2:
|
||||
close(secondRequestStarted)
|
||||
case 3:
|
||||
close(thirdRequestStarted)
|
||||
}
|
||||
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("done")...,
|
||||
)
|
||||
@@ -953,15 +976,7 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
clock.Advance(acquireInterval).MustWait(ctx)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-streamStarted:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
testutil.TryReceive(ctx, t, streamStarted)
|
||||
|
||||
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
@@ -972,29 +987,11 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
|
||||
require.True(t, queuedResult.Queued)
|
||||
require.NotNil(t, queuedResult.QueuedMessage)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-interrupted:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
testutil.TryReceive(ctx, t, interrupted)
|
||||
|
||||
close(allowFinish)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
queued, dbErr := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
if dbErr != nil || len(queued) != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Status == database.ChatStatusPending && !fromDB.WorkerID.Valid
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
chatd.WaitUntilIdleForTest(server)
|
||||
assertPendingWithoutQueuedMessages(chat.ID)
|
||||
|
||||
// Keep the acquire loop frozen here so "queued" stays pending.
|
||||
// That makes the later send queue because the chat is still busy,
|
||||
@@ -1045,63 +1042,41 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
clock.Advance(acquireInterval).MustWait(ctx)
|
||||
require.Eventually(t, func() bool {
|
||||
return requestCount.Load() >= 2
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
queued, dbErr := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
if dbErr != nil || len(queued) != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Status == database.ChatStatusPending && !fromDB.WorkerID.Valid
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
testutil.TryReceive(ctx, t, secondRequestStarted)
|
||||
chatd.WaitUntilIdleForTest(server)
|
||||
assertPendingWithoutQueuedMessages(chat.ID)
|
||||
|
||||
clock.Advance(acquireInterval).MustWait(ctx)
|
||||
testutil.TryReceive(ctx, t, thirdRequestStarted)
|
||||
chatd.WaitUntilIdleForTest(server)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
queued, dbErr := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
if dbErr != nil || len(queued) != 0 {
|
||||
return false
|
||||
}
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, queued)
|
||||
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil || fromDB.Status != database.ChatStatusWaiting {
|
||||
return false
|
||||
}
|
||||
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
|
||||
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
userTexts := make([]string, 0, 3)
|
||||
for _, message := range messages {
|
||||
if message.Role != database.ChatMessageRoleUser {
|
||||
continue
|
||||
}
|
||||
sdkMessage := db2sdk.ChatMessage(message)
|
||||
if len(sdkMessage.Content) != 1 {
|
||||
continue
|
||||
}
|
||||
userTexts = append(userTexts, sdkMessage.Content[0].Text)
|
||||
userTexts := make([]string, 0, 3)
|
||||
for _, message := range messages {
|
||||
if message.Role != database.ChatMessageRoleUser {
|
||||
continue
|
||||
}
|
||||
if len(userTexts) != 3 {
|
||||
return false
|
||||
sdkMessage := db2sdk.ChatMessage(message)
|
||||
if len(sdkMessage.Content) != 1 {
|
||||
continue
|
||||
}
|
||||
return requestCount.Load() >= 3 &&
|
||||
userTexts[0] == "hello" &&
|
||||
userTexts[1] == "queued" &&
|
||||
userTexts[2] == "later queued"
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
userTexts = append(userTexts, sdkMessage.Content[0].Text)
|
||||
}
|
||||
require.Equal(t, []string{"hello", "queued", "later queued"}, userTexts)
|
||||
}
|
||||
|
||||
func TestEditMessageRejectsWhenUsageLimitReached(t *testing.T) {
|
||||
@@ -1844,6 +1819,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
agentToken := uuid.NewString()
|
||||
// Add a startup script so the agent spends time in the
|
||||
@@ -1898,7 +1874,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
)
|
||||
})
|
||||
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
APIKey: "test-api-key",
|
||||
BaseURL: openAIURL,
|
||||
@@ -1907,7 +1883,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
@@ -1915,7 +1891,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -1927,7 +1903,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
|
||||
var chatResult codersdk.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := client.GetChat(ctx, chat.ID)
|
||||
got, getErr := expClient.GetChat(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
@@ -1949,7 +1925,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, workspaceName, workspace.Name)
|
||||
|
||||
chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
chatMsgs, err := expClient.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var foundCreateWorkspaceResult bool
|
||||
@@ -2023,6 +1999,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
@@ -2067,7 +2044,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
)
|
||||
})
|
||||
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
APIKey: "test-api-key",
|
||||
BaseURL: openAIURL,
|
||||
@@ -2076,7 +2053,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
@@ -2085,7 +2062,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a chat with the stopped workspace pre-associated.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -2098,7 +2075,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
|
||||
var chatResult codersdk.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := client.GetChat(ctx, chat.ID)
|
||||
got, getErr := expClient.GetChat(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
@@ -2120,7 +2097,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.WorkspaceTransitionStart, updatedWorkspace.LatestBuild.Transition)
|
||||
|
||||
chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
chatMsgs, err := expClient.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify start_workspace tool result exists in the chat messages.
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
"charm.land/fantasy/schema"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
+1
-1
@@ -14,11 +14,11 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
+1
-1
@@ -9,8 +9,8 @@ import (
|
||||
fantasyvercel "charm.land/fantasy/providers/vercel"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
+2
-2
@@ -12,8 +12,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
)
|
||||
|
||||
func TestUserAgent(t *testing.T) {
|
||||
+1
-1
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
)
|
||||
|
||||
func TestIsRetryable(t *testing.T) {
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
)
|
||||
|
||||
func TestAnthropic_Streaming(t *testing.T) {
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
)
|
||||
|
||||
func TestOpenAI_Streaming(t *testing.T) {
|
||||
@@ -3,7 +3,6 @@ package chattool
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -25,23 +24,25 @@ const (
|
||||
// computerUseTool implements fantasy.AgentTool and
|
||||
// chatloop.ToolDefiner for Anthropic computer use.
|
||||
type computerUseTool struct {
|
||||
displayWidth int
|
||||
displayHeight int
|
||||
declaredWidth int
|
||||
declaredHeight int
|
||||
getWorkspaceConn func(ctx context.Context) (workspacesdk.AgentConn, error)
|
||||
providerOptions fantasy.ProviderOptions
|
||||
clock quartz.Clock
|
||||
}
|
||||
|
||||
// NewComputerUseTool creates a computer use AgentTool that
|
||||
// delegates to the agent's desktop endpoints.
|
||||
// NewComputerUseTool creates a computer use AgentTool that delegates to the
|
||||
// agent's desktop endpoints. declaredWidth and declaredHeight are the
|
||||
// model-facing desktop dimensions advertised to Anthropic and requested for
|
||||
// screenshots.
|
||||
func NewComputerUseTool(
|
||||
displayWidth, displayHeight int,
|
||||
declaredWidth, declaredHeight int,
|
||||
getWorkspaceConn func(ctx context.Context) (workspacesdk.AgentConn, error),
|
||||
clock quartz.Clock,
|
||||
) fantasy.AgentTool {
|
||||
return &computerUseTool{
|
||||
displayWidth: displayWidth,
|
||||
displayHeight: displayHeight,
|
||||
declaredWidth: declaredWidth,
|
||||
declaredHeight: declaredHeight,
|
||||
getWorkspaceConn: getWorkspaceConn,
|
||||
clock: clock,
|
||||
}
|
||||
@@ -56,14 +57,13 @@ func (*computerUseTool) Info() fantasy.ToolInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// ComputerUseProviderTool creates the provider-defined tool
|
||||
// definition for Anthropic computer use. This is passed via
|
||||
// ProviderTools so the API receives the correct wire format.
|
||||
func ComputerUseProviderTool(displayWidth, displayHeight int) fantasy.Tool {
|
||||
// ComputerUseProviderTool creates the provider-defined Anthropic computer-use
|
||||
// tool definition using the declared model-facing desktop geometry.
|
||||
func ComputerUseProviderTool(declaredWidth, declaredHeight int) fantasy.Tool {
|
||||
return fantasyanthropic.NewComputerUseTool(
|
||||
fantasyanthropic.ComputerUseToolOptions{
|
||||
DisplayWidthPx: int64(displayWidth),
|
||||
DisplayHeightPx: int64(displayHeight),
|
||||
DisplayWidthPx: int64(declaredWidth),
|
||||
DisplayHeightPx: int64(declaredHeight),
|
||||
ToolVersion: fantasyanthropic.ComputerUse20251124,
|
||||
},
|
||||
)
|
||||
@@ -92,10 +92,7 @@ func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fanta
|
||||
), nil
|
||||
}
|
||||
|
||||
// Compute scaled screenshot size for Anthropic constraints.
|
||||
scaledW, scaledH := computeScaledScreenshotSize(
|
||||
t.displayWidth, t.displayHeight,
|
||||
)
|
||||
declaredWidth, declaredHeight := t.declaredActionDimensions()
|
||||
|
||||
// For wait actions, sleep then return a screenshot.
|
||||
if input.Action == fantasyanthropic.ActionWait {
|
||||
@@ -111,8 +108,8 @@ func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fanta
|
||||
}
|
||||
screenshotAction := workspacesdk.DesktopAction{
|
||||
Action: "screenshot",
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
ScaledWidth: &declaredWidth,
|
||||
ScaledHeight: &declaredHeight,
|
||||
}
|
||||
screenResp, sErr := conn.ExecuteDesktopAction(ctx, screenshotAction)
|
||||
if sErr != nil {
|
||||
@@ -129,8 +126,8 @@ func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fanta
|
||||
if input.Action == fantasyanthropic.ActionScreenshot {
|
||||
screenshotAction := workspacesdk.DesktopAction{
|
||||
Action: "screenshot",
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
ScaledWidth: &declaredWidth,
|
||||
ScaledHeight: &declaredHeight,
|
||||
}
|
||||
screenResp, sErr := conn.ExecuteDesktopAction(ctx, screenshotAction)
|
||||
if sErr != nil {
|
||||
@@ -146,8 +143,8 @@ func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fanta
|
||||
// Build the action request.
|
||||
action := workspacesdk.DesktopAction{
|
||||
Action: string(input.Action),
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
ScaledWidth: &declaredWidth,
|
||||
ScaledHeight: &declaredHeight,
|
||||
}
|
||||
if input.Coordinate != ([2]int64{}) {
|
||||
coord := [2]int{int(input.Coordinate[0]), int(input.Coordinate[1])}
|
||||
@@ -183,8 +180,8 @@ func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fanta
|
||||
// Take a screenshot after every action (Anthropic pattern).
|
||||
screenshotAction := workspacesdk.DesktopAction{
|
||||
Action: "screenshot",
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
ScaledWidth: &declaredWidth,
|
||||
ScaledHeight: &declaredHeight,
|
||||
}
|
||||
screenResp, sErr := conn.ExecuteDesktopAction(ctx, screenshotAction)
|
||||
if sErr != nil {
|
||||
@@ -198,23 +195,17 @@ func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fanta
|
||||
), nil
|
||||
}
|
||||
|
||||
// computeScaledScreenshotSize computes the target screenshot
|
||||
// dimensions to fit within Anthropic's constraints.
|
||||
func computeScaledScreenshotSize(width, height int) (scaledWidth int, scaledHeight int) {
|
||||
const maxLongEdge = 1568
|
||||
const maxTotalPixels = 1_150_000
|
||||
|
||||
longEdge := max(width, height)
|
||||
totalPixels := width * height
|
||||
longEdgeScale := float64(maxLongEdge) / float64(longEdge)
|
||||
totalPixelsScale := math.Sqrt(
|
||||
float64(maxTotalPixels) / float64(totalPixels),
|
||||
)
|
||||
scale := min(1.0, longEdgeScale, totalPixelsScale)
|
||||
|
||||
if scale >= 1.0 {
|
||||
return width, height
|
||||
func (t *computerUseTool) declaredActionDimensions() (declaredWidth, declaredHeight int) {
|
||||
if t.declaredWidth <= 0 || t.declaredHeight <= 0 {
|
||||
geometry := workspacesdk.DefaultDesktopGeometry()
|
||||
return geometry.DeclaredWidth, geometry.DeclaredHeight
|
||||
}
|
||||
return max(1, int(float64(width)*scale)),
|
||||
max(1, int(float64(height)*scale))
|
||||
return t.declaredWidth, t.declaredHeight
|
||||
}
|
||||
|
||||
// computeScaledScreenshotSize preserves the historical helper name while using
|
||||
// the shared declared-geometry selection logic.
|
||||
func computeScaledScreenshotSize(width, height int) (scaledWidth int, scaledHeight int) {
|
||||
geometry := workspacesdk.NewDesktopGeometry(width, height)
|
||||
return geometry.DeclaredWidth, geometry.DeclaredHeight
|
||||
}
|
||||
+20
-16
@@ -15,11 +15,11 @@ func TestComputeScaledScreenshotSize(t *testing.T) {
|
||||
wantW, wantH int
|
||||
}{
|
||||
{
|
||||
name: "1920x1080_scales_down",
|
||||
name: "1920x1080_prefers_standard_1280x720",
|
||||
width: 1920,
|
||||
height: 1080,
|
||||
wantW: 1429,
|
||||
wantH: 804,
|
||||
wantW: 1280,
|
||||
wantH: 720,
|
||||
},
|
||||
{
|
||||
name: "1280x800_no_scaling",
|
||||
@@ -29,18 +29,18 @@ func TestComputeScaledScreenshotSize(t *testing.T) {
|
||||
wantH: 800,
|
||||
},
|
||||
{
|
||||
name: "3840x2160_large_display",
|
||||
name: "3840x2160_prefers_standard_1280x720",
|
||||
width: 3840,
|
||||
height: 2160,
|
||||
wantW: 1429,
|
||||
wantH: 804,
|
||||
wantW: 1280,
|
||||
wantH: 720,
|
||||
},
|
||||
{
|
||||
name: "1568x1000_pixel_cap_applies",
|
||||
name: "1568x1000_prefers_standard_1280x816",
|
||||
width: 1568,
|
||||
height: 1000,
|
||||
wantW: 1342,
|
||||
wantH: 856,
|
||||
wantW: 1280,
|
||||
wantH: 816,
|
||||
},
|
||||
{
|
||||
name: "100x100_small_display",
|
||||
@@ -50,14 +50,18 @@ func TestComputeScaledScreenshotSize(t *testing.T) {
|
||||
wantH: 100,
|
||||
},
|
||||
{
|
||||
name: "4000x3000_stays_within_limits",
|
||||
width: 4000,
|
||||
// Both constraints apply. The function should keep
|
||||
// the result within maxLongEdge=1568 and
|
||||
// totalPixels<=1,150,000.
|
||||
name: "4000x3000_prefers_standard_1024x768",
|
||||
width: 4000,
|
||||
height: 3000,
|
||||
wantW: 1238,
|
||||
wantH: 928,
|
||||
wantW: 1024,
|
||||
wantH: 768,
|
||||
},
|
||||
{
|
||||
name: "1920x1200_prefers_standard_1280x800",
|
||||
width: 1920,
|
||||
height: 1200,
|
||||
wantW: 1280,
|
||||
wantH: 800,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
package chattool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestComputerUseTool_Info(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
geometry := workspacesdk.DefaultDesktopGeometry()
|
||||
tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, nil, quartz.NewReal())
|
||||
info := tool.Info()
|
||||
assert.Equal(t, "computer", info.Name)
|
||||
assert.NotEmpty(t, info.Description)
|
||||
}
|
||||
|
||||
func TestComputerUseProviderTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
geometry := workspacesdk.DefaultDesktopGeometry()
|
||||
def := chattool.ComputerUseProviderTool(geometry.DeclaredWidth, geometry.DeclaredHeight)
|
||||
pdt, ok := def.(fantasy.ProviderDefinedTool)
|
||||
require.True(t, ok, "ComputerUseProviderTool should return a ProviderDefinedTool")
|
||||
assert.Contains(t, pdt.ID, "computer")
|
||||
assert.Equal(t, "computer", pdt.Name)
|
||||
assert.Equal(t, int64(geometry.DeclaredWidth), pdt.Args["display_width_px"])
|
||||
assert.Equal(t, int64(geometry.DeclaredHeight), pdt.Args["display_height_px"])
|
||||
}
|
||||
|
||||
func TestComputerUseProviderTool_PrefersDeclaredGeometry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
geometry := workspacesdk.NewDesktopGeometry(1920, 1080)
|
||||
def := chattool.ComputerUseProviderTool(geometry.DeclaredWidth, geometry.DeclaredHeight)
|
||||
pdt, ok := def.(fantasy.ProviderDefinedTool)
|
||||
require.True(t, ok, "ComputerUseProviderTool should return a ProviderDefinedTool")
|
||||
assert.Equal(t, int64(1280), pdt.Args["display_width_px"])
|
||||
assert.Equal(t, int64(720), pdt.Args["display_height_px"])
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
geometry := workspacesdk.DefaultDesktopGeometry()
|
||||
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}),
|
||||
).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) {
|
||||
require.NotNil(t, action.ScaledWidth)
|
||||
require.NotNil(t, action.ScaledHeight)
|
||||
assert.Equal(t, geometry.DeclaredWidth, *action.ScaledWidth)
|
||||
assert.Equal(t, geometry.DeclaredHeight, *action.ScaledHeight)
|
||||
return workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "base64png",
|
||||
ScreenshotWidth: geometry.DeclaredWidth,
|
||||
ScreenshotHeight: geometry.DeclaredHeight,
|
||||
}, nil
|
||||
})
|
||||
|
||||
tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-1",
|
||||
Name: "computer",
|
||||
Input: `{"action":"screenshot"}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, "image/png", resp.MediaType)
|
||||
assert.Equal(t, []byte("base64png"), resp.Data)
|
||||
assert.False(t, resp.IsError)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_LeftClick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
geometry := workspacesdk.DefaultDesktopGeometry()
|
||||
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}),
|
||||
).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) {
|
||||
require.NotNil(t, action.Coordinate)
|
||||
assert.Equal(t, [2]int{100, 200}, *action.Coordinate)
|
||||
require.NotNil(t, action.ScaledWidth)
|
||||
require.NotNil(t, action.ScaledHeight)
|
||||
assert.Equal(t, geometry.DeclaredWidth, *action.ScaledWidth)
|
||||
assert.Equal(t, geometry.DeclaredHeight, *action.ScaledHeight)
|
||||
return workspacesdk.DesktopActionResponse{Output: "left_click performed"}, nil
|
||||
})
|
||||
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}),
|
||||
).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) {
|
||||
assert.Equal(t, "screenshot", action.Action)
|
||||
require.NotNil(t, action.ScaledWidth)
|
||||
require.NotNil(t, action.ScaledHeight)
|
||||
assert.Equal(t, geometry.DeclaredWidth, *action.ScaledWidth)
|
||||
assert.Equal(t, geometry.DeclaredHeight, *action.ScaledHeight)
|
||||
return workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "after-click",
|
||||
ScreenshotWidth: geometry.DeclaredWidth,
|
||||
ScreenshotHeight: geometry.DeclaredHeight,
|
||||
}, nil
|
||||
})
|
||||
|
||||
tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-2",
|
||||
Name: "computer",
|
||||
Input: `{"action":"left_click","coordinate":[100,200]}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, []byte("after-click"), resp.Data)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_Wait(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
geometry := workspacesdk.DefaultDesktopGeometry()
|
||||
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(workspacesdk.DesktopAction{}),
|
||||
).DoAndReturn(func(_ context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) {
|
||||
require.NotNil(t, action.ScaledWidth)
|
||||
require.NotNil(t, action.ScaledHeight)
|
||||
assert.Equal(t, geometry.DeclaredWidth, *action.ScaledWidth)
|
||||
assert.Equal(t, geometry.DeclaredHeight, *action.ScaledHeight)
|
||||
return workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "after-wait",
|
||||
ScreenshotWidth: geometry.DeclaredWidth,
|
||||
ScreenshotHeight: geometry.DeclaredHeight,
|
||||
}, nil
|
||||
})
|
||||
|
||||
tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-3",
|
||||
Name: "computer",
|
||||
Input: `{"action":"wait","duration":10}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, "image/png", resp.MediaType)
|
||||
assert.Equal(t, []byte("after-wait"), resp.Data)
|
||||
assert.False(t, resp.IsError)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_ConnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
geometry := workspacesdk.DefaultDesktopGeometry()
|
||||
tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return nil, xerrors.New("workspace not available")
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-4",
|
||||
Name: "computer",
|
||||
Input: `{"action":"screenshot"}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.Contains(t, resp.Content, "workspace not available")
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
geometry := workspacesdk.DefaultDesktopGeometry()
|
||||
tool := chattool.NewComputerUseTool(geometry.DeclaredWidth, geometry.DeclaredHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return nil, xerrors.New("should not be called")
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-5",
|
||||
Name: "computer",
|
||||
Input: `{invalid json`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.Contains(t, resp.Content, "invalid computer use input")
|
||||
}
|
||||
@@ -78,9 +78,9 @@ type ProcessToolOptions struct {
|
||||
// ExecuteArgs are the parameters accepted by the execute tool.
|
||||
type ExecuteArgs struct {
|
||||
Command string `json:"command" description:"The shell command to execute."`
|
||||
Timeout *string `json:"timeout,omitempty" description:"Timeout duration (e.g. '30s', '5m'). Default is 10s. Only applies to foreground commands."`
|
||||
Timeout *string `json:"timeout,omitempty" description:"How long to wait for completion (e.g. '30s', '5m'). Default is 10s. The process keeps running if this expires and you get a background_process_id to re-attach. Only applies to foreground commands."`
|
||||
WorkDir *string `json:"workdir,omitempty" description:"Working directory for the command."`
|
||||
RunInBackground *bool `json:"run_in_background,omitempty" description:"Run this command in the background without blocking. Use for long-running processes like dev servers, file watchers, or builds that run longer than 5 seconds. Do NOT use shell & to background processes — it will not work correctly. Always use this parameter instead."`
|
||||
RunInBackground *bool `json:"run_in_background,omitempty" description:"Run without blocking. Use for persistent processes (dev servers, file watchers) or when you want to continue working while a command runs and check the result later with process_output. For commands whose result you need before continuing, prefer foreground with a longer timeout. Do NOT use shell & to background processes — it will not work correctly. Always use this parameter instead."`
|
||||
}
|
||||
|
||||
// Execute returns an AgentTool that runs a shell command in the
|
||||
@@ -88,7 +88,7 @@ type ExecuteArgs struct {
|
||||
func Execute(options ExecuteOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"execute",
|
||||
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding. If the command fails or times out, the response may include a background_process_id; use process_output with that ID to retrieve the result.",
|
||||
"Execute a shell command in the workspace. Runs the command and waits for completion up to the timeout (default 10s, override with the timeout parameter e.g. '30s', '5m'). If the command exceeds the timeout, the response includes a background_process_id; use process_output with that ID to re-attach and wait for the result. Use run_in_background=true for persistent processes (dev servers, file watchers) or when you want to continue other work while the command runs. Never use shell '&' for backgrounding.",
|
||||
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
@@ -389,11 +389,11 @@ type ProcessOutputArgs struct {
|
||||
}
|
||||
|
||||
// ProcessOutput returns an AgentTool that retrieves the output
|
||||
// of a background process by its ID.
|
||||
// of a tracked process by its ID.
|
||||
func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"process_output",
|
||||
"Retrieve output from a background process. "+
|
||||
"Retrieve output from a tracked process by ID. "+
|
||||
"Use the process_id returned by execute with "+
|
||||
"run_in_background=true or from a timed-out "+
|
||||
"execute's background_process_id. Blocks up to "+
|
||||
@@ -401,7 +401,8 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
|
||||
"output and exit_code. If still running after "+
|
||||
"the timeout, returns the output so far. Use "+
|
||||
"wait_timeout to override the default 10s wait "+
|
||||
"(e.g. '30s', or '0s' for an immediate snapshot).",
|
||||
"(e.g. '30s', or '0s' for an immediate snapshot "+
|
||||
"without waiting).",
|
||||
func(ctx context.Context, args ProcessOutputArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
@@ -486,8 +487,7 @@ func ProcessList(options ProcessToolOptions) fantasy.AgentTool {
|
||||
"List all tracked processes in the workspace. "+
|
||||
"Returns process IDs, commands, status (running or "+
|
||||
"exited), and exit codes. Use this to discover "+
|
||||
"background processes or check which processes are "+
|
||||
"still running.",
|
||||
"processes or check which are still running.",
|
||||
func(ctx context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
@@ -517,11 +517,11 @@ type ProcessSignalArgs struct {
|
||||
}
|
||||
|
||||
// ProcessSignal returns an AgentTool that sends a signal to a
|
||||
// tracked process on the workspace agent.
|
||||
// tracked process on the workspace agent by its ID.
|
||||
func ProcessSignal(options ProcessToolOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"process_signal",
|
||||
"Send a signal to a background process. "+
|
||||
"Send a signal to a tracked process. "+
|
||||
"Use \"terminate\" (SIGTERM) for graceful shutdown "+
|
||||
"or \"kill\" (SIGKILL) to force stop. Use the "+
|
||||
"process_id returned by execute with "+
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
+1
-1
@@ -11,11 +11,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -0,0 +1,9 @@
|
||||
package chatd
|
||||
|
||||
// WaitUntilIdleForTest waits for background chat work tracked by the server to
|
||||
// finish without shutting the server down. Tests use this to assert final
|
||||
// database state only after asynchronous chat processing has completed.
|
||||
// Close waits for the same tracked work, but also stops the server.
|
||||
func WaitUntilIdleForTest(server *Server) {
|
||||
server.inflight.Wait()
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user