Compare commits
75 Commits
| SHA1 |
|---|
| 2574576426 |
| fd1e2f0dd9 |
| be5e080de6 |
| 19e86628da |
| 02356c61f6 |
| b9f0c479ac |
| 803cfeb882 |
| 08577006c6 |
| 13241a58ba |
| 631e4449bb |
| 76eac82e5b |
| 405d81be09 |
| 1c0442c247 |
| 16edcbdd5b |
| f62f2ffe6a |
| 2dc3466f07 |
| cbd56d33d4 |
| b23aed034f |
| 56e80b0a27 |
| dba9f68b11 |
| 245ce91199 |
| 5d0734e005 |
| 43a1af3cd6 |
| 3d5d58ec2b |
| 37d937554e |
| 796190d435 |
| c1474c7ee2 |
| a8e7cc10b6 |
| 82f965a0ae |
| acbfb90c30 |
| c344d7c00e |
| 53350377b3 |
| 147df5c971 |
| 9e4c283370 |
| 145817e8d3 |
| 956f6b2473 |
| d2afda8191 |
| c389c2bc5c |
| 4c9e37b659 |
| 3b268c95d3 |
| 138bc41563 |
| 80a172f932 |
| 86d8b6daee |
| 470e6c7217 |
| ed19a3a08e |
| 975373704f |
| 522288c9d5 |
| edd13482a0 |
| ef14654078 |
| ea37f1ff86 |
| c49170b6b3 |
| ee9b46fe08 |
| 1ad3c898a0 |
| b8e09d09b0 |
| 0900a44ff3 |
| 4537413315 |
| ab86ed0df8 |
| f2b9d5f8f7 |
| b73983e309 |
| c11cc0ba30 |
| 3163e74b77 |
| eca2257c26 |
| 75f5b60eb6 |
| 69d430f51b |
| 0f3d40b97f |
| 3729ff46fb |
| b87171086c |
| b763b72b53 |
| a08b6848f2 |
| bf702cc3b9 |
| 47daca6eea |
| 4b707515c0 |
| ecc28a6650 |
| cf24c59b56 |
| a85800c90b |
@@ -0,0 +1,140 @@
---
name: refine-plan
description: Iteratively refine development plans using TDD methodology. Ensures plans are clear, actionable, and include red-green-refactor cycles with proper test coverage.
---

# Refine Development Plan

## Overview

Good plans eliminate ambiguity through clear requirements, break work into distinct phases, and always include refactoring to capture implementation insights.

## When to Use This Skill

| Symptom | Example |
|-----------------------------|----------------------------------------|
| Unclear acceptance criteria | No definition of "done" |
| Vague implementation | Missing concrete steps or file changes |
| Missing/undefined tests | Tests mentioned only as afterthought |
| Absent refactor phase | No plan to improve code after it works |
| Ambiguous requirements | Multiple interpretations possible |
| Missing verification | No way to confirm the change works |

## Planning Principles

### 1. Plans Must Be Actionable and Unambiguous

Every step should be concrete enough that another agent could execute it without guessing.

- ❌ "Improve error handling" → ✓ "Add try-catch to API calls in user-service.ts, return 400 with error message"
- ❌ "Update tests" → ✓ "Add test case to auth.test.ts: 'should reject expired tokens with 401'"

NEVER include thinking output or other stream-of-consciousness prose mid-plan.

### 2. Push Back on Unclear Requirements

When requirements are ambiguous, ask questions before proceeding.

### 3. Tests Define Requirements

Writing test cases forces disambiguation. Use test definition as a requirements clarification tool.

### 4. TDD is Non-Negotiable

All plans follow: **Red → Green → Refactor**. The refactor phase is MANDATORY.

## The TDD Workflow

### Red Phase: Write Failing Tests First

**Purpose:** Define success criteria through concrete test cases.

**What to test:**

- Happy path (normal usage), edge cases (boundaries, empty/null), error conditions (invalid input, failures), integration points

**Test types:**

- Unit tests: Individual functions in isolation (most tests should be these - fast, focused)
- Integration tests: Component interactions (use for critical paths)
- E2E tests: Complete workflows (use sparingly)

**Write descriptive test cases:**

**If you can't write the test, you don't understand the requirement and MUST ask for clarification.**

### Green Phase: Make Tests Pass

**Purpose:** Implement minimal working solution.

Focus on correctness first. Hardcode if needed. Add just enough logic. Resist urge to "improve" code. Run tests frequently.

### Refactor Phase: Improve the Implementation

**Purpose:** Apply insights gained during implementation.

**This phase is MANDATORY.** During implementation you'll discover better structure, repeated patterns, and simplification opportunities.

**When to Extract vs Keep Duplication:**

This is highly subjective, so use the following rules of thumb combined with good judgement:

1) Follow the "rule of three": if the same 10+ lines appear verbatim three or more times, extract them.
2) The "wrong abstraction" is harder to fix than duplication.
3) If extraction would harm readability, prefer duplication.

**Common refactorings:**

- Rename for clarity
- Simplify complex conditionals
- Extract repeated code (if meets criteria above)
- Apply design patterns

**Constraints:**

- All tests must still pass after refactoring
- Don't add new features (that's a new Red phase)

## Plan Refinement Process

### Step 1: Review Current Plan for Completeness

- [ ] Clear context explaining why
- [ ] Specific, unambiguous requirements
- [ ] Test cases defined before implementation
- [ ] Step-by-step implementation approach
- [ ] Explicit refactor phase
- [ ] Verification steps

### Step 2: Identify Gaps

Look for missing tests, vague steps, no refactor phase, ambiguous requirements, missing verification.

### Step 3: Handle Unclear Requirements

If you can't write the plan without this information, ask the user. Otherwise, make reasonable assumptions and note them in the plan.

### Step 4: Define Test Cases

For each requirement, write concrete test cases. If you struggle to write test cases, you need more clarification.

### Step 5: Structure with Red-Green-Refactor

Organize the plan into three explicit phases.

### Step 6: Add Verification Steps

Specify how to confirm the change works (automated tests + manual checks).

## Tips for Success

1. **Start with tests:** If you can't write the test, you don't understand the requirement.
2. **Be specific:** "Update API" is not a step. "Add error handling to POST /users endpoint" is.
3. **Always refactor:** Even if code looks good, ask "How could this be clearer?"
4. **Question everything:** Ambiguity is the enemy.
5. **Think in phases:** Red → Green → Refactor.
6. **Keep plans manageable:** If the plan exceeds ~10 files or >5 phases, consider splitting.

---

**Remember:** A good plan makes implementation straightforward. A vague plan leads to confusion, rework, and bugs.
@@ -177,16 +177,6 @@ Dependabot PRs are auto-generated - don't try to match their verbose style for m
Changes from https://github.com/upstream/repo/pull/XXX/
```

## Attribution Footer

For AI-generated PRs, end with:

```markdown
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
```

## Creating PRs as Draft

**IMPORTANT**: Unless explicitly told otherwise, always create PRs as drafts using the `--draft` flag:

@@ -197,11 +187,12 @@ gh pr create --draft --title "..." --body "..."

After creating the PR, encourage the user to review it before marking as ready:

```
```text
I've created draft PR #XXXX. Please review the changes and mark it as ready for review when you're satisfied.
```

This allows the user to:

- Review the code changes before requesting reviews from maintainers
- Make additional adjustments if needed
- Ensure CI passes before notifying reviewers
@@ -45,7 +45,7 @@ jobs:
          fetch-depth: 1
          persist-credentials: false
      - name: check changed files
        uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
        uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
        id: filter
        with:
          filters: |
@@ -1119,6 +1119,8 @@ jobs:

      - name: Setup Go
        uses: ./.github/actions/setup-go
        with:
          use-cache: false

      - name: Install rcodesign
        run: |
@@ -135,7 +135,7 @@ jobs:
          PR_NUMBER: ${{ steps.pr_info.outputs.PR_NUMBER }}

      - name: Check changed files
        uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
        uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
        id: filter
        with:
          base: ${{ github.ref }}
@@ -163,6 +163,8 @@ jobs:

      - name: Setup Go
        uses: ./.github/actions/setup-go
        with:
          use-cache: false

      - name: Setup Node
        uses: ./.github/actions/setup-node
@@ -63,116 +63,3 @@ jobs:
            --data "{\"content\": \"$msg\"}" \
            "${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"

  trivy:
    permissions:
      security-events: write
    runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
        with:
          egress-policy: audit

      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          fetch-depth: 0
          persist-credentials: false

      - name: Setup Go
        uses: ./.github/actions/setup-go

      - name: Setup Node
        uses: ./.github/actions/setup-node

      - name: Setup sqlc
        uses: ./.github/actions/setup-sqlc

      - name: Install cosign
        uses: ./.github/actions/install-cosign

      - name: Install syft
        uses: ./.github/actions/install-syft

      - name: Install yq
        run: go run github.com/mikefarah/yq/v4@v4.44.3
      - name: Install mockgen
        run: ./.github/scripts/retry.sh -- go install go.uber.org/mock/mockgen@v0.6.0
      - name: Install protoc-gen-go
        run: ./.github/scripts/retry.sh -- go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.30
      - name: Install protoc-gen-go-drpc
        run: ./.github/scripts/retry.sh -- go install storj.io/drpc/cmd/protoc-gen-go-drpc@v0.0.34
      - name: Install Protoc
        run: |
          # protoc must be in lockstep with our dogfood Dockerfile or the
          # version in the comments will differ. This is also defined in
          # ci.yaml.
          set -euxo pipefail
          cd dogfood/coder
          mkdir -p /usr/local/bin
          mkdir -p /usr/local/include

          DOCKER_BUILDKIT=1 docker build . --target proto -t protoc
          protoc_path=/usr/local/bin/protoc
          docker run --rm --entrypoint cat protoc /tmp/bin/protoc > $protoc_path
          chmod +x $protoc_path
          protoc --version
          # Copy the generated files to the include directory.
          docker run --rm -v /usr/local/include:/target protoc cp -r /tmp/include/google /target/
          ls -la /usr/local/include/google/protobuf/
          stat /usr/local/include/google/protobuf/timestamp.proto

      - name: Build Coder linux amd64 Docker image
        id: build
        run: |
          set -euo pipefail

          version="$(./scripts/version.sh)"
          image_job="build/coder_${version}_linux_amd64.tag"

          # This environment variable forces make to not build packages and
          # archives (which the Docker image depends on due to technical reasons
          # related to concurrent FS writes).
          export DOCKER_IMAGE_NO_PREREQUISITES=true
          # This environment variable forces scripts/build_docker.sh to build
          # the base image tag locally instead of using the cached version from
          # the registry.
          CODER_IMAGE_BUILD_BASE_TAG="$(CODER_IMAGE_BASE=coder-base ./scripts/image_tag.sh --version "$version")"
          export CODER_IMAGE_BUILD_BASE_TAG

          # We would like to use make -j here, but it doesn't work with some recent
          # additions to our code generation.
          make "$image_job"
          echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"

      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # v0.34.0
        with:
          image-ref: ${{ steps.build.outputs.image }}
          format: sarif
          output: trivy-results.sarif
          severity: "CRITICAL,HIGH"

      - name: Upload Trivy scan results to GitHub Security tab
        uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
        with:
          sarif_file: trivy-results.sarif
          category: "Trivy"

      - name: Upload Trivy scan results as an artifact
        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
        with:
          name: trivy
          path: trivy-results.sarif
          retention-days: 7

      - name: Send Slack notification on failure
        if: ${{ failure() }}
        run: |
          msg="❌ Trivy Failed\n\nhttps://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
          curl \
            -qfsSL \
            -X POST \
            -H "Content-Type: application/json" \
            --data "{\"content\": \"$msg\"}" \
            "${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
@@ -1255,7 +1255,7 @@ coderd/notifications/.gen-golden: $(wildcard coderd/notifications/testdata/*/*.g
	TZ=UTC go test ./coderd/notifications -run="Test.*Golden$$" -update
	touch "$@"

provisioner/terraform/testdata/.gen-golden: $(wildcard provisioner/terraform/testdata/*/*.golden) $(GO_SRC_FILES) $(wildcard provisioner/terraform/*_test.go)
provisioner/terraform/testdata/.gen-golden: $(wildcard provisioner/terraform/testdata/*/*.golden) $(wildcard provisioner/terraform/testdata/*/*/*.golden) $(GO_SRC_FILES) $(wildcard provisioner/terraform/*_test.go)
	TZ=UTC go test ./provisioner/terraform -run="Test.*Golden$$" -update
	touch "$@"
@@ -1343,6 +1343,7 @@ test-js: site/node_modules/.installed

test-storybook: site/node_modules/.installed
	cd site/
	pnpm playwright:install
	pnpm exec vitest run --project=storybook
.PHONY: test-storybook
+2
-5
@@ -16,7 +16,6 @@ import (
	"os/user"
	"path/filepath"
	"slices"
	"sort"
	"strconv"
	"strings"
	"sync"
@@ -463,9 +462,7 @@ func (a *agent) runLoop() {
	// messages.
	ctx := a.hardCtx
	defer a.logger.Info(ctx, "agent main loop exited")
	retrier := retry.New(100*time.Millisecond, 10*time.Second)
	retrier.Jitter = 0.5
	for ; retrier.Wait(ctx); {
	for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
		a.logger.Info(ctx, "connecting to coderd")
		err := a.run()
		if err == nil {
@@ -1879,7 +1876,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
		}()
	}
	wg.Wait()
	sort.Float64s(durations)
	slices.Sort(durations)
	durationsLength := len(durations)
	switch {
	case durationsLength == 0:
@@ -433,7 +433,7 @@ func convertDockerInspect(raw []byte) ([]codersdk.WorkspaceAgentContainer, []str
	}
	portKeys := maps.Keys(in.NetworkSettings.Ports)
	// Sort the ports for deterministic output.
	sort.Strings(portKeys)
	slices.Sort(portKeys)
	// If we see the same port bound to both ipv4 and ipv6 loopback or unspecified
	// interfaces to the same container port, there is no point in adding it multiple times.
	loopbackHostPortContainerPorts := make(map[int]uint16, 0)
+31
-46
@@ -2,7 +2,6 @@ package agentdesktop

import (
	"encoding/json"
	"math"
	"net/http"
	"strconv"
	"time"
@@ -13,6 +12,7 @@ import (
	"github.com/coder/coder/v2/agent/agentssh"
	"github.com/coder/coder/v2/coderd/httpapi"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/workspacesdk"
	"github.com/coder/quartz"
	"github.com/coder/websocket"
)
@@ -26,9 +26,9 @@ type DesktopAction struct {
	Duration        *int    `json:"duration,omitempty"`
	ScrollAmount    *int    `json:"scroll_amount,omitempty"`
	ScrollDirection *string `json:"scroll_direction,omitempty"`
	// ScaledWidth and ScaledHeight are the coordinate space the
	// model is using. When provided, coordinates are linearly
	// mapped from scaled → native before dispatching.
	// ScaledWidth and ScaledHeight describe the declared model-facing desktop
	// geometry. When provided, input coordinates are mapped from declared space
	// to native desktop pixels before dispatching.
	ScaledWidth  *int `json:"scaled_width,omitempty"`
	ScaledHeight *int `json:"scaled_height,omitempty"`
}
@@ -144,17 +144,8 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
		slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
	)

	// Helper to scale a coordinate pair from the model's space to
	// native display pixels.
	scaleXY := func(x, y int) (int, int) {
		if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
			x = scaleCoordinate(x, *action.ScaledWidth, cfg.Width)
		}
		if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
			y = scaleCoordinate(y, *action.ScaledHeight, cfg.Height)
		}
		return x, y
	}
	geometry := desktopGeometryForAction(cfg, action)
	scaleXY := geometry.DeclaredPointToNative

	var resp DesktopActionResponse
@@ -192,7 +183,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
		resp.Output = "type action performed"

	case "cursor_position":
		x, y, err := a.desktop.CursorPosition(ctx)
		nativeX, nativeY, err := a.desktop.CursorPosition(ctx)
		if err != nil {
			httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
				Message: "Cursor position failed.",
@@ -200,6 +191,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
			})
			return
		}
		x, y := geometry.NativePointToDeclared(nativeX, nativeY)
		resp.Output = "x=" + strconv.Itoa(x) + ",y=" + strconv.Itoa(y)

	case "mouse_move":
@@ -447,14 +439,10 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
		resp.Output = "hold_key action performed"

	case "screenshot":
		var opts ScreenshotOptions
		if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
			opts.TargetWidth = *action.ScaledWidth
		}
		if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
			opts.TargetHeight = *action.ScaledHeight
		}
		result, err := a.desktop.Screenshot(ctx, opts)
		result, err := a.desktop.Screenshot(ctx, ScreenshotOptions{
			TargetWidth:  geometry.DeclaredWidth,
			TargetHeight: geometry.DeclaredHeight,
		})
		if err != nil {
			httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
				Message: "Screenshot failed.",
@@ -464,16 +452,8 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
		}
		resp.Output = "screenshot"
		resp.ScreenshotData = result.Data
		if action.ScaledWidth != nil && *action.ScaledWidth > 0 && *action.ScaledWidth != cfg.Width {
			resp.ScreenshotWidth = *action.ScaledWidth
		} else {
			resp.ScreenshotWidth = cfg.Width
		}
		if action.ScaledHeight != nil && *action.ScaledHeight > 0 && *action.ScaledHeight != cfg.Height {
			resp.ScreenshotHeight = *action.ScaledHeight
		} else {
			resp.ScreenshotHeight = cfg.Height
		}
		resp.ScreenshotWidth = geometry.DeclaredWidth
		resp.ScreenshotHeight = geometry.DeclaredHeight

	default:
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@@ -512,6 +492,23 @@ func coordFromAction(action DesktopAction) (x, y int, err error) {
	return action.Coordinate[0], action.Coordinate[1], nil
}

func desktopGeometryForAction(cfg DisplayConfig, action DesktopAction) workspacesdk.DesktopGeometry {
	declaredWidth := cfg.Width
	declaredHeight := cfg.Height
	if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
		declaredWidth = *action.ScaledWidth
	}
	if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
		declaredHeight = *action.ScaledHeight
	}
	return workspacesdk.NewDesktopGeometryWithDeclared(
		cfg.Width,
		cfg.Height,
		declaredWidth,
		declaredHeight,
	)
}

// missingFieldError is returned when a required field is absent from
// a DesktopAction.
type missingFieldError struct {
@@ -522,15 +519,3 @@ type missingFieldError struct {
func (e *missingFieldError) Error() string {
	return "Missing \"" + e.field + "\" for " + e.action + " action."
}

// scaleCoordinate maps a coordinate from scaled → native space.
func scaleCoordinate(scaled, scaledDim, nativeDim int) int {
	if scaledDim == 0 || scaledDim == nativeDim {
		return scaled
	}
	native := (float64(scaled)+0.5)*float64(nativeDim)/float64(scaledDim) - 0.5
	// Clamp to valid range.
	native = math.Max(native, 0)
	native = math.Min(native, float64(nativeDim-1))
	return int(native)
}
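
The removed `scaleCoordinate` treats each integer coordinate as a pixel center: it adds 0.5 before scaling and subtracts 0.5 after, so midpoints map to midpoints rather than drifting toward the origin. A standalone sketch of the same arithmetic (restated here outside the agent package, for illustration only):

```go
package main

import (
	"fmt"
	"math"
)

// scaleCoordinate maps a coordinate from scaled → native space by treating
// each integer coordinate as a pixel center, then clamping to the display.
func scaleCoordinate(scaled, scaledDim, nativeDim int) int {
	if scaledDim == 0 || scaledDim == nativeDim {
		return scaled
	}
	native := (float64(scaled)+0.5)*float64(nativeDim)/float64(scaledDim) - 0.5
	native = math.Max(native, 0)
	native = math.Min(native, float64(nativeDim-1))
	return int(native)
}

func main() {
	// The midpoint of a 1280x720 model space maps to the midpoint of 1920x1080.
	fmt.Println(scaleCoordinate(640, 1280, 1920)) // 960
	fmt.Println(scaleCoordinate(360, 720, 1080))  // 540
}
```

Note the diff replaces this helper with `workspacesdk.DesktopGeometry.DeclaredPointToNative`, whose exact rounding behavior is not shown here.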
+125
-16
@@ -27,10 +27,12 @@ var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
// fakeDesktop is a minimal Desktop implementation for unit tests.
type fakeDesktop struct {
	startErr      error
	cursorPos     [2]int
	startCfg      agentdesktop.DisplayConfig
	vncConnErr    error
	screenshotErr error
	screenshotRes agentdesktop.ScreenshotResult
	lastShotOpts  agentdesktop.ScreenshotOptions
	closed        bool

	// Track calls for assertions.
@@ -51,7 +53,8 @@ func (f *fakeDesktop) VNCConn(context.Context) (net.Conn, error) {
	return nil, f.vncConnErr
}

func (f *fakeDesktop) Screenshot(_ context.Context, _ agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
func (f *fakeDesktop) Screenshot(_ context.Context, opts agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
	f.lastShotOpts = opts
	return f.screenshotRes, f.screenshotErr
}
@@ -100,8 +103,8 @@ func (f *fakeDesktop) Type(_ context.Context, text string) error {
	return nil
}

func (*fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
	return 10, 20, nil
func (f *fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
	return f.cursorPos[0], f.cursorPos[1], nil
}

func (f *fakeDesktop) Close() error {
@@ -135,8 +138,12 @@ func TestHandleAction_Screenshot(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	geometry := workspacesdk.DefaultDesktopGeometry()
	fake := &fakeDesktop{
		startCfg: agentdesktop.DisplayConfig{Width: workspacesdk.DesktopDisplayWidth, Height: workspacesdk.DesktopDisplayHeight},
		startCfg: agentdesktop.DisplayConfig{
			Width:  geometry.NativeWidth,
			Height: geometry.NativeHeight,
		},
		screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
	}
	api := agentdesktop.NewAPI(logger, fake, nil)
@@ -158,11 +165,52 @@ func TestHandleAction_Screenshot(t *testing.T) {
	var result agentdesktop.DesktopActionResponse
	err = json.NewDecoder(rr.Body).Decode(&result)
	require.NoError(t, err)
	// Dimensions come from DisplayConfig, not the screenshot CLI.
	assert.Equal(t, "screenshot", result.Output)
	assert.Equal(t, "base64data", result.ScreenshotData)
	assert.Equal(t, workspacesdk.DesktopDisplayWidth, result.ScreenshotWidth)
	assert.Equal(t, workspacesdk.DesktopDisplayHeight, result.ScreenshotHeight)
	assert.Equal(t, geometry.NativeWidth, result.ScreenshotWidth)
	assert.Equal(t, geometry.NativeHeight, result.ScreenshotHeight)
	assert.Equal(t, agentdesktop.ScreenshotOptions{
		TargetWidth:  geometry.NativeWidth,
		TargetHeight: geometry.NativeHeight,
	}, fake.lastShotOpts)
}

func TestHandleAction_ScreenshotUsesDeclaredDimensionsFromRequest(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	fake := &fakeDesktop{
		startCfg:      agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
		screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
	}
	api := agentdesktop.NewAPI(logger, fake, nil)
	defer api.Close()

	sw := 1280
	sh := 720
	body := agentdesktop.DesktopAction{
		Action:       "screenshot",
		ScaledWidth:  &sw,
		ScaledHeight: &sh,
	}
	b, err := json.Marshal(body)
	require.NoError(t, err)

	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
	req.Header.Set("Content-Type", "application/json")

	handler := api.Routes()
	handler.ServeHTTP(rr, req)

	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Equal(t, agentdesktop.ScreenshotOptions{TargetWidth: 1280, TargetHeight: 720}, fake.lastShotOpts)

	var result agentdesktop.DesktopActionResponse
	err = json.NewDecoder(rr.Body).Decode(&result)
	require.NoError(t, err)
	assert.Equal(t, 1280, result.ScreenshotWidth)
	assert.Equal(t, 720, result.ScreenshotHeight)
}

func TestHandleAction_LeftClick(t *testing.T) {
@@ -315,7 +363,6 @@ func TestHandleAction_HoldKey(t *testing.T) {
		handler.ServeHTTP(rr, req)
	}()

	// Wait for the timer to be created, then advance past it.
	trap.MustWait(req.Context()).MustRelease(req.Context())
	mClk.Advance(time.Duration(dur) * time.Millisecond).MustWait(req.Context())
@@ -389,7 +436,6 @@ func TestHandleAction_ScrollDown(t *testing.T) {
	handler.ServeHTTP(rr, req)

	assert.Equal(t, http.StatusOK, rr.Code)
	// dy should be positive 5 for "down".
	assert.Equal(t, [4]int{500, 400, 0, 5}, fake.lastScroll)
}
@@ -398,13 +444,11 @@ func TestHandleAction_CoordinateScaling(t *testing.T) {

	logger := slogtest.Make(t, nil)
	fake := &fakeDesktop{
		// Native display is 1920x1080.
		startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
	}
	api := agentdesktop.NewAPI(logger, fake, nil)
	defer api.Close()

	// Model is working in a 1280x720 coordinate space.
	sw := 1280
	sh := 720
	body := agentdesktop.DesktopAction{
@@ -424,12 +468,43 @@ func TestHandleAction_CoordinateScaling(t *testing.T) {
	handler.ServeHTTP(rr, req)

	assert.Equal(t, http.StatusOK, rr.Code)
	// 640 in 1280-space → 960 in 1920-space (midpoint maps to
	// midpoint).
	assert.Equal(t, 960, fake.lastMove[0])
	assert.Equal(t, 540, fake.lastMove[1])
}

func TestHandleAction_CoordinateScalingClampsToLastPixel(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	fake := &fakeDesktop{
		startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
	}
	api := agentdesktop.NewAPI(logger, fake, nil)
	defer api.Close()

	sw := 1366
	sh := 768
	body := agentdesktop.DesktopAction{
		Action:       "mouse_move",
		Coordinate:   &[2]int{1365, 767},
		ScaledWidth:  &sw,
		ScaledHeight: &sh,
	}
	b, err := json.Marshal(body)
	require.NoError(t, err)

	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
	req.Header.Set("Content-Type", "application/json")

	handler := api.Routes()
	handler.ServeHTTP(rr, req)

	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Equal(t, 1919, fake.lastMove[0])
	assert.Equal(t, 1079, fake.lastMove[1])
}

func TestClose_DelegatesToDesktop(t *testing.T) {
	t.Parallel()
@@ -446,15 +521,12 @@ func TestClose_PreventsNewSessions(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	// After Close(), Start() will return an error because the
	// underlying Desktop is closed.
	fake := &fakeDesktop{}
	api := agentdesktop.NewAPI(logger, fake, nil)

	err := api.Close()
	require.NoError(t, err)

	// Simulate the closed desktop returning an error on Start().
	fake.startErr = xerrors.New("desktop is closed")

	rr := httptest.NewRecorder()
@@ -465,3 +537,40 @@ func TestClose_PreventsNewSessions(t *testing.T) {

	assert.Equal(t, http.StatusInternalServerError, rr.Code)
}

func TestHandleAction_CursorPositionReturnsDeclaredCoordinates(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	fake := &fakeDesktop{
		startCfg:  agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
		cursorPos: [2]int{960, 540},
	}
	api := agentdesktop.NewAPI(logger, fake, nil)
	defer api.Close()

	sw := 1280
	sh := 720
	body := agentdesktop.DesktopAction{
		Action:       "cursor_position",
		ScaledWidth:  &sw,
		ScaledHeight: &sh,
	}
	b, err := json.Marshal(body)
	require.NoError(t, err)

	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
	req.Header.Set("Content-Type", "application/json")

	handler := api.Routes()
	handler.ServeHTTP(rr, req)

	assert.Equal(t, http.StatusOK, rr.Code)

	var resp agentdesktop.DesktopActionResponse
	err = json.NewDecoder(rr.Body).Decode(&resp)
	require.NoError(t, err)
	// Native (960,540) in 1920x1080 should map to declared space in 1280x720.
	assert.Equal(t, "x=640,y=360", resp.Output)
}
@@ -111,7 +111,7 @@ func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {

	//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
	cmd := p.execer.CommandContext(sessionCtx, p.binPath, "up", "--json",
		"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight))
		"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopNativeWidth, workspacesdk.DesktopNativeHeight))
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		sessionCancel()
+135 -22
@@ -14,6 +14,7 @@ import (
	"syscall"

	"github.com/google/uuid"
	"github.com/spf13/afero"
	"golang.org/x/xerrors"

	"cdr.dev/slog/v3"
@@ -319,8 +320,14 @@ func (api *API) writeFile(ctx context.Context, r *http.Request, path string) (HT
		return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
	}

	resolved, err := api.resolveSymlink(path)
	if err != nil {
		return http.StatusInternalServerError, xerrors.Errorf("resolve symlink %q: %w", path, err)
	}
	path = resolved

	dir := filepath.Dir(path)
	err := api.filesystem.MkdirAll(dir, 0o755)
	err = api.filesystem.MkdirAll(dir, 0o755)
	if err != nil {
		status := http.StatusInternalServerError
		switch {
@@ -410,6 +417,12 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
		return http.StatusBadRequest, xerrors.New("must specify at least one edit")
	}

	resolved, err := api.resolveSymlink(path)
	if err != nil {
		return http.StatusInternalServerError, xerrors.Errorf("resolve symlink %q: %w", path, err)
	}
	path = resolved

	f, err := api.filesystem.Open(path)
	if err != nil {
		status := http.StatusInternalServerError
@@ -510,6 +523,52 @@ func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode,
	return 0, nil
}

// resolveSymlink resolves a path through any symlinks so that
// subsequent operations (such as atomic rename) target the real
// file instead of replacing the symlink itself.
//
// The filesystem must implement afero.Lstater and afero.LinkReader
// for resolution to occur; if it does not (e.g. MemMapFs), the
// path is returned unchanged.
func (api *API) resolveSymlink(path string) (string, error) {
	const maxDepth = 10

	lstater, hasLstat := api.filesystem.(afero.Lstater)
	if !hasLstat {
		return path, nil
	}
	reader, hasReadlink := api.filesystem.(afero.LinkReader)
	if !hasReadlink {
		return path, nil
	}

	for range maxDepth {
		info, _, err := lstater.LstatIfPossible(path)
		if err != nil {
			// If the file does not exist yet (new file write),
			// there is nothing to resolve.
			if errors.Is(err, os.ErrNotExist) {
				return path, nil
			}
			return "", err
		}
		if info.Mode()&os.ModeSymlink == 0 {
			return path, nil
		}

		target, err := reader.ReadlinkIfPossible(path)
		if err != nil {
			return "", err
		}
		if !filepath.IsAbs(target) {
			target = filepath.Join(filepath.Dir(path), target)
		}
		path = target
	}

	return "", xerrors.Errorf("too many levels of symlinks resolving %q", path)
}
// fuzzyReplace attempts to find `search` inside `content` and replace it
// with `replace`. It uses a cascading match strategy inspired by
// openai/codex's apply_patch:
@@ -567,30 +626,15 @@ func fuzzyReplace(content string, edit workspacesdk.FileEdit) (string, error) {
	}

	// Pass 2 – trim trailing whitespace on each line.
	if start, end, ok := seekLines(contentLines, searchLines, trimRight); ok {
		if !edit.ReplaceAll {
			if count := countLineMatches(contentLines, searchLines, trimRight); count > 1 {
				return "", xerrors.Errorf("search string matches %d occurrences "+
					"(expected exactly 1). Include more surrounding "+
					"context to make the match unique, or set "+
					"replace_all to true", count)
			}
		}
		return spliceLines(contentLines, start, end, replace), nil
	if result, matched, err := fuzzyReplaceLines(contentLines, searchLines, replace, trimRight, edit.ReplaceAll); matched {
		return result, err
	}

	// Pass 3 – trim all leading and trailing whitespace
	// (indentation-tolerant).
	if start, end, ok := seekLines(contentLines, searchLines, trimAll); ok {
		if !edit.ReplaceAll {
			if count := countLineMatches(contentLines, searchLines, trimAll); count > 1 {
				return "", xerrors.Errorf("search string matches %d occurrences "+
					"(expected exactly 1). Include more surrounding "+
					"context to make the match unique, or set "+
					"replace_all to true", count)
			}
		}
		return spliceLines(contentLines, start, end, replace), nil
	// (indentation-tolerant). The replacement is inserted verbatim;
	// callers must provide correctly indented replacement text.
	if result, matched, err := fuzzyReplaceLines(contentLines, searchLines, replace, trimAll, edit.ReplaceAll); matched {
		return result, err
	}

	return "", xerrors.New("search string not found in file. Verify the search " +
@@ -653,3 +697,72 @@ func spliceLines(contentLines []string, start, end int, replacement string) stri
	}
	return b.String()
}

// fuzzyReplaceLines handles fuzzy matching passes (2 and 3) for
// fuzzyReplace. When replaceAll is false and there are multiple
// matches, an error is returned. When replaceAll is true, all
// non-overlapping matches are replaced.
//
// Returns (result, true, nil) on success, ("", false, nil) when
// searchLines don't match at all, or ("", true, err) when the match
// is ambiguous.
//
//nolint:revive // replaceAll is a direct pass-through of the user's flag, not a control coupling.
func fuzzyReplaceLines(
	contentLines, searchLines []string,
	replace string,
	eq func(a, b string) bool,
	replaceAll bool,
) (string, bool, error) {
	start, end, ok := seekLines(contentLines, searchLines, eq)
	if !ok {
		return "", false, nil
	}

	if !replaceAll {
		if count := countLineMatches(contentLines, searchLines, eq); count > 1 {
			return "", true, xerrors.Errorf("search string matches %d occurrences "+
				"(expected exactly 1). Include more surrounding "+
				"context to make the match unique, or set "+
				"replace_all to true", count)
		}
		return spliceLines(contentLines, start, end, replace), true, nil
	}

	// Replace all: collect all match positions, then apply from last
	// to first to preserve indices.
	type lineMatch struct{ start, end int }
	var matches []lineMatch
	for i := 0; i <= len(contentLines)-len(searchLines); {
		found := true
		for j, sLine := range searchLines {
			if !eq(contentLines[i+j], sLine) {
				found = false
				break
			}
		}
		if found {
			matches = append(matches, lineMatch{i, i + len(searchLines)})
			i += len(searchLines) // skip past this match
		} else {
			i++
		}
	}

	// Apply replacements from last to first.
	repLines := strings.SplitAfter(replace, "\n")
	for i := len(matches) - 1; i >= 0; i-- {
		m := matches[i]
		newLines := make([]string, 0, m.start+len(repLines)+(len(contentLines)-m.end))
		newLines = append(newLines, contentLines[:m.start]...)
		newLines = append(newLines, repLines...)
		newLines = append(newLines, contentLines[m.end:]...)
		contentLines = newLines
	}

	var b strings.Builder
	for _, l := range contentLines {
		_, _ = b.WriteString(l)
	}
	return b.String(), true, nil
}
@@ -881,6 +881,43 @@ func TestEditFiles(t *testing.T) {
			},
			expected: map[string]string{filepath.Join(tmpdir, "ra-exact"): "qux bar qux baz qux"},
		},
		{
			// replace_all with fuzzy trailing-whitespace match.
			name:     "ReplaceAllFuzzyTrailing",
			contents: map[string]string{filepath.Join(tmpdir, "ra-fuzzy-trail"): "hello \nworld\nhello \nagain"},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "ra-fuzzy-trail"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:     "hello\n",
							Replace:    "bye\n",
							ReplaceAll: true,
						},
					},
				},
			},
			expected: map[string]string{filepath.Join(tmpdir, "ra-fuzzy-trail"): "bye\nworld\nbye\nagain"},
		},
		{
			// replace_all with fuzzy indent match (pass 3).
			name:     "ReplaceAllFuzzyIndent",
			contents: map[string]string{filepath.Join(tmpdir, "ra-fuzzy-indent"): "\t\talpha\n\t\tbeta\n\t\talpha\n\t\tgamma"},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "ra-fuzzy-indent"),
					Edits: []workspacesdk.FileEdit{
						{
							// Search uses different indentation (spaces instead of tabs).
							Search:     "    alpha\n",
							Replace:    "\t\tREPLACED\n",
							ReplaceAll: true,
						},
					},
				},
			},
			expected: map[string]string{filepath.Join(tmpdir, "ra-fuzzy-indent"): "\t\tREPLACED\n\t\tbeta\n\t\tREPLACED\n\t\tgamma"},
		},
		{
			name:     "MixedWhitespaceMultiline",
			contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
@@ -1395,3 +1432,105 @@ func TestReadFileLines(t *testing.T) {
		})
	}
}

func TestWriteFile_FollowsSymlinks(t *testing.T) {
	t.Parallel()

	if runtime.GOOS == "windows" {
		t.Skip("symlinks are not reliably supported on Windows")
	}

	dir := t.TempDir()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	osFs := afero.NewOsFs()
	api := agentfiles.NewAPI(logger, osFs, nil)

	// Create a real file and a symlink pointing to it.
	realPath := filepath.Join(dir, "real.txt")
	err := afero.WriteFile(osFs, realPath, []byte("original"), 0o644)
	require.NoError(t, err)

	linkPath := filepath.Join(dir, "link.txt")
	err = os.Symlink(realPath, linkPath)
	require.NoError(t, err)

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()

	// Write through the symlink.
	w := httptest.NewRecorder()
	r := httptest.NewRequestWithContext(ctx, http.MethodPost,
		fmt.Sprintf("/write-file?path=%s", linkPath),
		bytes.NewReader([]byte("updated")))
	api.Routes().ServeHTTP(w, r)
	require.Equal(t, http.StatusOK, w.Code)

	// The symlink must still be a symlink.
	fi, err := os.Lstat(linkPath)
	require.NoError(t, err)
	require.NotZero(t, fi.Mode()&os.ModeSymlink, "symlink was replaced")

	// The real file must have the new content.
	data, err := os.ReadFile(realPath)
	require.NoError(t, err)
	require.Equal(t, "updated", string(data))
}

func TestEditFiles_FollowsSymlinks(t *testing.T) {
	t.Parallel()

	if runtime.GOOS == "windows" {
		t.Skip("symlinks are not reliably supported on Windows")
	}

	dir := t.TempDir()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	osFs := afero.NewOsFs()
	api := agentfiles.NewAPI(logger, osFs, nil)

	// Create a real file and a symlink pointing to it.
	realPath := filepath.Join(dir, "real.txt")
	err := afero.WriteFile(osFs, realPath, []byte("hello world"), 0o644)
	require.NoError(t, err)

	linkPath := filepath.Join(dir, "link.txt")
	err = os.Symlink(realPath, linkPath)
	require.NoError(t, err)

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()

	body := workspacesdk.FileEditRequest{
		Files: []workspacesdk.FileEdits{
			{
				Path: linkPath,
				Edits: []workspacesdk.FileEdit{
					{
						Search:  "hello",
						Replace: "goodbye",
					},
				},
			},
		},
	}
	buf := bytes.NewBuffer(nil)
	enc := json.NewEncoder(buf)
	enc.SetEscapeHTML(false)
	err = enc.Encode(body)
	require.NoError(t, err)

	w := httptest.NewRecorder()
	r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/edit-files", buf)
	api.Routes().ServeHTTP(w, r)
	require.Equal(t, http.StatusOK, w.Code)

	// The symlink must still be a symlink.
	fi, err := os.Lstat(linkPath)
	require.NoError(t, err)
	require.NotZero(t, fi.Mode()&os.ModeSymlink, "symlink was replaced")

	// The real file must have the edited content.
	data, err := os.ReadFile(realPath)
	require.NoError(t, err)
	require.Equal(t, "goodbye world", string(data))
}
@@ -1,7 +1,7 @@
package agentgit

import (
	"sort"
	"slices"
	"sync"

	"github.com/google/uuid"
@@ -99,7 +99,7 @@ func (ps *PathStore) GetPaths(chatID uuid.UUID) []string {
	for p := range m {
		out = append(out, p)
	}
	sort.Strings(out)
	slices.Sort(out)
	return out
}
@@ -4,7 +4,7 @@ import (
	"context"
	"os"
	"path/filepath"
	"sort"
	"slices"
	"testing"

	"github.com/stretchr/testify/require"
@@ -228,6 +228,6 @@ func resultPaths(results []filefinder.Result) []string {
	for i, r := range results {
		paths[i] = r.Path
	}
	sort.Strings(paths)
	slices.Sort(paths)
	return paths
}
@@ -104,7 +104,7 @@ func (b *Builder) Build(inv *serpent.Invocation) (log slog.Logger, closeLog func

	addSinkIfProvided := func(sinkFn func(io.Writer) slog.Sink, loc string) error {
		switch loc {
		case "", "/dev/null":
		case "":
		case "/dev/stdout":
			sinks = append(sinks, sinkFn(inv.Stdout))
@@ -5,7 +5,7 @@ import (
	"os/exec"
	"path/filepath"
	"runtime"
	"sort"
	"slices"
	"strings"
	"testing"

@@ -376,8 +376,8 @@ func Test_sshConfigOptions_addOption(t *testing.T) {
				return
			}
			require.NoError(t, err)
			sort.Strings(tt.Expect)
			sort.Strings(o.sshOptions)
			slices.Sort(tt.Expect)
			slices.Sort(o.sshOptions)
			require.Equal(t, tt.Expect, o.sshOptions)
		})
	}
+12 -31
@@ -732,7 +732,6 @@ func (r *RootCmd) scaletestCreateWorkspaces() *serpent.Command {
	if err != nil {
		return xerrors.Errorf("create tracer provider: %w", err)
	}
	client.Trace = tracingFlags.tracePropagate
	defer func() {
		// Allow time for traces to flush even if command context is
		// canceled. This is a no-op if tracing is not enabled.
@@ -1080,7 +1079,6 @@ func (r *RootCmd) scaletestWorkspaceUpdates() *serpent.Command {
	if err != nil {
		return xerrors.Errorf("create tracer provider: %w", err)
	}
	client.Trace = tracingFlags.tracePropagate
	tracer := tracerProvider.Tracer(scaletestTracerName)

	reg := prometheus.NewRegistry()
@@ -1339,7 +1337,6 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
	if err != nil {
		return xerrors.Errorf("create tracer provider: %w", err)
	}
	client.Trace = tracingFlags.tracePropagate
	defer func() {
		// Allow time for traces to flush even if command context is
		// canceled. This is a no-op if tracing is not enabled.
@@ -1404,9 +1401,6 @@
	// Setup our workspace agent connection.
	config := workspacetraffic.Config{
		AgentID:       agent.ID,
		WorkspaceID:   ws.ID,
		WorkspaceName: ws.Name,
		AgentName:     agent.Name,
		BytesPerTick:  bytesPerTick,
		Duration:      strategy.timeout,
		TickInterval:  tickInterval,
@@ -1446,35 +1440,24 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
	_, _ = fmt.Fprintln(inv.Stderr, "Running load test...")
	testCtx, testCancel := strategy.toContext(ctx)
	defer testCancel()
	runErr := th.Run(testCtx)

	res := th.Results()

	// Write full results to the configured output destination
	// (default: text to stdout via --output flag).
	// for _, o := range outputs {
	_ = outputs
	// if writeErr := o.write(res, os.Stdout); writeErr != nil {
	// 	_, _ = fmt.Fprintf(os.Stderr, "Failed to write output %q to %q: %v\n", o.format, o.path, writeErr)
	// }
	// }

	// Always write a summary to stderr for visibility in
	// container logs. Full output goes to --output above.
	// Limit to 10 failures to avoid exceeding kubelet log
	// rotation limits.
	res.PrintSummary(os.Stderr, 10)

	if runErr != nil {
		return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", runErr)
	err = th.Run(testCtx)
	if err != nil {
		return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
	}

	// Check for interrupt after printing results so we always
	// have visibility into what happened.
	// If the command was interrupted, skip stats.
	if notifyCtx.Err() != nil {
		return notifyCtx.Err()
	}

	res := th.Results()
	for _, o := range outputs {
		err = o.write(res, inv.Stdout)
		if err != nil {
			return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
		}
	}

	if res.TotalFail > 0 {
		return xerrors.New("load test failed, see above for more details")
	}
@@ -1580,7 +1563,6 @@ func (r *RootCmd) scaletestDashboard() *serpent.Command {
	if err != nil {
		return xerrors.Errorf("create tracer provider: %w", err)
	}
	client.Trace = tracingFlags.tracePropagate
	tracer := tracerProvider.Tracer(scaletestTracerName)
	outputs, err := output.parse()
	if err != nil {
@@ -1818,7 +1800,6 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
	if err != nil {
		return xerrors.Errorf("create tracer provider: %w", err)
	}
	client.Trace = tracingFlags.tracePropagate
	tracer := tracerProvider.Tracer(scaletestTracerName)

	setupBarrier := new(sync.WaitGroup)
+3 -3
@@ -24,7 +24,7 @@ import (
	"os/user"
	"path/filepath"
	"regexp"
	"sort"
	"slices"
	"strconv"
	"strings"
	"sync"
@@ -2291,7 +2291,7 @@ func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logg

	ep := embeddedpostgres.NewDatabase(
		embeddedpostgres.DefaultConfig().
			Version(embeddedpostgres.V13).
			Version(embeddedpostgres.V16).
			BinariesPath(filepath.Join(cfg.PostgresPath(), "bin")).
			// Default BinaryRepositoryURL repo1.maven.org is flaky.
			BinaryRepositoryURL("https://repo.maven.apache.org/maven2").
@@ -2825,7 +2825,7 @@ func ReadExternalAuthProvidersFromEnv(environ []string) ([]codersdk.ExternalAuth
// parsing of `GITAUTH` environment variables.
func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]codersdk.ExternalAuthConfig, error) {
	// The index numbers must be in-order.
	sort.Strings(environ)
	slices.Sort(environ)

	var providers []codersdk.ExternalAuthConfig
	for _, v := range serpent.ParseEnviron(environ, prefix) {
+3 -3
@@ -7,7 +7,7 @@ import (
	"io"
	"os"
	"path/filepath"
	"sort"
	"slices"

	"golang.org/x/exp/maps"
	"golang.org/x/xerrors"
@@ -31,7 +31,7 @@ func (*RootCmd) templateInit() *serpent.Command {
	for _, ex := range exampleList {
		templateIDs = append(templateIDs, ex.ID)
	}
	sort.Strings(templateIDs)
	slices.Sort(templateIDs)
	cmd := &serpent.Command{
		Use:   "init [directory]",
		Short: "Get started with a templated template.",
@@ -50,7 +50,7 @@
		optsToID[name] = example.ID
	}
	opts := maps.Keys(optsToID)
	sort.Strings(opts)
	slices.Sort(opts)
	_, _ = fmt.Fprintln(
		inv.Stdout,
		pretty.Sprint(
@@ -4,7 +4,7 @@ import (
	"bytes"
	"context"
	"encoding/json"
	"sort"
	"slices"
	"testing"

	"github.com/stretchr/testify/require"
@@ -47,7 +47,7 @@ func TestTemplateList(t *testing.T) {

	// expect that templates are listed alphabetically
	templatesList := []string{firstTemplate.Name, secondTemplate.Name}
	sort.Strings(templatesList)
	slices.Sort(templatesList)

	require.NoError(t, <-errC)
@@ -6,7 +6,7 @@ USAGE:
  List all organization members

OPTIONS:
  -c, --column [username|name|user id|organization id|created at|updated at|organization roles] (default: username,organization roles)
  -c, --column [username|name|last seen at|user created at|user updated at|user id|organization id|created at|updated at|organization roles] (default: username,organization roles)
          Columns to display in table output.

  -o, --output table|json (default: table)
-12
@@ -195,18 +195,6 @@ autobuildPollInterval: 1m0s
# Interval to poll for hung and pending jobs and automatically terminate them.
# (default: 1m0s, type: duration)
jobHangDetectorInterval: 1m0s
# Number of querier workers for the PG coordinator. 0 uses the default.
# (default: 0, type: int)
tailnetQuerierWorkers: 0
# Number of binder workers for the PG coordinator. 0 uses the default.
# (default: 0, type: int)
tailnetBinderWorkers: 0
# Number of tunneler workers for the PG coordinator. 0 uses the default.
# (default: 0, type: int)
tailnetTunnelerWorkers: 0
# Number of handshaker workers for the PG coordinator. 0 uses the default.
# (default: 0, type: int)
tailnetHandshakerWorkers: 0
introspection:
  statsCollection:
    usageStats:
+2 -3
@@ -4,7 +4,6 @@ import (
	"fmt"
	"os"
	"slices"
	"sort"
	"strings"
	"time"

@@ -194,7 +193,7 @@ func joinScopes(scopes []codersdk.APIKeyScope) string {
		return ""
	}
	vals := slice.ToStrings(scopes)
	sort.Strings(vals)
	slices.Sort(vals)
	return strings.Join(vals, ", ")
}

@@ -206,7 +205,7 @@ func joinAllowList(entries []codersdk.APIAllowListTarget) string {
	for i, entry := range entries {
		vals[i] = entry.String()
	}
	sort.Strings(vals)
	slices.Sort(vals)
	return strings.Join(vals, ", ")
}
@@ -1,263 +0,0 @@
package agentconnectionbatcher

import (
	"context"
	"database/sql"
	"time"

	"github.com/google/uuid"

	"cdr.dev/slog/v3"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbauthz"
	"github.com/coder/quartz"
)

const (
	// defaultBatchSize is the maximum number of agent connection updates
	// to batch before forcing a flush. With one entry per agent, this
	// accommodates 500 concurrently connected agents per batch.
	defaultBatchSize = 500

	// defaultChannelBufferMultiplier is the multiplier for the channel
	// buffer size relative to the batch size. A 5x multiplier provides
	// significant headroom for bursts while the batch is being flushed.
	defaultChannelBufferMultiplier = 5

	// defaultFlushInterval is how frequently to flush batched connection
	// updates to the database. 5 seconds provides a good balance between
	// reducing database load and keeping connection state reasonably
	// current.
	defaultFlushInterval = 5 * time.Second

	// finalFlushTimeout is the timeout for the final flush when the
	// batcher is shutting down.
	finalFlushTimeout = 15 * time.Second
)

// Update represents a single agent connection state update to be batched.
type Update struct {
	ID                     uuid.UUID
	FirstConnectedAt       sql.NullTime
	LastConnectedAt        sql.NullTime
	LastConnectedReplicaID uuid.NullUUID
	DisconnectedAt         sql.NullTime
	UpdatedAt              time.Time
}

// Batcher accumulates agent connection updates and periodically flushes
// them to the database in a single batch query. This reduces per-heartbeat
// database write pressure from O(n) queries to O(1).
type Batcher struct {
	store database.Store
	log   slog.Logger

	updateCh     chan Update
	batch        map[uuid.UUID]Update
	maxBatchSize int

	clock    quartz.Clock
	timer    *quartz.Timer
	interval time.Duration

	ctx    context.Context
	cancel context.CancelFunc
	done   chan struct{}
}

// Option is a functional option for configuring a Batcher.
type Option func(b *Batcher)

// WithBatchSize sets the maximum number of updates to accumulate before
// forcing a flush.
func WithBatchSize(size int) Option {
	return func(b *Batcher) {
		b.maxBatchSize = size
	}
}

// WithInterval sets how frequently the batcher flushes to the database.
func WithInterval(d time.Duration) Option {
	return func(b *Batcher) {
		b.interval = d
	}
}

// WithLogger sets the logger for the batcher.
func WithLogger(log slog.Logger) Option {
	return func(b *Batcher) {
		b.log = log
	}
}

// WithClock sets the clock for the batcher, useful for testing.
func WithClock(clock quartz.Clock) Option {
	return func(b *Batcher) {
		b.clock = clock
	}
}

// New creates a new Batcher and starts its background processing loop.
// The provided context controls the lifetime of the batcher.
func New(ctx context.Context, store database.Store, opts ...Option) *Batcher {
	b := &Batcher{
		store: store,
		done:  make(chan struct{}),
		log:   slog.Logger{},
		clock: quartz.NewReal(),
	}

	for _, opt := range opts {
		opt(b)
	}

	if b.interval == 0 {
		b.interval = defaultFlushInterval
	}
	if b.maxBatchSize == 0 {
		b.maxBatchSize = defaultBatchSize
	}

	b.timer = b.clock.NewTimer(b.interval)
	channelSize := b.maxBatchSize * defaultChannelBufferMultiplier
	b.updateCh = make(chan Update, channelSize)
	b.batch = make(map[uuid.UUID]Update)

	b.ctx, b.cancel = context.WithCancel(ctx)
	go func() {
		b.run(b.ctx)
		close(b.done)
	}()

	return b
}

// Close cancels the batcher context and waits for the final flush to
// complete.
func (b *Batcher) Close() {
	b.cancel()
	if b.timer != nil {
		b.timer.Stop()
	}
	<-b.done
}

// Add enqueues an agent connection update for batching. If the internal
// channel is full, the update is dropped and a warning is logged.
func (b *Batcher) Add(u Update) {
	select {
	case b.updateCh <- u:
	default:
		b.log.Warn(context.Background(), "connection batcher channel full, dropping update",
			slog.F("agent_id", u.ID),
		)
	}
}

func (b *Batcher) processUpdate(u Update) {
	existing, exists := b.batch[u.ID]
	if exists && u.UpdatedAt.Before(existing.UpdatedAt) {
		return
	}
	b.batch[u.ID] = u
}

func (b *Batcher) run(ctx context.Context) {
	//nolint:gocritic // System-level batch operation for agent connections.
	authCtx := dbauthz.AsSystemRestricted(ctx)
	for {
		select {
		case u := <-b.updateCh:
			b.processUpdate(u)

			if len(b.batch) >= b.maxBatchSize {
				b.flush(authCtx)
				b.timer.Reset(b.interval, "connectionBatcher", "capacityFlush")
			}

		case <-b.timer.C:
			b.flush(authCtx)
			b.timer.Reset(b.interval, "connectionBatcher", "scheduledFlush")

		case <-ctx.Done():
			b.log.Debug(ctx, "context done, flushing before exit")

			ctxTimeout, cancel := context.WithTimeout(context.Background(), finalFlushTimeout)
			defer cancel() //nolint:revive // Returning after this.

			//nolint:gocritic // System-level batch operation for agent connections.
			b.flush(dbauthz.AsSystemRestricted(ctxTimeout))
			return
		}
	}
}

func (b *Batcher) flush(ctx context.Context) {
	count := len(b.batch)
	if count == 0 {
		return
	}

	b.log.Debug(ctx, "flushing connection batch", slog.F("count", count))

	var (
		ids                    = make([]uuid.UUID, 0, count)
		firstConnectedAt       = make([]time.Time, 0, count)
		lastConnectedAt        = make([]time.Time, 0, count)
		lastConnectedReplicaID = make([]uuid.UUID, 0, count)
		disconnectedAt         = make([]time.Time, 0, count)
		updatedAt              = make([]time.Time, 0, count)
	)

	for _, u := range b.batch {
		ids = append(ids, u.ID)
		firstConnectedAt = append(firstConnectedAt, nullTimeToTime(u.FirstConnectedAt))
		lastConnectedAt = append(lastConnectedAt, nullTimeToTime(u.LastConnectedAt))
		lastConnectedReplicaID = append(lastConnectedReplicaID, nullUUIDToUUID(u.LastConnectedReplicaID))
		disconnectedAt = append(disconnectedAt, nullTimeToTime(u.DisconnectedAt))
		updatedAt = append(updatedAt, u.UpdatedAt)
	}

	// Clear batch before the DB call. Losing a batch of heartbeat
	// timestamps is acceptable; the next heartbeat will update them.
	b.batch = make(map[uuid.UUID]Update)

	err := b.store.BatchUpdateWorkspaceAgentConnections(ctx, database.BatchUpdateWorkspaceAgentConnectionsParams{
		ID:                     ids,
		FirstConnectedAt:       firstConnectedAt,
		LastConnectedAt:        lastConnectedAt,
		LastConnectedReplicaID: lastConnectedReplicaID,
		DisconnectedAt:         disconnectedAt,
		UpdatedAt:              updatedAt,
	})
	if err != nil {
		if database.IsQueryCanceledError(err) {
			b.log.Debug(ctx, "query canceled, skipping connection batch update")
			return
		}
		b.log.Error(ctx, "failed to batch update agent connections", slog.Error(err))
		return
	}

	b.log.Debug(ctx, "connection batch flush complete", slog.F("count", count))
}

// nullTimeToTime converts a sql.NullTime to a time.Time. When the
// NullTime is not valid, the zero time is returned which PostgreSQL
// will store as the epoch. The batch query uses unnest over plain
// time arrays, so we cannot pass NULL directly.
func nullTimeToTime(nt sql.NullTime) time.Time {
	if nt.Valid {
		return nt.Time
	}
	return time.Time{}
}

// nullUUIDToUUID converts a uuid.NullUUID to a uuid.UUID. When the
// NullUUID is not valid, uuid.Nil is returned.
func nullUUIDToUUID(nu uuid.NullUUID) uuid.UUID {
	if nu.Valid {
		return nu.UUID
	}
	return uuid.Nil
}
+1
-1
@@ -773,7 +773,7 @@ func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
	}

	if statusResp.Status != agentapisdk.StatusStable {
-		return httperror.NewResponseError(http.StatusBadGateway, codersdk.Response{
+		return httperror.NewResponseError(http.StatusConflict, codersdk.Response{
			Message: "Task app is not ready to accept input.",
			Detail:  fmt.Sprintf("Status: %s", statusResp.Status),
		})

@@ -789,6 +789,11 @@ func TestTasks(t *testing.T) {
	})
	require.Error(t, err, "wanted error due to bad status")

+	var sdkErr *codersdk.Error
+	require.ErrorAs(t, err, &sdkErr)
+	require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
+	require.Contains(t, sdkErr.Message, "not ready to accept input")

	statusResponse = agentapisdk.StatusStable

	//nolint:tparallel // Not intended to run in parallel.
Generated
+162
-35
@@ -163,6 +163,57 @@ const docTemplate = `{
                ]
            }
        },
        "/aibridge/sessions": {
            "get": {
                "produces": [
                    "application/json"
                ],
                "tags": [
                    "AI Bridge"
                ],
                "summary": "List AI Bridge sessions",
                "operationId": "list-ai-bridge-sessions",
                "parameters": [
                    {
                        "type": "string",
                        "description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
                        "name": "q",
                        "in": "query"
                    },
                    {
                        "type": "integer",
                        "description": "Page limit",
                        "name": "limit",
                        "in": "query"
                    },
                    {
                        "type": "string",
                        "description": "Cursor pagination after session ID (cannot be used with offset)",
                        "name": "after_session_id",
                        "in": "query"
                    },
                    {
                        "type": "integer",
                        "description": "Offset pagination (cannot be used with after_session_id)",
                        "name": "offset",
                        "in": "query"
                    }
                ],
                "responses": {
                    "200": {
                        "description": "OK",
                        "schema": {
                            "$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse"
                        }
                    }
                },
                "security": [
                    {
                        "CoderSessionToken": []
                    }
                ]
            }
        },
        "/appearance": {
            "get": {
                "produces": [
@@ -4082,6 +4133,19 @@ const docTemplate = `{
                        "in": "path",
                        "required": true
                    },
                    {
                        "type": "string",
                        "description": "Member search query",
                        "name": "q",
                        "in": "query"
                    },
                    {
                        "type": "string",
                        "format": "uuid",
                        "description": "After ID",
                        "name": "after_id",
                        "in": "query"
                    },
                    {
                        "type": "integer",
                        "description": "Page limit, if 0 returns all members",
@@ -7958,29 +8022,6 @@ const docTemplate = `{
                ]
            }
        },
        "/users/me/session/token-to-cookie": {
            "post": {
                "description": "Converts the current session token into a Set-Cookie response.\nThis is used by embedded iframes (e.g. VS Code chat) that\nreceive a session token out-of-band via postMessage but need\ncookie-based auth for WebSocket connections.",
                "tags": [
                    "Authorization"
                ],
                "summary": "Set session token cookie",
                "operationId": "set-session-token-cookie",
                "responses": {
                    "204": {
                        "description": "No Content"
                    }
                },
                "security": [
                    {
                        "CoderSessionToken": []
                    }
                ],
                "x-apidocgen": {
                    "skip": true
                }
            }
        },
        "/users/oauth2/github/callback": {
            "get": {
                "tags": [
@@ -12788,6 +12829,20 @@ const docTemplate = `{
                }
            }
        },
        "codersdk.AIBridgeListSessionsResponse": {
            "type": "object",
            "properties": {
                "count": {
                    "type": "integer"
                },
                "sessions": {
                    "type": "array",
                    "items": {
                        "$ref": "#/definitions/codersdk.AIBridgeSession"
                    }
                }
            }
        },
        "codersdk.AIBridgeOpenAIConfig": {
            "type": "object",
            "properties": {
@@ -12840,6 +12895,64 @@ const docTemplate = `{
                }
            }
        },
        "codersdk.AIBridgeSession": {
            "type": "object",
            "properties": {
                "client": {
                    "type": "string"
                },
                "ended_at": {
                    "type": "string",
                    "format": "date-time"
                },
                "id": {
                    "type": "string"
                },
                "initiator": {
                    "$ref": "#/definitions/codersdk.MinimalUser"
                },
                "last_prompt": {
                    "type": "string"
                },
                "metadata": {
                    "type": "object",
                    "additionalProperties": {}
                },
                "models": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "providers": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    }
                },
                "started_at": {
                    "type": "string",
                    "format": "date-time"
                },
                "threads": {
                    "type": "integer"
                },
                "token_usage_summary": {
                    "$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary"
                }
            }
        },
        "codersdk.AIBridgeSessionTokenUsageSummary": {
            "type": "object",
            "properties": {
                "input_tokens": {
                    "type": "integer"
                },
                "output_tokens": {
                    "type": "integer"
                }
            }
        },
        "codersdk.AIBridgeTokenUsage": {
            "type": "object",
            "properties": {
@@ -15273,18 +15386,6 @@ const docTemplate = `{
                "swagger": {
                    "$ref": "#/definitions/codersdk.SwaggerConfig"
                },
                "tailnet_binder_workers": {
                    "type": "integer"
                },
                "tailnet_handshaker_workers": {
                    "type": "integer"
                },
                "tailnet_querier_workers": {
                    "type": "integer"
                },
                "tailnet_tunneler_workers": {
                    "type": "integer"
                },
                "telemetry": {
                    "$ref": "#/definitions/codersdk.TelemetryConfig"
                },
@@ -17325,6 +17426,13 @@ const docTemplate = `{
                        "$ref": "#/definitions/codersdk.SlimRole"
                    }
                },
                "last_seen_at": {
                    "type": "string",
                    "format": "date-time"
                },
                "login_type": {
                    "$ref": "#/definitions/codersdk.LoginType"
                },
                "name": {
                    "type": "string"
                },
@@ -17338,14 +17446,33 @@ const docTemplate = `{
                        "$ref": "#/definitions/codersdk.SlimRole"
                    }
                },
                "status": {
                    "enum": [
                        "active",
                        "suspended"
                    ],
                    "allOf": [
                        {
                            "$ref": "#/definitions/codersdk.UserStatus"
                        }
                    ]
                },
                "updated_at": {
                    "type": "string",
                    "format": "date-time"
                },
                "user_created_at": {
                    "type": "string",
                    "format": "date-time"
                },
                "user_id": {
                    "type": "string",
                    "format": "uuid"
                },
                "user_updated_at": {
                    "type": "string",
                    "format": "date-time"
                },
                "username": {
                    "type": "string"
                }
Generated
+155
-33
@@ -136,6 +136,53 @@
			]
		}
	},
	"/aibridge/sessions": {
		"get": {
			"produces": ["application/json"],
			"tags": ["AI Bridge"],
			"summary": "List AI Bridge sessions",
			"operationId": "list-ai-bridge-sessions",
			"parameters": [
				{
					"type": "string",
					"description": "Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
					"name": "q",
					"in": "query"
				},
				{
					"type": "integer",
					"description": "Page limit",
					"name": "limit",
					"in": "query"
				},
				{
					"type": "string",
					"description": "Cursor pagination after session ID (cannot be used with offset)",
					"name": "after_session_id",
					"in": "query"
				},
				{
					"type": "integer",
					"description": "Offset pagination (cannot be used with after_session_id)",
					"name": "offset",
					"in": "query"
				}
			],
			"responses": {
				"200": {
					"description": "OK",
					"schema": {
						"$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse"
					}
				}
			},
			"security": [
				{
					"CoderSessionToken": []
				}
			]
		}
	},
	"/appearance": {
		"get": {
			"produces": ["application/json"],
@@ -3603,6 +3650,19 @@
				"in": "path",
				"required": true
			},
			{
				"type": "string",
				"description": "Member search query",
				"name": "q",
				"in": "query"
			},
			{
				"type": "string",
				"format": "uuid",
				"description": "After ID",
				"name": "after_id",
				"in": "query"
			},
			{
				"type": "integer",
				"description": "Page limit, if 0 returns all members",
@@ -7051,27 +7111,6 @@
			]
		}
	},
	"/users/me/session/token-to-cookie": {
		"post": {
			"description": "Converts the current session token into a Set-Cookie response.\nThis is used by embedded iframes (e.g. VS Code chat) that\nreceive a session token out-of-band via postMessage but need\ncookie-based auth for WebSocket connections.",
			"tags": ["Authorization"],
			"summary": "Set session token cookie",
			"operationId": "set-session-token-cookie",
			"responses": {
				"204": {
					"description": "No Content"
				}
			},
			"security": [
				{
					"CoderSessionToken": []
				}
			],
			"x-apidocgen": {
				"skip": true
			}
		}
	},
	"/users/oauth2/github/callback": {
		"get": {
			"tags": ["Users"],
@@ -11376,6 +11415,20 @@
			}
		}
	},
	"codersdk.AIBridgeListSessionsResponse": {
		"type": "object",
		"properties": {
			"count": {
				"type": "integer"
			},
			"sessions": {
				"type": "array",
				"items": {
					"$ref": "#/definitions/codersdk.AIBridgeSession"
				}
			}
		}
	},
	"codersdk.AIBridgeOpenAIConfig": {
		"type": "object",
		"properties": {
@@ -11428,6 +11481,64 @@
			}
		}
	},
	"codersdk.AIBridgeSession": {
		"type": "object",
		"properties": {
			"client": {
				"type": "string"
			},
			"ended_at": {
				"type": "string",
				"format": "date-time"
			},
			"id": {
				"type": "string"
			},
			"initiator": {
				"$ref": "#/definitions/codersdk.MinimalUser"
			},
			"last_prompt": {
				"type": "string"
			},
			"metadata": {
				"type": "object",
				"additionalProperties": {}
			},
			"models": {
				"type": "array",
				"items": {
					"type": "string"
				}
			},
			"providers": {
				"type": "array",
				"items": {
					"type": "string"
				}
			},
			"started_at": {
				"type": "string",
				"format": "date-time"
			},
			"threads": {
				"type": "integer"
			},
			"token_usage_summary": {
				"$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary"
			}
		}
	},
	"codersdk.AIBridgeSessionTokenUsageSummary": {
		"type": "object",
		"properties": {
			"input_tokens": {
				"type": "integer"
			},
			"output_tokens": {
				"type": "integer"
			}
		}
	},
	"codersdk.AIBridgeTokenUsage": {
		"type": "object",
		"properties": {
@@ -13780,18 +13891,6 @@
			"swagger": {
				"$ref": "#/definitions/codersdk.SwaggerConfig"
			},
			"tailnet_binder_workers": {
				"type": "integer"
			},
			"tailnet_handshaker_workers": {
				"type": "integer"
			},
			"tailnet_querier_workers": {
				"type": "integer"
			},
			"tailnet_tunneler_workers": {
				"type": "integer"
			},
			"telemetry": {
				"$ref": "#/definitions/codersdk.TelemetryConfig"
			},
@@ -15752,6 +15851,13 @@
				"$ref": "#/definitions/codersdk.SlimRole"
			}
		},
		"last_seen_at": {
			"type": "string",
			"format": "date-time"
		},
		"login_type": {
			"$ref": "#/definitions/codersdk.LoginType"
		},
		"name": {
			"type": "string"
		},
@@ -15765,14 +15871,30 @@
				"$ref": "#/definitions/codersdk.SlimRole"
			}
		},
		"status": {
			"enum": ["active", "suspended"],
			"allOf": [
				{
					"$ref": "#/definitions/codersdk.UserStatus"
				}
			]
		},
		"updated_at": {
			"type": "string",
			"format": "date-time"
		},
		"user_created_at": {
			"type": "string",
			"format": "date-time"
		},
		"user_id": {
			"type": "string",
			"format": "uuid"
		},
		"user_updated_at": {
			"type": "string",
			"format": "date-time"
		},
		"username": {
			"type": "string"
		}
@@ -1,186 +0,0 @@
package chattool_test

import (
	"context"
	"testing"

	"charm.land/fantasy"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/mock/gomock"
	"golang.org/x/xerrors"

	"github.com/coder/coder/v2/coderd/chatd/chattool"
	"github.com/coder/coder/v2/codersdk/workspacesdk"
	"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
	"github.com/coder/quartz"
)

func TestComputerUseTool_Info(t *testing.T) {
	t.Parallel()

	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, nil, quartz.NewReal())
	info := tool.Info()
	assert.Equal(t, "computer", info.Name)
	assert.NotEmpty(t, info.Description)
}

func TestComputerUseProviderTool(t *testing.T) {
	t.Parallel()

	def := chattool.ComputerUseProviderTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight)
	pdt, ok := def.(fantasy.ProviderDefinedTool)
	require.True(t, ok, "ComputerUseProviderTool should return a ProviderDefinedTool")
	assert.Contains(t, pdt.ID, "computer")
	assert.Equal(t, "computer", pdt.Name)
	// Verify display dimensions are passed through.
	assert.Equal(t, int64(workspacesdk.DesktopDisplayWidth), pdt.Args["display_width_px"])
	assert.Equal(t, int64(workspacesdk.DesktopDisplayHeight), pdt.Args["display_height_px"])
}

func TestComputerUseTool_Run_Screenshot(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)

	mockConn.EXPECT().ExecuteDesktopAction(
		gomock.Any(),
		gomock.Any(),
	).Return(workspacesdk.DesktopActionResponse{
		Output:           "screenshot",
		ScreenshotData:   "base64png",
		ScreenshotWidth:  1024,
		ScreenshotHeight: 768,
	}, nil)

	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return mockConn, nil
	}, quartz.NewReal())

	call := fantasy.ToolCall{
		ID:    "test-1",
		Name:  "computer",
		Input: `{"action":"screenshot"}`,
	}

	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.Equal(t, "image", resp.Type)
	assert.Equal(t, "image/png", resp.MediaType)
	assert.Equal(t, []byte("base64png"), resp.Data)
	assert.False(t, resp.IsError)
}

func TestComputerUseTool_Run_LeftClick(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)

	// Expect the action call first.
	mockConn.EXPECT().ExecuteDesktopAction(
		gomock.Any(),
		gomock.Any(),
	).Return(workspacesdk.DesktopActionResponse{
		Output: "left_click performed",
	}, nil)

	// Then expect a screenshot (auto-screenshot after action).
	mockConn.EXPECT().ExecuteDesktopAction(
		gomock.Any(),
		gomock.Any(),
	).Return(workspacesdk.DesktopActionResponse{
		Output:           "screenshot",
		ScreenshotData:   "after-click",
		ScreenshotWidth:  1024,
		ScreenshotHeight: 768,
	}, nil)

	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return mockConn, nil
	}, quartz.NewReal())

	call := fantasy.ToolCall{
		ID:    "test-2",
		Name:  "computer",
		Input: `{"action":"left_click","coordinate":[100,200]}`,
	}

	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.Equal(t, "image", resp.Type)
	assert.Equal(t, []byte("after-click"), resp.Data)
}

func TestComputerUseTool_Run_Wait(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)
	// Expect a screenshot after the wait completes.
	mockConn.EXPECT().ExecuteDesktopAction(
		gomock.Any(),
		gomock.Any(),
	).Return(workspacesdk.DesktopActionResponse{
		Output:           "screenshot",
		ScreenshotData:   "after-wait",
		ScreenshotWidth:  1024,
		ScreenshotHeight: 768,
	}, nil)

	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return mockConn, nil
	}, quartz.NewReal())

	call := fantasy.ToolCall{
		ID:    "test-3",
		Name:  "computer",
		Input: `{"action":"wait","duration":10}`,
	}

	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.Equal(t, "image", resp.Type)
	assert.Equal(t, "image/png", resp.MediaType)
	assert.Equal(t, []byte("after-wait"), resp.Data)
	assert.False(t, resp.IsError)
}

func TestComputerUseTool_Run_ConnError(t *testing.T) {
	t.Parallel()

	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return nil, xerrors.New("workspace not available")
	}, quartz.NewReal())

	call := fantasy.ToolCall{
		ID:    "test-4",
		Name:  "computer",
		Input: `{"action":"screenshot"}`,
	}

	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.True(t, resp.IsError)
	assert.Contains(t, resp.Content, "workspace not available")
}

func TestComputerUseTool_Run_InvalidInput(t *testing.T) {
	t.Parallel()

	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return nil, xerrors.New("should not be called")
	}, quartz.NewReal())

	call := fantasy.ToolCall{
		ID:    "test-5",
		Name:  "computer",
		Input: `{invalid json`,
	}

	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.True(t, resp.IsError)
	assert.Contains(t, resp.Content, "invalid computer use input")
}
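The deleted tests above encode a convention worth noting: tool failures (unreachable workspace, invalid JSON input) come back as a response with `IsError` set and `err == nil`, so the model can see and react to the failure instead of the run aborting. A minimal sketch of the pattern under assumed types (`toolResponse` and `run` are illustrative, not the chattool API):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// toolResponse is an illustrative stand-in for a tool result type:
// failures are carried in-band so the caller (the model) sees them.
type toolResponse struct {
	Content string
	IsError bool
}

type computerInput struct {
	Action string `json:"action"`
}

// run never returns a Go error for user-level failures; it reports
// them in the response instead, as the tests above expect.
func run(rawInput string) (toolResponse, error) {
	var in computerInput
	if err := json.Unmarshal([]byte(rawInput), &in); err != nil {
		return toolResponse{Content: "invalid computer use input: " + err.Error(), IsError: true}, nil
	}
	return toolResponse{Content: "performed " + in.Action}, nil
}

func main() {
	resp, err := run(`{invalid json`)
	fmt.Println(err == nil, resp.IsError) // true true

	resp, _ = run(`{"action":"screenshot"}`)
	fmt.Println(resp.Content) // performed screenshot
}
```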
+49
-62
@@ -45,14 +45,12 @@ import (
	"github.com/coder/coder/v2/buildinfo"
	"github.com/coder/coder/v2/coderd/agentapi"
	"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
	"github.com/coder/coder/v2/coderd/agentconnectionbatcher"
	"github.com/coder/coder/v2/coderd/aiseats"
	_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
	"github.com/coder/coder/v2/coderd/appearance"
	"github.com/coder/coder/v2/coderd/audit"
	"github.com/coder/coder/v2/coderd/awsidentity"
	"github.com/coder/coder/v2/coderd/boundaryusage"
	"github.com/coder/coder/v2/coderd/chatd"
	"github.com/coder/coder/v2/coderd/connectionlog"
	"github.com/coder/coder/v2/coderd/cryptokeys"
	"github.com/coder/coder/v2/coderd/database"
@@ -64,7 +62,6 @@ import (
	"github.com/coder/coder/v2/coderd/externalauth"
	"github.com/coder/coder/v2/coderd/files"
	"github.com/coder/coder/v2/coderd/gitsshkey"
	"github.com/coder/coder/v2/coderd/gitsync"
	"github.com/coder/coder/v2/coderd/healthcheck"
	"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
	"github.com/coder/coder/v2/coderd/httpapi"
@@ -95,6 +92,8 @@ import (
	"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
	"github.com/coder/coder/v2/coderd/workspacestats"
	"github.com/coder/coder/v2/coderd/wsbuilder"
	"github.com/coder/coder/v2/coderd/x/chatd"
	"github.com/coder/coder/v2/coderd/x/gitsync"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/drpcsdk"
	"github.com/coder/coder/v2/codersdk/healthsdk"
@@ -251,8 +250,7 @@ type Options struct {
	UpdateAgentMetrics func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
	StatsBatcher       workspacestats.Batcher

	MetadataBatcherOptions   []metadatabatcher.Option
	ConnectionBatcherOptions []agentconnectionbatcher.Option
	MetadataBatcherOptions   []metadatabatcher.Option

	ProvisionerdServerMetrics *provisionerdserver.Metrics
	WorkspaceBuilderMetrics   *wsbuilder.Metrics
@@ -769,43 +767,45 @@ func New(options *Options) *API {
	}
	api.agentProvider = stn

	maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
	if maxChatsPerAcquire > math.MaxInt32 {
		maxChatsPerAcquire = math.MaxInt32
	}
	if maxChatsPerAcquire < math.MinInt32 {
		maxChatsPerAcquire = math.MinInt32
	}
	{ // Experimental: agents — chat daemon and git sync worker initialization.
		maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
		if maxChatsPerAcquire > math.MaxInt32 {
			maxChatsPerAcquire = math.MaxInt32
		}
		if maxChatsPerAcquire < math.MinInt32 {
			maxChatsPerAcquire = math.MinInt32
		}

	api.chatDaemon = chatd.New(chatd.Config{
		Logger:             options.Logger.Named("chatd"),
		Database:           options.Database,
		ReplicaID:          api.ID,
		SubscribeFn:        options.ChatSubscribeFn,
		MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
		ProviderAPIKeys:    chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
		AgentConn:          api.agentProvider.AgentConn,
		CreateWorkspace:    api.chatCreateWorkspace,
		StartWorkspace:     api.chatStartWorkspace,
		Pubsub:             options.Pubsub,
		WebpushDispatcher:  options.WebPushDispatcher,
		UsageTracker:       options.WorkspaceUsageTracker,
	})
	gitSyncLogger := options.Logger.Named("gitsync")
	refresher := gitsync.NewRefresher(
		api.resolveGitProvider,
		api.resolveChatGitAccessToken,
		gitSyncLogger.Named("refresher"),
		quartz.NewReal(),
	)
	api.gitSyncWorker = gitsync.NewWorker(options.Database,
		refresher,
		api.chatDaemon.PublishDiffStatusChange,
		quartz.NewReal(),
		gitSyncLogger,
	)
	// nolint:gocritic // chat diff worker needs to be able to CRUD chats.
	go api.gitSyncWorker.Start(dbauthz.AsChatd(api.ctx))
		api.chatDaemon = chatd.New(chatd.Config{
			Logger:             options.Logger.Named("chatd"),
			Database:           options.Database,
			ReplicaID:          api.ID,
			SubscribeFn:        options.ChatSubscribeFn,
			MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
			ProviderAPIKeys:    chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
			AgentConn:          api.agentProvider.AgentConn,
			CreateWorkspace:    api.chatCreateWorkspace,
			StartWorkspace:     api.chatStartWorkspace,
			Pubsub:             options.Pubsub,
			WebpushDispatcher:  options.WebPushDispatcher,
			UsageTracker:       options.WorkspaceUsageTracker,
		})
		gitSyncLogger := options.Logger.Named("gitsync")
		refresher := gitsync.NewRefresher(
			api.resolveGitProvider,
			api.resolveChatGitAccessToken,
			gitSyncLogger.Named("refresher"),
			quartz.NewReal(),
		)
		api.gitSyncWorker = gitsync.NewWorker(options.Database,
			refresher,
			api.chatDaemon.PublishDiffStatusChange,
			quartz.NewReal(),
			gitSyncLogger,
		)
		// nolint:gocritic // chat diff worker needs to be able to CRUD chats.
		go api.gitSyncWorker.Start(dbauthz.AsChatd(api.ctx))
	}
	if options.DeploymentValues.Prometheus.Enable {
		options.PrometheusRegistry.MustRegister(stn)
		api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry)
@@ -860,17 +860,6 @@ func New(options *Options) *API {
		api.Logger.Fatal(context.Background(), "failed to initialize metadata batcher", slog.Error(err))
	}

	// Initialize the connection batcher for batching agent heartbeat writes.
	connBatcherOpts := []agentconnectionbatcher.Option{
		agentconnectionbatcher.WithLogger(options.Logger.Named("connection_batcher")),
	}
	connBatcherOpts = append(connBatcherOpts, options.ConnectionBatcherOptions...)
	api.connectionBatcher = agentconnectionbatcher.New(
		api.ctx,
		options.Database,
		connBatcherOpts...,
	)

	workspaceAppsLogger := options.Logger.Named("workspaceapps")
	if options.WorkspaceAppsStatsCollectorOptions.Logger == nil {
		named := workspaceAppsLogger.Named("stats_collector")
@@ -1159,6 +1148,7 @@ func New(options *Options) *API {
				})
			})
		})
		// Experimental(agents): chat API routes gated by ExperimentAgents.
		r.Route("/chats", func(r chi.Router) {
			r.Use(
				apiKeyMiddleware,
@@ -1190,6 +1180,9 @@ func New(options *Options) *API {
			r.Put("/desktop-enabled", api.putChatDesktopEnabled)
			r.Get("/user-prompt", api.getUserChatCustomPrompt)
			r.Put("/user-prompt", api.putUserChatCustomPrompt)
			r.Get("/user-compaction-thresholds", api.getUserChatCompactionThresholds)
			r.Put("/user-compaction-thresholds/{modelConfig}", api.putUserChatCompactionThreshold)
			r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold)
			r.Get("/workspace-ttl", api.getChatWorkspaceTTL)
			r.Put("/workspace-ttl", api.putChatWorkspaceTTL)
		})
@@ -1530,7 +1523,6 @@ func New(options *Options) *API {
			r.Post("/", api.postUser)
			r.Get("/", api.users)
			r.Post("/logout", api.postLogout)
			r.Post("/me/session/token-to-cookie", api.postSessionTokenCookie)
			r.Get("/oidc-claims", api.userOIDCClaims)
			// These routes query information about site wide roles.
			r.Route("/roles", func(r chi.Router) {
@@ -2093,21 +2085,19 @@ type API struct {
	healthCheckProgress healthcheck.Progress

	statsReporter     *workspacestats.Reporter
	metadataBatcher   *metadatabatcher.Batcher
	connectionBatcher *agentconnectionbatcher.Batcher
	metadataBatcher   *metadatabatcher.Batcher
	lifecycleMetrics  *agentapi.LifecycleMetrics

	Acquirer *provisionerdserver.Acquirer
	// dbRolluper rolls up template usage stats from raw agent and app
	// stats. This is used to provide insights in the WebUI.
	dbRolluper *dbrollup.Rolluper
	// chatDaemon handles background processing of pending chats.
	// Experimental(agents): chatDaemon handles background processing of pending chats.
	chatDaemon *chatd.Server
	// Experimental(agents): gitSyncWorker refreshes stale chat diff statuses in the background.
	gitSyncWorker *gitsync.Worker
	// AISeatTracker records AI seat usage.
	AISeatTracker aiseats.SeatTracker
	// gitSyncWorker refreshes stale chat diff statuses in the
	// background.
	gitSyncWorker *gitsync.Worker

	// ProfileCollector abstracts the runtime/pprof and runtime/trace
	// calls used by the /debug/profile endpoint. Tests override this
@@ -2175,9 +2165,6 @@ func (api *API) Close() error {
	if api.metadataBatcher != nil {
		api.metadataBatcher.Close()
	}
	if api.connectionBatcher != nil {
		api.connectionBatcher.Close()
	}
	_ = api.NetworkTelemetryBatcher.Close()
	_ = api.OIDCConvertKeyCache.Close()
	_ = api.AppSigningKeyCache.Close()
@@ -1,344 +0,0 @@
package connectionlogbatcher

import (
	"context"
	"database/sql"
	"time"

	"github.com/google/uuid"
	"github.com/sqlc-dev/pqtype"

	"cdr.dev/slog/v3"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbauthz"
	"github.com/coder/quartz"
)

const (
	// defaultBatchSize is the maximum number of connection log entries
	// to batch before forcing a flush.
	defaultBatchSize = 500

	// defaultChannelBufferMultiplier is the multiplier for the channel
	// buffer size relative to the batch size. A 5x multiplier provides
	// significant headroom for bursts while the batch is being flushed.
	defaultChannelBufferMultiplier = 5

	// defaultFlushInterval is how frequently to flush batched connection
	// log entries to the database. 1 second keeps audit logs near
	// real-time.
	defaultFlushInterval = time.Second

	// finalFlushTimeout is the timeout for the final flush when the
	// batcher is shutting down.
	finalFlushTimeout = 15 * time.Second
)

// Batcher accumulates connection log upserts and periodically flushes
// them to the database in a single batch query. This reduces per-event
// database write pressure from O(n) queries to O(1).
type Batcher struct {
	store database.Store
	log   slog.Logger

	itemCh       chan database.UpsertConnectionLogParams
	batch        []database.UpsertConnectionLogParams
	maxBatchSize int

	clock    quartz.Clock
	timer    *quartz.Timer
	interval time.Duration

	ctx    context.Context
	cancel context.CancelFunc
	done   chan struct{}
}

// Option is a functional option for configuring a Batcher.
type Option func(b *Batcher)

// WithBatchSize sets the maximum number of entries to accumulate before
// forcing a flush.
func WithBatchSize(size int) Option {
	return func(b *Batcher) {
		b.maxBatchSize = size
	}
}

// WithInterval sets how frequently the batcher flushes to the database.
func WithInterval(d time.Duration) Option {
	return func(b *Batcher) {
		b.interval = d
	}
}

// WithLogger sets the logger for the batcher.
func WithLogger(log slog.Logger) Option {
	return func(b *Batcher) {
		b.log = log
	}
}

// WithClock sets the clock for the batcher, useful for testing.
func WithClock(clock quartz.Clock) Option {
	return func(b *Batcher) {
		b.clock = clock
	}
}

// New creates a new Batcher and starts its background processing loop.
// The provided context controls the lifetime of the batcher.
func New(ctx context.Context, store database.Store, opts ...Option) *Batcher {
	b := &Batcher{
		store: store,
		done:  make(chan struct{}),
		log:   slog.Logger{},
		clock: quartz.NewReal(),
	}

	for _, opt := range opts {
		opt(b)
	}

	if b.interval == 0 {
		b.interval = defaultFlushInterval
	}
	if b.maxBatchSize == 0 {
		b.maxBatchSize = defaultBatchSize
	}

	b.timer = b.clock.NewTimer(b.interval)
	channelSize := b.maxBatchSize * defaultChannelBufferMultiplier
	b.itemCh = make(chan database.UpsertConnectionLogParams, channelSize)
	b.batch = make([]database.UpsertConnectionLogParams, 0, b.maxBatchSize)

	b.ctx, b.cancel = context.WithCancel(ctx)
	go func() {
		b.run(b.ctx)
		close(b.done)
	}()

	return b
}

// Close cancels the batcher context and waits for the final flush to
// complete.
func (b *Batcher) Close() {
	b.cancel()
	if b.timer != nil {
		b.timer.Stop()
	}
	<-b.done
}

// Add enqueues a connection log upsert for batching. If the internal
// channel is full, the entry is dropped and a warning is logged.
func (b *Batcher) Add(item database.UpsertConnectionLogParams) {
	select {
	case b.itemCh <- item:
default:
|
||||
b.log.Warn(context.Background(), "connection log batcher channel full, dropping entry",
|
||||
slog.F("connection_id", item.ConnectionID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Batcher) run(ctx context.Context) {
|
||||
//nolint:gocritic // System-level batch operation for connection logs.
|
||||
authCtx := dbauthz.AsConnectionLogger(ctx)
|
||||
for {
|
||||
select {
|
||||
case item := <-b.itemCh:
|
||||
b.batch = append(b.batch, item)
|
||||
|
||||
if len(b.batch) >= b.maxBatchSize {
|
||||
b.flush(authCtx)
|
||||
b.timer.Reset(b.interval, "connectionLogBatcher", "capacityFlush")
|
||||
}
|
||||
|
||||
case <-b.timer.C:
|
||||
b.flush(authCtx)
|
||||
b.timer.Reset(b.interval, "connectionLogBatcher", "scheduledFlush")
|
||||
|
||||
case <-ctx.Done():
|
||||
b.log.Debug(ctx, "context done, flushing before exit")
|
||||
|
||||
ctxTimeout, cancel := context.WithTimeout(context.Background(), finalFlushTimeout)
|
||||
defer cancel() //nolint:revive // Returning after this.
|
||||
|
||||
//nolint:gocritic // System-level batch operation for connection logs.
|
||||
b.flush(dbauthz.AsConnectionLogger(ctxTimeout))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// conflictKey represents the unique constraint columns used by
|
||||
// the upsert query. Entries sharing the same key cannot appear
|
||||
// in a single INSERT … ON CONFLICT DO UPDATE statement.
|
||||
type conflictKey struct {
|
||||
ConnectionID uuid.UUID
|
||||
WorkspaceID uuid.UUID
|
||||
AgentName string
|
||||
}
|
||||
|
||||
func (b *Batcher) flush(ctx context.Context) {
|
||||
count := len(b.batch)
|
||||
if count == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
b.log.Debug(ctx, "flushing connection log batch", slog.F("count", count))
|
||||
|
||||
// Deduplicate by conflict key so PostgreSQL never sees the
|
||||
// same row twice in one INSERT … ON CONFLICT DO UPDATE.
|
||||
// Entries with a NULL connection_id (web events) are exempt
|
||||
// because NULL != NULL in SQL unique constraints.
|
||||
deduped := make(map[conflictKey]database.UpsertConnectionLogParams, count)
|
||||
var nullConnIDEntries []database.UpsertConnectionLogParams
|
||||
|
||||
for _, item := range b.batch {
|
||||
if !item.ConnectionID.Valid {
|
||||
nullConnIDEntries = append(nullConnIDEntries, item)
|
||||
continue
|
||||
}
|
||||
key := conflictKey{
|
||||
ConnectionID: item.ConnectionID.UUID,
|
||||
WorkspaceID: item.WorkspaceID,
|
||||
AgentName: item.AgentName,
|
||||
}
|
||||
existing, ok := deduped[key]
|
||||
if !ok {
|
||||
deduped[key] = item
|
||||
continue
|
||||
}
|
||||
// Prefer disconnect over connect (superset of info).
|
||||
// If same status, prefer the later event.
|
||||
if item.ConnectionStatus == database.ConnectionStatusDisconnected &&
|
||||
existing.ConnectionStatus != database.ConnectionStatusDisconnected {
|
||||
deduped[key] = item
|
||||
} else if item.Time.After(existing.Time) {
|
||||
deduped[key] = item
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild batch from deduplicated entries.
|
||||
items := make([]database.UpsertConnectionLogParams, 0, len(deduped)+len(nullConnIDEntries))
|
||||
for _, item := range deduped {
|
||||
items = append(items, item)
|
||||
}
|
||||
items = append(items, nullConnIDEntries...)
|
||||
|
||||
dedupedCount := len(items)
|
||||
if dedupedCount < count {
|
||||
b.log.Debug(ctx, "deduplicated connection log batch",
|
||||
slog.F("original", count),
|
||||
slog.F("deduped", dedupedCount),
|
||||
)
|
||||
}
|
||||
|
||||
var (
|
||||
ids = make([]uuid.UUID, 0, dedupedCount)
|
||||
connectTime = make([]time.Time, 0, dedupedCount)
|
||||
organizationID = make([]uuid.UUID, 0, dedupedCount)
|
||||
workspaceOwnerID = make([]uuid.UUID, 0, dedupedCount)
|
||||
workspaceID = make([]uuid.UUID, 0, dedupedCount)
|
||||
workspaceName = make([]string, 0, dedupedCount)
|
||||
agentName = make([]string, 0, dedupedCount)
|
||||
connType = make([]database.ConnectionType, 0, dedupedCount)
|
||||
code = make([]int32, 0, dedupedCount)
|
||||
ip = make([]pqtype.Inet, 0, dedupedCount)
|
||||
userAgent = make([]string, 0, dedupedCount)
|
||||
userID = make([]uuid.UUID, 0, dedupedCount)
|
||||
slugOrPort = make([]string, 0, dedupedCount)
|
||||
connectionID = make([]uuid.UUID, 0, dedupedCount)
|
||||
disconnectReason = make([]string, 0, dedupedCount)
|
||||
disconnectTime = make([]time.Time, 0, dedupedCount)
|
||||
)
|
||||
|
||||
for _, item := range items {
|
||||
ids = append(ids, item.ID)
|
||||
connectTime = append(connectTime, item.Time)
|
||||
organizationID = append(organizationID, item.OrganizationID)
|
||||
workspaceOwnerID = append(workspaceOwnerID, item.WorkspaceOwnerID)
|
||||
workspaceID = append(workspaceID, item.WorkspaceID)
|
||||
workspaceName = append(workspaceName, item.WorkspaceName)
|
||||
agentName = append(agentName, item.AgentName)
|
||||
connType = append(connType, item.Type)
|
||||
code = append(code, nullInt32ToInt32(item.Code))
|
||||
ip = append(ip, item.Ip)
|
||||
userAgent = append(userAgent, nullStringToString(item.UserAgent))
|
||||
userID = append(userID, nullUUIDToUUID(item.UserID))
|
||||
slugOrPort = append(slugOrPort, nullStringToString(item.SlugOrPort))
|
||||
connectionID = append(connectionID, nullUUIDToUUID(item.ConnectionID))
|
||||
disconnectReason = append(disconnectReason, nullStringToString(item.DisconnectReason))
|
||||
// Pre-compute disconnect_time: if status is "disconnected",
|
||||
// use the event time; otherwise use zero time (epoch) which
|
||||
// the SQL CASE will treat as no disconnect.
|
||||
if item.ConnectionStatus == database.ConnectionStatusDisconnected {
|
||||
disconnectTime = append(disconnectTime, item.Time)
|
||||
} else {
|
||||
disconnectTime = append(disconnectTime, time.Time{})
|
||||
}
|
||||
}
|
||||
|
||||
// Clear batch before the DB call. Losing a batch of connection
|
||||
// log entries is acceptable; the next event will be recorded.
|
||||
b.batch = make([]database.UpsertConnectionLogParams, 0, b.maxBatchSize)
|
||||
|
||||
err := b.store.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: ids,
|
||||
ConnectTime: connectTime,
|
||||
OrganizationID: organizationID,
|
||||
WorkspaceOwnerID: workspaceOwnerID,
|
||||
WorkspaceID: workspaceID,
|
||||
WorkspaceName: workspaceName,
|
||||
AgentName: agentName,
|
||||
Type: connType,
|
||||
Code: code,
|
||||
Ip: ip,
|
||||
UserAgent: userAgent,
|
||||
UserID: userID,
|
||||
SlugOrPort: slugOrPort,
|
||||
ConnectionID: connectionID,
|
||||
DisconnectReason: disconnectReason,
|
||||
DisconnectTime: disconnectTime,
|
||||
})
|
||||
if err != nil {
|
||||
if database.IsQueryCanceledError(err) {
|
||||
b.log.Debug(ctx, "query canceled, skipping connection log batch update")
|
||||
return
|
||||
}
|
||||
b.log.Error(ctx, "failed to batch upsert connection logs", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
b.log.Debug(ctx, "connection log batch flush complete", slog.F("count", count))
|
||||
}
|
||||
|
||||
// nullStringToString converts a sql.NullString to a string. When the
|
||||
// NullString is not valid, an empty string is returned.
|
||||
func nullStringToString(ns sql.NullString) string {
|
||||
if ns.Valid {
|
||||
return ns.String
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// nullInt32ToInt32 converts a sql.NullInt32 to an int32. When the
|
||||
// NullInt32 is not valid, zero is returned.
|
||||
func nullInt32ToInt32(ni sql.NullInt32) int32 {
|
||||
if ni.Valid {
|
||||
return ni.Int32
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// nullUUIDToUUID converts a uuid.NullUUID to a uuid.UUID. When the
|
||||
// NullUUID is not valid, uuid.Nil is returned.
|
||||
func nullUUIDToUUID(nu uuid.NullUUID) uuid.UUID {
|
||||
if nu.Valid {
|
||||
return nu.UUID
|
||||
}
|
||||
return uuid.Nil
|
||||
}
|
||||
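The accumulate-and-flush pattern the `Batcher` above implements can be reduced to a few lines. The sketch below is a toy illustration only: `miniBatcher` is a hypothetical name, it shows just the size-triggered flush, and it omits the real code's timer, buffered channel, deduplication, and database wiring.

```go
package main

import (
	"fmt"
	"sync"
)

// miniBatcher accumulates items and flushes them in groups of maxBatch.
// Toy sketch of the accumulate-and-flush pattern; not coder's API.
type miniBatcher struct {
	mu       sync.Mutex
	batch    []string
	maxBatch int
	flushed  [][]string // stands in for the batched DB upsert
}

// add appends an item and flushes once the batch reaches capacity,
// mirroring the capacityFlush branch of Batcher.run.
func (b *miniBatcher) add(item string) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.batch = append(b.batch, item)
	if len(b.batch) >= b.maxBatch {
		b.flush()
	}
}

// flush moves the current batch out as one unit; callers must hold mu.
func (b *miniBatcher) flush() {
	if len(b.batch) == 0 {
		return
	}
	b.flushed = append(b.flushed, b.batch)
	b.batch = make([]string, 0, b.maxBatch)
}

func main() {
	b := &miniBatcher{maxBatch: 2}
	for _, s := range []string{"a", "b", "c"} {
		b.add(s)
	}
	// Final flush on shutdown, like Batcher.Close.
	b.mu.Lock()
	b.flush()
	b.mu.Unlock()
	fmt.Println(len(b.flushed)) // 2
}
```

Three adds with a batch size of two produce one capacity flush (`["a", "b"]`) and one final flush (`["c"]`); the real implementation adds a timer so a partially filled batch still reaches the database within the flush interval.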
@@ -19,7 +19,6 @@ import (
	"tailscale.com/tailcfg"

	agentproto "github.com/coder/coder/v2/agent/proto"
	"github.com/coder/coder/v2/coderd/chatd/chatprompt"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
	"github.com/coder/coder/v2/coderd/rbac"
@@ -28,6 +27,7 @@ import (
	"github.com/coder/coder/v2/coderd/util/ptr"
	"github.com/coder/coder/v2/coderd/util/slice"
	"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
	"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/provisionersdk/proto"
	"github.com/coder/coder/v2/tailnet"
@@ -223,6 +223,7 @@ func UserFromGroupMember(member database.GroupMember) database.User {
		QuietHoursSchedule: member.UserQuietHoursSchedule,
		Name:               member.UserName,
		GithubComUserID:    member.UserGithubComUserID,
		IsServiceAccount:   member.UserIsServiceAccount,
	}
}

@@ -251,6 +252,7 @@ func UserFromGroupMemberRow(member database.GetGroupMembersByGroupIDPaginatedRow
		QuietHoursSchedule: member.UserQuietHoursSchedule,
		Name:               member.UserName,
		GithubComUserID:    member.UserGithubComUserID,
		IsServiceAccount:   member.UserIsServiceAccount,
	}
}

@@ -1019,6 +1021,44 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
	return intc
}

func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSession {
	session := codersdk.AIBridgeSession{
		ID: row.SessionID,
		Initiator: MinimalUserFromVisibleUser(database.VisibleUser{
			ID:        row.UserID,
			Username:  row.UserUsername,
			Name:      row.UserName,
			AvatarURL: row.UserAvatarUrl,
		}),
		Providers: row.Providers,
		Models:    row.Models,
		Metadata:  jsonOrEmptyMap(pqtype.NullRawMessage{RawMessage: row.Metadata, Valid: len(row.Metadata) > 0}),
		StartedAt: row.StartedAt,
		Threads:   row.Threads,
		TokenUsageSummary: codersdk.AIBridgeSessionTokenUsageSummary{
			InputTokens:  row.InputTokens,
			OutputTokens: row.OutputTokens,
		},
	}
	// Ensure non-nil slices for JSON serialization.
	if session.Providers == nil {
		session.Providers = []string{}
	}
	if session.Models == nil {
		session.Models = []string{}
	}
	if row.Client != "" {
		session.Client = &row.Client
	}
	if !row.EndedAt.IsZero() {
		session.EndedAt = &row.EndedAt
	}
	if row.LastPrompt != "" {
		session.LastPrompt = &row.LastPrompt
	}
	return session
}

func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
	return codersdk.AIBridgeTokenUsage{
		ID: usage.ID,
@@ -1602,15 +1602,6 @@ func (q *querier) BackoffChatDiffStatus(ctx context.Context, arg database.Backof
	return q.db.BackoffChatDiffStatus(ctx, arg)
}

func (q *querier) BatchUpdateWorkspaceAgentConnections(ctx context.Context, arg database.BatchUpdateWorkspaceAgentConnectionsParams) error {
	// Could be any workspace agent and checking auth to each workspace
	// agent is overkill for the purpose of this function.
	if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceWorkspace.All()); err != nil {
		return err
	}
	return q.db.BatchUpdateWorkspaceAgentConnections(ctx, arg)
}

func (q *querier) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
	// Could be any workspace agent and checking auth to each workspace agent is overkill for
	// the purpose of this function.
@@ -1636,13 +1627,6 @@ func (q *querier) BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg datab
	return q.db.BatchUpdateWorkspaceNextStartAt(ctx, arg)
}

func (q *querier) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
	if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
		return err
	}
	return q.db.BatchUpsertConnectionLogs(ctx, arg)
}

func (q *querier) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
	if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceNotificationMessage); err != nil {
		return 0, err
@@ -1725,6 +1709,14 @@ func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.C
	return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prep)
}

func (q *querier) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
	prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
	if err != nil {
		return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
	}
	return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prep)
}

func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
	// Shortcut if the user is an owner. The SQL filter is noticeable,
	// and this is an easy win for owners. Which is the common case.
@@ -2134,6 +2126,17 @@ func (q *querier) DeleteTask(ctx context.Context, arg database.DeleteTaskParams)
	return q.db.DeleteTask(ctx, arg)
}

func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error {
	u, err := q.db.GetUserByID(ctx, arg.UserID)
	if err != nil {
		return err
	}
	if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
		return err
	}
	return q.db.DeleteUserChatCompactionThreshold(ctx, arg)
}

func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
	// First get the secret to check ownership
	secret, err := q.GetUserSecret(ctx, id)
@@ -3584,13 +3587,6 @@ func (q *querier) GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.U
	return q.db.GetTailnetTunnelPeerBindings(ctx, srcID)
}

func (q *querier) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
		return nil, err
	}
	return q.db.GetTailnetTunnelPeerBindingsBatch(ctx, ids)
}

func (q *querier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
		return nil, err
@@ -3598,13 +3594,6 @@ func (q *querier) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID)
	return q.db.GetTailnetTunnelPeerIDs(ctx, srcID)
}

func (q *querier) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
		return nil, err
	}
	return q.db.GetTailnetTunnelPeerIDsBatch(ctx, ids)
}

func (q *querier) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task, error) {
	return fetch(q.log, q.auth, q.db.GetTaskByID)(ctx, id)
}
@@ -3951,6 +3940,17 @@ func (q *querier) GetUserByID(ctx context.Context, id uuid.UUID) (database.User,
	return fetch(q.log, q.auth, q.db.GetUserByID)(ctx, id)
}

func (q *querier) GetUserChatCompactionThreshold(ctx context.Context, arg database.GetUserChatCompactionThresholdParams) (string, error) {
	u, err := q.db.GetUserByID(ctx, arg.UserID)
	if err != nil {
		return "", err
	}
	if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
		return "", err
	}
	return q.db.GetUserChatCompactionThreshold(ctx, arg)
}

func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) {
	u, err := q.db.GetUserByID(ctx, userID)
	if err != nil {
@@ -5325,10 +5325,16 @@ func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBri
	return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep)
}

func (q *querier) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
	prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
	if err != nil {
		return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
	}
	return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prep)
}

func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
	// This function is a system function until we implement a join for aibridge interceptions.
	// Matches the behavior of the workspaces listing endpoint.
	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
		return nil, err
	}

@@ -5336,9 +5342,7 @@
}

func (q *querier) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeToolUsage, error) {
	// This function is a system function until we implement a join for aibridge interceptions.
	// Matches the behavior of the workspaces listing endpoint.
	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
		return nil, err
	}

@@ -5346,9 +5350,7 @@
}

func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeUserPrompt, error) {
	// This function is a system function until we implement a join for aibridge interceptions.
	// Matches the behavior of the workspaces listing endpoint.
	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
		return nil, err
	}

@@ -5382,6 +5384,17 @@ func (q *querier) ListTasks(ctx context.Context, arg database.ListTasksParams) (
	return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListTasks)(ctx, arg)
}

func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]database.UserConfig, error) {
	u, err := q.db.GetUserByID(ctx, userID)
	if err != nil {
		return nil, err
	}
	if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
		return nil, err
	}
	return q.db.ListUserChatCompactionThresholds(ctx, userID)
}

func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
	obj := rbac.ResourceUserSecret.WithOwner(userID.String())
	if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
@@ -6242,6 +6255,17 @@ func (q *querier) UpdateUsageEventsPostPublish(ctx context.Context, arg database
	return q.db.UpdateUsageEventsPostPublish(ctx, arg)
}

func (q *querier) UpdateUserChatCompactionThreshold(ctx context.Context, arg database.UpdateUserChatCompactionThresholdParams) (database.UserConfig, error) {
	u, err := q.db.GetUserByID(ctx, arg.UserID)
	if err != nil {
		return database.UserConfig{}, err
	}
	if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
		return database.UserConfig{}, err
	}
	return q.db.UpdateUserChatCompactionThreshold(ctx, arg)
}

func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.UpdateUserChatCustomPromptParams) (database.UserConfig, error) {
	u, err := q.db.GetUserByID(ctx, arg.UserID)
	if err != nil {
@@ -7114,6 +7138,14 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
	return q.ListAIBridgeModels(ctx, arg)
}

func (q *querier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
	return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prepared)
}

func (q *querier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
	return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prepared)
}

func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
	return q.GetChats(ctx, arg)
}
@@ -2278,6 +2278,35 @@ func (s *MethodTestSuite) TestUser() {
|
||||
dbm.EXPECT().UpdateUserChatCustomPrompt(gomock.Any(), arg).Return(uc, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(uc)
|
||||
}))
|
||||
s.Run("ListUserChatCompactionThresholds", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
uc := database.UserConfig{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001", Value: "75"}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().ListUserChatCompactionThresholds(gomock.Any(), u.ID).Return([]database.UserConfig{uc}, nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserConfig{uc})
|
||||
}))
|
||||
s.Run("GetUserChatCompactionThreshold", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.GetUserChatCompactionThresholdParams{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001"}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().GetUserChatCompactionThreshold(gomock.Any(), arg).Return("75", nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionReadPersonal).Returns("75")
|
||||
}))
|
||||
s.Run("UpdateUserChatCompactionThreshold", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
uc := database.UserConfig{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001", Value: "75"}
|
||||
arg := database.UpdateUserChatCompactionThresholdParams{UserID: u.ID, Key: uc.Key, ThresholdPercent: 75}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserChatCompactionThreshold(gomock.Any(), arg).Return(uc, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(uc)
|
||||
}))
|
||||
s.Run("DeleteUserChatCompactionThreshold", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserChatCompactionThresholdParams{UserID: u.ID, Key: codersdk.ChatCompactionThresholdKeyPrefix + "00000000-0000-0000-0000-000000000001"}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserChatCompactionThreshold(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal)
|
||||
}))
|
||||
s.Run("UpdateUserTaskNotificationAlertDismissed", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
userConfig := database.UserConfig{UserID: user.ID, Key: "task_notification_alert_dismissed", Value: "false"}
|
||||
@@ -2623,20 +2652,6 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
dbm.EXPECT().GetWorkspaceAgentMetadata(gomock.Any(), arg).Return([]database.WorkspaceAgentMetadatum{dt}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(w, policy.ActionRead).Returns([]database.WorkspaceAgentMetadatum{dt})
|
||||
}))
|
||||
s.Run("BatchUpdateWorkspaceAgentConnections", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
|
||||
now := dbtime.Now()
|
||||
arg := database.BatchUpdateWorkspaceAgentConnectionsParams{
|
||||
ID: []uuid.UUID{agt.ID},
|
||||
FirstConnectedAt: []time.Time{now},
|
||||
LastConnectedAt: []time.Time{now},
|
||||
LastConnectedReplicaID: []uuid.UUID{uuid.New()},
|
||||
DisconnectedAt: []time.Time{{}},
|
||||
UpdatedAt: []time.Time{now},
|
||||
}
|
||||
dbm.EXPECT().BatchUpdateWorkspaceAgentConnections(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceWorkspace.All(), policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("BatchUpdateWorkspaceAgentMetadata", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
|
||||
arg := database.BatchUpdateWorkspaceAgentMetadataParams{
|
||||
@@ -3174,109 +3189,59 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestWorkspacePortSharing() {
|
||||
s.Run("UpsertWorkspaceAgentPortShare", s.Subtest(func(db database.Store, check *expects) {
|
||||
u := dbgen.User(s.T(), db, database.User{})
|
||||
org := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
tpl := dbgen.Template(s.T(), db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: u.ID,
|
||||
})
|
||||
ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
|
||||
OwnerID: u.ID,
|
||||
OrganizationID: org.ID,
|
||||
TemplateID: tpl.ID,
|
||||
})
|
||||
ps := dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
|
||||
//nolint:gosimple // casting is not a simplification
|
||||
check.Args(database.UpsertWorkspaceAgentPortShareParams{
|
||||
s.Run("UpsertWorkspaceAgentPortShare", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ws := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
ps := testutil.Fake(s.T(), faker, database.WorkspaceAgentPortShare{})
|
||||
ps.WorkspaceID = ws.ID
|
||||
arg := database.UpsertWorkspaceAgentPortShareParams(ps)
|
||||
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertWorkspaceAgentPortShare(gomock.Any(), arg).Return(ps, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(ws, policy.ActionUpdate).Returns(ps)
|
||||
}))
|
||||
s.Run("GetWorkspaceAgentPortShare", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ws := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
ps := testutil.Fake(s.T(), faker, database.WorkspaceAgentPortShare{})
|
||||
ps.WorkspaceID = ws.ID
|
||||
arg := database.GetWorkspaceAgentPortShareParams{
|
||||
WorkspaceID: ps.WorkspaceID,
|
||||
AgentName: ps.AgentName,
|
||||
Port: ps.Port,
|
||||
ShareLevel: ps.ShareLevel,
|
||||
Protocol: ps.Protocol,
|
||||
}).Asserts(ws, policy.ActionUpdate).Returns(ps)
|
||||
}
|
||||
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
|
||||
dbm.EXPECT().GetWorkspaceAgentPortShare(gomock.Any(), arg).Return(ps, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(ws, policy.ActionRead).Returns(ps)
|
||||
}))
|
||||
s.Run("GetWorkspaceAgentPortShare", s.Subtest(func(db database.Store, check *expects) {
|
||||
u := dbgen.User(s.T(), db, database.User{})
|
||||
org := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
tpl := dbgen.Template(s.T(), db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: u.ID,
|
||||
})
|
||||
ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
|
||||
OwnerID: u.ID,
|
||||
OrganizationID: org.ID,
|
||||
TemplateID: tpl.ID,
|
||||
})
|
||||
ps := dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
|
||||
check.Args(database.GetWorkspaceAgentPortShareParams{
|
||||
WorkspaceID: ps.WorkspaceID,
|
||||
AgentName: ps.AgentName,
|
||||
Port: ps.Port,
|
||||
}).Asserts(ws, policy.ActionRead).Returns(ps)
|
||||
}))
|
||||
s.Run("ListWorkspaceAgentPortShares", s.Subtest(func(db database.Store, check *expects) {
	u := dbgen.User(s.T(), db, database.User{})
	org := dbgen.Organization(s.T(), db, database.Organization{})
	tpl := dbgen.Template(s.T(), db, database.Template{
		OrganizationID: org.ID,
		CreatedBy:      u.ID,
	})
	ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
		OwnerID:        u.ID,
		OrganizationID: org.ID,
		TemplateID:     tpl.ID,
	})
	ps := dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
s.Run("ListWorkspaceAgentPortShares", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
	ws := testutil.Fake(s.T(), faker, database.Workspace{})
	ps := testutil.Fake(s.T(), faker, database.WorkspaceAgentPortShare{})
	ps.WorkspaceID = ws.ID
	dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
	dbm.EXPECT().ListWorkspaceAgentPortShares(gomock.Any(), ws.ID).Return([]database.WorkspaceAgentPortShare{ps}, nil).AnyTimes()
	check.Args(ws.ID).Asserts(ws, policy.ActionRead).Returns([]database.WorkspaceAgentPortShare{ps})
}))
s.Run("DeleteWorkspaceAgentPortShare", s.Subtest(func(db database.Store, check *expects) {
	u := dbgen.User(s.T(), db, database.User{})
	org := dbgen.Organization(s.T(), db, database.Organization{})
	tpl := dbgen.Template(s.T(), db, database.Template{
		OrganizationID: org.ID,
		CreatedBy:      u.ID,
	})
	ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
		OwnerID:        u.ID,
		OrganizationID: org.ID,
		TemplateID:     tpl.ID,
	})
	ps := dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
	check.Args(database.DeleteWorkspaceAgentPortShareParams{
s.Run("DeleteWorkspaceAgentPortShare", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
	ws := testutil.Fake(s.T(), faker, database.Workspace{})
	ps := testutil.Fake(s.T(), faker, database.WorkspaceAgentPortShare{})
	ps.WorkspaceID = ws.ID
	arg := database.DeleteWorkspaceAgentPortShareParams{
		WorkspaceID: ps.WorkspaceID,
		AgentName:   ps.AgentName,
		Port:        ps.Port,
	}).Asserts(ws, policy.ActionUpdate).Returns()
	}
	dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
	dbm.EXPECT().DeleteWorkspaceAgentPortShare(gomock.Any(), arg).Return(nil).AnyTimes()
	check.Args(arg).Asserts(ws, policy.ActionUpdate).Returns()
}))
s.Run("DeleteWorkspaceAgentPortSharesByTemplate", s.Subtest(func(db database.Store, check *expects) {
	u := dbgen.User(s.T(), db, database.User{})
	org := dbgen.Organization(s.T(), db, database.Organization{})
	tpl := dbgen.Template(s.T(), db, database.Template{
		OrganizationID: org.ID,
		CreatedBy:      u.ID,
	})
	ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
		OwnerID:        u.ID,
		OrganizationID: org.ID,
		TemplateID:     tpl.ID,
	})
	_ = dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
s.Run("DeleteWorkspaceAgentPortSharesByTemplate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
	tpl := testutil.Fake(s.T(), faker, database.Template{})
	dbm.EXPECT().GetTemplateByID(gomock.Any(), tpl.ID).Return(tpl, nil).AnyTimes()
	dbm.EXPECT().DeleteWorkspaceAgentPortSharesByTemplate(gomock.Any(), tpl.ID).Return(nil).AnyTimes()
	check.Args(tpl.ID).Asserts(tpl, policy.ActionUpdate).Returns()
}))
s.Run("ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate", s.Subtest(func(db database.Store, check *expects) {
	u := dbgen.User(s.T(), db, database.User{})
	org := dbgen.Organization(s.T(), db, database.Organization{})
	tpl := dbgen.Template(s.T(), db, database.Template{
		OrganizationID: org.ID,
		CreatedBy:      u.ID,
	})
	ws := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
		OwnerID:        u.ID,
		OrganizationID: org.ID,
		TemplateID:     tpl.ID,
	})
	_ = dbgen.WorkspaceAgentPortShare(s.T(), db, database.WorkspaceAgentPortShare{WorkspaceID: ws.ID})
s.Run("ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
	tpl := testutil.Fake(s.T(), faker, database.Template{})
	dbm.EXPECT().GetTemplateByID(gomock.Any(), tpl.ID).Return(tpl, nil).AnyTimes()
	dbm.EXPECT().ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(gomock.Any(), tpl.ID).Return(nil).AnyTimes()
	check.Args(tpl.ID).Asserts(tpl, policy.ActionUpdate).Returns()
}))
}
@@ -4993,113 +4958,69 @@ func (s *MethodTestSuite) TestOAuth2ProviderAppTokens() {
}

func (s *MethodTestSuite) TestResourcesMonitor() {
	createAgent := func(t *testing.T, db database.Store) (database.WorkspaceAgent, database.WorkspaceTable) {
		t.Helper()

		u := dbgen.User(t, db, database.User{})
		o := dbgen.Organization(t, db, database.Organization{})
		tpl := dbgen.Template(t, db, database.Template{
			OrganizationID: o.ID,
			CreatedBy:      u.ID,
		})
		tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
			TemplateID:     uuid.NullUUID{UUID: tpl.ID, Valid: true},
			OrganizationID: o.ID,
			CreatedBy:      u.ID,
		})
		w := dbgen.Workspace(t, db, database.WorkspaceTable{
			TemplateID:     tpl.ID,
			OrganizationID: o.ID,
			OwnerID:        u.ID,
		})
		j := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
			Type: database.ProvisionerJobTypeWorkspaceBuild,
		})
		b := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
			JobID:             j.ID,
			WorkspaceID:       w.ID,
			TemplateVersionID: tv.ID,
		})
		res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{JobID: b.JobID})
		agt := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ResourceID: res.ID})

		return agt, w
	}

	s.Run("InsertMemoryResourceMonitor", s.Subtest(func(db database.Store, check *expects) {
		agt, _ := createAgent(s.T(), db)

		check.Args(database.InsertMemoryResourceMonitorParams{
			AgentID: agt.ID,
	s.Run("InsertMemoryResourceMonitor", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		arg := database.InsertMemoryResourceMonitorParams{
			AgentID: uuid.New(),
			State:   database.WorkspaceAgentMonitorStateOK,
		}).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionCreate)
		}
		dbm.EXPECT().InsertMemoryResourceMonitor(gomock.Any(), arg).Return(database.WorkspaceAgentMemoryResourceMonitor{}, nil).AnyTimes()
		check.Args(arg).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionCreate)
	}))

	s.Run("InsertVolumeResourceMonitor", s.Subtest(func(db database.Store, check *expects) {
		agt, _ := createAgent(s.T(), db)

		check.Args(database.InsertVolumeResourceMonitorParams{
			AgentID: agt.ID,
	s.Run("InsertVolumeResourceMonitor", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		arg := database.InsertVolumeResourceMonitorParams{
			AgentID: uuid.New(),
			State:   database.WorkspaceAgentMonitorStateOK,
		}).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionCreate)
		}
		dbm.EXPECT().InsertVolumeResourceMonitor(gomock.Any(), arg).Return(database.WorkspaceAgentVolumeResourceMonitor{}, nil).AnyTimes()
		check.Args(arg).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionCreate)
	}))

	s.Run("UpdateMemoryResourceMonitor", s.Subtest(func(db database.Store, check *expects) {
		agt, _ := createAgent(s.T(), db)

		check.Args(database.UpdateMemoryResourceMonitorParams{
			AgentID: agt.ID,
	s.Run("UpdateMemoryResourceMonitor", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		arg := database.UpdateMemoryResourceMonitorParams{
			AgentID: uuid.New(),
			State:   database.WorkspaceAgentMonitorStateOK,
		}).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionUpdate)
		}
		dbm.EXPECT().UpdateMemoryResourceMonitor(gomock.Any(), arg).Return(nil).AnyTimes()
		check.Args(arg).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionUpdate)
	}))

	s.Run("UpdateVolumeResourceMonitor", s.Subtest(func(db database.Store, check *expects) {
		agt, _ := createAgent(s.T(), db)

		check.Args(database.UpdateVolumeResourceMonitorParams{
			AgentID: agt.ID,
	s.Run("UpdateVolumeResourceMonitor", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		arg := database.UpdateVolumeResourceMonitorParams{
			AgentID: uuid.New(),
			State:   database.WorkspaceAgentMonitorStateOK,
		}).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionUpdate)
		}
		dbm.EXPECT().UpdateVolumeResourceMonitor(gomock.Any(), arg).Return(nil).AnyTimes()
		check.Args(arg).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionUpdate)
	}))

	s.Run("FetchMemoryResourceMonitorsUpdatedAfter", s.Subtest(func(db database.Store, check *expects) {
	s.Run("FetchMemoryResourceMonitorsUpdatedAfter", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		dbm.EXPECT().FetchMemoryResourceMonitorsUpdatedAfter(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
		check.Args(dbtime.Now()).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionRead)
	}))

	s.Run("FetchVolumesResourceMonitorsUpdatedAfter", s.Subtest(func(db database.Store, check *expects) {
	s.Run("FetchVolumesResourceMonitorsUpdatedAfter", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		dbm.EXPECT().FetchVolumesResourceMonitorsUpdatedAfter(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
		check.Args(dbtime.Now()).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionRead)
	}))

	s.Run("FetchMemoryResourceMonitorsByAgentID", s.Subtest(func(db database.Store, check *expects) {
		agt, w := createAgent(s.T(), db)

		dbgen.WorkspaceAgentMemoryResourceMonitor(s.T(), db, database.WorkspaceAgentMemoryResourceMonitor{
			AgentID:   agt.ID,
			Enabled:   true,
			Threshold: 80,
			CreatedAt: dbtime.Now(),
		})

		monitor, err := db.FetchMemoryResourceMonitorsByAgentID(context.Background(), agt.ID)
		require.NoError(s.T(), err)

	s.Run("FetchMemoryResourceMonitorsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		w := testutil.Fake(s.T(), faker, database.Workspace{})
		agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
		monitor := testutil.Fake(s.T(), faker, database.WorkspaceAgentMemoryResourceMonitor{})
		dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes()
		dbm.EXPECT().FetchMemoryResourceMonitorsByAgentID(gomock.Any(), agt.ID).Return(monitor, nil).AnyTimes()
		check.Args(agt.ID).Asserts(w, policy.ActionRead).Returns(monitor)
	}))

	s.Run("FetchVolumesResourceMonitorsByAgentID", s.Subtest(func(db database.Store, check *expects) {
		agt, w := createAgent(s.T(), db)

		dbgen.WorkspaceAgentVolumeResourceMonitor(s.T(), db, database.WorkspaceAgentVolumeResourceMonitor{
			AgentID:   agt.ID,
			Path:      "/var/lib",
			Enabled:   true,
			Threshold: 80,
			CreatedAt: dbtime.Now(),
		})

		monitors, err := db.FetchVolumesResourceMonitorsByAgentID(context.Background(), agt.ID)
		require.NoError(s.T(), err)

	s.Run("FetchVolumesResourceMonitorsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		w := testutil.Fake(s.T(), faker, database.Workspace{})
		agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
		monitors := []database.WorkspaceAgentVolumeResourceMonitor{
			testutil.Fake(s.T(), faker, database.WorkspaceAgentVolumeResourceMonitor{}),
		}
		dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes()
		dbm.EXPECT().FetchVolumesResourceMonitorsByAgentID(gomock.Any(), agt.ID).Return(monitors, nil).AnyTimes()
		check.Args(agt.ID).Asserts(w, policy.ActionRead).Returns(monitors)
	}))
}
@@ -5499,22 +5420,50 @@ func (s *MethodTestSuite) TestAIBridge() {
		check.Args(params, emptyPreparedAuthorized{}).Asserts()
	}))

	s.Run("ListAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		params := database.ListAIBridgeSessionsParams{}
		db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes()
		// No asserts here because SQLFilter.
		check.Args(params).Asserts()
	}))

	s.Run("ListAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		params := database.ListAIBridgeSessionsParams{}
		db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes()
		// No asserts here because SQLFilter.
		check.Args(params, emptyPreparedAuthorized{}).Asserts()
	}))

	s.Run("CountAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		params := database.CountAIBridgeSessionsParams{}
		db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes()
		// No asserts here because SQLFilter.
		check.Args(params).Asserts()
	}))

	s.Run("CountAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		params := database.CountAIBridgeSessionsParams{}
		db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes()
		// No asserts here because SQLFilter.
		check.Args(params, emptyPreparedAuthorized{}).Asserts()
	}))

	s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		ids := []uuid.UUID{{1}}
		db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes()
		check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeTokenUsage{})
		check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeTokenUsage{})
	}))

	s.Run("ListAIBridgeUserPromptsByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		ids := []uuid.UUID{{1}}
		db.EXPECT().ListAIBridgeUserPromptsByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeUserPrompt{}, nil).AnyTimes()
		check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeUserPrompt{})
		check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeUserPrompt{})
	}))

	s.Run("ListAIBridgeToolUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
		ids := []uuid.UUID{{1}}
		db.EXPECT().ListAIBridgeToolUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeToolUsage{}, nil).AnyTimes()
		check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
		check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
	}))

	s.Run("UpdateAIBridgeInterceptionEnded", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {

@@ -4,9 +4,10 @@ import (
	"context"
	"encoding/gob"
	"errors"
	"flag"
	"fmt"
	"reflect"
	"sort"
	"slices"
	"strings"
	"testing"

@@ -90,6 +91,16 @@ func (s *MethodTestSuite) SetupSuite() {
// TearDownSuite asserts that all methods were called at least once.
func (s *MethodTestSuite) TearDownSuite() {
	s.Run("Accounting", func() {
		// testify/suite's -testify.m flag filters which suite methods
		// run, but TearDownSuite still executes. Skip the Accounting
		// check when filtering to avoid misleading "method never
		// called" errors for every method that was filtered out.
		if f := flag.Lookup("testify.m"); f != nil {
			if f.Value.String() != "" {
				s.T().Skip("Skipping Accounting check: -testify.m flag is set")
			}
		}

		t := s.T()
		notCalled := []string{}
		for m, c := range s.methodAccounting {
@@ -97,7 +108,7 @@ func (s *MethodTestSuite) TearDownSuite() {
				notCalled = append(notCalled, m)
			}
		}
		sort.Strings(notCalled)
		slices.Sort(notCalled)
		for _, m := range notCalled {
			t.Errorf("Method never called: %q", m)
		}

@@ -184,14 +184,6 @@ func (m queryMetricsStore) BackoffChatDiffStatus(ctx context.Context, arg databa
	return r0
}

func (m queryMetricsStore) BatchUpdateWorkspaceAgentConnections(ctx context.Context, arg database.BatchUpdateWorkspaceAgentConnectionsParams) error {
	start := time.Now()
	r0 := m.s.BatchUpdateWorkspaceAgentConnections(ctx, arg)
	m.queryLatencies.WithLabelValues("BatchUpdateWorkspaceAgentConnections").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpdateWorkspaceAgentConnections").Inc()
	return r0
}

func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
	start := time.Now()
	r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
@@ -216,14 +208,6 @@ func (m queryMetricsStore) BatchUpdateWorkspaceNextStartAt(ctx context.Context,
	return r0
}

func (m queryMetricsStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
	start := time.Now()
	r0 := m.s.BatchUpsertConnectionLogs(ctx, arg)
	m.queryLatencies.WithLabelValues("BatchUpsertConnectionLogs").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpsertConnectionLogs").Inc()
	return r0
}

func (m queryMetricsStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
	start := time.Now()
	r0, r1 := m.s.BulkMarkNotificationMessagesFailed(ctx, arg)
@@ -296,6 +280,14 @@ func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg d
	return r0, r1
}

func (m queryMetricsStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
	start := time.Now()
	r0, r1 := m.s.CountAIBridgeSessions(ctx, arg)
	m.queryLatencies.WithLabelValues("CountAIBridgeSessions").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAIBridgeSessions").Inc()
	return r0, r1
}

func (m queryMetricsStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
	start := time.Now()
	r0, r1 := m.s.CountAuditLogs(ctx, arg)
@@ -696,6 +688,14 @@ func (m queryMetricsStore) DeleteTask(ctx context.Context, arg database.DeleteTa
	return r0, r1
}

func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error {
	start := time.Now()
	r0 := m.s.DeleteUserChatCompactionThreshold(ctx, arg)
	m.queryLatencies.WithLabelValues("DeleteUserChatCompactionThreshold").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatCompactionThreshold").Inc()
	return r0
}

func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
	start := time.Now()
	r0 := m.s.DeleteUserSecret(ctx, id)
@@ -2176,14 +2176,6 @@ func (m queryMetricsStore) GetTailnetTunnelPeerBindings(ctx context.Context, src
	return r0, r1
}

func (m queryMetricsStore) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
	start := time.Now()
	r0, r1 := m.s.GetTailnetTunnelPeerBindingsBatch(ctx, ids)
	m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerBindingsBatch").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerBindingsBatch").Inc()
	return r0, r1
}

func (m queryMetricsStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
	start := time.Now()
	r0, r1 := m.s.GetTailnetTunnelPeerIDs(ctx, srcID)
@@ -2192,14 +2184,6 @@ func (m queryMetricsStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uu
	return r0, r1
}

func (m queryMetricsStore) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
	start := time.Now()
	r0, r1 := m.s.GetTailnetTunnelPeerIDsBatch(ctx, ids)
	m.queryLatencies.WithLabelValues("GetTailnetTunnelPeerIDsBatch").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTailnetTunnelPeerIDsBatch").Inc()
	return r0, r1
}

func (m queryMetricsStore) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task, error) {
	start := time.Now()
	r0, r1 := m.s.GetTaskByID(ctx, id)
@@ -2480,6 +2464,14 @@ func (m queryMetricsStore) GetUserByID(ctx context.Context, id uuid.UUID) (datab
	return r0, r1
}

func (m queryMetricsStore) GetUserChatCompactionThreshold(ctx context.Context, arg database.GetUserChatCompactionThresholdParams) (string, error) {
	start := time.Now()
	r0, r1 := m.s.GetUserChatCompactionThreshold(ctx, arg)
	m.queryLatencies.WithLabelValues("GetUserChatCompactionThreshold").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatCompactionThreshold").Inc()
	return r0, r1
}

func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) {
	start := time.Now()
	r0, r1 := m.s.GetUserChatCustomPrompt(ctx, userID)
@@ -3736,6 +3728,14 @@ func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database.
	return r0, r1
}

func (m queryMetricsStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
	start := time.Now()
	r0, r1 := m.s.ListAIBridgeSessions(ctx, arg)
	m.queryLatencies.WithLabelValues("ListAIBridgeSessions").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeSessions").Inc()
	return r0, r1
}

func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
	start := time.Now()
	r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
@@ -3800,6 +3800,14 @@ func (m queryMetricsStore) ListTasks(ctx context.Context, arg database.ListTasks
	return r0, r1
}

func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]database.UserConfig, error) {
	start := time.Now()
	r0, r1 := m.s.ListUserChatCompactionThresholds(ctx, userID)
	m.queryLatencies.WithLabelValues("ListUserChatCompactionThresholds").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserChatCompactionThresholds").Inc()
	return r0, r1
}

func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
	start := time.Now()
	r0, r1 := m.s.ListUserSecrets(ctx, userID)
@@ -4392,6 +4400,14 @@ func (m queryMetricsStore) UpdateUsageEventsPostPublish(ctx context.Context, arg
	return r0
}

func (m queryMetricsStore) UpdateUserChatCompactionThreshold(ctx context.Context, arg database.UpdateUserChatCompactionThresholdParams) (database.UserConfig, error) {
	start := time.Now()
	r0, r1 := m.s.UpdateUserChatCompactionThreshold(ctx, arg)
	m.queryLatencies.WithLabelValues("UpdateUserChatCompactionThreshold").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatCompactionThreshold").Inc()
	return r0, r1
}

func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg database.UpdateUserChatCustomPromptParams) (database.UserConfig, error) {
	start := time.Now()
	r0, r1 := m.s.UpdateUserChatCustomPrompt(ctx, arg)
@@ -5136,6 +5152,22 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
	return r0, r1
}

func (m queryMetricsStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
	start := time.Now()
	r0, r1 := m.s.ListAuthorizedAIBridgeSessions(ctx, arg, prepared)
	m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeSessions").Inc()
	return r0, r1
}

func (m queryMetricsStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
	start := time.Now()
	r0, r1 := m.s.CountAuthorizedAIBridgeSessions(ctx, arg, prepared)
	m.queryLatencies.WithLabelValues("CountAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds())
	m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAuthorizedAIBridgeSessions").Inc()
	return r0, r1
}

func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
	start := time.Now()
	r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)

@@ -190,20 +190,6 @@ func (mr *MockStoreMockRecorder) BackoffChatDiffStatus(ctx, arg any) *gomock.Cal
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackoffChatDiffStatus", reflect.TypeOf((*MockStore)(nil).BackoffChatDiffStatus), ctx, arg)
}

// BatchUpdateWorkspaceAgentConnections mocks base method.
func (m *MockStore) BatchUpdateWorkspaceAgentConnections(ctx context.Context, arg database.BatchUpdateWorkspaceAgentConnectionsParams) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "BatchUpdateWorkspaceAgentConnections", ctx, arg)
	ret0, _ := ret[0].(error)
	return ret0
}

// BatchUpdateWorkspaceAgentConnections indicates an expected call of BatchUpdateWorkspaceAgentConnections.
func (mr *MockStoreMockRecorder) BatchUpdateWorkspaceAgentConnections(ctx, arg any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateWorkspaceAgentConnections", reflect.TypeOf((*MockStore)(nil).BatchUpdateWorkspaceAgentConnections), ctx, arg)
}

// BatchUpdateWorkspaceAgentMetadata mocks base method.
func (m *MockStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
	m.ctrl.T.Helper()
@@ -246,20 +232,6 @@ func (mr *MockStoreMockRecorder) BatchUpdateWorkspaceNextStartAt(ctx, arg any) *
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateWorkspaceNextStartAt", reflect.TypeOf((*MockStore)(nil).BatchUpdateWorkspaceNextStartAt), ctx, arg)
}

// BatchUpsertConnectionLogs mocks base method.
func (m *MockStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "BatchUpsertConnectionLogs", ctx, arg)
	ret0, _ := ret[0].(error)
	return ret0
}

// BatchUpsertConnectionLogs indicates an expected call of BatchUpsertConnectionLogs.
func (mr *MockStoreMockRecorder) BatchUpsertConnectionLogs(ctx, arg any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpsertConnectionLogs", reflect.TypeOf((*MockStore)(nil).BatchUpsertConnectionLogs), ctx, arg)
}

// BulkMarkNotificationMessagesFailed mocks base method.
func (m *MockStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
	m.ctrl.T.Helper()
@@ -391,6 +363,21 @@ func (mr *MockStoreMockRecorder) CountAIBridgeInterceptions(ctx, arg any) *gomoc
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeInterceptions), ctx, arg)
}

// CountAIBridgeSessions mocks base method.
func (m *MockStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CountAIBridgeSessions", ctx, arg)
	ret0, _ := ret[0].(int64)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// CountAIBridgeSessions indicates an expected call of CountAIBridgeSessions.
func (mr *MockStoreMockRecorder) CountAIBridgeSessions(ctx, arg any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeSessions), ctx, arg)
}

// CountAuditLogs mocks base method.
func (m *MockStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
	m.ctrl.T.Helper()
@@ -421,6 +408,21 @@ func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeInterceptions(ctx, arg,
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeInterceptions), ctx, arg, prepared)
}

// CountAuthorizedAIBridgeSessions mocks base method.
func (m *MockStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CountAuthorizedAIBridgeSessions", ctx, arg, prepared)
	ret0, _ := ret[0].(int64)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// CountAuthorizedAIBridgeSessions indicates an expected call of CountAuthorizedAIBridgeSessions.
func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeSessions), ctx, arg, prepared)
}

// CountAuthorizedAuditLogs mocks base method.
func (m *MockStore) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
	m.ctrl.T.Helper()
@@ -1154,6 +1156,20 @@ func (mr *MockStoreMockRecorder) DeleteTask(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTask", reflect.TypeOf((*MockStore)(nil).DeleteTask), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserChatCompactionThreshold mocks base method.
|
||||
func (m *MockStore) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserChatCompactionThreshold", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserChatCompactionThreshold indicates an expected call of DeleteUserChatCompactionThreshold.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecret mocks base method.
|
||||
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4022,21 +4038,6 @@ func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerBindings(ctx, srcID any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerBindings", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerBindings), ctx, srcID)
|
||||
}
|
||||
|
||||
// GetTailnetTunnelPeerBindingsBatch mocks base method.
|
||||
func (m *MockStore) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsBatchRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerBindingsBatch", ctx, ids)
|
||||
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerBindingsBatchRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetTailnetTunnelPeerBindingsBatch indicates an expected call of GetTailnetTunnelPeerBindingsBatch.
|
||||
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerBindingsBatch(ctx, ids any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerBindingsBatch", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerBindingsBatch), ctx, ids)
|
||||
}
|
||||
|
||||
// GetTailnetTunnelPeerIDs mocks base method.
|
||||
func (m *MockStore) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4052,21 +4053,6 @@ func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerIDs(ctx, srcID any) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerIDs", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerIDs), ctx, srcID)
|
||||
}
|
||||
|
||||
// GetTailnetTunnelPeerIDsBatch mocks base method.
|
||||
func (m *MockStore) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]database.GetTailnetTunnelPeerIDsBatchRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTailnetTunnelPeerIDsBatch", ctx, ids)
|
||||
ret0, _ := ret[0].([]database.GetTailnetTunnelPeerIDsBatchRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetTailnetTunnelPeerIDsBatch indicates an expected call of GetTailnetTunnelPeerIDsBatch.
|
||||
func (mr *MockStoreMockRecorder) GetTailnetTunnelPeerIDsBatch(ctx, ids any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetTunnelPeerIDsBatch", reflect.TypeOf((*MockStore)(nil).GetTailnetTunnelPeerIDsBatch), ctx, ids)
|
||||
}
|
||||
|
||||
// GetTaskByID mocks base method.
|
||||
func (m *MockStore) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4622,6 +4608,21 @@ func (mr *MockStoreMockRecorder) GetUserByID(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserByID", reflect.TypeOf((*MockStore)(nil).GetUserByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetUserChatCompactionThreshold mocks base method.
|
||||
func (m *MockStore) GetUserChatCompactionThreshold(ctx context.Context, arg database.GetUserChatCompactionThresholdParams) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserChatCompactionThreshold", ctx, arg)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserChatCompactionThreshold indicates an expected call of GetUserChatCompactionThreshold.
|
||||
func (mr *MockStoreMockRecorder) GetUserChatCompactionThreshold(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).GetUserChatCompactionThreshold), ctx, arg)
|
||||
}
|
||||
|
||||
// GetUserChatCustomPrompt mocks base method.
|
||||
func (m *MockStore) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6976,6 +6977,21 @@ func (mr *MockStoreMockRecorder) ListAIBridgeModels(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModels), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAIBridgeSessions", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAIBridgeSessions indicates an expected call of ListAIBridgeSessions.
|
||||
func (mr *MockStoreMockRecorder) ListAIBridgeSessions(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeSessions), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
|
||||
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7051,6 +7067,21 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeSessions", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeSessions indicates an expected call of ListAuthorizedAIBridgeSessions.
|
||||
func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeSessions), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListChatUsageLimitGroupOverrides mocks base method.
|
||||
func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7126,6 +7157,21 @@ func (mr *MockStoreMockRecorder) ListTasks(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTasks", reflect.TypeOf((*MockStore)(nil).ListTasks), ctx, arg)
|
||||
}
|
||||
|
||||
// ListUserChatCompactionThresholds mocks base method.
|
||||
func (m *MockStore) ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]database.UserConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserChatCompactionThresholds", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListUserChatCompactionThresholds indicates an expected call of ListUserChatCompactionThresholds.
|
||||
func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserChatCompactionThresholds", reflect.TypeOf((*MockStore)(nil).ListUserChatCompactionThresholds), ctx, userID)
|
||||
}
|
||||
|
||||
// ListUserSecrets mocks base method.
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8231,6 +8277,21 @@ func (mr *MockStoreMockRecorder) UpdateUsageEventsPostPublish(ctx, arg any) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUsageEventsPostPublish", reflect.TypeOf((*MockStore)(nil).UpdateUsageEventsPostPublish), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserChatCompactionThreshold mocks base method.
|
||||
func (m *MockStore) UpdateUserChatCompactionThreshold(ctx context.Context, arg database.UpdateUserChatCompactionThresholdParams) (database.UserConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserChatCompactionThreshold", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserChatCompactionThreshold indicates an expected call of UpdateUserChatCompactionThreshold.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserChatCompactionThreshold(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCompactionThreshold), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserChatCustomPrompt mocks base method.
|
||||
func (m *MockStore) UpdateUserChatCustomPrompt(ctx context.Context, arg database.UpdateUserChatCustomPromptParams) (database.UserConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
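All of the generated mocks above follow the same gomock pattern: the controller's Call returns the recorded return values as an untyped []any, and each value is recovered with a comma-ok type assertion so a mismatch degrades to the zero value rather than panicking. A minimal self-contained sketch of that mechanic (call and its canned values are illustrative stand-ins, not part of the generated file):

```go
package main

import "fmt"

// call stands in for gomock's (*Controller).Call, which returns the
// values recorded by EXPECT().….Return(…) as an untyped slice.
func call(method string, args ...any) []any {
	// Hypothetical canned values, as Return(int64(3), nil) would record.
	return []any{int64(3), nil}
}

// CountSessions mirrors the generated mock shape: each element is
// asserted back to its static type; a failed assertion yields the
// zero value, matching the `ret0, _ := ret[0].(int64)` lines above.
func CountSessions() (int64, error) {
	ret := call("CountSessions")
	ret0, _ := ret[0].(int64)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

func main() {
	n, err := CountSessions()
	fmt.Println(n, err) // 3 <nil>
}
```

The comma-ok form is what lets a single generated template serve every method signature without reflection at the call site.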
Generated file: +13 -4
@@ -1099,7 +1099,8 @@ CREATE TABLE aibridge_interceptions (
    client character varying(64) DEFAULT 'Unknown'::character varying,
    thread_parent_id uuid,
    thread_root_id uuid,
    client_session_id character varying(256)
    client_session_id character varying(256),
    session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL
);

COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';

@@ -1112,6 +1113,8 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio

COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';

COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception''s own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.';

CREATE TABLE aibridge_model_thoughts (
    interception_id uuid NOT NULL,
    content text NOT NULL,
@@ -1291,7 +1294,8 @@ CREATE TABLE chat_messages (
    content_version smallint NOT NULL,
    total_cost_micros bigint,
    runtime_ms bigint,
    deleted boolean DEFAULT false NOT NULL
    deleted boolean DEFAULT false NOT NULL,
    provider_response_id text
);

CREATE SEQUENCE chat_messages_id_seq
@@ -1619,6 +1623,7 @@ CREATE VIEW group_members_expanded AS
    users.name AS user_name,
    users.github_com_user_id AS user_github_com_user_id,
    users.is_system AS user_is_system,
    users.is_service_account AS user_is_service_account,
    groups.organization_id,
    groups.name AS group_name,
    all_members.group_id
@@ -1627,8 +1632,6 @@ CREATE VIEW group_members_expanded AS
    JOIN groups ON ((groups.id = all_members.group_id)))
  WHERE (users.deleted = false);

COMMENT ON VIEW group_members_expanded IS 'Joins group members with user information, organization ID, group name. Includes both regular group members and organization members (as part of the "Everyone" group).';

CREATE TABLE inbox_notifications (
    id uuid NOT NULL,
    user_id uuid NOT NULL,
@@ -3655,6 +3658,10 @@ CREATE INDEX idx_aibridge_interceptions_model ON aibridge_interceptions USING bt

CREATE INDEX idx_aibridge_interceptions_provider ON aibridge_interceptions USING btree (provider);

CREATE INDEX idx_aibridge_interceptions_session_id ON aibridge_interceptions USING btree (session_id) WHERE (ended_at IS NOT NULL);

CREATE INDEX idx_aibridge_interceptions_sessions_filter ON aibridge_interceptions USING btree (initiator_id, started_at DESC, id DESC) WHERE (ended_at IS NOT NULL);

CREATE INDEX idx_aibridge_interceptions_started_id_desc ON aibridge_interceptions USING btree (started_at DESC, id DESC);

CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptions USING btree (thread_parent_id);
@@ -3673,6 +3680,8 @@ CREATE INDEX idx_aibridge_tool_usages_provider_tool_call_id ON aibridge_tool_usa

CREATE INDEX idx_aibridge_tool_usagesprovider_response_id ON aibridge_tool_usages USING btree (provider_response_id);

CREATE INDEX idx_aibridge_user_prompts_interception_created ON aibridge_user_prompts USING btree (interception_id, created_at DESC, id DESC);

CREATE INDEX idx_aibridge_user_prompts_interception_id ON aibridge_user_prompts USING btree (interception_id);

CREATE INDEX idx_aibridge_user_prompts_provider_response_id ON aibridge_user_prompts USING btree (provider_response_id);

@@ -0,0 +1,35 @@
DROP VIEW group_members_expanded;

CREATE VIEW group_members_expanded AS
WITH all_members AS (
    SELECT group_members.user_id,
        group_members.group_id
    FROM group_members
    UNION
    SELECT organization_members.user_id,
        organization_members.organization_id AS group_id
    FROM organization_members
)
SELECT users.id AS user_id,
    users.email AS user_email,
    users.username AS user_username,
    users.hashed_password AS user_hashed_password,
    users.created_at AS user_created_at,
    users.updated_at AS user_updated_at,
    users.status AS user_status,
    users.rbac_roles AS user_rbac_roles,
    users.login_type AS user_login_type,
    users.avatar_url AS user_avatar_url,
    users.deleted AS user_deleted,
    users.last_seen_at AS user_last_seen_at,
    users.quiet_hours_schedule AS user_quiet_hours_schedule,
    users.name AS user_name,
    users.github_com_user_id AS user_github_com_user_id,
    users.is_system AS user_is_system,
    groups.organization_id,
    groups.name AS group_name,
    all_members.group_id
FROM ((all_members
    JOIN users ON ((users.id = all_members.user_id)))
    JOIN groups ON ((groups.id = all_members.group_id)))
WHERE (users.deleted = false);

@@ -0,0 +1,36 @@
DROP VIEW group_members_expanded;

CREATE VIEW group_members_expanded AS
WITH all_members AS (
    SELECT group_members.user_id,
        group_members.group_id
    FROM group_members
    UNION
    SELECT organization_members.user_id,
        organization_members.organization_id AS group_id
    FROM organization_members
)
SELECT users.id AS user_id,
    users.email AS user_email,
    users.username AS user_username,
    users.hashed_password AS user_hashed_password,
    users.created_at AS user_created_at,
    users.updated_at AS user_updated_at,
    users.status AS user_status,
    users.rbac_roles AS user_rbac_roles,
    users.login_type AS user_login_type,
    users.avatar_url AS user_avatar_url,
    users.deleted AS user_deleted,
    users.last_seen_at AS user_last_seen_at,
    users.quiet_hours_schedule AS user_quiet_hours_schedule,
    users.name AS user_name,
    users.github_com_user_id AS user_github_com_user_id,
    users.is_system AS user_is_system,
    users.is_service_account as user_is_service_account,
    groups.organization_id,
    groups.name AS group_name,
    all_members.group_id
FROM ((all_members
    JOIN users ON ((users.id = all_members.user_id)))
    JOIN groups ON ((groups.id = all_members.group_id)))
WHERE (users.deleted = false);

@@ -0,0 +1,5 @@
DROP INDEX IF EXISTS idx_aibridge_interceptions_session_id;
DROP INDEX IF EXISTS idx_aibridge_user_prompts_interception_created;
DROP INDEX IF EXISTS idx_aibridge_interceptions_sessions_filter;

ALTER TABLE aibridge_interceptions DROP COLUMN IF EXISTS session_id;

@@ -0,0 +1,40 @@
-- A "session" groups related interceptions together. See the COMMENT ON
-- COLUMN below for the full business-logic description.
ALTER TABLE aibridge_interceptions
    ADD COLUMN session_id TEXT NOT NULL
    GENERATED ALWAYS AS (
        COALESCE(
            client_session_id,
            thread_root_id::text,
            id::text
        )
    ) STORED;

-- Searching and grouping on the resolved session ID will be common.
CREATE INDEX idx_aibridge_interceptions_session_id
    ON aibridge_interceptions (session_id)
    WHERE ended_at IS NOT NULL;

COMMENT ON COLUMN aibridge_interceptions.session_id IS
    'Groups related interceptions into a logical session. '
    'Determined by a priority chain: '
    '(1) client_session_id — an explicit session identifier supplied by the '
    'calling client (e.g. Claude Code); '
    '(2) thread_root_id — the root of an agentic thread detected by Bridge '
    'through tool-call correlation, used when the client does not supply its '
    'own session ID; '
    '(3) id — the interception''s own ID, used as a last resort so every '
    'interception belongs to exactly one session even if it is standalone. '
    'This is a generated column stored on disk so it can be indexed and '
    'joined without recomputing the COALESCE on every query.';

-- Composite index for the most common filter path used by
-- ListAIBridgeSessions: initiator_id equality + started_at range,
-- with ended_at IS NOT NULL as a partial filter.
CREATE INDEX idx_aibridge_interceptions_sessions_filter
    ON aibridge_interceptions (initiator_id, started_at DESC, id DESC)
    WHERE ended_at IS NOT NULL;

-- Supports lateral prompt lookup by interception + recency.
CREATE INDEX idx_aibridge_user_prompts_interception_created
    ON aibridge_user_prompts (interception_id, created_at DESC, id DESC);
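The priority chain behind the generated session_id column can be read as plain code. A sketch in Go, assuming string inputs where the empty string stands in for SQL NULL (the real column operates on the SQL types in the migration above):

```go
package main

import "fmt"

// resolveSessionID mirrors the COALESCE priority chain of the generated
// session_id column; "" stands in for SQL NULL in this sketch.
func resolveSessionID(clientSessionID, threadRootID, id string) string {
	if clientSessionID != "" {
		return clientSessionID // (1) explicit client-supplied session ID
	}
	if threadRootID != "" {
		return threadRootID // (2) root of the Bridge-detected agentic thread
	}
	return id // (3) last resort: the interception's own ID
}

func main() {
	fmt.Println(resolveSessionID("sess-a", "root-1", "int-1")) // sess-a
	fmt.Println(resolveSessionID("", "root-1", "int-1"))       // root-1
	fmt.Println(resolveSessionID("", "", "int-1"))             // int-1
}
```

Because the fallback is the row's own ID, every interception resolves to exactly one non-null session, which is what lets the column carry NOT NULL and back a partial index.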
@@ -0,0 +1 @@
ALTER TABLE chat_messages DROP COLUMN provider_response_id;

@@ -0,0 +1 @@
ALTER TABLE chat_messages ADD COLUMN provider_response_id TEXT;
@@ -806,6 +806,8 @@ type aibridgeQuerier interface {
    ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error)
    CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error)
    ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error)
    ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error)
    CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error)
}

func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) {
@@ -852,6 +854,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
    &i.AIBridgeInterception.ThreadParentID,
    &i.AIBridgeInterception.ThreadRootID,
    &i.AIBridgeInterception.ClientSessionID,
    &i.AIBridgeInterception.SessionID,
    &i.VisibleUser.ID,
    &i.VisibleUser.Username,
    &i.VisibleUser.Name,
@@ -939,6 +942,109 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeModels(ctx context.Context, arg ListA
    return items, nil
}

func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error) {
    authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
        VariableConverter: regosql.AIBridgeInterceptionConverter(),
    })
    if err != nil {
        return nil, xerrors.Errorf("compile authorized filter: %w", err)
    }
    filtered, err := insertAuthorizedFilter(listAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter))
    if err != nil {
        return nil, xerrors.Errorf("insert authorized filter: %w", err)
    }

    query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeSessions :many\n%s", filtered)
    rows, err := q.db.QueryContext(ctx, query,
        arg.AfterSessionID,
        arg.Offset,
        arg.Limit,
        arg.StartedAfter,
        arg.StartedBefore,
        arg.InitiatorID,
        arg.Provider,
        arg.Model,
        arg.Client,
        arg.SessionID,
    )
    if err != nil {
        return nil, err
    }
    defer rows.Close()
    var items []ListAIBridgeSessionsRow
    for rows.Next() {
        var i ListAIBridgeSessionsRow
        if err := rows.Scan(
            &i.SessionID,
            &i.UserID,
            &i.UserUsername,
            &i.UserName,
            &i.UserAvatarUrl,
            pq.Array(&i.Providers),
            pq.Array(&i.Models),
            &i.Client,
            &i.Metadata,
            &i.StartedAt,
            &i.EndedAt,
            &i.Threads,
            &i.InputTokens,
            &i.OutputTokens,
            &i.LastPrompt,
        ); err != nil {
            return nil, err
        }
        items = append(items, i)
    }
    if err := rows.Close(); err != nil {
        return nil, err
    }
    if err := rows.Err(); err != nil {
        return nil, err
    }
    return items, nil
}

func (q *sqlQuerier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
    authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
        VariableConverter: regosql.AIBridgeInterceptionConverter(),
    })
    if err != nil {
        return 0, xerrors.Errorf("compile authorized filter: %w", err)
    }
    filtered, err := insertAuthorizedFilter(countAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter))
    if err != nil {
        return 0, xerrors.Errorf("insert authorized filter: %w", err)
    }

    query := fmt.Sprintf("-- name: CountAuthorizedAIBridgeSessions :one\n%s", filtered)
    rows, err := q.db.QueryContext(ctx, query,
        arg.StartedAfter,
        arg.StartedBefore,
        arg.InitiatorID,
        arg.Provider,
        arg.Model,
        arg.Client,
        arg.SessionID,
    )
    if err != nil {
        return 0, err
    }
    defer rows.Close()
    var count int64
    for rows.Next() {
        if err := rows.Scan(&count); err != nil {
            return 0, err
        }
    }
    if err := rows.Close(); err != nil {
        return 0, err
    }
    if err := rows.Err(); err != nil {
        return 0, err
    }
    return count, nil
}

func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
    if !strings.Contains(query, authorizedQueryPlaceholder) {
        return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
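The insertAuthorizedFilter call above splices the compiled RBAC filter into the base query at a known placeholder, and fails loudly when a query was not written to accept one. A stdlib-only sketch of that mechanic; the placeholder literal here is a hypothetical stand-in for the package's authorizedQueryPlaceholder constant, which is defined elsewhere:

```go
package main

import (
	"fmt"
	"strings"
)

// Hypothetical stand-in for the package's authorizedQueryPlaceholder.
const placeholder = "-- @authorize_filter"

// spliceFilter mirrors insertAuthorizedFilter: reject queries that lack
// the placeholder, otherwise substitute the compiled filter in its place.
func spliceFilter(query, filter string) (string, error) {
	if !strings.Contains(query, placeholder) {
		return "", fmt.Errorf("query does not contain authorized replace string")
	}
	return strings.Replace(query, placeholder, filter, 1), nil
}

func main() {
	q := "SELECT id FROM aibridge_interceptions WHERE true -- @authorize_filter"
	out, err := spliceFilter(q, "AND initiator_id = $1")
	fmt.Println(out, err)
}
```

Refusing queries without the marker is the safety property: an authorized variant can never silently run with the filter dropped.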
@@ -4036,6 +4036,8 @@ type AIBridgeInterception struct {
    ThreadRootID uuid.NullUUID `db:"thread_root_id" json:"thread_root_id"`
    // The session ID supplied by the client (optional and not universally supported).
    ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
    // Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception's own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.
    SessionID string `db:"session_id" json:"session_id"`
}

// Audit log of model thinking in intercepted requests in AI Bridge
@@ -4227,6 +4229,7 @@ type ChatMessage struct {
    TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
    RuntimeMs sql.NullInt64 `db:"runtime_ms" json:"runtime_ms"`
    Deleted bool `db:"deleted" json:"deleted"`
    ProviderResponseID sql.NullString `db:"provider_response_id" json:"provider_response_id"`
}

type ChatModelConfig struct {
@@ -4394,7 +4397,6 @@ type Group struct {
    ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
}

// Joins group members with user information, organization ID, group name. Includes both regular group members and organization members (as part of the "Everyone" group).
type GroupMember struct {
    UserID uuid.UUID `db:"user_id" json:"user_id"`
    UserEmail string `db:"user_email" json:"user_email"`
@@ -4412,6 +4414,7 @@ type GroupMember struct {
    UserName string `db:"user_name" json:"user_name"`
    UserGithubComUserID sql.NullInt64 `db:"user_github_com_user_id" json:"user_github_com_user_id"`
    UserIsSystem bool `db:"user_is_system" json:"user_is_system"`
    UserIsServiceAccount bool `db:"user_is_service_account" json:"user_is_service_account"`
    OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
    GroupName string `db:"group_name" json:"group_name"`
    GroupID uuid.UUID `db:"group_id" json:"group_id"`
@@ -231,7 +231,7 @@ type PGPubsub struct {

// BufferSize is the maximum number of unhandled messages we will buffer
// for a subscriber before dropping messages.
const BufferSize = 8192
const BufferSize = 2048

// Subscribe calls the listener when an event matching the name is received.
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
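The BufferSize constant above caps how many unhandled messages queue per subscriber before new ones are dropped. A minimal sketch of that drop-on-full delivery using a buffered channel and a non-blocking send (a toy buffer size and helper, not the real pubsub implementation):

```go
package main

import "fmt"

// deliver attempts a non-blocking send: when the subscriber's buffer is
// full, the message is dropped rather than blocking the publisher.
func deliver(buf chan string, msg string) bool {
	select {
	case buf <- msg:
		return true
	default:
		return false // buffer full: drop the message
	}
}

func main() {
	buf := make(chan string, 2) // toy stand-in for BufferSize
	fmt.Println(deliver(buf, "a"), deliver(buf, "b"), deliver(buf, "c")) // true true false
}
```

Dropping instead of blocking keeps one slow subscriber from stalling delivery to every other listener; shrinking the constant trades per-subscriber slack for lower worst-case memory.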
@@ -62,11 +62,9 @@ type sqlcQuerier interface {
    // referenced by the latest build of a workspace.
    ArchiveUnusedTemplateVersions(ctx context.Context, arg ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error)
    BackoffChatDiffStatus(ctx context.Context, arg BackoffChatDiffStatusParams) error
    BatchUpdateWorkspaceAgentConnections(ctx context.Context, arg BatchUpdateWorkspaceAgentConnectionsParams) error
    BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error
    BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error
    BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
    BatchUpsertConnectionLogs(ctx context.Context, arg BatchUpsertConnectionLogsParams) error
    BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error)
    BulkMarkNotificationMessagesSent(ctx context.Context, arg BulkMarkNotificationMessagesSentParams) (int64, error)
    // Calculates the telemetry summary for a given provider, model, and client
@@ -78,6 +76,7 @@ type sqlcQuerier interface {
    CleanTailnetTunnels(ctx context.Context) error
    CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
    CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
    CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
    CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
    CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error)
    // Counts enabled, non-deleted model configs that lack both input and
@@ -150,6 +149,7 @@ type sqlcQuerier interface {
    DeleteTailnetPeer(ctx context.Context, arg DeleteTailnetPeerParams) (DeleteTailnetPeerRow, error)
    DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error)
    DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
    DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
    DeleteUserSecret(ctx context.Context, id uuid.UUID) error
    DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
    DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
@@ -461,9 +461,7 @@ type sqlcQuerier interface {
    GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
    GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
    GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerBindingsRow, error)
    GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error)
    GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerIDsRow, error)
    GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerIDsBatchRow, error)
    GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error)
    GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByOwnerIDAndNameParams) (Task, error)
    GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error)
@@ -557,6 +555,7 @@ type sqlcQuerier interface {
    GetUserActivityInsights(ctx context.Context, arg GetUserActivityInsightsParams) ([]GetUserActivityInsightsRow, error)
    GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
    GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
    GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
    GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
    GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
    GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
@@ -761,6 +760,10 @@ type sqlcQuerier interface {
    // (provider, model, client) in the given timeframe for telemetry reporting.
    ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error)
    ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error)
|
||||
// Returns paginated sessions with aggregated metadata, token counts, and
|
||||
// the most recent user prompt. A "session" is a logical grouping of
|
||||
// interceptions that share the same session_id (set by the client).
|
||||
ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error)
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
|
||||
@@ -769,6 +772,7 @@ type sqlcQuerier interface {
|
||||
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
|
||||
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
|
||||
@@ -872,6 +876,7 @@ type sqlcQuerier interface {
|
||||
UpdateTemplateVersionFlagsByJobID(ctx context.Context, arg UpdateTemplateVersionFlagsByJobIDParams) error
|
||||
UpdateTemplateWorkspacesLastUsedAt(ctx context.Context, arg UpdateTemplateWorkspacesLastUsedAtParams) error
|
||||
UpdateUsageEventsPostPublish(ctx context.Context, arg UpdateUsageEventsPostPublishParams) error
|
||||
UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error)
|
||||
UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error)
|
||||
UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error
|
||||
UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
    "github.com/stretchr/testify/require"

    "cdr.dev/slog/v3/sloggers/slogtest"
    "github.com/coder/coder/v2/coderd/chatd/chatprompt"
    "github.com/coder/coder/v2/coderd/coderdtest"
    "github.com/coder/coder/v2/coderd/database"
    "github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -35,6 +34,7 @@ import (
    "github.com/coder/coder/v2/coderd/rbac"
    "github.com/coder/coder/v2/coderd/rbac/policy"
    "github.com/coder/coder/v2/coderd/util/slice"
    "github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
    "github.com/coder/coder/v2/codersdk"
    "github.com/coder/coder/v2/provisionersdk"
    "github.com/coder/coder/v2/testutil"
@@ -10417,6 +10417,49 @@ func TestGetPRInsights(t *testing.T) {
        assert.Equal(t, int64(0), recent[0].CostMicros)
    })

    t.Run("BlankDisplayNameFallsBackToModel", func(t *testing.T) {
        t.Parallel()
        store, userID, _ := setupChatInfra(t)

        const modelName = "claude-4.1"
        emptyDisplayModel, err := store.InsertChatModelConfig(context.Background(), database.InsertChatModelConfigParams{
            Provider:             "anthropic",
            Model:                modelName,
            DisplayName:          "",
            CreatedBy:            uuid.NullUUID{UUID: userID, Valid: true},
            UpdatedBy:            uuid.NullUUID{UUID: userID, Valid: true},
            Enabled:              true,
            IsDefault:            false,
            ContextLimit:         128000,
            CompressionThreshold: 80,
            Options:              json.RawMessage(`{}`),
        })
        require.NoError(t, err)

        chat := createChat(t, store, userID, emptyDisplayModel.ID, "chat-empty-display-name")
        insertCostMessage(t, store, chat.ID, userID, emptyDisplayModel.ID, 1_000_000)
        linkPR(t, store, chat.ID, "https://github.com/org/repo/pull/72", "merged", "fix: blank display name", 10, 2, 1)

        byModel, err := store.GetPRInsightsPerModel(context.Background(), database.GetPRInsightsPerModelParams{
            StartDate: startDate,
            EndDate:   endDate,
            OwnerID:   noOwner,
        })
        require.NoError(t, err)
        require.Len(t, byModel, 1)
        assert.Equal(t, modelName, byModel[0].DisplayName)

        recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
            StartDate: startDate,
            EndDate:   endDate,
            OwnerID:   noOwner,
            LimitVal:  20,
        })
        require.NoError(t, err)
        require.Len(t, recent, 1)
        assert.Equal(t, modelName, recent[0].ModelDisplayName)
    })

    t.Run("MergedCostMicros_OnlyCountsMerged", func(t *testing.T) {
        t.Parallel()
        store, userID, mcID := setupChatInfra(t)

+567 −258 File diff suppressed because it is too large
@@ -404,6 +404,194 @@ SELECT (
    (SELECT COUNT(*) FROM interceptions)
)::bigint as total_deleted;

-- name: CountAIBridgeSessions :one
SELECT
    COUNT(DISTINCT (aibridge_interceptions.session_id, aibridge_interceptions.initiator_id))
FROM
    aibridge_interceptions
WHERE
    -- Remove inflight interceptions (ones which lack an ended_at value).
    aibridge_interceptions.ended_at IS NOT NULL
    -- Filter by time frame
    AND CASE
        WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz
        ELSE true
    END
    AND CASE
        WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz
        ELSE true
    END
    -- Filter initiator_id
    AND CASE
        WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid
        ELSE true
    END
    -- Filter provider
    AND CASE
        WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
        ELSE true
    END
    -- Filter model
    AND CASE
        WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
        ELSE true
    END
    -- Filter client
    AND CASE
        WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text
        ELSE true
    END
    -- Filter session_id
    AND CASE
        WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text
        ELSE true
    END
    -- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions
    -- @authorize_filter
;

-- name: ListAIBridgeSessions :many
-- Returns paginated sessions with aggregated metadata, token counts, and
-- the most recent user prompt. A "session" is a logical grouping of
-- interceptions that share the same session_id (set by the client).
WITH filtered_interceptions AS (
    SELECT
        aibridge_interceptions.*
    FROM
        aibridge_interceptions
    WHERE
        -- Remove inflight interceptions (ones which lack an ended_at value).
        aibridge_interceptions.ended_at IS NOT NULL
        -- Filter by time frame
        AND CASE
            WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz
            ELSE true
        END
        AND CASE
            WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz
            ELSE true
        END
        -- Filter initiator_id
        AND CASE
            WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid
            ELSE true
        END
        -- Filter provider
        AND CASE
            WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
            ELSE true
        END
        -- Filter model
        AND CASE
            WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
            ELSE true
        END
        -- Filter client
        AND CASE
            WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text
            ELSE true
        END
        -- Filter session_id
        AND CASE
            WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text
            ELSE true
        END
        -- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions
        -- @authorize_filter
),
session_tokens AS (
    -- Aggregate token usage across all interceptions in each session.
    -- Group by (session_id, initiator_id) to avoid merging sessions from
    -- different users who happen to share the same client_session_id.
    SELECT
        fi.session_id,
        fi.initiator_id,
        COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
        COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
        -- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands.
    FROM
        filtered_interceptions fi
    LEFT JOIN
        aibridge_token_usages tu ON fi.id = tu.interception_id
    GROUP BY
        fi.session_id, fi.initiator_id
),
session_root AS (
    -- Build one summary row per session. Group by (session_id, initiator_id)
    -- to avoid merging sessions from different users who happen to share the
    -- same client_session_id. The ARRAY_AGG with ORDER BY picks values from
    -- the chronologically first interception for fields that should represent
    -- the session as a whole (client, metadata). Threads are counted as
    -- distinct root interception IDs: an interception with a NULL
    -- thread_root_id is itself a thread root.
    SELECT
        fi.session_id,
        fi.initiator_id,
        (ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client,
        (ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata,
        ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers,
        ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models,
        MIN(fi.started_at) AS started_at,
        MAX(fi.ended_at) AS ended_at,
        COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads,
        -- Collect IDs for lateral prompt lookup.
        ARRAY_AGG(fi.id) AS interception_ids
    FROM
        filtered_interceptions fi
    GROUP BY
        fi.session_id, fi.initiator_id
)
SELECT
    sr.session_id,
    visible_users.id AS user_id,
    visible_users.username AS user_username,
    visible_users.name AS user_name,
    visible_users.avatar_url AS user_avatar_url,
    sr.providers::text[] AS providers,
    sr.models::text[] AS models,
    COALESCE(sr.client, '')::varchar(64) AS client,
    sr.metadata::jsonb AS metadata,
    sr.started_at::timestamptz AS started_at,
    sr.ended_at::timestamptz AS ended_at,
    sr.threads,
    COALESCE(st.input_tokens, 0)::bigint AS input_tokens,
    COALESCE(st.output_tokens, 0)::bigint AS output_tokens,
    COALESCE(slp.prompt, '') AS last_prompt
FROM
    session_root sr
JOIN
    visible_users ON visible_users.id = sr.initiator_id
LEFT JOIN
    session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id
LEFT JOIN LATERAL (
    -- Lateral join to efficiently fetch only the most recent user prompt
    -- across all interceptions in the session, avoiding a full aggregation.
    SELECT up.prompt
    FROM aibridge_user_prompts up
    WHERE up.interception_id = ANY(sr.interception_ids)
    ORDER BY up.created_at DESC, up.id DESC
    LIMIT 1
) slp ON true
WHERE
    -- Cursor pagination: uses a composite (started_at, session_id) cursor
    -- to support keyset pagination. The less-than comparison matches the
    -- DESC sort order so that rows after the cursor come later in results.
    CASE
        WHEN @after_session_id::text != '' THEN (
            (sr.started_at, sr.session_id) < (
                (SELECT started_at FROM session_root WHERE session_id = @after_session_id),
                @after_session_id::text
            )
        )
        ELSE true
    END
ORDER BY
    sr.started_at DESC,
    sr.session_id DESC
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
OFFSET @offset_
;

-- name: ListAIBridgeModels :many
SELECT
    model

@@ -147,6 +147,7 @@ deduped AS (
        cds.deletions,
        cmc.id AS model_config_id,
        cmc.display_name,
        cmc.model,
        cmc.provider
    FROM chat_diff_statuses cds
    JOIN chats c ON c.id = cds.chat_id
@@ -159,7 +160,7 @@ deduped AS (
)
SELECT
    d.model_config_id,
    COALESCE(d.display_name, 'Unknown')::text AS display_name,
    COALESCE(NULLIF(d.display_name, ''), NULLIF(d.model, ''), 'Unknown')::text AS display_name,
    COALESCE(d.provider, 'unknown')::text AS provider,
    COUNT(*)::bigint AS total_prs,
    COUNT(*) FILTER (WHERE d.pull_request_state = 'merged')::bigint AS merged_prs,
@@ -169,7 +170,7 @@ SELECT
    COALESCE(SUM(pc.cost_micros) FILTER (WHERE d.pull_request_state = 'merged'), 0)::bigint AS merged_cost_micros
FROM deduped d
JOIN pr_costs pc ON pc.pr_key = d.pr_key
GROUP BY d.model_config_id, d.display_name, d.provider
GROUP BY d.model_config_id, d.display_name, d.model, d.provider
ORDER BY total_prs DESC;

-- name: GetPRInsightsRecentPRs :many
@@ -227,7 +228,7 @@ deduped AS (
        cds.author_login,
        cds.author_avatar_url,
        COALESCE(cds.base_branch, '')::text AS base_branch,
        COALESCE(cmc.display_name, cmc.model, 'Unknown')::text AS model_display_name,
        COALESCE(NULLIF(cmc.display_name, ''), NULLIF(cmc.model, ''), 'Unknown')::text AS model_display_name,
        c.created_at
    FROM chat_diff_statuses cds
    JOIN chats c ON c.id = cds.chat_id

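The change wraps each candidate in NULLIF before COALESCE so that an empty string (not just NULL) display_name falls back to the model name, which is exactly what the new BlankDisplayNameFallsBackToModel test exercises. The same fallback chain, expressed in Go for illustration:

```go
package main

import "fmt"

// displayName mirrors COALESCE(NULLIF(display_name, ''), NULLIF(model, ''), 'Unknown'):
// NULLIF(x, '') turns empty strings into NULL, so COALESCE skips them.
func displayName(display, model string) string {
	for _, s := range []string{display, model} {
		if s != "" {
			return s
		}
	}
	return "Unknown"
}

func main() {
	fmt.Println(displayName("", "claude-4.1"))       // blank display name falls back to model
	fmt.Println(displayName("Claude", "claude-4.1")) // display name wins when set
	fmt.Println(displayName("", ""))                 // both blank: final default
}
```

A plain `COALESCE(display_name, 'Unknown')` would have returned the empty string here, since `''` is not NULL in SQL.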
@@ -241,7 +241,8 @@ INSERT INTO chat_messages (
    context_limit,
    compressed,
    total_cost_micros,
    runtime_ms
    runtime_ms,
    provider_response_id
)
SELECT
    @chat_id::uuid,
@@ -260,7 +261,8 @@ SELECT
    NULLIF(UNNEST(@context_limit::bigint[]), 0),
    UNNEST(@compressed::boolean[]),
    NULLIF(UNNEST(@total_cost_micros::bigint[]), 0),
    NULLIF(UNNEST(@runtime_ms::bigint[]), 0)
    NULLIF(UNNEST(@runtime_ms::bigint[]), 0),
    NULLIF(UNNEST(@provider_response_id::text[]), '')
RETURNING
    *;

@@ -303,44 +303,3 @@ DO UPDATE SET
        ELSE connection_logs.code
    END
RETURNING *;

-- name: BatchUpsertConnectionLogs :exec
INSERT INTO connection_logs (
    id, connect_time, organization_id, workspace_owner_id, workspace_id,
    workspace_name, agent_name, type, code, ip, user_agent, user_id,
    slug_or_port, connection_id, disconnect_reason, disconnect_time
)
SELECT
    unnest(sqlc.arg('id')::uuid[]),
    unnest(sqlc.arg('connect_time')::timestamptz[]),
    unnest(sqlc.arg('organization_id')::uuid[]),
    unnest(sqlc.arg('workspace_owner_id')::uuid[]),
    unnest(sqlc.arg('workspace_id')::uuid[]),
    unnest(sqlc.arg('workspace_name')::text[]),
    unnest(sqlc.arg('agent_name')::text[]),
    unnest(sqlc.arg('type')::connection_type[]),
    unnest(sqlc.arg('code')::int4[]),
    unnest(sqlc.arg('ip')::inet[]),
    unnest(sqlc.arg('user_agent')::text[]),
    unnest(sqlc.arg('user_id')::uuid[]),
    unnest(sqlc.arg('slug_or_port')::text[]),
    unnest(sqlc.arg('connection_id')::uuid[]),
    unnest(sqlc.arg('disconnect_reason')::text[]),
    unnest(sqlc.arg('disconnect_time')::timestamptz[])
ON CONFLICT (connection_id, workspace_id, agent_name)
DO UPDATE SET
    disconnect_time = CASE
        WHEN connection_logs.disconnect_time IS NULL
        THEN EXCLUDED.disconnect_time
        ELSE connection_logs.disconnect_time
    END,
    disconnect_reason = CASE
        WHEN connection_logs.disconnect_reason IS NULL
        THEN EXCLUDED.disconnect_reason
        ELSE connection_logs.disconnect_reason
    END,
    code = CASE
        WHEN connection_logs.code IS NULL
        THEN EXCLUDED.code
        ELSE connection_logs.code
    END;

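Both batch queries above rely on the same trick: the caller passes one array per column and `unnest()` zips them row-wise, so every slice must be the same length and index i of each slice describes row i. A hedged Go sketch of assembling such parallel column arrays (types simplified; this is not the actual sqlc params struct):

```go
package main

import "fmt"

// connectionLog is a simplified row; the real query carries 16 columns.
type connectionLog struct {
	ID            string
	WorkspaceName string
	Code          int32
}

// toColumns converts row-oriented logs into the column-oriented parallel
// slices that unnest(...) expects: index i of every slice is row i, and
// all slices have equal length by construction.
func toColumns(logs []connectionLog) (ids, names []string, codes []int32) {
	for _, l := range logs {
		ids = append(ids, l.ID)
		names = append(names, l.WorkspaceName)
		codes = append(codes, l.Code)
	}
	return ids, names, codes
}

func main() {
	ids, names, codes := toColumns([]connectionLog{
		{"a1", "dev", 200},
		{"b2", "ci", 101},
	})
	// The arrays must stay aligned or unnest would mix fields across rows.
	fmt.Println(len(ids) == len(names) && len(names) == len(codes))
	fmt.Println(ids, names, codes)
}
```

Building every column from one loop over the rows is what guarantees the alignment; appending to the slices independently would be an easy way to corrupt the batch.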
@@ -5,7 +5,9 @@
-- - Use both to get a specific org member row
SELECT
    sqlc.embed(organization_members),
    users.username, users.avatar_url, users.name, users.email, users.rbac_roles as "global_roles"
    users.username, users.avatar_url, users.name, users.email, users.rbac_roles as "global_roles",
    users.last_seen_at, users.status, users.login_type,
    users.created_at as user_created_at, users.updated_at as user_updated_at
FROM
    organization_members
INNER JOIN
@@ -83,23 +85,115 @@ RETURNING *;
SELECT
    sqlc.embed(organization_members),
    users.username, users.avatar_url, users.name, users.email, users.rbac_roles as "global_roles",
    users.last_seen_at, users.status, users.login_type,
    users.created_at as user_created_at, users.updated_at as user_updated_at,
    COUNT(*) OVER() AS count
FROM
    organization_members
INNER JOIN
INNER JOIN
    users ON organization_members.user_id = users.id AND users.deleted = false
WHERE
    -- Filter by organization id
    CASE
        -- This allows using the last element on a page as effectively a cursor.
        -- This is an important option for scripts that need to paginate without
        -- duplicating or missing data.
        WHEN @after_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
            -- The pagination cursor is the last ID of the previous page.
            -- The query is ordered by the username field, so select all
            -- rows after the cursor.
            (LOWER(users.username)) > (
                SELECT
                    LOWER(users.username)
                FROM
                    organization_members
                INNER JOIN
                    users ON organization_members.user_id = users.id
                WHERE
                    organization_members.user_id = @after_id
            )
        )
        ELSE true
    END
    -- Start filters
    -- Filter by organization id
    AND CASE
        WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
            organization_id = @organization_id
        ELSE true
    END
    -- Filter by system type
    AND CASE WHEN @include_system::bool THEN TRUE ELSE is_system = false END
    -- Filter by email or username
    AND CASE
        WHEN @search :: text != '' THEN (
            users.email ILIKE concat('%', @search, '%')
            OR users.username ILIKE concat('%', @search, '%')
        )
        ELSE true
    END
    -- Filter by name (display name)
    AND CASE
        WHEN @name :: text != '' THEN
            users.name ILIKE concat('%', @name, '%')
        ELSE true
    END
    -- Filter by status
    AND CASE
        -- @status needs to be a text because it can be empty, If it was
        -- user_status enum, it would not.
        WHEN cardinality(@status :: user_status[]) > 0 THEN
            users.status = ANY(@status :: user_status[])
        ELSE true
    END
    -- Filter by global rbac_roles
    AND CASE
        -- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as
        -- everyone is a member.
        WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[]) THEN
            users.rbac_roles && @rbac_role :: text[]
        ELSE true
    END
    -- Filter by last_seen
    AND CASE
        WHEN @last_seen_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
            users.last_seen_at <= @last_seen_before
        ELSE true
    END
    AND CASE
        WHEN @last_seen_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
            users.last_seen_at >= @last_seen_after
        ELSE true
    END
    -- Filter by created_at (user creation date, not date added to org)
    AND CASE
        WHEN @created_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
            users.created_at <= @created_before
        ELSE true
    END
    AND CASE
        WHEN @created_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
            users.created_at >= @created_after
        ELSE true
    END
    -- Filter by system type
    AND CASE
        WHEN @include_system::bool THEN TRUE
        ELSE users.is_system = false
    END
    -- Filter by github.com user ID
    AND CASE
        WHEN @github_com_user_id :: bigint != 0 THEN
            users.github_com_user_id = @github_com_user_id
        ELSE true
    END
    -- Filter by login_type
    AND CASE
        WHEN cardinality(@login_type :: login_type[]) > 0 THEN
            users.login_type = ANY(@login_type :: login_type[])
        ELSE true
    END
    -- End of filters
ORDER BY
    -- Deterministic and consistent ordering of all users. This is to ensure consistent pagination.
    LOWER(username) ASC OFFSET @offset_opt
    LOWER(users.username) ASC OFFSET @offset_opt
LIMIT
    -- A null limit means "no limit", so 0 means return all
    NULLIF(@limit_opt :: int, 0);

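Every filter in the paginated members query follows one shape: `CASE WHEN <param is set> THEN <predicate> ELSE true END`, where "set" means non-zero UUID, non-empty string, or non-empty array. An unset parameter therefore disables its filter rather than matching nothing. The pattern reduces to this small predicate in Go (illustrative only):

```go
package main

import "fmt"

// optional mirrors CASE WHEN set THEN pred ELSE true END: when the
// parameter is unset, the filter passes every row.
func optional(set, pred bool) bool {
	if set {
		return pred
	}
	return true
}

func main() {
	username := "alice"

	search := "" // unset: the search filter is disabled
	fmt.Println(optional(search != "", username == search))

	search = "bob" // set: the filter is active and does not match
	fmt.Println(optional(search != "", username == search))
}
```

Chaining these with AND, as the query does, means each row must satisfy only the filters the caller actually supplied.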
@@ -118,26 +118,6 @@ WHERE id IN (
    WHERE tailnet_tunnels.dst_id = $1
);

-- name: GetTailnetTunnelPeerIDsBatch :many
SELECT src_id AS lookup_id, dst_id AS peer_id, coordinator_id, updated_at
FROM tailnet_tunnels WHERE src_id = ANY(@ids :: uuid[])
UNION ALL
SELECT dst_id AS lookup_id, src_id AS peer_id, coordinator_id, updated_at
FROM tailnet_tunnels WHERE dst_id = ANY(@ids :: uuid[]);

-- name: GetTailnetTunnelPeerBindingsBatch :many
SELECT tp.id AS peer_id, tp.coordinator_id, tp.updated_at, tp.node, tp.status,
    tt.src_id AS lookup_id
FROM tailnet_peers tp
INNER JOIN tailnet_tunnels tt ON tp.id = tt.dst_id
WHERE tt.src_id = ANY(@ids :: uuid[])
UNION ALL
SELECT tp.id AS peer_id, tp.coordinator_id, tp.updated_at, tp.node, tp.status,
    tt.dst_id AS lookup_id
FROM tailnet_peers tp
INNER JOIN tailnet_tunnels tt ON tp.id = tt.src_id
WHERE tt.dst_id = ANY(@ids :: uuid[]);

-- For PG Coordinator HTMLDebug

-- name: GetAllTailnetCoordinators :many

@@ -193,6 +193,26 @@ WHERE user_configs.user_id = @user_id
    AND user_configs.key = 'chat_custom_prompt'
RETURNING *;

-- name: ListUserChatCompactionThresholds :many
SELECT user_id, key, value FROM user_configs
WHERE user_id = @user_id
    AND key LIKE 'chat\_compaction\_threshold\_pct:%'
ORDER BY key;

-- name: GetUserChatCompactionThreshold :one
SELECT value AS threshold_percent FROM user_configs
WHERE user_id = @user_id AND key = @key;

-- name: UpdateUserChatCompactionThreshold :one
INSERT INTO user_configs (user_id, key, value)
VALUES (@user_id, @key, (@threshold_percent::int)::text)
ON CONFLICT ON CONSTRAINT user_configs_pkey
DO UPDATE SET value = (@threshold_percent::int)::text
RETURNING *;

-- name: DeleteUserChatCompactionThreshold :exec
DELETE FROM user_configs WHERE user_id = @user_id AND key = @key;

-- name: GetUserTaskNotificationAlertDismissed :one
SELECT
    value::boolean as task_notification_alert_dismissed

@@ -78,29 +78,6 @@ SET
WHERE
    id = $1;

-- name: BatchUpdateWorkspaceAgentConnections :exec
WITH agents AS (
    SELECT
        unnest(sqlc.arg('id')::uuid[]) AS id,
        unnest(sqlc.arg('first_connected_at')::timestamptz[]) AS first_connected_at,
        unnest(sqlc.arg('last_connected_at')::timestamptz[]) AS last_connected_at,
        unnest(sqlc.arg('last_connected_replica_id')::uuid[]) AS last_connected_replica_id,
        unnest(sqlc.arg('disconnected_at')::timestamptz[]) AS disconnected_at,
        unnest(sqlc.arg('updated_at')::timestamptz[]) AS updated_at
)
UPDATE
    workspace_agents wa
SET
    first_connected_at = a.first_connected_at,
    last_connected_at = a.last_connected_at,
    last_connected_replica_id = a.last_connected_replica_id,
    disconnected_at = a.disconnected_at,
    updated_at = a.updated_at
FROM
    agents a
WHERE
    wa.id = a.id;

-- name: UpdateWorkspaceAgentStartupByID :exec
UPDATE
    workspace_agents

@@ -3,7 +3,7 @@ package dynamicparameters
import (
    "fmt"
    "net/http"
    "sort"
    "slices"

    "github.com/hashicorp/hcl/v2"

@@ -94,7 +94,7 @@ func (e *DiagnosticError) Response() (int, codersdk.Response) {
    for name := range e.KeyedDiagnostics {
        sortedNames = append(sortedNames, name)
    }
    sort.Strings(sortedNames)
    slices.Sort(sortedNames)

    for _, name := range sortedNames {
        diag := e.KeyedDiagnostics[name]

@@ -28,14 +28,11 @@ import (
    "cdr.dev/slog/v3"
    "github.com/coder/coder/v2/agent/agentssh"
    "github.com/coder/coder/v2/coderd/audit"
    "github.com/coder/coder/v2/coderd/chatd"
    "github.com/coder/coder/v2/coderd/chatd/chatprovider"
    "github.com/coder/coder/v2/coderd/database"
    "github.com/coder/coder/v2/coderd/database/db2sdk"
    "github.com/coder/coder/v2/coderd/database/dbauthz"
    "github.com/coder/coder/v2/coderd/externalauth"
    "github.com/coder/coder/v2/coderd/externalauth/gitprovider"
    "github.com/coder/coder/v2/coderd/gitsync"
    "github.com/coder/coder/v2/coderd/httpapi"
    "github.com/coder/coder/v2/coderd/httpapi/httperror"
    "github.com/coder/coder/v2/coderd/httpmw"
@@ -46,6 +43,9 @@ import (
    "github.com/coder/coder/v2/coderd/tracing"
    "github.com/coder/coder/v2/coderd/util/ptr"
    "github.com/coder/coder/v2/coderd/workspaceapps"
    "github.com/coder/coder/v2/coderd/x/chatd"
    "github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
    "github.com/coder/coder/v2/coderd/x/gitsync"
    "github.com/coder/coder/v2/codersdk"
    "github.com/coder/coder/v2/codersdk/wsjson"
    "github.com/coder/websocket"
@@ -2542,6 +2542,17 @@ func normalizeChatCompressionThreshold(
    return threshold, nil
}

func parseCompactionThresholdKey(key string) (uuid.UUID, error) {
    if !strings.HasPrefix(key, codersdk.ChatCompactionThresholdKeyPrefix) {
        return uuid.Nil, xerrors.Errorf("invalid compaction threshold key: %q", key)
    }
    id, err := uuid.Parse(key[len(codersdk.ChatCompactionThresholdKeyPrefix):])
    if err != nil {
        return uuid.Nil, xerrors.Errorf("invalid model config ID in key %q: %w", key, err)
    }
    return id, nil
}

const (
    // maxChatFileSize is the maximum size of a chat file upload (10 MB).
    maxChatFileSize = 10 << 20
@@ -2816,6 +2827,170 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
    })
}

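parseCompactionThresholdKey splits a user_configs key of the form `<prefix><model-config-uuid>`, rejecting keys without the prefix or with a malformed suffix. A dependency-free sketch of the same round trip (the prefix value is an assumption here; the real constant lives in codersdk, and the real handler validates the suffix with uuid.Parse rather than a length check):

```go
package main

import (
	"fmt"
	"strings"
)

// keyPrefix is an assumed stand-in for codersdk.ChatCompactionThresholdKeyPrefix.
const keyPrefix = "chat_compaction_threshold_pct:"

// parseThresholdKey extracts the model-config ID from a user_configs key.
// A canonical UUID string is 36 characters; that check stands in for
// uuid.Parse so the sketch needs no external module.
func parseThresholdKey(key string) (string, error) {
	if !strings.HasPrefix(key, keyPrefix) {
		return "", fmt.Errorf("invalid compaction threshold key: %q", key)
	}
	id := key[len(keyPrefix):]
	if len(id) != 36 {
		return "", fmt.Errorf("invalid model config ID in key %q", key)
	}
	return id, nil
}

func main() {
	const id = "0f9c1d2e-3a4b-4c5d-8e6f-7a8b9c0d1e2f"
	got, err := parseThresholdKey(keyPrefix + id)
	fmt.Println(err == nil && got == id) // round trip succeeds
	_, err = parseThresholdKey("chat_custom_prompt")
	fmt.Println(err != nil) // unrelated keys are rejected
}
```

Encoding the model-config ID into the key is what lets one LIKE query (ListUserChatCompactionThresholds) return per-model thresholds from a generic key/value table.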
// @Summary Get user chat compaction thresholds
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
func (api *API) getUserChatCompactionThresholds(rw http.ResponseWriter, r *http.Request) {
    var (
        ctx    = r.Context()
        apiKey = httpmw.APIKey(r)
    )

    rows, err := api.Database.ListUserChatCompactionThresholds(ctx, apiKey.UserID)
    if err != nil {
        httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
            Message: "Error listing user chat compaction thresholds.",
            Detail:  err.Error(),
        })
        return
    }

    resp := codersdk.UserChatCompactionThresholds{
        Thresholds: make([]codersdk.UserChatCompactionThreshold, 0, len(rows)),
    }
    for _, row := range rows {
        modelConfigID, err := parseCompactionThresholdKey(row.Key)
        if err != nil {
            api.Logger.Warn(ctx, "skipping malformed user chat compaction threshold key",
                slog.F("key", row.Key),
                slog.F("value", row.Value),
                slog.Error(err),
            )
            continue
        }

        thresholdPercent, err := strconv.ParseInt(row.Value, 10, 32)
        if err != nil {
            api.Logger.Warn(ctx, "skipping malformed user chat compaction threshold value",
                slog.F("key", row.Key),
                slog.F("value", row.Value),
                slog.Error(err),
            )
            continue
        }
        if thresholdPercent < int64(minChatContextCompressionThreshold) ||
            thresholdPercent > int64(maxChatContextCompressionThreshold) {
            api.Logger.Warn(ctx, "skipping out-of-range user chat compaction threshold",
                slog.F("key", row.Key),
                slog.F("value", row.Value),
            )
            continue
        }

        resp.Thresholds = append(resp.Thresholds, codersdk.UserChatCompactionThreshold{
            ModelConfigID:    modelConfigID,
            ThresholdPercent: int32(thresholdPercent),
        })
    }

    httpapi.Write(ctx, rw, http.StatusOK, resp)
}

// @Summary Set user chat compaction threshold for a model config
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) putUserChatCompactionThreshold(rw http.ResponseWriter, r *http.Request) {
    var (
        ctx    = r.Context()
        apiKey = httpmw.APIKey(r)
    )

    modelConfigID, ok := parseChatModelConfigID(rw, r)
    if !ok {
        return
    }

    var req codersdk.UpdateUserChatCompactionThresholdRequest
    if !httpapi.Read(ctx, rw, r, &req) {
        return
    }
    if req.ThresholdPercent < minChatContextCompressionThreshold ||
        req.ThresholdPercent > maxChatContextCompressionThreshold {
        httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
            Message: "threshold_percent is out of range.",
            Detail: fmt.Sprintf(
                "threshold_percent must be between %d and %d, got %d.",
                minChatContextCompressionThreshold,
                maxChatContextCompressionThreshold,
                req.ThresholdPercent,
            ),
        })
        return
    }

    // Use system context because GetChatModelConfigByID requires
    // deployment-config read access, which non-admin users lack.
    // The user is only checking if the model exists and is enabled
    // before writing their own personal preference.
    //nolint:gocritic // Non-admin users need this lookup to save their own setting.
    modelConfig, err := api.Database.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), modelConfigID)
    if err != nil {
        if errors.Is(err, sql.ErrNoRows) || httpapi.Is404Error(err) {
            httpapi.ResourceNotFound(rw)
            return
        }
        httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
            Message: "Failed to get chat model config.",
            Detail:  err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !modelConfig.Enabled {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Model config is disabled.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = api.Database.UpdateUserChatCompactionThreshold(ctx, database.UpdateUserChatCompactionThresholdParams{
|
||||
UserID: apiKey.UserID,
|
||||
Key: codersdk.CompactionThresholdKey(modelConfigID),
|
||||
ThresholdPercent: req.ThresholdPercent,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Error updating user chat compaction threshold.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCompactionThreshold{
|
||||
ModelConfigID: modelConfigID,
|
||||
ThresholdPercent: req.ThresholdPercent,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Delete user chat compaction threshold for a model config
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) deleteUserChatCompactionThreshold(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
modelConfigID, ok := parseChatModelConfigID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.Database.DeleteUserChatCompactionThreshold(ctx, database.DeleteUserChatCompactionThresholdParams{
|
||||
UserID: apiKey.UserID,
|
||||
Key: codersdk.CompactionThresholdKey(modelConfigID),
|
||||
}); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Error deleting user chat compaction threshold.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
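The three handlers above share a tolerant-decoding pattern: a malformed or out-of-range stored row is logged and skipped rather than failing the whole request. A minimal runnable sketch of that pattern, with illustrative type and function names rather than the actual coderd types:

```go
package main

import (
	"fmt"
	"strconv"
)

// kv mirrors a user-settings row: an opaque string key and value.
type kv struct{ Key, Value string }

// decodeThresholds keeps the tolerant-decoding shape of the handler:
// rows that fail to parse or fall outside [min, max] are skipped
// (and would be logged) instead of aborting the whole response.
func decodeThresholds(rows []kv, min, max int64) []int32 {
	out := make([]int32, 0, len(rows))
	for _, row := range rows {
		v, err := strconv.ParseInt(row.Value, 10, 32)
		if err != nil {
			continue // skip malformed value
		}
		if v < min || v > max {
			continue // skip out-of-range value
		}
		out = append(out, int32(v))
	}
	return out
}

func main() {
	rows := []kv{{"a", "50"}, {"b", "oops"}, {"c", "999"}}
	fmt.Println(decodeThresholds(rows, 10, 95)) // only "a" survives
}
```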
func (api *API) resolvedChatSystemPrompt(ctx context.Context) string {
	custom, err := api.Database.GetChatSystemPrompt(ctx)
	if err != nil {
File diff suppressed because it is too large

+10 -27
@@ -21,11 +21,14 @@ import (

func TestPostFiles(t *testing.T) {
	t.Parallel()

	// Single instance shared across all sub-tests. Each sub-test
	// creates independent resources with unique IDs so parallel
	// execution is safe.
	client := coderdtest.New(t, nil)
	_ = coderdtest.CreateFirstUser(t, client)
	t.Run("BadContentType", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		_ = coderdtest.CreateFirstUser(t, client)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

@@ -35,9 +38,6 @@ func TestPostFiles(t *testing.T) {

	t.Run("Insert", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		_ = coderdtest.CreateFirstUser(t, client)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

@@ -47,9 +47,6 @@ func TestPostFiles(t *testing.T) {

	t.Run("InsertWindowsZip", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		_ = coderdtest.CreateFirstUser(t, client)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

@@ -59,9 +56,6 @@ func TestPostFiles(t *testing.T) {

	t.Run("InsertAlreadyExists", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		_ = coderdtest.CreateFirstUser(t, client)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

@@ -73,9 +67,6 @@ func TestPostFiles(t *testing.T) {
	})
	t.Run("InsertConcurrent", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		_ = coderdtest.CreateFirstUser(t, client)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

@@ -99,11 +90,12 @@ func TestPostFiles(t *testing.T) {

func TestDownload(t *testing.T) {
	t.Parallel()

	// Shared instance — see TestPostFiles for rationale.
	client := coderdtest.New(t, nil)
	_ = coderdtest.CreateFirstUser(t, client)
	t.Run("NotFound", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		_ = coderdtest.CreateFirstUser(t, client)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

@@ -115,9 +107,6 @@ func TestDownload(t *testing.T) {

	t.Run("InsertTar_DownloadTar", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		_ = coderdtest.CreateFirstUser(t, client)

		// given
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()
@@ -139,9 +128,6 @@ func TestDownload(t *testing.T) {

	t.Run("InsertZip_DownloadTar", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		_ = coderdtest.CreateFirstUser(t, client)

		// given
		zipContent := archivetest.TestZipFileBytes()

@@ -164,9 +150,6 @@ func TestDownload(t *testing.T) {

	t.Run("InsertTar_DownloadZip", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		_ = coderdtest.CreateFirstUser(t, client)

		// given
		tarball := archivetest.TestTarFileBytes()

+50 -28
@@ -248,12 +248,9 @@ func PrecheckAPIKey(cfg ValidateAPIKeyConfig) func(http.Handler) http.Handler {
//
// Returns (result, nil) on success or (nil, error) on failure.
func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Request) (*ValidateAPIKeyResult, *ValidateAPIKeyError) {
	key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
	if !ok {
		return nil, &ValidateAPIKeyError{
			Code:     http.StatusUnauthorized,
			Response: resp,
		}
	key, valErr := apiKeyFromRequestValidate(ctx, cfg.DB, cfg.SessionTokenFunc, r)
	if valErr != nil {
		return nil, valErr
	}

	// Log the API key ID for all requests that have a valid key
@@ -475,7 +472,7 @@ func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Reque
	actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
	if err != nil {
		return nil, &ValidateAPIKeyError{
			Code: http.StatusUnauthorized,
			Code: http.StatusInternalServerError,
			Response: codersdk.Response{
				Message: internalErrorMessage,
				Detail:  fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
@@ -492,6 +489,15 @@ func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Reque
}

func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
	key, valErr := apiKeyFromRequestValidate(ctx, db, sessionTokenFunc, r)
	if valErr != nil {
		return nil, valErr.Response, false
	}

	return key, codersdk.Response{}, true
}

func apiKeyFromRequestValidate(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, *ValidateAPIKeyError) {
	tokenFunc := APITokenFromRequest
	if sessionTokenFunc != nil {
		tokenFunc = sessionTokenFunc
@@ -499,45 +505,61 @@ func APIKeyFromRequest

	token := tokenFunc(r)
	if token == "" {
		return nil, codersdk.Response{
			Message: SignedOutErrorMessage,
			Detail:  fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
		}, false
		return nil, &ValidateAPIKeyError{
			Code: http.StatusUnauthorized,
			Response: codersdk.Response{
				Message: SignedOutErrorMessage,
				Detail:  fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
			},
		}
	}

	keyID, keySecret, err := SplitAPIToken(token)
	if err != nil {
		return nil, codersdk.Response{
			Message: SignedOutErrorMessage,
			Detail:  "Invalid API key format: " + err.Error(),
		}, false
		return nil, &ValidateAPIKeyError{
			Code: http.StatusUnauthorized,
			Response: codersdk.Response{
				Message: SignedOutErrorMessage,
				Detail:  "Invalid API key format: " + err.Error(),
			},
		}
	}

	//nolint:gocritic // System needs to fetch API key to check if it's valid.
	key, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, codersdk.Response{
				Message: SignedOutErrorMessage,
				Detail:  "API key is invalid.",
			}, false
			return nil, &ValidateAPIKeyError{
				Code: http.StatusUnauthorized,
				Response: codersdk.Response{
					Message: SignedOutErrorMessage,
					Detail:  "API key is invalid.",
				},
			}
		}

		return nil, codersdk.Response{
			Message: internalErrorMessage,
			Detail:  fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
		}, false
		return nil, &ValidateAPIKeyError{
			Code: http.StatusInternalServerError,
			Response: codersdk.Response{
				Message: internalErrorMessage,
				Detail:  fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
			},
			Hard: true,
		}
	}

	// Checking to see if the secret is valid.
	if !apikey.ValidateHash(key.HashedSecret, keySecret) {
		return nil, codersdk.Response{
			Message: SignedOutErrorMessage,
			Detail:  "API key secret is invalid.",
		}, false
		return nil, &ValidateAPIKeyError{
			Code: http.StatusUnauthorized,
			Response: codersdk.Response{
				Message: SignedOutErrorMessage,
				Detail:  "API key secret is invalid.",
			},
		}
	}

	return &key, codersdk.Response{}, true
	return &key, nil
}

// ExtractAPIKey requires authentication using a valid API key. It handles
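The refactor above keeps the exported APIKeyFromRequest signature as a thin wrapper over the new apiKeyFromRequestValidate, which returns a typed error instead of a (response, ok) pair. A minimal sketch of that compatibility-wrapper pattern, with illustrative names and shapes rather than the actual httpmw API:

```go
package main

import "fmt"

// apiError loosely mirrors ValidateAPIKeyError: an HTTP status code
// plus a response message.
type apiError struct {
	Code    int
	Message string
}

// validate is the new-style function: a typed error or nil.
func validate(token string) (string, *apiError) {
	if token == "" {
		return "", &apiError{Code: 401, Message: "signed out"}
	}
	return "key-" + token, nil
}

// legacyValidate preserves the old (value, message, ok) contract by
// delegating to validate, as the diff does for APIKeyFromRequest.
func legacyValidate(token string) (string, string, bool) {
	key, err := validate(token)
	if err != nil {
		return "", err.Message, false
	}
	return key, "", true
}

func main() {
	_, msg, ok := legacyValidate("")
	fmt.Println(ok, msg) // rejected path
	key, _, ok := legacyValidate("abc")
	fmt.Println(ok, key) // accepted path
}
```

Keeping the old signature as a delegating wrapper lets callers migrate to the richer error type incrementally instead of in one sweeping change.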
@@ -19,12 +19,14 @@ import (
	"go.uber.org/mock/gomock"
	"golang.org/x/exp/slices"
	"golang.org/x/oauth2"
	"golang.org/x/xerrors"

	"cdr.dev/slog/v3"
	"github.com/coder/coder/v2/coderd/apikey"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbauthz"
	"github.com/coder/coder/v2/coderd/database/dbgen"
	"github.com/coder/coder/v2/coderd/database/dbmock"
	"github.com/coder/coder/v2/coderd/database/dbtestutil"
	"github.com/coder/coder/v2/coderd/database/dbtime"
	"github.com/coder/coder/v2/coderd/httpapi"
@@ -192,6 +194,31 @@ func TestAPIKey(t *testing.T) {
		require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	})

	t.Run("GetAPIKeyByIDInternalError", func(t *testing.T) {
		t.Parallel()
		ctrl := gomock.NewController(t)
		db := dbmock.NewMockStore(ctrl)
		id, secret, _ := randomAPIKeyParts()
		r := httptest.NewRequest("GET", "/", nil)
		rw := httptest.NewRecorder()
		r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret))

		db.EXPECT().GetAPIKeyByID(gomock.Any(), id).Return(database.APIKey{}, xerrors.New("db unavailable"))

		httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
			DB:              db,
			RedirectToLogin: false,
		})(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusInternalServerError, res.StatusCode)

		var resp codersdk.Response
		require.NoError(t, json.NewDecoder(res.Body).Decode(&resp))
		require.NotEqual(t, httpmw.SignedOutErrorMessage, resp.Message)
		require.Contains(t, resp.Detail, "Internal error fetching API key by id")
	})

	t.Run("UserLinkNotFound", func(t *testing.T) {
		t.Parallel()
		var (
@@ -14,9 +14,13 @@ import (
func TestInitScript(t *testing.T) {
	t.Parallel()

	// Single instance shared across all sub-tests. All operations
	// are read-only (fetching init scripts) so parallel execution
	// is safe.
	client := coderdtest.New(t, nil)

	t.Run("OK Windows amd64", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		script, err := client.InitScript(context.Background(), "windows", "amd64")
		require.NoError(t, err)
		require.NotEmpty(t, script)
@@ -26,7 +30,6 @@ func TestInitScript(t *testing.T) {

	t.Run("OK Windows arm64", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		script, err := client.InitScript(context.Background(), "windows", "arm64")
		require.NoError(t, err)
		require.NotEmpty(t, script)
@@ -36,7 +39,6 @@ func TestInitScript(t *testing.T) {

	t.Run("OK Linux amd64", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		script, err := client.InitScript(context.Background(), "linux", "amd64")
		require.NoError(t, err)
		require.NotEmpty(t, script)
@@ -46,7 +48,6 @@ func TestInitScript(t *testing.T) {

	t.Run("OK Linux arm64", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		script, err := client.InitScript(context.Background(), "linux", "arm64")
		require.NoError(t, err)
		require.NotEmpty(t, script)
@@ -56,7 +57,6 @@ func TestInitScript(t *testing.T) {

	t.Run("BadRequest", func(t *testing.T) {
		t.Parallel()
		client := coderdtest.New(t, nil)
		_, err := client.InitScript(context.Background(), "darwin", "armv7")
		require.Error(t, err)
		var apiErr *codersdk.Error

+207 -2
@@ -6,12 +6,16 @@ import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"strings"

	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
	"github.com/mark3labs/mcp-go/client/transport"
	"golang.org/x/oauth2"
	"golang.org/x/xerrors"

	"cdr.dev/slog/v3"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbauthz"
	"github.com/coder/coder/v2/coderd/httpapi"
@@ -107,9 +111,147 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
	// Validate auth-type-dependent fields.
	switch req.AuthType {
	case "oauth2":
		if req.OAuth2ClientID == "" || req.OAuth2AuthURL == "" || req.OAuth2TokenURL == "" {
		// When the admin does not provide OAuth2 credentials, attempt
		// automatic discovery and Dynamic Client Registration (RFC 7591)
		// using the MCP server URL. This follows the MCP authorization
		// spec: discover the authorization server via Protected Resource
		// Metadata (RFC 9728) and Authorization Server Metadata
		// (RFC 8414), then register a client dynamically.
		if req.OAuth2ClientID == "" && req.OAuth2AuthURL == "" && req.OAuth2TokenURL == "" {
			// Auto-discovery flow: we need the config ID first to
			// build the correct callback URL. Insert the record
			// with empty OAuth2 fields, perform discovery, then
			// update.
			customHeadersJSON, err := marshalCustomHeaders(req.CustomHeaders)
			if err != nil {
				httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
					Message: "Invalid custom headers.",
					Detail:  err.Error(),
				})
				return
			}

			inserted, err := api.Database.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{
				DisplayName:             strings.TrimSpace(req.DisplayName),
				Slug:                    strings.TrimSpace(req.Slug),
				Description:             strings.TrimSpace(req.Description),
				IconURL:                 strings.TrimSpace(req.IconURL),
				Transport:               strings.TrimSpace(req.Transport),
				Url:                     strings.TrimSpace(req.URL),
				AuthType:                strings.TrimSpace(req.AuthType),
				OAuth2ClientID:          "",
				OAuth2ClientSecret:      "",
				OAuth2ClientSecretKeyID: sql.NullString{},
				OAuth2AuthURL:           "",
				OAuth2TokenURL:          "",
				OAuth2Scopes:            "",
				APIKeyHeader:            strings.TrimSpace(req.APIKeyHeader),
				APIKeyValue:             strings.TrimSpace(req.APIKeyValue),
				APIKeyValueKeyID:        sql.NullString{},
				CustomHeaders:           customHeadersJSON,
				CustomHeadersKeyID:      sql.NullString{},
				ToolAllowList:           coalesceStringSlice(trimStringSlice(req.ToolAllowList)),
				ToolDenyList:            coalesceStringSlice(trimStringSlice(req.ToolDenyList)),
				Availability:            strings.TrimSpace(req.Availability),
				Enabled:                 req.Enabled,
				CreatedBy:               apiKey.UserID,
				UpdatedBy:               apiKey.UserID,
			})
			if err != nil {
				switch {
				case database.IsUniqueViolation(err):
					httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
						Message: "MCP server config already exists.",
						Detail:  err.Error(),
					})
					return
				case database.IsCheckViolation(err):
					httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
						Message: "Invalid MCP server config.",
						Detail:  err.Error(),
					})
					return
				default:
					httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
						Message: "Failed to create MCP server config.",
						Detail:  err.Error(),
					})
					return
				}
			}

			// Now build the callback URL with the actual ID.
			callbackURL := fmt.Sprintf("%s/api/experimental/mcp/servers/%s/oauth2/callback", api.AccessURL.String(), inserted.ID)
			result, err := discoverAndRegisterMCPOAuth2(ctx, strings.TrimSpace(req.URL), callbackURL)
			if err != nil {
				// Clean up: delete the partially created config.
				deleteErr := api.Database.DeleteMCPServerConfigByID(ctx, inserted.ID)
				if deleteErr != nil {
					api.Logger.Warn(ctx, "failed to clean up MCP server config after OAuth2 discovery failure",
						slog.F("config_id", inserted.ID),
						slog.Error(deleteErr),
					)
				}

				api.Logger.Warn(ctx, "mcp oauth2 auto-discovery failed",
					slog.F("url", req.URL),
					slog.Error(err),
				)
				httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
					Message: "OAuth2 auto-discovery failed. Provide oauth2_client_id, oauth2_auth_url, and oauth2_token_url manually, or ensure the MCP server supports RFC 9728 (Protected Resource Metadata) and RFC 7591 (Dynamic Client Registration).",
					Detail:  err.Error(),
				})
				return
			}

			// Determine scopes: use the request value if provided,
			// otherwise fall back to the discovered value.
			oauth2Scopes := strings.TrimSpace(req.OAuth2Scopes)
			if oauth2Scopes == "" {
				oauth2Scopes = result.scopes
			}

			// Update the record with discovered OAuth2 credentials.
			updated, err := api.Database.UpdateMCPServerConfig(ctx, database.UpdateMCPServerConfigParams{
				ID:                      inserted.ID,
				DisplayName:             inserted.DisplayName,
				Slug:                    inserted.Slug,
				Description:             inserted.Description,
				IconURL:                 inserted.IconURL,
				Transport:               inserted.Transport,
				Url:                     inserted.Url,
				AuthType:                inserted.AuthType,
				OAuth2ClientID:          result.clientID,
				OAuth2ClientSecret:      result.clientSecret,
				OAuth2ClientSecretKeyID: sql.NullString{},
				OAuth2AuthURL:           result.authURL,
				OAuth2TokenURL:          result.tokenURL,
				OAuth2Scopes:            oauth2Scopes,
				APIKeyHeader:            inserted.APIKeyHeader,
				APIKeyValue:             inserted.APIKeyValue,
				APIKeyValueKeyID:        inserted.APIKeyValueKeyID,
				CustomHeaders:           inserted.CustomHeaders,
				CustomHeadersKeyID:      inserted.CustomHeadersKeyID,
				ToolAllowList:           inserted.ToolAllowList,
				ToolDenyList:            inserted.ToolDenyList,
				Availability:            inserted.Availability,
				Enabled:                 inserted.Enabled,
				UpdatedBy:               apiKey.UserID,
			})
			if err != nil {
				httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
					Message: "Failed to update MCP server config with OAuth2 credentials.",
					Detail:  err.Error(),
				})
				return
			}

			httpapi.Write(ctx, rw, http.StatusCreated, convertMCPServerConfig(updated))
			return
		} else if req.OAuth2ClientID == "" || req.OAuth2AuthURL == "" || req.OAuth2TokenURL == "" {
			// Partial manual config: all three fields are required together.
			httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
				Message: "OAuth2 auth type requires oauth2_client_id, oauth2_auth_url, and oauth2_token_url.",
				Message: "OAuth2 auth type requires either all of oauth2_client_id, oauth2_auth_url, and oauth2_token_url (manual configuration), or none of them (automatic discovery via RFC 7591).",
			})
			return
		}
@@ -919,3 +1061,66 @@ func coalesceStringSlice(ss []string) []string {
	}
	return ss
}

// mcpOAuth2Discovery holds the result of MCP OAuth2 auto-discovery
// and Dynamic Client Registration.
type mcpOAuth2Discovery struct {
	clientID     string
	clientSecret string
	authURL      string
	tokenURL     string
	scopes       string // space-separated
}

// discoverAndRegisterMCPOAuth2 uses the mcp-go library's OAuthHandler to
// perform the MCP OAuth2 discovery and Dynamic Client Registration flow:
//
// 1. Discover the authorization server via Protected Resource Metadata
//    (RFC 9728) and Authorization Server Metadata (RFC 8414).
// 2. Register a client via Dynamic Client Registration (RFC 7591).
// 3. Return the discovered endpoints and generated credentials.
func discoverAndRegisterMCPOAuth2(ctx context.Context, mcpServerURL, callbackURL string) (*mcpOAuth2Discovery, error) {
	// Per the MCP spec, the authorization base URL is the MCP server
	// URL with the path component discarded (scheme + host only).
	parsed, err := url.Parse(mcpServerURL)
	if err != nil {
		return nil, xerrors.Errorf("parse MCP server URL: %w", err)
	}
	origin := fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host)

	oauthHandler := transport.NewOAuthHandler(transport.OAuthConfig{
		RedirectURI: callbackURL,
		TokenStore:  transport.NewMemoryTokenStore(),
	})
	oauthHandler.SetBaseURL(origin)

	// Step 1: Discover authorization server metadata (RFC 9728 + RFC 8414).
	metadata, err := oauthHandler.GetServerMetadata(ctx)
	if err != nil {
		return nil, xerrors.Errorf("discover authorization server: %w", err)
	}
	if metadata.AuthorizationEndpoint == "" {
		return nil, xerrors.New("authorization server metadata missing authorization_endpoint")
	}
	if metadata.TokenEndpoint == "" {
		return nil, xerrors.New("authorization server metadata missing token_endpoint")
	}
	if metadata.RegistrationEndpoint == "" {
		return nil, xerrors.New("authorization server does not advertise a registration_endpoint (dynamic client registration may not be supported)")
	}

	// Step 2: Register a client via Dynamic Client Registration (RFC 7591).
	if err := oauthHandler.RegisterClient(ctx, "Coder"); err != nil {
		return nil, xerrors.Errorf("dynamic client registration: %w", err)
	}

	scopes := strings.Join(metadata.ScopesSupported, " ")

	return &mcpOAuth2Discovery{
		clientID:     oauthHandler.GetClientID(),
		clientSecret: oauthHandler.GetClientSecret(),
		authURL:      metadata.AuthorizationEndpoint,
		tokenURL:     metadata.TokenEndpoint,
		scopes:       scopes,
	}, nil
}
+310 -4
@@ -3,6 +3,7 @@ package coderd_test
import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

@@ -430,6 +431,309 @@ func TestMCPServerConfigsOAuth2Disconnect(t *testing.T) {
	require.NoError(t, err)
}

func TestMCPServerConfigsOAuth2AutoDiscovery(t *testing.T) {
	t.Parallel()

	t.Run("Success", func(t *testing.T) {
		t.Parallel()

		ctx := testutil.Context(t, testutil.WaitLong)

		// Stand up a mock auth server that serves RFC 8414 metadata and
		// a RFC 7591 dynamic client registration endpoint.
		authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch r.URL.Path {
			case "/.well-known/oauth-authorization-server":
				w.Header().Set("Content-Type", "application/json")
				_, _ = w.Write([]byte(`{
					"issuer": "` + r.Host + `",
					"authorization_endpoint": "` + "http://" + r.Host + `/authorize",
					"token_endpoint": "` + "http://" + r.Host + `/token",
					"registration_endpoint": "` + "http://" + r.Host + `/register",
					"response_types_supported": ["code"],
					"scopes_supported": ["read", "write"]
				}`))
			case "/register":
				if r.Method != http.MethodPost {
					http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
					return
				}
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusCreated)
				_, _ = w.Write([]byte(`{
					"client_id": "auto-discovered-client-id",
					"client_secret": "auto-discovered-client-secret"
				}`))
			default:
				http.NotFound(w, r)
			}
		}))
		t.Cleanup(authServer.Close)

		// Stand up a mock MCP server that serves RFC 9728 Protected
		// Resource Metadata pointing to the auth server above.
		mcpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/.well-known/oauth-protected-resource" {
				w.Header().Set("Content-Type", "application/json")
				_, _ = w.Write([]byte(`{
					"resource": "` + "http://" + r.Host + `",
					"authorization_servers": ["` + authServer.URL + `"]
				}`))
				return
			}
			http.NotFound(w, r)
		}))
		t.Cleanup(mcpServer.Close)

		client := newMCPClient(t)
		_ = coderdtest.CreateFirstUser(t, client)

		// Create config with auth_type=oauth2 but no OAuth2 fields —
		// the server should auto-discover them.
		created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
			DisplayName:   "Auto-Discovery Server",
			Slug:          "auto-discovery",
			Transport:     "streamable_http",
			URL:           mcpServer.URL + "/v1/mcp",
			AuthType:      "oauth2",
			Availability:  "default_on",
			Enabled:       true,
			ToolAllowList: []string{},
			ToolDenyList:  []string{},
		})
		require.NoError(t, err)
		require.Equal(t, "auto-discovered-client-id", created.OAuth2ClientID)
		require.True(t, created.HasOAuth2Secret)
		require.Equal(t, authServer.URL+"/authorize", created.OAuth2AuthURL)
		require.Equal(t, authServer.URL+"/token", created.OAuth2TokenURL)
		require.Equal(t, "read write", created.OAuth2Scopes)
	})

	// Regression test: verify that during dynamic client registration
	// the redirect_uris sent to the authorization server contain the
	// real config UUID, NOT the literal string "{id}". Before the
	// fix, the callback URL was built before the config row existed,
	// so it contained "{id}" literally, which caused "redirect URIs
	// not approved" errors when the user later tried to connect.
	t.Run("RedirectURIContainsRealConfigID", func(t *testing.T) {
		t.Parallel()

		ctx := testutil.Context(t, testutil.WaitLong)

		// Buffered channel so the handler never blocks.
		registeredRedirectURI := make(chan string, 1)

		// Stand up a mock auth server that captures the redirect_uris
		// from the RFC 7591 Dynamic Client Registration request.
		authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch r.URL.Path {
			case "/.well-known/oauth-authorization-server":
				w.Header().Set("Content-Type", "application/json")
				_, _ = w.Write([]byte(`{
					"issuer": "` + "http://" + r.Host + `",
					"authorization_endpoint": "` + "http://" + r.Host + `/authorize",
					"token_endpoint": "` + "http://" + r.Host + `/token",
					"registration_endpoint": "` + "http://" + r.Host + `/register",
					"response_types_supported": ["code"],
					"scopes_supported": ["read", "write"]
				}`))
			case "/register":
				if r.Method != http.MethodPost {
					http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
					return
				}

				// Decode the registration body and capture redirect_uris.
				var body map[string]interface{}
				if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
					http.Error(w, "bad json", http.StatusBadRequest)
					return
				}
				if uris, ok := body["redirect_uris"].([]interface{}); ok && len(uris) > 0 {
					if uri, ok := uris[0].(string); ok {
						registeredRedirectURI <- uri
					}
				}

				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusCreated)
				_, _ = w.Write([]byte(`{
					"client_id": "test-client-id",
					"client_secret": "test-client-secret"
				}`))
			default:
				http.NotFound(w, r)
			}
		}))
		t.Cleanup(authServer.Close)

		// Stand up a mock MCP server that returns RFC 9728 Protected
		// Resource Metadata pointing to the auth server.
		mcpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/.well-known/oauth-protected-resource" {
				w.Header().Set("Content-Type", "application/json")
				_, _ = w.Write([]byte(`{
					"resource": "` + "http://" + r.Host + `",
					"authorization_servers": ["` + authServer.URL + `"]
				}`))
				return
			}
			http.NotFound(w, r)
		}))
		t.Cleanup(mcpServer.Close)

		client := newMCPClient(t)
		_ = coderdtest.CreateFirstUser(t, client)

		// Create config with auth_type=oauth2 but no OAuth2 fields to
		// trigger auto-discovery and dynamic client registration.
		created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
			DisplayName:   "Redirect URI Test",
			Slug:          "redirect-uri-test",
			Transport:     "streamable_http",
			URL:           mcpServer.URL + "/v1/mcp",
			AuthType:      "oauth2",
			Availability:  "default_on",
			Enabled:       true,
			ToolAllowList: []string{},
			ToolDenyList:  []string{},
		})
		require.NoError(t, err)
		require.Equal(t, "test-client-id", created.OAuth2ClientID)
		require.True(t, created.HasOAuth2Secret)

		// The registration request has already completed by the time
		// CreateMCPServerConfig returns, so the URI is in the channel.
		var redirectURI string
		select {
		case redirectURI = <-registeredRedirectURI:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for registration redirect URI")
|
||||
}
|
||||
|
||||
// Core assertion: the redirect URI must NOT contain the
|
||||
// literal placeholder "{id}". Before the fix the callback
|
||||
// URL was built before the database insert, so it had
|
||||
// "{id}" where the UUID should be.
|
||||
require.NotContains(t, redirectURI, "{id}",
|
||||
"redirect URI sent during registration must not contain the literal \"{id}\" placeholder")
|
||||
|
||||
// Verify the redirect URI contains the real config UUID that
|
||||
// was assigned by the database.
|
||||
require.Contains(t, redirectURI, created.ID.String(),
|
||||
"redirect URI should contain the actual config UUID")
|
||||
|
||||
// Sanity-check the full path structure.
|
||||
require.Contains(t, redirectURI,
|
||||
"/api/experimental/mcp/servers/"+created.ID.String()+"/oauth2/callback",
|
||||
"redirect URI should have the expected callback path")
|
||||
|
||||
// Double-check that the ID segment is a valid UUID (not some
|
||||
// other placeholder or malformed value).
|
||||
pathParts := strings.Split(redirectURI, "/")
|
||||
var foundUUID bool
|
||||
for _, part := range pathParts {
|
||||
if _, err := uuid.Parse(part); err == nil {
|
||||
foundUUID = true
|
||||
require.Equal(t, created.ID.String(), part,
|
||||
"UUID in redirect URI path should match created config ID")
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundUUID,
|
||||
"redirect URI path should contain a valid UUID segment")
|
||||
})
|
||||
|
||||
t.Run("PartialOAuth2FieldsRejected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Provide client_id but omit auth_url and token_url.
|
||||
_, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Partial Fields",
|
||||
Slug: "partial-oauth2",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/partial",
|
||||
AuthType: "oauth2",
|
||||
OAuth2ClientID: "only-client-id",
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
require.Contains(t, sdkErr.Message, "automatic discovery")
|
||||
})
|
||||
|
||||
t.Run("DiscoveryFailure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// MCP server that returns 404 for the well-known endpoint and
|
||||
// a non-401 status for the root — discovery has nothing to latch
|
||||
// onto.
|
||||
mcpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}))
|
||||
t.Cleanup(mcpServer.Close)
|
||||
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Will Fail",
|
||||
Slug: "discovery-fail",
|
||||
Transport: "streamable_http",
|
||||
URL: mcpServer.URL + "/v1/mcp",
|
||||
AuthType: "oauth2",
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
require.Contains(t, sdkErr.Message, "auto-discovery failed")
|
||||
})
|
||||
|
||||
t.Run("ManualConfigStillWorks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Providing all three OAuth2 fields bypasses discovery entirely.
|
||||
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Manual Config",
|
||||
Slug: "manual-oauth2",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/manual",
|
||||
AuthType: "oauth2",
|
||||
OAuth2ClientID: "manual-client-id",
|
||||
OAuth2AuthURL: "https://auth.example.com/authorize",
|
||||
OAuth2TokenURL: "https://auth.example.com/token",
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "manual-client-id", created.OAuth2ClientID)
|
||||
require.Equal(t, "https://auth.example.com/authorize", created.OAuth2AuthURL)
|
||||
require.Equal(t, "https://auth.example.com/token", created.OAuth2TokenURL)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatWithMCPServerIDs(t *testing.T) {
    t.Parallel()

@@ -437,14 +741,16 @@ func TestChatWithMCPServerIDs(t *testing.T) {
    client := newMCPClient(t)
    _ = coderdtest.CreateFirstUser(t, client)

+   expClient := codersdk.NewExperimentalClient(client)
+
    // Create the chat model config required for creating a chat.
-   _ = createChatModelConfigForMCP(t, client)
+   _ = createChatModelConfigForMCP(t, expClient)

    // Create an enabled MCP server config.
    mcpConfig := createMCPServerConfig(t, client, "chat-mcp-server", true)

    // Create a chat referencing the MCP server.
-   chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
+   chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
        Content: []codersdk.ChatInputPart{
            {
                Type: codersdk.ChatInputPartTypeText,
@@ -458,7 +764,7 @@ func TestChatWithMCPServerIDs(t *testing.T) {
    require.Contains(t, chat.MCPServerIDs, mcpConfig.ID)

    // Fetch the chat and verify the MCP server IDs persist.
-   fetched, err := client.GetChat(ctx, chat.ID)
+   fetched, err := expClient.GetChat(ctx, chat.ID)
    require.NoError(t, err)
    require.Contains(t, fetched.MCPServerIDs, mcpConfig.ID)
}

@@ -466,7 +772,7 @@ func TestChatWithMCPServerIDs(t *testing.T) {
// createChatModelConfigForMCP sets up a chat provider and model
// config so that CreateChat succeeds. This mirrors the helper in
// chats_test.go but is defined here to avoid coupling.
-func createChatModelConfigForMCP(t testing.TB, client *codersdk.Client) codersdk.ChatModelConfig {
+func createChatModelConfigForMCP(t testing.TB, client *codersdk.ExperimentalClient) codersdk.ChatModelConfig {
    t.Helper()

    ctx := testutil.Context(t, testutil.WaitLong)

+44 -12
@@ -242,27 +242,51 @@ func (api *API) listMembers(rw http.ResponseWriter, r *http.Request) {
// @Produce json
// @Tags Members
// @Param organization path string true "Organization ID"
+// @Param q query string false "Member search query"
// @Param after_id query string false "After ID" format(uuid)
// @Param limit query int false "Page limit, if 0 returns all members"
// @Param offset query int false "Page offset"
// @Success 200 {object} []codersdk.PaginatedMembersResponse
// @Router /organizations/{organization}/paginated-members [get]
func (api *API) paginatedMembers(rw http.ResponseWriter, r *http.Request) {
    var (
-       ctx                  = r.Context()
-       organization         = httpmw.OrganizationParam(r)
-       paginationParams, ok = ParsePagination(rw, r)
+       ctx          = r.Context()
+       organization = httpmw.OrganizationParam(r)
    )

+   filterQuery := r.URL.Query().Get("q")
+   userFilterParams, filterErrs := searchquery.Users(filterQuery)
+   if len(filterErrs) > 0 {
+       httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
+           Message:     "Invalid member search query.",
+           Validations: filterErrs,
+       })
+       return
+   }
+
+   paginationParams, ok := ParsePagination(rw, r)
    if !ok {
        return
    }

    paginatedMemberRows, err := api.Database.PaginatedOrganizationMembers(ctx, database.PaginatedOrganizationMembersParams{
-       OrganizationID: organization.ID,
-       IncludeSystem:  false,
-       // #nosec G115 - Pagination limits are small and fit in int32
-       LimitOpt: int32(paginationParams.Limit),
-       AfterID:  paginationParams.AfterID,
+       OrganizationID:  organization.ID,
+       IncludeSystem:   false,
+       Search:          userFilterParams.Search,
+       Name:            userFilterParams.Name,
+       Status:          userFilterParams.Status,
+       RbacRole:        userFilterParams.RbacRole,
+       LastSeenBefore:  userFilterParams.LastSeenBefore,
+       LastSeenAfter:   userFilterParams.LastSeenAfter,
+       CreatedAfter:    userFilterParams.CreatedAfter,
+       CreatedBefore:   userFilterParams.CreatedBefore,
+       GithubComUserID: userFilterParams.GithubComUserID,
+       LoginType:       userFilterParams.LoginType,
+       // #nosec G115 - Pagination offsets are small and fit in int32
+       OffsetOpt: int32(paginationParams.Offset),
+       // #nosec G115 - Pagination limits are small and fit in int32
+       LimitOpt: int32(paginationParams.Limit),
    })
    if httpapi.Is404Error(err) {
        httpapi.ResourceNotFound(rw)
@@ -273,18 +297,21 @@ func (api *API) paginatedMembers(rw http.ResponseWriter, r *http.Request) {
        return
    }

-   memberRows := make([]database.OrganizationMembersRow, 0)
-   for _, pRow := range paginatedMemberRows {
-       row := database.OrganizationMembersRow{
+   memberRows := make([]database.OrganizationMembersRow, len(paginatedMemberRows))
+   for i, pRow := range paginatedMemberRows {
+       memberRows[i] = database.OrganizationMembersRow{
            OrganizationMember: pRow.OrganizationMember,
            Username:           pRow.Username,
            AvatarURL:          pRow.AvatarURL,
            Name:               pRow.Name,
            Email:              pRow.Email,
            GlobalRoles:        pRow.GlobalRoles,
            LastSeenAt:         pRow.LastSeenAt,
            Status:             pRow.Status,
            LoginType:          pRow.LoginType,
            UserCreatedAt:      pRow.UserCreatedAt,
            UserUpdatedAt:      pRow.UserUpdatedAt,
        }
-
-       memberRows = append(memberRows, row)
    }

    if len(paginatedMemberRows) == 0 {
@@ -501,6 +528,11 @@ func convertOrganizationMembersWithUserData(ctx context.Context, db database.Sto
        Name:               rows[i].Name,
        Email:              rows[i].Email,
        GlobalRoles:        db2sdk.SlimRolesFromNames(rows[i].GlobalRoles),
        LastSeenAt:         rows[i].LastSeenAt,
        Status:             codersdk.UserStatus(rows[i].Status),
        LoginType:          codersdk.LoginType(rows[i].LoginType),
        UserCreatedAt:      rows[i].UserCreatedAt,
        UserUpdatedAt:      rows[i].UserUpdatedAt,
        OrganizationMember: convertedMembers[i],
    })
}

@@ -1,12 +1,14 @@
package coderd_test

import (
    "context"
    "database/sql"
    "testing"

    "github.com/google/uuid"
    "github.com/stretchr/testify/require"

    "github.com/coder/coder/v2/coderd"
    "github.com/coder/coder/v2/coderd/coderdtest"
    "github.com/coder/coder/v2/coderd/database"
    "github.com/coder/coder/v2/coderd/database/dbgen"
@@ -132,6 +134,67 @@ func TestListMembers(t *testing.T) {
    })
}

func TestGetOrgMembersFilter(t *testing.T) {
    t.Parallel()

    client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
        IncludeProvisionerDaemon: true,
        OIDCConfig: &coderd.OIDCConfig{
            AllowSignups: true,
        },
    })
    first := coderdtest.CreateFirstUser(t, client)

    setupCtx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
    defer cancel()

    coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
        res, err := client.OrganizationMembersPaginated(testCtx, first.OrganizationID, req)
        require.NoError(t, err)
        reduced := make([]codersdk.ReducedUser, len(res.Members))
        for i, user := range res.Members {
            reduced[i] = orgMemberToReducedUser(user)
        }
        return reduced
    })
}

func TestGetOrgMembersPagination(t *testing.T) {
    t.Parallel()
    client := coderdtest.New(t, nil)
    first := coderdtest.CreateFirstUser(t, client)

    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
    defer cancel()

    coderdtest.UsersPagination(ctx, t, client, nil, func(req codersdk.UsersRequest) ([]codersdk.ReducedUser, int) {
        res, err := client.OrganizationMembersPaginated(ctx, first.OrganizationID, req)
        require.NoError(t, err)
        reduced := make([]codersdk.ReducedUser, len(res.Members))
        for i, user := range res.Members {
            reduced[i] = orgMemberToReducedUser(user)
        }
        return reduced, res.Count
    })
}

func onlyIDs(u codersdk.OrganizationMemberWithUserData) uuid.UUID {
    return u.UserID
}

func orgMemberToReducedUser(user codersdk.OrganizationMemberWithUserData) codersdk.ReducedUser {
    return codersdk.ReducedUser{
        MinimalUser: codersdk.MinimalUser{
            ID:        user.UserID,
            Username:  user.Username,
            Name:      user.Name,
            AvatarURL: user.AvatarURL,
        },
        Email:      user.Email,
        CreatedAt:  user.UserCreatedAt,
        UpdatedAt:  user.UserUpdatedAt,
        LastSeenAt: user.LastSeenAt,
        Status:     user.Status,
        LoginType:  user.LoginType,
    }
}

@@ -18,7 +18,6 @@ import (
    "path/filepath"
    "regexp"
    "slices"
-   "sort"
    "strings"
    "sync"
    "testing"
@@ -549,8 +548,8 @@ func TestExpiredLeaseIsRequeued(t *testing.T) {
        leasedIDs = append(leasedIDs, msg.ID.String())
    }

-   sort.Strings(msgs)
-   sort.Strings(leasedIDs)
+   slices.Sort(msgs)
+   slices.Sort(leasedIDs)
    require.EqualValues(t, msgs, leasedIDs)

    // Wait out the lease period; all messages should be eligible to be re-acquired.

@@ -18,12 +18,13 @@ import (
func TestOAuth2ClientMetadataValidation(t *testing.T) {
    t.Parallel()

+   // Single instance shared across all sub-tests. Each registers independent OAuth2 apps with unique client names.
+   client := coderdtest.New(t, nil)
+   _ = coderdtest.CreateFirstUser(t, client)
+
    t.Run("RedirectURIValidation", func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
-
        tests := []struct {
            name         string
            redirectURIs []string
@@ -132,9 +133,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
    t.Run("ClientURIValidation", func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
-
        tests := []struct {
            name      string
            clientURI string
@@ -207,9 +205,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
    t.Run("LogoURIValidation", func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
-
        tests := []struct {
            name    string
            logoURI string
@@ -272,9 +267,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
    t.Run("GrantTypeValidation", func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
-
        tests := []struct {
            name       string
            grantTypes []codersdk.OAuth2ProviderGrantType
@@ -347,9 +339,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
    t.Run("ResponseTypeValidation", func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
-
        tests := []struct {
            name          string
            responseTypes []codersdk.OAuth2ProviderResponseType
@@ -407,9 +396,6 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
    t.Run("TokenEndpointAuthMethodValidation", func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
-
        tests := []struct {
            name       string
            authMethod codersdk.OAuth2TokenEndpointAuthMethod
@@ -479,6 +465,10 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
func TestOAuth2ClientNameValidation(t *testing.T) {
    t.Parallel()

+   // Single instance shared across all sub-tests. Each registers independent OAuth2 apps.
+   client := coderdtest.New(t, nil)
+   _ = coderdtest.CreateFirstUser(t, client)
+
    tests := []struct {
        name       string
        clientName string
@@ -530,8 +520,6 @@ func TestOAuth2ClientNameValidation(t *testing.T) {
    t.Run(test.name, func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
        ctx := testutil.Context(t, testutil.WaitLong)

        req := codersdk.OAuth2ClientRegistrationRequest{
@@ -554,6 +542,10 @@ func TestOAuth2ClientNameValidation(t *testing.T) {
func TestOAuth2ClientScopeValidation(t *testing.T) {
    t.Parallel()

+   // Single instance shared across all sub-tests. Each registers independent OAuth2 apps.
+   client := coderdtest.New(t, nil)
+   _ = coderdtest.CreateFirstUser(t, client)
+
    tests := []struct {
        name  string
        scope string
@@ -615,8 +607,6 @@ func TestOAuth2ClientScopeValidation(t *testing.T) {
    t.Run(test.name, func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
        ctx := testutil.Context(t, testutil.WaitLong)

        req := codersdk.OAuth2ClientRegistrationRequest{
@@ -682,11 +672,13 @@ func TestOAuth2ClientMetadataDefaults(t *testing.T) {
func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
    t.Parallel()

+   // Single instance shared across all sub-tests. Each registers independent OAuth2 apps with unique client names.
+   client := coderdtest.New(t, nil)
+   _ = coderdtest.CreateFirstUser(t, client)
+
    t.Run("ExtremelyLongRedirectURI", func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
        ctx := testutil.Context(t, testutil.WaitLong)

        // Create a very long but valid HTTPS URI
@@ -709,8 +701,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
    t.Run("ManyRedirectURIs", func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
        ctx := testutil.Context(t, testutil.WaitLong)

        // Test with many redirect URIs
@@ -732,8 +722,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
    t.Run("URIWithUnusualPort", func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
        ctx := testutil.Context(t, testutil.WaitLong)

        req := codersdk.OAuth2ClientRegistrationRequest{
@@ -748,8 +736,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
    t.Run("URIWithComplexPath", func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
        ctx := testutil.Context(t, testutil.WaitLong)

        req := codersdk.OAuth2ClientRegistrationRequest{
@@ -764,8 +750,6 @@ func TestOAuth2ClientMetadataEdgeCases(t *testing.T) {
    t.Run("URIWithEncodedCharacters", func(t *testing.T) {
        t.Parallel()

-       client := coderdtest.New(t, nil)
-       _ = coderdtest.CreateFirstUser(t, client)
        ctx := testutil.Context(t, testutil.WaitLong)

        // Test with URL-encoded characters

@@ -1,6 +1,7 @@
package prometheusmetrics_test

import (
+   "slices"
    "sort"
    "testing"

@@ -134,7 +135,7 @@ func collectAndSortMetrics(t *testing.T, collector prometheus.Collector, count i
    // Ensure always the same order of metrics
    sort.Slice(metrics, func(i, j int) bool {
-       return sort.StringsAreSorted([]string{metrics[i].Label[0].GetValue(), metrics[j].Label[1].GetValue()})
+       return slices.IsSorted([]string{metrics[i].Label[0].GetValue(), metrics[j].Label[1].GetValue()})
    })
    return metrics
}

+19 -2
@@ -316,13 +316,16 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
            denyPermissions...,
        ),
        User: append(
-           allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember, ResourceBoundaryUsage),
+           allPermsExcept(ResourceWorkspaceDormant, ResourcePrebuiltWorkspace, ResourceWorkspace, ResourceUser, ResourceOrganizationMember, ResourceOrganizationMember, ResourceBoundaryUsage, ResourceAibridgeInterception),
            Permissions(map[string][]policy.Action{
                // Users cannot do create/update/delete on themselves, but they
                // can read their own details.
                ResourceUser.Type: {policy.ActionRead, policy.ActionReadPersonal, policy.ActionUpdatePersonal},
                // Users can create provisioner daemons scoped to themselves.
                ResourceProvisionerDaemon.Type: {policy.ActionRead, policy.ActionCreate, policy.ActionRead, policy.ActionUpdate},
+               // Members can create and update AI Bridge interceptions but
+               // cannot read them back.
+               ResourceAibridgeInterception.Type: {policy.ActionCreate, policy.ActionUpdate},
            })...,
        ),
        ByOrgID: map[string]OrgPermissions{},
@@ -345,7 +348,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
            // Allow auditors to query deployment stats and insights.
            ResourceDeploymentStats.Type:  {policy.ActionRead},
            ResourceDeploymentConfig.Type: {policy.ActionRead},
-           // Allow auditors to query aibridge interceptions.
+           // Allow auditors to query AI Bridge interceptions.
            ResourceAibridgeInterception.Type: {policy.ActionRead},
        }),
        User: []Permission{},
@@ -998,6 +1001,7 @@ func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
            ResourcePrebuiltWorkspace,
            ResourceUser,
            ResourceOrganizationMember,
+           ResourceAibridgeInterception,
        ),
        Permissions(map[string][]policy.Action{
            // Reduced permission set on dormant workspaces. No build,
@@ -1016,6 +1020,12 @@ func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
            ResourceOrganizationMember.Type: {
                policy.ActionRead,
            },
+           // Members can create and update AI Bridge interceptions but
+           // cannot read them back.
+           ResourceAibridgeInterception.Type: {
+               policy.ActionCreate,
+               policy.ActionUpdate,
+           },
        })...,
    )

@@ -1073,6 +1083,7 @@ func OrgServiceAccountPermissions(org OrgSettings) OrgRolePermissions {
            ResourcePrebuiltWorkspace,
            ResourceUser,
            ResourceOrganizationMember,
+           ResourceAibridgeInterception,
        ),
        Permissions(map[string][]policy.Action{
            // Reduced permission set on dormant workspaces. No build,
@@ -1091,6 +1102,12 @@ func OrgServiceAccountPermissions(org OrgSettings) OrgRolePermissions {
            ResourceOrganizationMember.Type: {
                policy.ActionRead,
            },
+           // Service accounts can create and update AI Bridge
+           // interceptions but cannot read them back.
+           ResourceAibridgeInterception.Type: {
+               policy.ActionCreate,
+               policy.ActionUpdate,
+           },
        })...,
    )

@@ -1023,8 +1023,9 @@ func TestRolePermissions(t *testing.T) {
},
},
{
-Name: "AIBridgeInterceptions",
-Actions: []policy.Action{policy.ActionCreate, policy.ActionRead, policy.ActionUpdate},
+// Members can create/update records but can't read them afterwards.
+Name: "AIBridgeInterceptionsCreateUpdate",
+Actions: []policy.Action{policy.ActionCreate, policy.ActionUpdate},
Resource: rbac.ResourceAibridgeInterception.WithOwner(currentUser.String()),
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner, memberMe},
@@ -1036,6 +1037,22 @@ func TestRolePermissions(t *testing.T) {
},
},
},
+{
+// Only owners and site-wide auditors can view interceptions and their sub-resources.
+Name: "AIBridgeInterceptionsRead",
+Actions: []policy.Action{policy.ActionRead},
+Resource: rbac.ResourceAibridgeInterception.WithOwner(currentUser.String()),
+AuthorizeMap: map[bool][]hasAuthSubjects{
+true: {owner, auditor},
+false: {
+memberMe,
+orgAdmin, otherOrgAdmin,
+orgAuditor, otherOrgAuditor,
+templateAdmin, orgTemplateAdmin, otherOrgTemplateAdmin,
+userAdmin, orgUserAdmin, otherOrgUserAdmin,
+},
+},
+},
{
Name: "BoundaryUsage",
Actions: []policy.Action{policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},

@@ -3,7 +3,6 @@ package rbac
import (
"fmt"
"slices"
-"sort"
"strings"

"github.com/google/uuid"
@@ -176,7 +175,7 @@ func CompositeScopeNames() []string {
for k := range compositePerms {
out = append(out, string(k))
}
-sort.Strings(out)
+slices.Sort(out)
return out
}

@@ -40,7 +40,8 @@ var externalLowLevel = map[ScopeName]struct{}{
"file:create": {},
"file:*": {},

-// Users (personal profile only)
+// Users
+"user:read": {},
"user:read_personal": {},
"user:update_personal": {},
"user.*": {},

@@ -1,7 +1,7 @@
package rbac

import (
-"sort"
+"slices"
"strings"
"testing"

@@ -16,7 +16,7 @@ func TestExternalScopeNames(t *testing.T) {

// Ensure sorted ascending
sorted := append([]string(nil), names...)
-sort.Strings(sorted)
+slices.Sort(sorted)
require.Equal(t, sorted, names)

// Ensure each entry expands to site-only
@@ -62,6 +62,7 @@ func TestIsExternalScope(t *testing.T) {
require.True(t, IsExternalScope("template:use"))
require.True(t, IsExternalScope("workspace:*"))
require.True(t, IsExternalScope("coder:workspaces.create"))
+require.True(t, IsExternalScope("user:read"))
require.False(t, IsExternalScope("debug_info:read")) // internal-only
require.False(t, IsExternalScope("unknown:read"))
}

@@ -401,6 +401,49 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string,
return filter, parser.Errors
}
+
+func AIBridgeSessions(ctx context.Context, db database.Store, query string, page codersdk.Pagination, actorID uuid.UUID, afterSessionID string) (database.ListAIBridgeSessionsParams, []codersdk.ValidationError) {
+// nolint:exhaustruct // Empty values just means "don't filter by that field".
+filter := database.ListAIBridgeSessionsParams{
+AfterSessionID: afterSessionID,
+// #nosec G115 - Safe conversion for pagination limit which is expected to be within int32 range
+Limit: int32(page.Limit),
+// #nosec G115 - Safe conversion for pagination offset which is expected to be within int32 range
+Offset: int32(page.Offset),
+}
+
+if query == "" {
+return filter, nil
+}
+
+values, errors := searchTerms(query, func(string, url.Values) error {
+// Do not specify a default search key; let's be explicit to prevent user confusion.
+return xerrors.New("no search key specified")
+})
+if len(errors) > 0 {
+return filter, errors
+}
+
+parser := httpapi.NewQueryParamParser()
+filter.InitiatorID = parseUser(ctx, db, parser, values, "initiator", actorID)
+filter.Provider = parser.String(values, "", "provider")
+filter.Model = parser.String(values, "", "model")
+filter.Client = parser.String(values, "", "client")
+filter.SessionID = parser.String(values, "", "session_id")
+
+// Time must be between started_after and started_before.
+filter.StartedAfter = parser.Time3339Nano(values, time.Time{}, "started_after")
+filter.StartedBefore = parser.Time3339Nano(values, time.Time{}, "started_before")
+if !filter.StartedBefore.IsZero() && !filter.StartedAfter.IsZero() && !filter.StartedBefore.After(filter.StartedAfter) {
+parser.Errors = append(parser.Errors, codersdk.ValidationError{
+Field: "started_before",
+Detail: `Query param "started_before" has invalid value: "started_before" must be after "started_after" if set`,
+})
+}
+
+parser.ErrorExcessParams(values)
+return filter, parser.Errors
+}

func AIBridgeModels(query string, page codersdk.Pagination) (database.ListAIBridgeModelsParams, []codersdk.ValidationError) {
// nolint:exhaustruct // Empty values just means "don't filter by that field".
filter := database.ListAIBridgeModelsParams{

@@ -1272,10 +1272,14 @@ func TestTemplateVersionsByTemplate(t *testing.T) {

func TestTemplateVersionByName(t *testing.T) {
t.Parallel()

+// Single instance shared across all sub-tests. Each sub-test
+// creates its own template version and template with unique
+// IDs so parallel execution is safe.
+client := coderdtest.New(t, nil)
+user := coderdtest.CreateFirstUser(t, client)
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)

@@ -1290,8 +1294,6 @@ func TestTemplateVersionByName(t *testing.T) {

t.Run("Found", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)

@@ -1935,10 +1937,12 @@ func TestPaginatedTemplateVersions(t *testing.T) {

func TestTemplateVersionByOrganizationTemplateAndName(t *testing.T) {
t.Parallel()

+// Shared instance — see TestTemplateVersionByName for rationale.
+client := coderdtest.New(t, nil)
+user := coderdtest.CreateFirstUser(t, client)
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)

@@ -1953,8 +1957,6 @@ func TestTemplateVersionByOrganizationTemplateAndName(t *testing.T) {

t.Run("Found", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)

@@ -2204,10 +2206,14 @@ func TestTemplateVersionVariables(t *testing.T) {

func TestTemplateVersionPatch(t *testing.T) {
t.Parallel()

+// Single instance shared across all 9 sub-tests. Each sub-test
+// creates its own template version(s) and template(s) with
+// unique IDs so parallel execution is safe.
+client := coderdtest.New(t, nil)
+user := coderdtest.CreateFirstUser(t, client)
t.Run("Update the name", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)

@@ -2226,8 +2232,6 @@ func TestTemplateVersionPatch(t *testing.T) {

t.Run("Update the message", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(req *codersdk.CreateTemplateVersionRequest) {
req.Message = "Example message"
})
@@ -2247,8 +2251,6 @@ func TestTemplateVersionPatch(t *testing.T) {

t.Run("Remove the message", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(req *codersdk.CreateTemplateVersionRequest) {
req.Message = "Example message"
})
@@ -2268,8 +2270,6 @@ func TestTemplateVersionPatch(t *testing.T) {

t.Run("Keep the message", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
wantMessage := "Example message"
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(req *codersdk.CreateTemplateVersionRequest) {
req.Message = wantMessage
@@ -2291,8 +2291,6 @@ func TestTemplateVersionPatch(t *testing.T) {

t.Run("Use the same name if a new name is not passed", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)

@@ -2306,9 +2304,6 @@ func TestTemplateVersionPatch(t *testing.T) {

t.Run("Use the same name for two different templates", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)

version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
coderdtest.CreateTemplate(t, client, user.OrganizationID, version1.ID)
version2 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
@@ -2334,8 +2329,6 @@ func TestTemplateVersionPatch(t *testing.T) {

t.Run("Use the same name for two versions for the same templates", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil, func(ctvr *codersdk.CreateTemplateVersionRequest) {
ctvr.Name = "v1"
})
@@ -2356,8 +2349,6 @@ func TestTemplateVersionPatch(t *testing.T) {

t.Run("Rename the unassigned template", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@@ -2373,8 +2364,6 @@ func TestTemplateVersionPatch(t *testing.T) {

t.Run("Use incorrect template version name", func(t *testing.T) {
t.Parallel()
-client := coderdtest.New(t, nil)
-user := coderdtest.CreateFirstUser(t, client)
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)

@@ -7,7 +7,7 @@ import (
"fmt"
"net/http"
"net/mail"
-"sort"
+"slices"
"strconv"
"strings"
"sync"
@@ -744,43 +744,6 @@ func (api *API) postLogout(rw http.ResponseWriter, r *http.Request) {
})
}
-
-// @Summary Set session token cookie
-// @Description Converts the current session token into a Set-Cookie response.
-// @Description This is used by embedded iframes (e.g. VS Code chat) that
-// @Description receive a session token out-of-band via postMessage but need
-// @Description cookie-based auth for WebSocket connections.
-// @ID set-session-token-cookie
-// @Security CoderSessionToken
-// @Tags Authorization
-// @Success 204
-// @Router /users/me/session/token-to-cookie [post]
-// @x-apidocgen {"skip": true}
-func (api *API) postSessionTokenCookie(rw http.ResponseWriter, r *http.Request) {
-// Only accept the token from the Coder-Session-Token header.
-// Other sources (query params, cookies) should not be allowed
-// to bootstrap a new cookie.
-token := r.Header.Get(codersdk.SessionTokenHeader)
-if token == "" {
-httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
-Message: "Session token must be provided via the Coder-Session-Token header.",
-})
-return
-}
-
-apiKey := httpmw.APIKey(r)
-
-cookie := api.DeploymentValues.HTTPCookies.Apply(&http.Cookie{
-Name: codersdk.SessionTokenCookie,
-Value: token,
-Path: "/",
-HttpOnly: true,
-// Expire the cookie when the underlying API key expires.
-Expires: apiKey.ExpiresAt,
-})
-http.SetCookie(rw, cookie)
-rw.WriteHeader(http.StatusNoContent)
-}

// GithubOAuth2Team represents a team scoped to an organization.
type GithubOAuth2Team struct {
Organization string
@@ -1626,7 +1589,7 @@ func claimFields(claims map[string]interface{}) []string {
for field := range claims {
fields = append(fields, field)
}
-sort.Strings(fields)
+slices.Sort(fields)
return fields
}

@@ -1639,7 +1602,7 @@ func blankFields(claims map[string]interface{}) []string {
fields = append(fields, field)
}
}
-sort.Strings(fields)
+slices.Sort(fields)
return fields
}

@@ -17,7 +17,6 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
-"github.com/coder/coder/v2/coderd/agentconnectionbatcher"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
@@ -258,7 +257,6 @@ func (api *API) startAgentYamuxMonitor(ctx context.Context,
db: api.Database,
replicaID: api.ID,
updater: api,
-connectionBatcher: api.connectionBatcher,
disconnectTimeout: api.AgentInactiveDisconnectTimeout,
logger: api.Logger.With(
slog.F("workspace_id", workspaceBuild.WorkspaceID),
@@ -294,8 +292,6 @@ type agentConnectionMonitor struct {
logger slog.Logger
pingPeriod time.Duration
-
-connectionBatcher *agentconnectionbatcher.Batcher

// state manipulated by both sendPings() and monitor() goroutines: needs to be threadsafe
lastPing atomic.Pointer[time.Time]

@@ -458,32 +454,17 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) {
Valid: true,
}

-if m.connectionBatcher != nil {
-m.connectionBatcher.Add(agentconnectionbatcher.Update{
-ID: m.workspaceAgent.ID,
-FirstConnectedAt: m.firstConnectedAt,
-LastConnectedAt: m.lastConnectedAt,
-DisconnectedAt: m.disconnectedAt,
-UpdatedAt: dbtime.Now(),
-LastConnectedReplicaID: uuid.NullUUID{
-UUID: m.replicaID,
-Valid: true,
-},
-})
-} else {
-err = m.updateConnectionTimes(ctx)
-if err != nil {
-reason = err.Error()
-if !database.IsQueryCanceledError(err) {
-m.logger.Error(ctx, "failed to update agent connection times", slog.Error(err))
-}
-return
+err = m.updateConnectionTimes(ctx)
+if err != nil {
+reason = err.Error()
+if !database.IsQueryCanceledError(err) {
+m.logger.Error(ctx, "failed to update agent connection times", slog.Error(err))
+}
+return
}
-// We don't need to publish a workspace update here because we
-// published an update when the workspace first connected. Since
-// all we've done is updated lastConnectedAt, the workspace is
-// still connected and hasn't changed status.
+// we don't need to publish a workspace update here because we published an update when the workspace first
+// connected. Since all we've done is updated lastConnectedAt, the workspace is still connected and hasn't
+// changed status. We don't expect to get updates just for the times changing.

ctx, err := dbauthz.WithWorkspaceRBAC(ctx, m.workspace.RBACObject())
if err != nil {

@@ -6,8 +6,8 @@ import (
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"

-"github.com/coder/coder/v2/coderd/chatd/chatcost"
"github.com/coder/coder/v2/coderd/util/ptr"
+"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
"github.com/coder/coder/v2/codersdk"
)

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"net/http"
+"strconv"
"strings"
"sync"
"time"
@@ -20,12 +21,6 @@ import (
"golang.org/x/xerrors"

"cdr.dev/slog/v3"
-"github.com/coder/coder/v2/coderd/chatd/chatcost"
-"github.com/coder/coder/v2/coderd/chatd/chatloop"
-"github.com/coder/coder/v2/coderd/chatd/chatprompt"
-"github.com/coder/coder/v2/coderd/chatd/chatprovider"
-"github.com/coder/coder/v2/coderd/chatd/chattool"
-"github.com/coder/coder/v2/coderd/chatd/mcpclient"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -34,6 +29,12 @@ import (
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/coderd/workspacestats"
+"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
+"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
+"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
+"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
+"github.com/coder/coder/v2/coderd/x/chatd/chattool"
+"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
@@ -1199,6 +1200,7 @@ type chatMessage struct {
contextLimit int64
totalCostMicros int64
runtimeMs int64
+providerResponseID string
}

func newChatMessage(
@@ -1255,6 +1257,101 @@ func (m chatMessage) withRuntimeMs(ms int64) chatMessage {
return m
}
+
+func (m chatMessage) withProviderResponseID(id string) chatMessage {
+m.providerResponseID = id
+return m
+}
+
+// chainModeInfo holds the information needed to determine whether
+// a follow-up turn can use OpenAI's previous_response_id chaining
+// instead of replaying full conversation history.
+type chainModeInfo struct {
+// previousResponseID is the provider response ID from the last
+// assistant message, if any.
+previousResponseID string
+// modelConfigID is the model configuration used to produce the
+// assistant message referenced by previousResponseID.
+modelConfigID uuid.UUID
+// trailingUserCount is the number of contiguous user messages
+// at the end of the conversation that form the current turn.
+trailingUserCount int
+}
+
+// resolveChainMode scans DB messages from the end to count trailing user
+// messages for the current turn and detect whether the immediately
+// preceding assistant/tool block can chain from a provider response ID.
+func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
+var info chainModeInfo
+i := len(messages) - 1
+for ; i >= 0; i-- {
+if messages[i].Role == database.ChatMessageRoleUser {
+info.trailingUserCount++
+continue
+}
+break
+}
+for ; i >= 0; i-- {
+switch messages[i].Role {
+case database.ChatMessageRoleAssistant:
+if messages[i].ProviderResponseID.Valid &&
+messages[i].ProviderResponseID.String != "" {
+info.previousResponseID = messages[i].ProviderResponseID.String
+if messages[i].ModelConfigID.Valid {
+info.modelConfigID = messages[i].ModelConfigID.UUID
+}
+return info
+}
+return info
+case database.ChatMessageRoleTool:
+continue
+default:
+return info
+}
+}
+return info
+}
+
+// filterPromptForChainMode keeps only system messages and the last
+// trailingUserCount user messages from the prompt. Assistant and tool
+// messages are dropped because the provider already has them via the
+// previous_response_id chain.
+func filterPromptForChainMode(
+prompt []fantasy.Message,
+trailingUserCount int,
+) []fantasy.Message {
+if trailingUserCount <= 0 {
+return prompt
+}
+
+totalUsers := 0
+for _, msg := range prompt {
+if msg.Role == "user" {
+totalUsers++
+}
+}
+
+usersToSkip := totalUsers - trailingUserCount
+if usersToSkip < 0 {
+usersToSkip = 0
+}
+
+filtered := make([]fantasy.Message, 0, len(prompt))
+usersSeen := 0
+for _, msg := range prompt {
+switch msg.Role {
+case "system":
+filtered = append(filtered, msg)
+case "user":
+usersSeen++
+if usersSeen > usersToSkip {
+filtered = append(filtered, msg)
+}
+}
+}
+
+return filtered
+}

// appendChatMessage appends a single message to the batch insert params.
func appendChatMessage(
params *database.InsertChatMessagesParams,
@@ -1276,6 +1373,7 @@
params.Compressed = append(params.Compressed, msg.compressed)
params.TotalCostMicros = append(params.TotalCostMicros, msg.totalCostMicros)
params.RuntimeMs = append(params.RuntimeMs, msg.runtimeMs)
+params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID)
}

func insertUserMessageAndSetPending(
@@ -2823,6 +2921,7 @@ func (p *Server) runChat(
if err := g.Wait(); err != nil {
return result, err
}
+chainInfo := resolveChainMode(messages)
result.PushSummaryModel = model
result.ProviderKeys = providerKeys
// Fire title generation asynchronously so it doesn't block the
@@ -3092,7 +3191,8 @@ func (p *Server) runChat(
reasoningTokens, cacheCreationTokens, cacheReadTokens,
).withContextLimit(contextLimit).
withTotalCostMicros(totalCostVal).
-withRuntimeMs(runtimeMs))
+withRuntimeMs(runtimeMs).
+withProviderResponseID(step.ProviderResponseID))
}

for _, resultContent := range toolResultContents {
@@ -3150,8 +3250,14 @@ func (p *Server) runChat(
// "Summarizing..." tool call with the "Summarized" tool
// result.
compactionToolCallID := "chat_summarized_" + uuid.NewString()
+effectiveThreshold := modelConfig.CompressionThreshold
+thresholdSource := "model_default"
+if override, ok := p.resolveUserCompactionThreshold(ctx, chat.OwnerID, modelConfig.ID); ok {
+effectiveThreshold = override
+thresholdSource = "user_override"
+}
compactionOptions := &chatloop.CompactionOptions{
-ThresholdPercent: modelConfig.CompressionThreshold,
+ThresholdPercent: effectiveThreshold,
ContextLimit: modelConfig.ContextLimit,
Persist: func(
persistCtx context.Context,
@@ -3168,6 +3274,7 @@ func (p *Server) runChat(
}
logger.Info(persistCtx, "chat context summarized",
slog.F("chat_id", chat.ID),
+slog.F("threshold_source", thresholdSource),
slog.F("threshold_percent", result.ThresholdPercent),
slog.F("usage_percent", result.UsagePercent),
slog.F("context_tokens", result.ContextTokens),
@@ -3227,6 +3334,7 @@ func (p *Server) runChat(
// create workspaces or spawn further subagents — they should
// focus on completing their delegated task.
if !chat.ParentChatID.Valid {
+// Workspace provisioning tools.
tools = append(tools,
chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: p.db,
@@ -3254,6 +3362,37 @@ func (p *Server) runChat(
WorkspaceMu: &workspaceMu,
}),
)
+// Plan presentation tool.
+tools = append(tools, chattool.ProposePlan(chattool.ProposePlanOptions{
+GetWorkspaceConn: workspaceCtx.getWorkspaceConn,
+StoreFile: func(ctx context.Context, name string, mediaType string, data []byte) (uuid.UUID, error) {
+workspaceCtx.chatStateMu.Lock()
+chatSnapshot := *workspaceCtx.currentChat
+workspaceCtx.chatStateMu.Unlock()
+
+if !chatSnapshot.WorkspaceID.Valid {
+return uuid.Nil, xerrors.New("chat has no workspace")
+}
+
+ws, err := p.db.GetWorkspaceByID(ctx, chatSnapshot.WorkspaceID.UUID)
+if err != nil {
+return uuid.Nil, xerrors.Errorf("resolve workspace: %w", err)
+}
+
+row, err := p.db.InsertChatFile(ctx, database.InsertChatFileParams{
+OwnerID: chatSnapshot.OwnerID,
+OrganizationID: ws.OrganizationID,
+Name: name,
+Mimetype: mediaType,
+Data: data,
+})
+if err != nil {
+return uuid.Nil, xerrors.Errorf("insert chat file: %w", err)
+}
+
+return row.ID, nil
+},
+}))
tools = append(tools, p.subagentTools(ctx, func() database.Chat {
return chat
})...)
@@ -3272,24 +3411,49 @@ func (p *Server) runChat(
}

if isComputerUse {
+desktopGeometry := workspacesdk.DefaultDesktopGeometry()
providerTools = append(providerTools, chatloop.ProviderTool{
Definition: chattool.ComputerUseProviderTool(
-workspacesdk.DesktopDisplayWidth,
-workspacesdk.DesktopDisplayHeight),
+desktopGeometry.DeclaredWidth,
+desktopGeometry.DeclaredHeight,
+),
Runner: chattool.NewComputerUseTool(
-workspacesdk.DesktopDisplayWidth,
-workspacesdk.DesktopDisplayHeight,
-workspaceCtx.getWorkspaceConn, quartz.NewReal(),
+desktopGeometry.DeclaredWidth,
+desktopGeometry.DeclaredHeight,
+workspaceCtx.getWorkspaceConn,
+quartz.NewReal(),
),
})
}

+providerOptions := chatprovider.ProviderOptionsFromChatModelConfig(
+model,
+callConfig.ProviderOptions,
+)
+// When the OpenAI Responses API has store=true, the provider
+// retains conversation history server-side. For follow-up turns,
+// we set previous_response_id and send only system instructions
+// plus the new user input, avoiding redundant replay of prior
+// assistant and tool messages that the provider already has.
+chainModeActive := chatprovider.IsResponsesStoreEnabled(providerOptions) &&
+chainInfo.previousResponseID != "" &&
+chainInfo.trailingUserCount > 0 &&
+chainInfo.modelConfigID == modelConfig.ID
+if chainModeActive {
+providerOptions = chatprovider.CloneWithPreviousResponseID(
+providerOptions,
+chainInfo.previousResponseID,
+)
+prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount)
+}

err = chatloop.Run(ctx, chatloop.RunOptions{
Model: model,
Messages: prompt,
Tools: tools, MaxSteps: maxChatSteps,

ModelConfig: callConfig,
-ProviderOptions: chatprovider.ProviderOptionsFromChatModelConfig(model, callConfig.ProviderOptions),
+ProviderOptions: providerOptions,
ProviderTools: providerTools,

ContextLimitFallback: modelConfigContextLimit,
@@ -3337,8 +3501,17 @@ func (p *Server) runChat(
if reloadUserPrompt != "" {
reloadedPrompt = chatprompt.InsertSystem(reloadedPrompt, reloadUserPrompt)
}
+if chainModeActive {
+reloadedPrompt = filterPromptForChainMode(
+reloadedPrompt,
+chainInfo.trailingUserCount,
+)
+}
return reloadedPrompt, nil
},
+DisableChainMode: func() {
+chainModeActive = false
+},

OnRetry: func(attempt int, retryErr error, delay time.Duration) {
if val, ok := p.chatStreams.Load(chat.ID); ok {
@@ -3715,6 +3888,34 @@ func (p *Server) resolveInstructions(
return instruction
}
+
+// resolveUserCompactionThreshold looks up the user's per-model
+// compaction threshold override. Returns the override value and
+// true if one exists and is valid, or 0 and false otherwise.
+func (p *Server) resolveUserCompactionThreshold(ctx context.Context, userID uuid.UUID, modelConfigID uuid.UUID) (int32, bool) {
+raw, err := p.db.GetUserChatCompactionThreshold(ctx, database.GetUserChatCompactionThresholdParams{
+UserID: userID,
+Key: codersdk.CompactionThresholdKey(modelConfigID),
+})
+if errors.Is(err, sql.ErrNoRows) {
+return 0, false
+}
+if err != nil {
+p.logger.Warn(ctx, "failed to fetch compaction threshold override",
+slog.F("user_id", userID),
+slog.F("model_config_id", modelConfigID),
+slog.Error(err),
+)
+return 0, false
+}
+// Range 0..100 must stay in sync with handler validation in
+// coderd/chats.go.
+val, err := strconv.ParseInt(raw, 10, 32)
+if err != nil || val < 0 || val > 100 {
+return 0, false
+}
+return int32(val), true
+}

// resolveUserPrompt fetches the user's custom chat prompt from the
// database and wraps it in <user-instructions> tags. Returns empty
// string if no prompt is set.

@@ -2,6 +2,7 @@ package chatd

import (
"context"
"database/sql"
"sync"
"testing"
"time"
@@ -606,6 +607,85 @@ func TestPublishToStream_DropWarnRateLimiting(t *testing.T) {
requireFieldValue(t, subWarn[2], "dropped_count", int64(1))
}

func TestResolveUserCompactionThreshold(t *testing.T) {
t.Parallel()

userID := uuid.New()
modelConfigID := uuid.New()
expectedKey := codersdk.CompactionThresholdKey(modelConfigID)

tests := []struct {
name string
dbReturn string
dbErr error
wantVal int32
wantOK bool
wantWarnLog bool
}{
{
name: "NoRowsReturnsDefault",
dbErr: sql.ErrNoRows,
wantOK: false,
},
{
name: "ValidOverride",
dbReturn: "75",
wantVal: 75,
wantOK: true,
},
{
name: "OutOfRangeValue",
dbReturn: "101",
wantOK: false,
},
{
name: "NonIntegerValue",
dbReturn: "abc",
wantOK: false,
},
{
name: "UnexpectedDBError",
dbErr: xerrors.New("connection refused"),
wantOK: false,
wantWarnLog: true,
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
mockDB := dbmock.NewMockStore(ctrl)
sink := testutil.NewFakeSink(t)

srv := &Server{
db: mockDB,
logger: sink.Logger(),
}

mockDB.EXPECT().GetUserChatCompactionThreshold(gomock.Any(), database.GetUserChatCompactionThresholdParams{
UserID: userID,
Key: expectedKey,
}).Return(tc.dbReturn, tc.dbErr)

val, ok := srv.resolveUserCompactionThreshold(context.Background(), userID, modelConfigID)
require.Equal(t, tc.wantVal, val)
require.Equal(t, tc.wantOK, ok)

warns := sink.Entries(func(e slog.SinkEntry) bool {
return e.Level == slog.LevelWarn
})
if tc.wantWarnLog {
require.NotEmpty(t, warns, "expected a warning log entry")
return
}
require.Empty(t, warns, "unexpected warning log entry")
})
}
}

// requireFieldValue asserts that a SinkEntry contains a field with
// the given name and value.
func requireFieldValue(t *testing.T, entry slog.SinkEntry, name string, expected interface{}) {
@@ -26,10 +26,6 @@ import (

"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/chatd/chattest"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
@@ -41,6 +37,10 @@ import (
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
@@ -111,6 +111,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
expClient := codersdk.NewExperimentalClient(client)

agentToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
@@ -161,7 +162,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
)
})

_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
@@ -170,7 +171,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {

contextLimit := int64(4096)
isDefault := true
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
@@ -179,7 +180,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
require.NoError(t, err)

// Create a root chat whose first model call will spawn a subagent.
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
@@ -193,7 +194,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
// The root chat finishes first, then the chatd server
// picks up and runs the child (subagent) chat.
require.Eventually(t, func() bool {
got, getErr := client.GetChat(ctx, chat.ID)
got, getErr := expClient.GetChat(ctx, chat.ID)
if getErr != nil {
return false
}
@@ -217,7 +218,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
require.GreaterOrEqual(t, len(recorded), 2,
"expected at least 2 streamed LLM calls (root + subagent)")

workspaceTools := []string{"list_templates", "read_template", "create_workspace"}
workspaceTools := []string{"propose_plan", "list_templates", "read_template", "create_workspace"}
subagentTools := []string{"spawn_agent", "wait_agent", "message_agent", "close_agent"}

// Identify root and subagent calls. Root chat calls include
@@ -901,15 +902,32 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
acquireTrap := clock.Trap().NewTicker("chatd", "acquire")
defer acquireTrap.Close()

assertPendingWithoutQueuedMessages := func(chatID uuid.UUID) {
t.Helper()

queued, dbErr := db.GetChatQueuedMessages(ctx, chatID)
require.NoError(t, dbErr)
require.Empty(t, queued)

fromDB, dbErr := db.GetChatByID(ctx, chatID)
require.NoError(t, dbErr)
require.Equal(t, database.ChatStatusPending, fromDB.Status)
require.False(t, fromDB.WorkerID.Valid)
}

streamStarted := make(chan struct{})
interrupted := make(chan struct{})
secondRequestStarted := make(chan struct{})
thirdRequestStarted := make(chan struct{})
allowFinish := make(chan struct{})
var requestCount atomic.Int32
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
return chattest.OpenAINonStreamingResponse("title")
}
if requestCount.Add(1) == 1 {

switch requestCount.Add(1) {
case 1:
chunks := make(chan chattest.OpenAIChunk, 1)
go func() {
defer close(chunks)
@@ -928,7 +946,12 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
<-allowFinish
}()
return chattest.OpenAIResponse{StreamingChunks: chunks}
case 2:
close(secondRequestStarted)
case 3:
close(thirdRequestStarted)
}

return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("done")...,
)
@@ -953,15 +976,7 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
require.NoError(t, err)

clock.Advance(acquireInterval).MustWait(ctx)

require.Eventually(t, func() bool {
select {
case <-streamStarted:
return true
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
testutil.TryReceive(ctx, t, streamStarted)

queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
@@ -972,29 +987,11 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
require.True(t, queuedResult.Queued)
require.NotNil(t, queuedResult.QueuedMessage)

require.Eventually(t, func() bool {
select {
case <-interrupted:
return true
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
testutil.TryReceive(ctx, t, interrupted)

close(allowFinish)

require.Eventually(t, func() bool {
queued, dbErr := db.GetChatQueuedMessages(ctx, chat.ID)
if dbErr != nil || len(queued) != 0 {
return false
}

fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil {
return false
}
return fromDB.Status == database.ChatStatusPending && !fromDB.WorkerID.Valid
}, testutil.WaitMedium, testutil.IntervalFast)
chatd.WaitUntilIdleForTest(server)
assertPendingWithoutQueuedMessages(chat.ID)

// Keep the acquire loop frozen here so "queued" stays pending.
// That makes the later send queue because the chat is still busy,
@@ -1045,63 +1042,41 @@ func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
require.NoError(t, err)

clock.Advance(acquireInterval).MustWait(ctx)
require.Eventually(t, func() bool {
return requestCount.Load() >= 2
}, testutil.WaitMedium, testutil.IntervalFast)

require.Eventually(t, func() bool {
queued, dbErr := db.GetChatQueuedMessages(ctx, chat.ID)
if dbErr != nil || len(queued) != 0 {
return false
}

fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil {
return false
}
return fromDB.Status == database.ChatStatusPending && !fromDB.WorkerID.Valid
}, testutil.WaitMedium, testutil.IntervalFast)
testutil.TryReceive(ctx, t, secondRequestStarted)
chatd.WaitUntilIdleForTest(server)
assertPendingWithoutQueuedMessages(chat.ID)

clock.Advance(acquireInterval).MustWait(ctx)
testutil.TryReceive(ctx, t, thirdRequestStarted)
chatd.WaitUntilIdleForTest(server)

require.Eventually(t, func() bool {
queued, dbErr := db.GetChatQueuedMessages(ctx, chat.ID)
if dbErr != nil || len(queued) != 0 {
return false
}
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Empty(t, queued)

fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil || fromDB.Status != database.ChatStatusWaiting {
return false
}
fromDB, err := db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
require.False(t, fromDB.WorkerID.Valid)

messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
if dbErr != nil {
return false
}
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)

userTexts := make([]string, 0, 3)
for _, message := range messages {
if message.Role != database.ChatMessageRoleUser {
continue
}
sdkMessage := db2sdk.ChatMessage(message)
if len(sdkMessage.Content) != 1 {
continue
}
userTexts = append(userTexts, sdkMessage.Content[0].Text)
userTexts := make([]string, 0, 3)
for _, message := range messages {
if message.Role != database.ChatMessageRoleUser {
continue
}
if len(userTexts) != 3 {
return false
sdkMessage := db2sdk.ChatMessage(message)
if len(sdkMessage.Content) != 1 {
continue
}
return requestCount.Load() >= 3 &&
userTexts[0] == "hello" &&
userTexts[1] == "queued" &&
userTexts[2] == "later queued"
}, testutil.WaitLong, testutil.IntervalFast)
userTexts = append(userTexts, sdkMessage.Content[0].Text)
}
require.Equal(t, []string{"hello", "queued", "later queued"}, userTexts)
}

func TestEditMessageRejectsWhenUsageLimitReached(t *testing.T) {
@@ -1844,6 +1819,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
expClient := codersdk.NewExperimentalClient(client)

agentToken := uuid.NewString()
// Add a startup script so the agent spends time in the
@@ -1898,7 +1874,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
)
})

_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
@@ -1907,7 +1883,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {

contextLimit := int64(4096)
isDefault := true
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
@@ -1915,7 +1891,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
})
require.NoError(t, err)

chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
@@ -1927,7 +1903,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {

var chatResult codersdk.Chat
require.Eventually(t, func() bool {
got, getErr := client.GetChat(ctx, chat.ID)
got, getErr := expClient.GetChat(ctx, chat.ID)
if getErr != nil {
return false
}
@@ -1949,7 +1925,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
require.NoError(t, err)
require.Equal(t, workspaceName, workspace.Name)

chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
chatMsgs, err := expClient.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)

var foundCreateWorkspaceResult bool
@@ -2023,6 +1999,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
expClient := codersdk.NewExperimentalClient(client)

version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
@@ -2067,7 +2044,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
)
})

_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
@@ -2076,7 +2053,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {

contextLimit := int64(4096)
isDefault := true
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
@@ -2085,7 +2062,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
require.NoError(t, err)

// Create a chat with the stopped workspace pre-associated.
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
@@ -2098,7 +2075,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {

var chatResult codersdk.Chat
require.Eventually(t, func() bool {
got, getErr := client.GetChat(ctx, chat.ID)
got, getErr := expClient.GetChat(ctx, chat.ID)
if getErr != nil {
return false
}
@@ -2120,7 +2097,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
require.NoError(t, err)
require.Equal(t, codersdk.WorkspaceTransitionStart, updatedWorkspace.LatestBuild.Transition)

chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
chatMsgs, err := expClient.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)

// Verify start_workspace tool result exists in the chat messages.
@@ -13,11 +13,12 @@ import (

"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
fantasyopenai "charm.land/fantasy/providers/openai"
"charm.land/fantasy/schema"
"golang.org/x/xerrors"

"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/chatd/chatretry"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
"github.com/coder/coder/v2/codersdk"
)

@@ -39,9 +40,10 @@ var ErrInterrupted = xerrors.New("chat interrupted")
// persistence layer is responsible for splitting these into
// separate database messages by role.
type PersistedStep struct {
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
ProviderResponseID string
// Runtime is the wall-clock duration of this step,
// covering LLM streaming, tool execution, and retries.
// Zero indicates the duration was not measured (e.g.
@@ -80,8 +82,9 @@ type RunOptions struct {
role codersdk.ChatMessageRole,
part codersdk.ChatMessagePart,
)
Compaction *CompactionOptions
ReloadMessages func(context.Context) ([]fantasy.Message, error)
Compaction *CompactionOptions
ReloadMessages func(context.Context) ([]fantasy.Message, error)
DisableChainMode func()

// OnRetry is called before each retry attempt when the LLM
// stream fails with a retryable error. It provides the attempt
@@ -245,6 +248,18 @@ func Run(ctx context.Context, opts RunOptions) error {
messages := opts.Messages
var lastUsage fantasy.Usage
var lastProviderMetadata fantasy.ProviderMetadata
needsFullHistoryReload := false
reloadFullHistory := func(stage string) error {
if opts.ReloadMessages == nil {
return nil
}
reloaded, err := opts.ReloadMessages(ctx)
if err != nil {
return xerrors.Errorf("reload messages %s: %w", stage, err)
}
messages = reloaded
return nil
}

totalSteps := 0
// When totalSteps reaches MaxSteps the inner loop exits immediately
@@ -368,10 +383,11 @@ func Run(ctx context.Context, opts RunOptions) error {
// check and here, fall back to the interrupt-safe
// path so partial content is not lost.
if err := opts.PersistStep(ctx, PersistedStep{
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
Runtime: time.Since(stepStart),
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata),
Runtime: time.Since(stepStart),
}); err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
@@ -382,14 +398,41 @@ func Run(ctx context.Context, opts RunOptions) error {
lastUsage = result.usage
lastProviderMetadata = result.providerMetadata

// Append the step's response messages so that both
// inline and post-loop compaction see the full
// conversation including the latest assistant reply.
// When chain mode is active (PreviousResponseID set), exit
// it after persisting the first chained step. Continuation
// steps include tool-result messages, which fantasy rejects
// when previous_response_id is set, so we must leave chain
// mode and reload the full history before the next call.
stepMessages := result.toResponseMessages()
messages = append(messages, stepMessages...)
if hasPreviousResponseID(opts.ProviderOptions) {
clearPreviousResponseID(opts.ProviderOptions)
if opts.DisableChainMode != nil {
opts.DisableChainMode()
}
switch {
case opts.ReloadMessages != nil:
if err := reloadFullHistory("after chain mode exit"); err != nil {
return err
}
needsFullHistoryReload = false
default:
messages = append(messages, stepMessages...)
needsFullHistoryReload = false
}
} else {
messages = append(messages, stepMessages...)
}

if needsFullHistoryReload && !result.shouldContinue &&
opts.ReloadMessages != nil {
if err := reloadFullHistory("before final compaction after chain mode exit"); err != nil {
return err
}
needsFullHistoryReload = false
}

// Inline compaction.
if opts.Compaction != nil && opts.ReloadMessages != nil {
if !needsFullHistoryReload && opts.Compaction != nil && opts.ReloadMessages != nil {
did, compactErr := tryCompact(
ctx,
opts.Model,
@@ -405,14 +448,11 @@ func Run(ctx context.Context, opts RunOptions) error {
if did {
alreadyCompacted = true
compactedOnFinalStep = true
reloaded, reloadErr := opts.ReloadMessages(ctx)
if reloadErr != nil {
return xerrors.Errorf("reload messages after compaction: %w", reloadErr)
if err := reloadFullHistory("after compaction"); err != nil {
return err
}
messages = reloaded
}
}

if !result.shouldContinue {
stoppedByModel = true
break
@@ -423,9 +463,16 @@ func Run(ctx context.Context, opts RunOptions) error {
compactedOnFinalStep = false
}

if needsFullHistoryReload && stoppedByModel && opts.ReloadMessages != nil {
if err := reloadFullHistory("before post-run compaction after chain mode exit"); err != nil {
return err
}
needsFullHistoryReload = false
}

// Post-run compaction safety net: if we never compacted
// during the loop, try once at the end.
if !alreadyCompacted && opts.Compaction != nil && opts.ReloadMessages != nil {
if !needsFullHistoryReload && !alreadyCompacted && opts.Compaction != nil && opts.ReloadMessages != nil {
did, err := tryCompact(
ctx,
opts.Model,
@@ -973,6 +1020,85 @@ func addAnthropicPromptCaching(messages []fantasy.Message) {
}
}

// hasPreviousResponseID checks whether the provider options contain
// an OpenAI Responses entry with a non-empty PreviousResponseID.
func hasPreviousResponseID(providerOptions fantasy.ProviderOptions) bool {
if providerOptions == nil {
return false
}

for _, entry := range providerOptions {
if options, ok := entry.(*fantasyopenai.ResponsesProviderOptions); ok {
return options.PreviousResponseID != nil &&
*options.PreviousResponseID != ""
}
}

return false
}

// clearPreviousResponseID removes PreviousResponseID from the OpenAI
// Responses provider options entry, if present.
func clearPreviousResponseID(providerOptions fantasy.ProviderOptions) {
if providerOptions == nil {
return
}

for _, entry := range providerOptions {
if options, ok := entry.(*fantasyopenai.ResponsesProviderOptions); ok {
options.PreviousResponseID = nil
}
}
}

// extractOpenAIResponseID extracts the OpenAI Responses API response
// ID from provider metadata. Returns an empty string if no OpenAI
// Responses metadata is present.
func extractOpenAIResponseID(metadata fantasy.ProviderMetadata) string {
if len(metadata) == 0 {
return ""
}

for _, entry := range metadata {
if providerMetadata, ok := entry.(*fantasyopenai.ResponsesProviderMetadata); ok && providerMetadata != nil {
return providerMetadata.ResponseID
}
}

return ""
}

// extractOpenAIResponseIDIfStored returns the OpenAI response ID
// only when the provider options indicate store=true. Response IDs
// from store=false turns are not persisted server-side and cannot
// be used for chaining.
func extractOpenAIResponseIDIfStored(
providerOptions fantasy.ProviderOptions,
metadata fantasy.ProviderMetadata,
) string {
if !isResponsesStoreEnabled(providerOptions) {
return ""
}

return extractOpenAIResponseID(metadata)
}

// isResponsesStoreEnabled checks whether the OpenAI Responses
// provider options explicitly enable store=true.
func isResponsesStoreEnabled(providerOptions fantasy.ProviderOptions) bool {
if providerOptions == nil {
return false
}

for _, entry := range providerOptions {
if options, ok := entry.(*fantasyopenai.ResponsesProviderOptions); ok {
return options.Store != nil && *options.Store
}
}

return false
}

func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 {
if len(metadata) == 0 {
return sql.NullInt64{}
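The store gate above (persist a response ID only when the provider options set store=true, since unstored responses cannot be chained) can be sketched with simplified stand-ins for the fantasy types; `responsesOptions`, `responsesMetadata`, and `extractIfStored` are illustrative names, not identifiers from the diff:

```go
package main

import "fmt"

// responsesOptions and responsesMetadata are simplified stand-ins
// for the OpenAI Responses provider types used above.
type responsesOptions struct {
	Store *bool
}

type responsesMetadata struct {
	ResponseID string
}

// extractIfStored mirrors extractOpenAIResponseIDIfStored: the
// response ID is only returned when store=true was explicitly set,
// because store=false turns are not kept server-side.
func extractIfStored(opts *responsesOptions, meta *responsesMetadata) string {
	if opts == nil || opts.Store == nil || !*opts.Store {
		return ""
	}
	if meta == nil {
		return ""
	}
	return meta.ResponseID
}

func main() {
	yes := true
	meta := &responsesMetadata{ResponseID: "resp_abc"}
	fmt.Println(extractIfStored(&responsesOptions{Store: &yes}, meta))
	fmt.Println(extractIfStored(&responsesOptions{}, meta) == "")
}
```

Treating "Store unset" the same as "Store false" keeps the default conservative: chaining is only attempted when storage was explicitly requested.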
@@ -14,11 +14,11 @@ import (
"github.com/stretchr/testify/require"

"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)

@@ -1063,6 +1063,46 @@ func ProviderOptionsFromChatModelConfig(
return result
}

// IsResponsesStoreEnabled checks if the OpenAI Responses provider
// options are present and have Store set to true. When true, the
// provider stores conversation history server-side, enabling
// follow-up chaining via PreviousResponseID.
func IsResponsesStoreEnabled(opts fantasy.ProviderOptions) bool {
if opts == nil {
return false
}
raw, ok := opts[fantasyopenai.Name]
if !ok {
return false
}
respOpts, ok := raw.(*fantasyopenai.ResponsesProviderOptions)
if !ok || respOpts == nil {
return false
}
return respOpts.Store != nil && *respOpts.Store
}

// CloneWithPreviousResponseID shallow-clones the provider options
// map and the OpenAI Responses entry, setting PreviousResponseID
// on the clone. The original map and entry are not mutated.
func CloneWithPreviousResponseID(
opts fantasy.ProviderOptions,
previousResponseID string,
) fantasy.ProviderOptions {
cloned := make(fantasy.ProviderOptions, len(opts))
for k, v := range opts {
cloned[k] = v
}
if raw, ok := cloned[fantasyopenai.Name]; ok {
if respOpts, ok := raw.(*fantasyopenai.ResponsesProviderOptions); ok && respOpts != nil {
clone := *respOpts
clone.PreviousResponseID = &previousResponseID
cloned[fantasyopenai.Name] = &clone
}
}
return cloned
}

func openAIProviderOptionsFromChatConfig(
model fantasy.LanguageModel,
options *codersdk.ChatModelOpenAIProviderOptions,
@@ -9,8 +9,8 @@ import (
fantasyvercel "charm.land/fantasy/providers/vercel"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/codersdk"
)

@@ -12,8 +12,8 @@ import (
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/chatd/chattest"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
)

func TestUserAgent(t *testing.T) {
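The clone-before-mutate pattern used by CloneWithPreviousResponseID above (copy both the options map and the entry being changed, so callers holding the original see no mutation) can be shown with plain stand-in types; `respOptions`, `options`, and `cloneWithPreviousResponseID` are simplified illustrations, not the library's types:

```go
package main

import "fmt"

// respOptions stands in for the OpenAI Responses provider options.
type respOptions struct {
	PreviousResponseID *string
}

// options stands in for fantasy.ProviderOptions: a map of provider
// name to provider-specific settings.
type options map[string]any

// cloneWithPreviousResponseID copies the map, then copies the struct
// it is about to change before setting the field. A map copy alone
// would not be enough: both maps would still point at the same
// *respOptions, so mutating it would leak into the original.
func cloneWithPreviousResponseID(opts options, id string) options {
	cloned := make(options, len(opts))
	for k, v := range opts {
		cloned[k] = v
	}
	if raw, ok := cloned["openai"]; ok {
		if ro, ok := raw.(*respOptions); ok && ro != nil {
			c := *ro // copy the struct, not just the pointer
			c.PreviousResponseID = &id
			cloned["openai"] = &c
		}
	}
	return cloned
}

func main() {
	orig := options{"openai": &respOptions{}}
	clone := cloneWithPreviousResponseID(orig, "resp_123")
	fmt.Println(orig["openai"].(*respOptions).PreviousResponseID == nil)
	fmt.Println(*clone["openai"].(*respOptions).PreviousResponseID)
}
```

This matters in the chatloop above because the same base options are reused across turns; setting PreviousResponseID in place would silently turn chain mode on for every later call.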