Compare commits

..

9 Commits

Author SHA1 Message Date
Paweł Banaszewski d15a6be1c7 review: documentation nit 2026-02-17 11:00:51 +00:00
Paweł Banaszewski fe9202ab6f docs 2026-02-17 09:43:08 +00:00
Paweł Banaszewski e36944f9cc feat: add --client flag to list interceptions command 2026-02-17 09:43:08 +00:00
Paweł Banaszewski 670801f310 review: metadata key 2026-02-17 09:43:08 +00:00
Paweł Banaszewski 3a826bcbb9 feat: record user agent in AI Bridge interceptions 2026-02-17 09:43:08 +00:00
Paweł Banaszewski 42b467acf7 migration number fix + telemetry test fix 2026-02-17 09:43:08 +00:00
Paweł Banaszewski 98bc5f04fb migration fix 2026-02-17 09:43:08 +00:00
Paweł Banaszewski 02aeef5a91 review: dafault + unknown -> Unknown 2026-02-17 09:43:08 +00:00
Paweł Banaszewski 1cbc8929ab feat: add client columns to aibridge_interceptions
Also adds interception filtering by client.
2026-02-17 09:43:08 +00:00
626 changed files with 6905 additions and 55324 deletions
+1 -1
View File
@@ -4,7 +4,7 @@ description: |
inputs:
version:
description: "The Go version to use."
default: "1.25.7"
default: "1.25.6"
use-preinstalled-go:
description: "Whether to use preinstalled Go."
default: "false"
+1 -1
View File
@@ -7,5 +7,5 @@ runs:
- name: Install Terraform
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2
with:
terraform_version: 1.14.5
terraform_version: 1.14.1
terraform_wrapper: false
-8
View File
@@ -489,14 +489,6 @@ jobs:
# macOS will output "The default interactive shell is now zsh" intermittently in CI.
touch ~/.bash_profile && echo "export BASH_SILENCE_DEPRECATION_WARNING=1" >> ~/.bash_profile
- name: Increase PTY limit (macOS)
if: runner.os == 'macOS'
shell: bash
run: |
# Increase PTY limit to avoid exhaustion during tests.
# Default is 511; 999 is the maximum value on CI runner.
sudo sysctl -w kern.tty.ptmx_max=999
- name: Test with PostgreSQL Database (Linux)
if: runner.os == 'Linux'
uses: ./.github/actions/test-go-pg
+1 -1
View File
@@ -146,7 +146,7 @@ jobs:
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif
+1 -1
View File
@@ -23,7 +23,7 @@ jobs:
egress-policy: audit
- name: stale
uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
with:
stale-issue-label: "stale"
stale-pr-label: "stale"
-3
View File
@@ -98,6 +98,3 @@ AGENTS.local.md
# Ignore plans written by AI agents.
PLAN.md
# Ignore any dev licenses
license.txt
+8 -16
View File
@@ -854,7 +854,7 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
# -C sets the directory for the go run command
go run -C ./scripts/apitypings main.go > $@
./scripts/biome_format.sh src/api/typesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/typesGenerated.ts)
touch "$@"
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
@@ -863,7 +863,7 @@ site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/prot
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
go run ./scripts/gensite/ -icons "$@"
./scripts/biome_format.sh src/theme/icons.json
(cd site/ && pnpm exec biome format --write src/theme/icons.json)
touch "$@"
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
@@ -901,12 +901,12 @@ codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scope
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
go run scripts/typegen/main.go rbac typescript > "$@"
./scripts/biome_format.sh src/api/rbacresourcesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/rbacresourcesGenerated.ts)
touch "$@"
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
go run scripts/typegen/main.go countries > "$@"
./scripts/biome_format.sh src/api/countriesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/countriesGenerated.ts)
touch "$@"
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES)
@@ -950,11 +950,11 @@ coderd/apidoc/.gen: \
touch "$@"
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
./scripts/biome_format.sh ../docs/manifest.json
(cd site/ && pnpm exec biome format --write ../docs/manifest.json)
touch "$@"
coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
./scripts/biome_format.sh ../coderd/apidoc/swagger.json
(cd site/ && pnpm exec biome format --write ../coderd/apidoc/swagger.json)
touch "$@"
update-golden-files:
@@ -999,19 +999,11 @@ enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/
touch "$@"
helm/coder/tests/testdata/.gen-golden: $(wildcard helm/coder/tests/testdata/*.yaml) $(wildcard helm/coder/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/coder/tests/*_test.go)
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/coder golden generation" >&2
fi
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
touch "$@"
helm/provisioner/tests/testdata/.gen-golden: $(wildcard helm/provisioner/tests/testdata/*.yaml) $(wildcard helm/provisioner/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/provisioner/tests/*_test.go)
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/provisioner golden generation" >&2
fi
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
touch "$@"
coderd/.gen-golden: $(wildcard coderd/testdata/*/*.golden) $(GO_SRC_FILES) $(wildcard coderd/*_test.go)
+2 -10
View File
@@ -111,12 +111,6 @@ type Client interface {
ConnectRPC28(ctx context.Context) (
proto.DRPCAgentClient28, tailnetproto.DRPCTailnetClient28, error,
)
// ConnectRPC28WithRole is like ConnectRPC28 but sends an explicit
// role query parameter to the server. The workspace agent should
// use role "agent" to enable connection monitoring.
ConnectRPC28WithRole(ctx context.Context, role string) (
proto.DRPCAgentClient28, tailnetproto.DRPCTailnetClient28, error,
)
tailnet.DERPMapRewriter
agentsdk.RefreshableSessionTokenProvider
}
@@ -1003,10 +997,8 @@ func (a *agent) run() (retErr error) {
return xerrors.Errorf("refresh token: %w", err)
}
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs.
// We pass role "agent" to enable connection monitoring on the server, which tracks
// the agent's connectivity state (first_connected_at, last_connected_at, disconnected_at).
aAPI, tAPI, err := a.client.ConnectRPC28WithRole(a.hardCtx, "agent")
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
aAPI, tAPI, err := a.client.ConnectRPC28(a.hardCtx)
if err != nil {
return err
}
+103 -2
View File
@@ -1,22 +1,37 @@
package agentsocket_test
import (
"context"
"path/filepath"
"runtime"
"testing"
"github.com/google/uuid"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/agenttest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
func TestServer(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("agentsocket is not supported on Windows")
}
t.Run("StartStop", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.NoError(t, err)
@@ -26,7 +41,7 @@ func TestServer(t *testing.T) {
t.Run("AlreadyStarted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server1, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.NoError(t, err)
@@ -34,4 +49,90 @@ func TestServer(t *testing.T) {
_, err = agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.ErrorContains(t, err, "create socket")
})
t.Run("AutoSocketPath", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.NoError(t, err)
require.NoError(t, server.Close())
})
}
func TestServerWindowsNotSupported(t *testing.T) {
t.Parallel()
if runtime.GOOS != "windows" {
t.Skip("this test only runs on Windows")
}
t.Run("NewServer", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
_, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.ErrorContains(t, err, "agentsocket is not supported on Windows")
})
t.Run("NewClient", func(t *testing.T) {
t.Parallel()
_, err := agentsocket.NewClient(context.Background(), agentsocket.WithPath("test.sock"))
require.ErrorContains(t, err, "agentsocket is not supported on Windows")
})
}
func TestAgentInitializesOnWindowsWithoutSocketServer(t *testing.T) {
t.Parallel()
if runtime.GOOS != "windows" {
t.Skip("this test only runs on Windows")
}
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t).Named("agent")
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator(logger)
t.Cleanup(func() {
_ = coordinator.Close()
})
statsCh := make(chan *agentproto.Stats, 50)
agentID := uuid.New()
manifest := agentsdk.Manifest{
AgentID: agentID,
AgentName: "test-agent",
WorkspaceName: "test-workspace",
OwnerName: "test-user",
WorkspaceID: uuid.New(),
DERPMap: derpMap,
}
client := agenttest.NewClient(t, logger.Named("agenttest"), agentID, manifest, statsCh, coordinator)
t.Cleanup(client.Close)
options := agent.Options{
Client: client,
Filesystem: afero.NewMemMapFs(),
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: testutil.WaitShort,
EnvironmentVariables: map[string]string{},
SocketPath: "",
}
agnt := agent.New(options)
t.Cleanup(func() {
_ = agnt.Close()
})
startup := testutil.TryReceive(ctx, t, client.GetStartup())
require.NotNil(t, startup, "agent should send startup message")
err := agnt.Close()
require.NoError(t, err, "agent should close cleanly")
}
+17 -11
View File
@@ -2,6 +2,8 @@ package agentsocket_test
import (
"context"
"path/filepath"
"runtime"
"testing"
"github.com/stretchr/testify/require"
@@ -28,10 +30,14 @@ func newSocketClient(ctx context.Context, t *testing.T, socketPath string) *agen
func TestDRPCAgentSocketService(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("agentsocket is not supported on Windows")
}
t.Run("Ping", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -51,7 +57,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("NewUnit", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -73,7 +79,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitAlreadyStarted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -103,7 +109,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitAlreadyCompleted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -142,7 +148,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitNotReady", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -172,7 +178,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("NewUnits", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -197,7 +203,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("DependencyAlreadyRegistered", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -232,7 +238,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("DependencyAddedAfterDependentStarted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -274,7 +280,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnregisteredUnit", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -293,7 +299,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitNotReady", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -317,7 +323,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitReady", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
+6 -47
View File
@@ -4,60 +4,19 @@ package agentsocket
import (
"context"
"fmt"
"net"
"os"
"os/user"
"strings"
"github.com/Microsoft/go-winio"
"golang.org/x/xerrors"
)
const defaultSocketPath = `\\.\pipe\com.coder.agentsocket`
func createSocket(path string) (net.Listener, error) {
if path == "" {
path = defaultSocketPath
}
if !strings.HasPrefix(path, `\\.\pipe\`) {
return nil, xerrors.Errorf("%q is not a valid local socket path", path)
}
user, err := user.Current()
if err != nil {
return nil, fmt.Errorf("unable to look up current user: %w", err)
}
sid := user.Uid
// SecurityDescriptor is in SDDL format. c.f.
// https://learn.microsoft.com/en-us/windows/win32/secauthz/security-descriptor-string-format for full details.
// D: indicates this is a Discretionary Access Control List (DACL), which is Windows-speak for ACLs that allow or
// deny access (as opposed to SACL which controls audit logging).
// P indicates that this DACL is "protected" from being modified thru inheritance
// () delimit access control entries (ACEs), here we only have one, which, allows (A) generic all (GA) access to our
// specific user's security ID (SID).
//
// Note that although Microsoft docs at https://learn.microsoft.com/en-us/windows/win32/ipc/named-pipes warns that
// named pipes are accessible from remote machines in the general case, the `winio` package sets the flag
// windows.FILE_PIPE_REJECT_REMOTE_CLIENTS when creating pipes, so connections from remote machines are always
// denied. This is important because we sort of expect customers to run the Coder agent under a generic user
// account unless they are very sophisticated. We don't want this socket to cross the boundary of the local machine.
configuration := &winio.PipeConfig{
SecurityDescriptor: fmt.Sprintf("D:P(A;;GA;;;%s)", sid),
}
listener, err := winio.ListenPipe(path, configuration)
if err != nil {
return nil, xerrors.Errorf("failed to open named pipe: %w", err)
}
return listener, nil
func createSocket(_ string) (net.Listener, error) {
return nil, xerrors.New("agentsocket is not supported on Windows")
}
func cleanupSocket(path string) error {
return os.Remove(path)
func cleanupSocket(_ string) error {
return nil
}
func dialSocket(ctx context.Context, path string) (net.Conn, error) {
return winio.DialPipeContext(ctx, path)
func dialSocket(_ context.Context, _ string) (net.Conn, error) {
return nil, xerrors.New("agentsocket is not supported on Windows")
}
-10
View File
@@ -124,12 +124,6 @@ func (c *Client) Close() {
c.derpMapOnce.Do(func() { close(c.derpMapUpdates) })
}
func (c *Client) ConnectRPC28WithRole(ctx context.Context, _ string) (
agentproto.DRPCAgentClient28, proto.DRPCTailnetClient28, error,
) {
return c.ConnectRPC28(ctx)
}
func (c *Client) ConnectRPC28(ctx context.Context) (
agentproto.DRPCAgentClient28, proto.DRPCTailnetClient28, error,
) {
@@ -235,10 +229,6 @@ type FakeAgentAPI struct {
pushResourcesMonitoringUsageFunc func(*agentproto.PushResourcesMonitoringUsageRequest) (*agentproto.PushResourcesMonitoringUsageResponse, error)
}
func (*FakeAgentAPI) UpdateAppStatus(context.Context, *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
panic("unimplemented")
}
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
return f.manifest, nil
}
+330 -544
View File
File diff suppressed because it is too large Load Diff
+1 -20
View File
@@ -436,7 +436,7 @@ message CreateSubAgentRequest {
}
repeated DisplayApp display_apps = 6;
optional bytes id = 7;
}
@@ -494,24 +494,6 @@ message ReportBoundaryLogsRequest {
message ReportBoundaryLogsResponse {}
// UpdateAppStatusRequest updates the given Workspace App's status. c.f. agentsdk.PatchAppStatus
message UpdateAppStatusRequest {
string slug = 1;
enum AppStatusState {
WORKING = 0;
IDLE = 1;
COMPLETE = 2;
FAILURE = 3;
}
AppStatusState state = 2;
string message = 3;
string uri = 4;
}
message UpdateAppStatusResponse {}
service Agent {
rpc GetManifest(GetManifestRequest) returns (Manifest);
rpc GetServiceBanner(GetServiceBannerRequest) returns (ServiceBanner);
@@ -530,5 +512,4 @@ service Agent {
rpc DeleteSubAgent(DeleteSubAgentRequest) returns (DeleteSubAgentResponse);
rpc ListSubAgents(ListSubAgentsRequest) returns (ListSubAgentsResponse);
rpc ReportBoundaryLogs(ReportBoundaryLogsRequest) returns (ReportBoundaryLogsResponse);
rpc UpdateAppStatus(UpdateAppStatusRequest) returns (UpdateAppStatusResponse);
}
+1 -41
View File
@@ -56,7 +56,6 @@ type DRPCAgentClient interface {
DeleteSubAgent(ctx context.Context, in *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(ctx context.Context, in *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type drpcAgentClient struct {
@@ -222,15 +221,6 @@ func (c *drpcAgentClient) ReportBoundaryLogs(ctx context.Context, in *ReportBoun
return out, nil
}
func (c *drpcAgentClient) UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
out := new(UpdateAppStatusResponse)
err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
type DRPCAgentServer interface {
GetManifest(context.Context, *GetManifestRequest) (*Manifest, error)
GetServiceBanner(context.Context, *GetServiceBannerRequest) (*ServiceBanner, error)
@@ -249,7 +239,6 @@ type DRPCAgentServer interface {
DeleteSubAgent(context.Context, *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(context.Context, *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(context.Context, *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type DRPCAgentUnimplementedServer struct{}
@@ -322,13 +311,9 @@ func (s *DRPCAgentUnimplementedServer) ReportBoundaryLogs(context.Context, *Repo
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentUnimplementedServer) UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCAgentDescription struct{}
func (DRPCAgentDescription) NumMethods() int { return 18 }
func (DRPCAgentDescription) NumMethods() int { return 17 }
func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
@@ -485,15 +470,6 @@ func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
in1.(*ReportBoundaryLogsRequest),
)
}, DRPCAgentServer.ReportBoundaryLogs, true
case 17:
return "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentServer).
UpdateAppStatus(
ctx,
in1.(*UpdateAppStatusRequest),
)
}, DRPCAgentServer.UpdateAppStatus, true
default:
return "", nil, nil, nil, false
}
@@ -774,19 +750,3 @@ func (x *drpcAgent_ReportBoundaryLogsStream) SendAndClose(m *ReportBoundaryLogsR
}
return x.CloseSend()
}
type DRPCAgent_UpdateAppStatusStream interface {
drpc.Stream
SendAndClose(*UpdateAppStatusResponse) error
}
type drpcAgent_UpdateAppStatusStream struct {
drpc.Stream
}
func (x *drpcAgent_UpdateAppStatusStream) SendAndClose(m *UpdateAppStatusResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
return err
}
return x.CloseSend()
}
+3 -7
View File
@@ -73,13 +73,9 @@ type DRPCAgentClient27 interface {
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
}
// DRPCAgentClient28 is the Agent API at v2.8. It adds
// - a SubagentId field to the WorkspaceAgentDevcontainer message
// - an Id field to the CreateSubAgentRequest message.
// - UpdateAppStatus RPC.
//
// Compatible with Coder v2.31+
// DRPCAgentClient28 is the Agent API at v2.8. It adds a SubagentId field to the
// WorkspaceAgentDevcontainer message, and a Id field to the CreateSubAgentRequest
// message. Compatible with Coder v2.31+
type DRPCAgentClient28 interface {
DRPCAgentClient27
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
+45 -50
View File
@@ -10,7 +10,6 @@ import (
"path/filepath"
"slices"
"strings"
"time"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
@@ -24,7 +23,6 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/toolsdk"
"github.com/coder/retry"
"github.com/coder/serpent"
)
@@ -541,6 +539,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
defer cancel()
defer srv.queue.Close()
cliui.Infof(inv.Stderr, "Failed to watch screen events")
// Start the reporter, watcher, and server. These are all tied to the
// lifetime of the MCP server, which is itself tied to the lifetime of the
// AI agent.
@@ -614,51 +613,48 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
}
func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
return
}
go func() {
for retrier := retry.New(time.Second, 30*time.Second); retrier.Wait(ctx); {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err == nil {
retrier.Reset()
loop:
for {
select {
case <-ctx.Done():
for {
select {
case <-ctx.Done():
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
// If the screen is stable, report idle.
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
}
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
break loop
}
}
} else {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
return
}
}
}()
@@ -696,14 +692,13 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
// Add tool dependencies.
toolOpts := []func(*toolsdk.Deps){
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
state := codersdk.WorkspaceAppStatusState(args.State)
// The agent does not reliably report idle, so when AgentAPI is
// enabled we override idle to working and let the screen watcher
// detect the real idle via StatusStable. Final states (failure,
// complete) are trusted from the agent since the screen watcher
// cannot produce them.
if s.aiAgentAPIClient != nil && state == codersdk.WorkspaceAppStatusStateIdle {
state = codersdk.WorkspaceAppStatusStateWorking
// The agent does not reliably report its status correctly. If AgentAPI
// is enabled, we will always set the status to "working" when we get an
// MCP message, and rely on the screen watcher to eventually catch the
// idle state.
state := codersdk.WorkspaceAppStatusStateWorking
if s.aiAgentAPIClient == nil {
state = codersdk.WorkspaceAppStatusState(args.State)
}
return s.queue.Push(taskReport{
link: args.Link,
+1 -185
View File
@@ -921,7 +921,7 @@ func TestExpMcpReporter(t *testing.T) {
},
},
},
// We override idle from the agent to working, but trust final states.
// We ignore the state from the agent and assume "working".
{
name: "IgnoreAgentState",
// AI agent reports that it is finished but the summary says it is doing
@@ -953,46 +953,6 @@ func TestExpMcpReporter(t *testing.T) {
Message: "finished",
},
},
// Agent reports failure; trusted even with AgentAPI enabled.
{
state: codersdk.WorkspaceAppStatusStateFailure,
summary: "something broke",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateFailure,
Message: "something broke",
},
},
// After failure, watcher reports stable -> idle.
{
event: makeStatusEvent(agentapi.StatusStable),
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "something broke",
},
},
},
},
// Final states pass through with AgentAPI enabled.
{
name: "AllowFinalStates",
tests: []test{
{
state: codersdk.WorkspaceAppStatusStateWorking,
summary: "doing work",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateWorking,
Message: "doing work",
},
},
// Agent reports complete; not overridden.
{
state: codersdk.WorkspaceAppStatusStateComplete,
summary: "all done",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateComplete,
Message: "all done",
},
},
},
},
// When AgentAPI is not being used, we accept agent state updates as-is.
@@ -1150,148 +1110,4 @@ func TestExpMcpReporter(t *testing.T) {
<-cmdDone
})
}
t.Run("Reconnect", func(t *testing.T) {
t.Parallel()
// Create a test deployment and workspace.
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user2.ID,
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
a[0].Apps = []*proto.App{
{
Slug: "vscode",
},
}
return a
}).Do()
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitLong))
// Watch the workspace for changes.
watcher, err := client.WatchWorkspace(ctx, r.Workspace.ID)
require.NoError(t, err)
var lastAppStatus codersdk.WorkspaceAppStatus
nextUpdate := func() codersdk.WorkspaceAppStatus {
for {
select {
case <-ctx.Done():
require.FailNow(t, "timed out waiting for status update")
case w, ok := <-watcher:
require.True(t, ok, "watch channel closed")
if w.LatestAppStatus != nil && w.LatestAppStatus.ID != lastAppStatus.ID {
t.Logf("Got status update: %s > %s", lastAppStatus.State, w.LatestAppStatus.State)
lastAppStatus = *w.LatestAppStatus
return lastAppStatus
}
}
}
}
// Mock AI AgentAPI server that supports disconnect/reconnect.
disconnect := make(chan struct{})
listening := make(chan func(sse codersdk.ServerSentEvent) error)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create a cancelable context so we can stop the SSE sender
// goroutine on disconnect without waiting for the HTTP
// serve loop to cancel r.Context().
sseCtx, sseCancel := context.WithCancel(r.Context())
defer sseCancel()
r = r.WithContext(sseCtx)
send, closed, err := httpapi.ServerSentEventSender(w, r)
if err != nil {
httpapi.Write(sseCtx, w, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up server-sent events.",
Detail: err.Error(),
})
return
}
// Send initial message so the watcher knows the agent is active.
send(*makeMessageEvent(0, agentapi.RoleAgent))
select {
case listening <- send:
case <-r.Context().Done():
return
}
select {
case <-closed:
case <-disconnect:
sseCancel()
<-closed
}
}))
t.Cleanup(srv.Close)
inv, _ := clitest.New(t,
"exp", "mcp", "server",
"--agent-url", client.URL.String(),
"--agent-token", r.AgentToken,
"--app-status-slug", "vscode",
"--allowed-tools=coder_report_task",
"--ai-agentapi-url", srv.URL,
)
inv = inv.WithContext(ctx)
pty := ptytest.New(t)
inv.Stdin = pty.Input()
inv.Stdout = pty.Output()
stderr := ptytest.New(t)
inv.Stderr = stderr.Output()
// Run the MCP server.
clitest.Start(t, inv)
// Initialize.
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
pty.WriteLine(payload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore init response
// Get first sender from the initial SSE connection.
sender := testutil.RequireReceive(ctx, t, listening)
// Self-report a working status via tool call.
toolPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"doing work","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got := nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "doing work", got.Message)
// Watcher sends stable, verify idle is reported.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
// Disconnect the SSE connection by signaling the handler to return.
testutil.RequireSend(ctx, t, disconnect, struct{}{})
// Wait for the watcher to reconnect and get the new sender.
sender = testutil.RequireReceive(ctx, t, listening)
// After reconnect, self-report a working status again.
toolPayload = `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"reconnected","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "reconnected", got.Message)
// Verify the watcher still processes events after reconnect.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
cancel()
})
}
-12
View File
@@ -29,7 +29,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
templateVersionJobTimeout time.Duration
prebuildWorkspaceTimeout time.Duration
noCleanup bool
provisionerTags []string
tracingFlags = &scaletestTracingFlags{}
timeoutStrategy = &timeoutFlags{}
@@ -112,16 +111,10 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
tags, err := ParseProvisionerTags(provisionerTags)
if err != nil {
return err
}
for i := range numTemplates {
id := strconv.Itoa(int(i))
cfg := prebuilds.Config{
OrganizationID: me.OrganizationIDs[0],
ProvisionerTags: tags,
NumPresets: int(numPresets),
NumPresetPrebuilds: int(numPresetPrebuilds),
TemplateVersionJobTimeout: templateVersionJobTimeout,
@@ -290,11 +283,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
Description: "Skip cleanup (deletion test) and leave resources intact.",
Value: serpent.BoolOf(&noCleanup),
},
{
Flag: "provisioner-tag",
Description: "Specify a set of tags to target provisioner daemons.",
Value: serpent.StringArrayOf(&provisionerTags),
},
}
tracingFlags.attach(&cmd.Options)
+1 -45
View File
@@ -4,9 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"strings"
"time"
"golang.org/x/xerrors"
@@ -19,29 +16,6 @@ import (
"github.com/coder/serpent"
)
// detectGitRef attempts to resolve the current git branch and remote
// origin URL from the given working directory. These are sent to the
// control plane so it can look up PR/diff status via the GitHub API
// without SSHing into the workspace. Failures are silently ignored
// since this is best-effort.
func detectGitRef(workingDirectory string) (branch string, remoteOrigin string) {
run := func(args ...string) string {
//nolint:gosec
cmd := exec.Command(args[0], args[1:]...)
if workingDirectory != "" {
cmd.Dir = workingDirectory
}
out, err := cmd.Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
branch = run("git", "rev-parse", "--abbrev-ref", "HEAD")
remoteOrigin = run("git", "config", "--get", "remote.origin.url")
return branch, remoteOrigin
}
// gitAskpass is used by the Coder agent to automatically authenticate
// with Git providers based on a hostname.
func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
@@ -64,20 +38,8 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("create agent client: %w", err)
}
workingDirectory, err := os.Getwd()
if err != nil {
workingDirectory = ""
}
// Detect the current git branch and remote origin so
// the control plane can resolve diffs without needing
// to SSH back into the workspace.
gitBranch, gitRemoteOrigin := detectGitRef(workingDirectory)
token, err := client.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
Match: host,
GitBranch: gitBranch,
GitRemoteOrigin: gitRemoteOrigin,
Match: host,
})
if err != nil {
var apiError *codersdk.Error
@@ -96,12 +58,6 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("get git token: %w", err)
}
if token.URL != "" {
// This is to help the agent authenticate with Git.
if inv.Environ.Get("CODER_CHAT_AGENT") == "true" {
_, _ = fmt.Fprintf(inv.Stderr, `You must notify the user to authenticate with Git.\n\nThe URL is: %s\n`, token.URL)
return cliui.ErrCanceled
}
if err := openURL(inv, token.URL); err == nil {
cliui.Infof(inv.Stderr, "Your browser has been opened to authenticate with Git:\n%s", token.URL)
} else {
-57
View File
@@ -1,12 +1,9 @@
package cli_test
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
@@ -111,58 +108,4 @@ func TestGitAskpass(t *testing.T) {
})
stdout.ExpectMatch("username")
})
t.Run("ChatAgentAuthRequired", func(t *testing.T) {
t.Parallel()
listenCalls := atomic.Int64{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Has("listen") {
listenCalls.Add(1)
}
httpapi.Write(context.Background(), w, http.StatusOK, agentsdk.ExternalAuthResponse{
URL: "https://coder.example.com/external-auth/github",
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
})
}))
t.Cleanup(srv.Close)
inv, _ := clitest.New(
t,
"--agent-url",
srv.URL,
"--no-open",
"Username for 'https://github.com':",
)
inv.Environ.Set("GIT_PREFIX", "/")
inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token")
inv.Environ.Set("CODER_CHAT_AGENT", "true")
var stderr bytes.Buffer
inv.Stderr = &stderr
err := inv.Run()
require.Error(t, err)
require.Contains(t, err.Error(), "exit code")
require.Zero(t, listenCalls.Load())
output := stderr.String()
require.Contains(t, output, "CODER_GITAUTH_REQUIRED:")
require.NotContains(t, output, "Open the following URL to authenticate")
require.NotContains(t, output, "Your browser has been opened")
_, markerRaw, found := strings.Cut(output, "CODER_GITAUTH_REQUIRED:")
require.True(t, found)
var marker struct {
ProviderID string `json:"provider_id"`
ProviderType string `json:"provider_type"`
ProviderDisplayName string `json:"provider_display_name"`
AuthenticateURL string `json:"authenticate_url"`
}
require.NoError(t, json.Unmarshal([]byte(strings.TrimSpace(markerRaw)), &marker))
require.Equal(t, "github", marker.ProviderID)
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), marker.ProviderType)
require.Equal(t, "GitHub", marker.ProviderDisplayName)
require.Equal(t, "https://coder.example.com/external-auth/github", marker.AuthenticateURL)
})
}
+5 -1
View File
@@ -106,7 +106,11 @@ func TestList(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
memberClient, member = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
sharedWorkspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
+1 -1
View File
@@ -297,7 +297,7 @@ func (pr *ParameterResolver) verifyConstraints(resolved []codersdk.WorkspaceBuil
return xerrors.Errorf("ephemeral parameter %q can be used only with --prompt-ephemeral-parameters or --ephemeral-parameter flag", r.Name)
}
if !tvp.Mutable && action != WorkspaceCreate && !pr.isFirstTimeUse(r.Name) {
if !tvp.Mutable && action != WorkspaceCreate {
return xerrors.Errorf("parameter %q is immutable and cannot be updated", r.Name)
}
}
+43 -120
View File
@@ -137,15 +137,6 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De
if err != nil {
return nil, xerrors.Errorf("parse oidc oauth callback url: %w", err)
}
if vals.OIDC.RedirectURL.String() != "" {
redirectURL, err = vals.OIDC.RedirectURL.Value().Parse("/api/v2/users/oidc/callback")
if err != nil {
return nil, xerrors.Errorf("parse oidc redirect url %q", err)
}
logger.Warn(ctx, "custom OIDC redirect URL used instead of 'access_url', ensure this matches the value configured in your OIDC provider")
}
// If the scopes contain 'groups', we enable group support.
// Do not override any custom value set by the user.
if slice.Contains(vals.OIDC.Scopes, "groups") && vals.OIDC.GroupField == "" {
@@ -617,8 +608,28 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
}
}
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
promRegistry := prometheus.NewRegistry()
oauthInstrument := promoauth.NewFactory(promRegistry)
vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...)
externalAuthConfigs, err := externalauth.ConvertConfig(
oauthInstrument,
vals.ExternalAuthConfigs.Value,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range externalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
if err != nil {
@@ -649,7 +660,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
Pubsub: nil,
CacheDir: cacheDir,
GoogleTokenValidator: googleTokenValidator,
ExternalAuthConfigs: nil,
ExternalAuthConfigs: externalAuthConfigs,
RealIPConfig: realIPConfig,
SSHKeygenAlgorithm: sshKeygenAlgorithm,
TracerProvider: tracerProvider,
@@ -809,40 +820,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("set deployment id: %w", err)
}
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
mergedExternalAuthProviders := append([]codersdk.ExternalAuthConfig{}, vals.ExternalAuthConfigs.Value...)
mergedExternalAuthProviders = append(mergedExternalAuthProviders, extAuthEnv...)
vals.ExternalAuthConfigs.Value = mergedExternalAuthProviders
mergedExternalAuthProviders, err = maybeAppendDefaultGithubExternalAuthProvider(
ctx,
options.Logger,
options.Database,
vals,
mergedExternalAuthProviders,
)
if err != nil {
return xerrors.Errorf("maybe append default github external auth provider: %w", err)
}
options.ExternalAuthConfigs, err = externalauth.ConvertConfig(
oauthInstrument,
mergedExternalAuthProviders,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range options.ExternalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}
// Manage push notifications.
experiments := coderd.ReadExperiments(options.Logger, options.DeploymentValues.Experiments.Value())
if experiments.Enabled(codersdk.ExperimentWebPush) {
@@ -1940,79 +1917,6 @@ type githubOAuth2ConfigParams struct {
enterpriseBaseURL string
}
func isDeploymentEligibleForGithubDefaultProvider(ctx context.Context, db database.Store) (bool, error) {
// We want to enable the default provider only for new deployments, and avoid
// enabling it if a deployment was upgraded from an older version.
// nolint:gocritic // Requires system privileges
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, xerrors.Errorf("get github default eligible: %w", err)
}
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return false, xerrors.Errorf("get user count: %w", err)
}
// We check if a deployment is new by checking if it has any users.
defaultEligible = userCount == 0
// nolint:gocritic // Requires system privileges
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
return false, xerrors.Errorf("upsert github default eligible: %w", err)
}
}
return defaultEligible, nil
}
func maybeAppendDefaultGithubExternalAuthProvider(
ctx context.Context,
logger slog.Logger,
db database.Store,
vals *codersdk.DeploymentValues,
mergedExplicitProviders []codersdk.ExternalAuthConfig,
) ([]codersdk.ExternalAuthConfig, error) {
if !vals.ExternalAuthGithubDefaultProviderEnable.Value() {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "disabled by configuration"),
slog.F("flag", "external-auth-github-default-provider-enable"),
)
return mergedExplicitProviders, nil
}
if len(mergedExplicitProviders) > 0 {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "explicit external auth providers configured"),
slog.F("provider_count", len(mergedExplicitProviders)),
)
return mergedExplicitProviders, nil
}
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
}
if !defaultEligible {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "deployment is not eligible"),
)
return mergedExplicitProviders, nil
}
logger.Info(ctx, "injecting default github external auth provider",
slog.F("type", codersdk.EnhancedExternalAuthProviderGitHub.String()),
slog.F("client_id", GithubOAuth2DefaultProviderClientID),
slog.F("device_flow", GithubOAuth2DefaultProviderDeviceFlow),
)
return append(mergedExplicitProviders, codersdk.ExternalAuthConfig{
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
ClientID: GithubOAuth2DefaultProviderClientID,
DeviceFlow: GithubOAuth2DefaultProviderDeviceFlow,
}), nil
}
func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *codersdk.DeploymentValues) (*githubOAuth2ConfigParams, error) {
params := githubOAuth2ConfigParams{
accessURL: vals.AccessURL.Value(),
@@ -2037,9 +1941,28 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
return nil, nil //nolint:nilnil
}
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
// Check if the deployment is eligible for the default GitHub OAuth2 provider.
// We want to enable it only for new deployments, and avoid enabling it
// if a deployment was upgraded from an older version.
// nolint:gocritic // Requires system privileges
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get github default eligible: %w", err)
}
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return nil, xerrors.Errorf("get user count: %w", err)
}
// We check if a deployment is new by checking if it has any users.
defaultEligible = userCount == 0
// nolint:gocritic // Requires system privileges
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
return nil, xerrors.Errorf("upsert github default eligible: %w", err)
}
}
if !defaultEligible {
-162
View File
@@ -53,7 +53,6 @@ import (
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/userpassword"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/pty/ptytest"
@@ -1741,18 +1740,6 @@ func TestServer(t *testing.T) {
// Next, we instruct the same server to display the YAML config
// and then save it.
// Because this is literally the same invocation, DefaultFn sets the
// value of 'Default'. Which triggers a mutually exclusive error
// on the next parse.
// Usually we only parse flags once, so this is not an issue
for _, c := range inv.Command.Children {
if c.Name() == "server" {
for i := range c.Options {
c.Options[i].DefaultFn = nil
}
break
}
}
inv = inv.WithContext(testutil.Context(t, testutil.WaitMedium))
//nolint:gocritic
inv.Args = append(args, "--write-config")
@@ -1806,155 +1793,6 @@ func TestServer(t *testing.T) {
})
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_ExternalAuthGitHubDefaultProvider(t *testing.T) {
type testCase struct {
name string
args []string
env map[string]string
createUserPreStart bool
expectedProviders []string
}
run := func(t *testing.T, tc testCase) {
ctx := testutil.Context(t, testutil.WaitLong)
unsetPrefixedEnv := func(prefix string) {
t.Helper()
for _, envVar := range os.Environ() {
envKey, _, found := strings.Cut(envVar, "=")
if !found || !strings.HasPrefix(envKey, prefix) {
continue
}
value, had := os.LookupEnv(envKey)
require.True(t, had)
require.NoError(t, os.Unsetenv(envKey))
keyCopy := envKey
valueCopy := value
t.Cleanup(func() {
// This is for setting/unsetting a number of prefixed env vars.
// t.Setenv doesn't cover this use case.
// nolint:usetesting
_ = os.Setenv(keyCopy, valueCopy)
})
}
}
unsetPrefixedEnv("CODER_EXTERNAL_AUTH_")
unsetPrefixedEnv("CODER_GITAUTH_")
dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)
db, _ := dbtestutil.NewDB(t, dbtestutil.WithURL(dbURL))
const (
existingUserEmail = "existing-user@coder.com"
existingUserUsername = "existing-user"
existingUserPassword = "SomeSecurePassword!"
)
if tc.createUserPreStart {
hashedPassword, err := userpassword.Hash(existingUserPassword)
require.NoError(t, err)
_ = dbgen.User(t, db, database.User{
Email: existingUserEmail,
Username: existingUserUsername,
HashedPassword: []byte(hashedPassword),
})
}
args := []string{
"server",
"--postgres-url", dbURL,
"--http-address", ":0",
"--access-url", "https://example.com",
}
args = append(args, tc.args...)
inv, cfg := clitest.New(t, args...)
for envKey, value := range tc.env {
t.Setenv(envKey, value)
}
clitest.Start(t, inv)
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
if tc.createUserPreStart {
loginResp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
Email: existingUserEmail,
Password: existingUserPassword,
})
require.NoError(t, err)
client.SetSessionToken(loginResp.SessionToken)
} else {
_ = coderdtest.CreateFirstUser(t, client)
}
externalAuthResp, err := client.ListExternalAuths(ctx)
require.NoError(t, err)
gotProviders := map[string]codersdk.ExternalAuthLinkProvider{}
for _, provider := range externalAuthResp.Providers {
gotProviders[provider.ID] = provider
}
require.Len(t, gotProviders, len(tc.expectedProviders))
for _, providerID := range tc.expectedProviders {
provider, ok := gotProviders[providerID]
require.Truef(t, ok, "expected provider %q to be configured", providerID)
if providerID == codersdk.EnhancedExternalAuthProviderGitHub.String() {
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), provider.Type)
require.True(t, provider.Device)
}
}
}
for _, tc := range []testCase{
{
name: "NewDeployment_NoExplicitProviders_InjectsDefaultGithub",
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitHub.String()},
},
{
name: "ExistingDeployment_DoesNotInjectDefaultGithub",
createUserPreStart: true,
expectedProviders: nil,
},
{
name: "DefaultProviderDisabled_DoesNotInjectDefaultGithub",
args: []string{
"--external-auth-github-default-provider-enable=false",
},
expectedProviders: nil,
},
{
name: "ExplicitProviderViaConfig_DoesNotInjectDefaultGithub",
args: []string{
`--external-auth-providers=[{"type":"gitlab","client_id":"config-client-id"}]`,
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_EXTERNAL_AUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_EXTERNAL_AUTH_0_CLIENT_ID": "env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaLegacyEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_GITAUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_GITAUTH_0_CLIENT_ID": "legacy-env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
} {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_Logging_NoParallel(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+31 -7
View File
@@ -25,7 +25,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -64,8 +68,12 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
orgOwner = coderdtest.CreateFirstUser(t, client)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -119,7 +127,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -170,7 +182,11 @@ func TestSharingStatus(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -214,7 +230,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -271,7 +291,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
+1 -1
View File
@@ -120,7 +120,7 @@ func (r *RootCmd) start() *serpent.Command {
func buildWorkspaceStartRequest(inv *serpent.Invocation, client *codersdk.Client, workspace codersdk.Workspace, parameterFlags workspaceParameterFlags, buildFlags buildFlags, action WorkspaceCLIAction) (codersdk.CreateWorkspaceBuildRequest, error) {
version := workspace.LatestBuild.TemplateVersionID
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || workspace.TemplateRequireActiveVersion || action == WorkspaceUpdate {
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || action == WorkspaceUpdate {
version = workspace.TemplateActiveVersionID
if version != workspace.LatestBuild.TemplateVersionID {
action = WorkspaceUpdate
+4 -4
View File
@@ -33,7 +33,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
statefilePath := filepath.Join(t.TempDir(), "state")
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name, statefilePath)
@@ -54,7 +54,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name)
var gotState bytes.Buffer
@@ -74,7 +74,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", taUser.Username+"/"+r.Workspace.Name,
"--build", fmt.Sprintf("%d", r.Build.BuildNumber))
@@ -170,7 +170,7 @@ func TestStatePush(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(initialState).
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
Do()
wantState := []byte("updated state")
stateFile, err := os.CreateTemp(t.TempDir(), "")
+7 -9
View File
@@ -1,3 +1,5 @@
//go:build !windows
package cli_test
import (
@@ -5,7 +7,6 @@ import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
"time"
@@ -24,15 +25,12 @@ func setupSocketServer(t *testing.T) (path string, cleanup func()) {
t.Helper()
// Use a temporary socket path for each test
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
// Create parent directory if needed. Not necessary on Windows because named pipes live in an abstract namespace
// not tied to any real files.
if runtime.GOOS != "windows" {
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0o700)
require.NoError(t, err, "create socket directory")
}
// Create parent directory if needed
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0o700)
require.NoError(t, err, "create socket directory")
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
-1
View File
@@ -18,7 +18,6 @@ func (r *RootCmd) tasksCommand() *serpent.Command {
r.taskList(),
r.taskLogs(),
r.taskPause(),
r.taskResume(),
r.taskSend(),
r.taskStatus(),
},
-95
View File
@@ -1,95 +0,0 @@
package cli
import (
"fmt"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/pretty"
"github.com/coder/serpent"
)
func (r *RootCmd) taskResume() *serpent.Command {
var noWait bool
cmd := &serpent.Command{
Use: "resume <task>",
Short: "Resume a task",
Long: FormatExamples(
Example{
Description: "Resume a task by name",
Command: "coder task resume my-task",
},
Example{
Description: "Resume another user's task",
Command: "coder task resume alice/my-task",
},
Example{
Description: "Resume a task without confirmation",
Command: "coder task resume my-task --yes",
},
),
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Options: serpent.OptionSet{
{
Flag: "no-wait",
Description: "Return immediately after resuming the task.",
Value: serpent.BoolOf(&noWait),
},
cliui.SkipPromptOption(),
},
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
task, err := client.TaskByIdentifier(ctx, inv.Args[0])
if err != nil {
return xerrors.Errorf("resolve task %q: %w", inv.Args[0], err)
}
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
if task.Status == codersdk.TaskStatusError || task.Status == codersdk.TaskStatusUnknown {
return xerrors.Errorf("task %q is in %s state and cannot be resumed; check the workspace build logs and agent status for details", display, task.Status)
} else if task.Status != codersdk.TaskStatusPaused {
return xerrors.Errorf("task %q cannot be resumed (current status: %s)", display, task.Status)
}
_, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Resume task %s?", pretty.Sprint(cliui.DefaultStyles.Code, display)),
IsConfirm: true,
Default: cliui.ConfirmNo,
})
if err != nil {
return err
}
resp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
if err != nil {
return xerrors.Errorf("resume task %q: %w", display, err)
} else if resp.WorkspaceBuild == nil {
return xerrors.Errorf("resume task %q: no workspace build returned", display)
}
if noWait {
_, _ = fmt.Fprintf(inv.Stdout, "Resuming task %q in the background.\n", cliui.Keyword(display))
return nil
}
if err = cliui.WorkspaceBuild(ctx, inv.Stdout, client, resp.WorkspaceBuild.ID); err != nil {
return xerrors.Errorf("watch resume build for task %q: %w", display, err)
}
_, _ = fmt.Fprintf(inv.Stdout, "\nThe %s task has been resumed.\n", cliui.Keyword(display))
return nil
},
}
return cmd
}
-183
View File
@@ -1,183 +0,0 @@
package cli_test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
func TestExpTaskResume(t *testing.T) {
t.Parallel()
// pauseTask is a helper that pauses a task and waits for the stop
// build to complete.
pauseTask := func(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
t.Helper()
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
require.NoError(t, err)
require.NotNil(t, pauseResp.WorkspaceBuild)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
}
t.Run("WithYesFlag", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, userClient, root)
// Then: We expect the task to be resumed
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been resumed")
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
// OtherUserTask verifies that an admin can resume a task owned by
// another user using the "owner/name" identifier format.
t.Run("OtherUserTask", func(t *testing.T) {
t.Parallel()
// Given: A different user's paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
adminClient, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume their task
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
inv, root := clitest.New(t, "task", "resume", identifier, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, adminClient, root)
// Then: We expect the task to be resumed
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been resumed")
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
t.Run("NoWait", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task (and specify no wait)
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes", "--no-wait")
output := clitest.Capture(inv)
clitest.SetupConfig(t, userClient, root)
// Then: We expect the task to be resumed in the background
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "in the background")
// And: The task to eventually be resumed
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
ws := coderdtest.MustWorkspace(t, userClient, task.WorkspaceID.UUID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, ws.LatestBuild.ID)
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
t.Run("PromptConfirm", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", task.Name)
clitest.SetupConfig(t, userClient, root)
// And: We confirm we want to resume the task
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Resume task")
pty.WriteLine("yes")
// Then: We expect the task to be resumed
pty.ExpectMatchContext(ctx, "has been resumed")
require.NoError(t, w.Wait())
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
t.Run("PromptDecline", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", task.Name)
clitest.SetupConfig(t, userClient, root)
// But: Say no at the confirmation screen
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Resume task")
pty.WriteLine("no")
require.Error(t, w.Wait())
// Then: We expect the task to still be paused
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
t.Run("TaskNotPaused", func(t *testing.T) {
t.Parallel()
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to resume the task that is not paused
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
clitest.SetupConfig(t, userClient, root)
// Then: We expect to get an error that the task is not paused
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.ErrorContains(t, err, "cannot be resumed")
})
}
-17
View File
@@ -137,23 +137,6 @@ func Test_Tasks(t *testing.T) {
require.Equal(t, codersdk.TaskStatusPaused, task.Status, "task should be paused")
},
},
{
name: "resume task",
cmdArgs: []string{"task", "resume", taskName, "--yes"},
assertFn: func(stdout string, userClient *codersdk.Client) {
require.Contains(t, stdout, "has been resumed", "resume output should confirm task was resumed")
},
},
{
name: "get task status after resume",
cmdArgs: []string{"task", "status", taskName, "--output", "json"},
assertFn: func(stdout string, userClient *codersdk.Client) {
var task codersdk.Task
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&task), "should unmarshal task status")
require.Equal(t, taskName, task.Name, "task name should match")
require.Equal(t, codersdk.TaskStatusInitializing, task.Status, "task should be initializing after resume")
},
},
{
name: "delete task",
cmdArgs: []string{"task", "delete", taskName, "--yes"},
-4
View File
@@ -139,10 +139,8 @@ func (r *RootCmd) templateVersionsList() *serpent.Command {
type templateVersionRow struct {
// For json format:
TemplateVersion codersdk.TemplateVersion `table:"-"`
ActiveJSON bool `json:"active" table:"-"`
// For table format:
ID string `json:"-" table:"id"`
Name string `json:"-" table:"name,default_sort"`
CreatedAt time.Time `json:"-" table:"created at"`
CreatedBy string `json:"-" table:"created by"`
@@ -168,8 +166,6 @@ func templateVersionsToRows(activeVersionID uuid.UUID, templateVersions ...coder
rows[i] = templateVersionRow{
TemplateVersion: templateVersion,
ActiveJSON: templateVersion.ID == activeVersionID,
ID: templateVersion.ID.String(),
Name: templateVersion.Name,
CreatedAt: templateVersion.CreatedAt,
CreatedBy: templateVersion.CreatedBy.Username,
-29
View File
@@ -1,9 +1,7 @@
package cli_test
import (
"bytes"
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
@@ -42,33 +40,6 @@ func TestTemplateVersions(t *testing.T) {
pty.ExpectMatch(version.CreatedBy.Username)
pty.ExpectMatch("Active")
})
t.Run("ListVersionsJSON", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
_ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
inv, root := clitest.New(t, "templates", "versions", "list", template.Name, "--output", "json")
clitest.SetupConfig(t, member, root)
var stdout bytes.Buffer
inv.Stdout = &stdout
require.NoError(t, inv.Run())
var rows []struct {
TemplateVersion codersdk.TemplateVersion `json:"TemplateVersion"`
Active bool `json:"active"`
}
require.NoError(t, json.Unmarshal(stdout.Bytes(), &rows))
require.Len(t, rows, 1)
assert.Equal(t, version.ID, rows[0].TemplateVersion.ID)
assert.True(t, rows[0].Active)
})
}
func TestTemplateVersionsPromote(t *testing.T) {
+5 -11
View File
@@ -49,9 +49,10 @@ OPTIONS:
security purposes if a --wildcard-access-url is configured.
--disable-workspace-sharing bool, $CODER_DISABLE_WORKSPACE_SHARING
Disable workspace sharing. Workspace ACL checking is disabled and only
owners can have ssh, apps and terminal access to workspaces. Access
based on the 'owner' role is also allowed unless disabled via
Disable workspace sharing (requires the "workspace-sharing" experiment
to be enabled). Workspace ACL checking is disabled and only owners can
have ssh, apps and terminal access to workspaces. Access based on the
'owner' role is also allowed unless disabled via
--disable-owner-workspace-access.
--swagger-enable bool, $CODER_SWAGGER_ENABLE
@@ -62,9 +63,6 @@ OPTIONS:
Separate multiple experiments with commas, or enter '*' to opt-in to
all available experiments.
--external-auth-github-default-provider-enable bool, $CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE (default: true)
Enable the default GitHub external auth provider managed by Coder.
--postgres-auth password|awsiamrds, $CODER_PG_AUTH (default: password)
Type of auth to use when connecting to postgres. For AWS RDS, using
IAM authentication (awsiamrds) is recommended.
@@ -385,17 +383,13 @@ NETWORKING OPTIONS:
--samesite-auth-cookie lax|none, $CODER_SAMESITE_AUTH_COOKIE (default: lax)
Controls the 'SameSite' property is set on browser session cookies.
--secure-auth-cookie bool, $CODER_SECURE_AUTH_COOKIE (default: false)
--secure-auth-cookie bool, $CODER_SECURE_AUTH_COOKIE
Controls if the 'Secure' property is set on browser session cookies.
--wildcard-access-url string, $CODER_WILDCARD_ACCESS_URL
Specifies the wildcard hostname to use for workspace applications in
the form "*.example.com".
--host-prefix-cookie bool, $CODER_HOST_PREFIX_COOKIE (default: false)
Recommended to be enabled. Enables `__Host-` prefix for cookies to
guarantee they are only set by the right domain.
NETWORKING / DERP OPTIONS:
Most Coder deployments never have to think about DERP because all connections
between workspaces and users are peer-to-peer. However, when Coder cannot
-1
View File
@@ -13,7 +13,6 @@ SUBCOMMANDS:
list List tasks
logs Show a task's logs
pause Pause a task
resume Resume a task
send Send input to a task
status Show the status of a task.
-28
View File
@@ -1,28 +0,0 @@
coder v0.0.0-devel
USAGE:
coder task resume [flags] <task>
Resume a task
- Resume a task by name:
$ coder task resume my-task
- Resume another user's task:
$ coder task resume alice/my-task
- Resume a task without confirmation:
$ coder task resume my-task --yes
OPTIONS:
--no-wait bool
Return immediately after resuming the task.
-y, --yes bool
Bypass confirmation prompts.
———
Run `coder --help` for a list of global options.
+1 -1
View File
@@ -9,7 +9,7 @@ OPTIONS:
-O, --org string, $CODER_ORGANIZATION
Select which organization (uuid or name) to use.
-c, --column [id|name|created at|created by|status|active|archived] (default: name,created at,created by,status,active)
-c, --column [name|created at|created by|status|active|archived] (default: name,created at,created by,status,active)
Columns to display in table output.
--include-archived bool
+1 -1
View File
@@ -27,7 +27,7 @@ USAGE:
SUBCOMMANDS:
create Create a token
list List tokens
remove Expire or delete a token
remove Delete a token
view Display detailed information about a token
———
-4
View File
@@ -15,10 +15,6 @@ OPTIONS:
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at)
Columns to display in table output.
--include-expired bool
Include expired tokens in the output. By default, expired tokens are
hidden.
-o, --output table|json (default: table)
Output format.
+2 -10
View File
@@ -1,19 +1,11 @@
coder v0.0.0-devel
USAGE:
coder tokens remove [flags] <name|id|token>
coder tokens remove <name|id|token>
Expire or delete a token
Delete a token
Aliases: delete, rm
Remove a token by expiring it. Use --delete to permanently hard-delete the
token instead.
OPTIONS:
--delete bool
Permanently delete the token instead of expiring it. This removes the
audit trail.
———
Run `coder --help` for a list of global options.
+5 -17
View File
@@ -176,15 +176,11 @@ networking:
# (default: <unset>, type: string-array)
proxyTrustedOrigins: []
# Controls if the 'Secure' property is set on browser session cookies.
# (default: false, type: bool)
# (default: <unset>, type: bool)
secureAuthCookie: false
# Controls the 'SameSite' property is set on browser session cookies.
# (default: lax, type: enum[lax\|none])
sameSiteAuthCookie: lax
# Recommended to be enabled. Enables `__Host-` prefix for cookies to guarantee
# they are only set by the right domain.
# (default: false, type: bool)
hostPrefixCookie: false
# Whether Coder only allows connections to workspaces via the browser.
# (default: <unset>, type: bool)
browserOnly: false
@@ -421,11 +417,6 @@ oidc:
# an insecure OIDC configuration. It is not recommended to use this flag.
# (default: <unset>, type: bool)
dangerousSkipIssuerChecks: false
# Optional override of the default redirect url which uses the deployment's access
# url. Useful in situations where a deployment has more than 1 domain. Using this
# setting can also break OIDC, so use with caution.
# (default: <unset>, type: url)
oidc-redirect-url:
# Telemetry is critical to our ability to improve Coder. We strip all personal
# information before sending data to our servers. Please only disable telemetry
# when required by your organization's security policy.
@@ -523,10 +514,10 @@ disablePathApps: false
# workspaces.
# (default: <unset>, type: bool)
disableOwnerWorkspaceAccess: false
# Disable workspace sharing. Workspace ACL checking is disabled and only owners
# can have ssh, apps and terminal access to workspaces. Access based on the
# 'owner' role is also allowed unless disabled via
# --disable-owner-workspace-access.
# Disable workspace sharing (requires the "workspace-sharing" experiment to be
# enabled). Workspace ACL checking is disabled and only owners can have ssh, apps
# and terminal access to workspaces. Access based on the 'owner' role is also
# allowed unless disabled via --disable-owner-workspace-access.
# (default: <unset>, type: bool)
disableWorkspaceSharing: false
# These options change the behavior of how clients interact with the Coder.
@@ -563,9 +554,6 @@ supportLinks: []
# External Authentication providers.
# (default: <unset>, type: struct[[]codersdk.ExternalAuthConfig])
externalAuthProviders: []
# Enable the default GitHub external auth provider managed by Coder.
# (default: true, type: bool)
externalAuthGithubDefaultProviderEnable: true
# Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By
# default, this will pick the best available wgtunnel server hosted by Coder. e.g.
# "tunnel.example.com".
+14 -37
View File
@@ -218,10 +218,9 @@ func (r *RootCmd) listTokens() *serpent.Command {
}
var (
all bool
includeExpired bool
displayTokens []tokenListRow
formatter = cliui.NewOutputFormatter(
all bool
displayTokens []tokenListRow
formatter = cliui.NewOutputFormatter(
cliui.TableFormat([]tokenListRow{}, defaultCols),
cliui.JSONFormat(),
)
@@ -241,8 +240,7 @@ func (r *RootCmd) listTokens() *serpent.Command {
}
tokens, err := client.Tokens(inv.Context(), codersdk.Me, codersdk.TokensFilter{
IncludeAll: all,
IncludeExpired: includeExpired,
IncludeAll: all,
})
if err != nil {
return xerrors.Errorf("list tokens: %w", err)
@@ -276,12 +274,6 @@ func (r *RootCmd) listTokens() *serpent.Command {
Description: "Specifies whether all users' tokens will be listed or not (must have Owner role to see all tokens).",
Value: serpent.BoolOf(&all),
},
{
Name: "include-expired",
Flag: "include-expired",
Description: "Include expired tokens in the output. By default, expired tokens are hidden.",
Value: serpent.BoolOf(&includeExpired),
},
}
formatter.AttachOptions(&cmd.Options)
@@ -331,13 +323,10 @@ func (r *RootCmd) viewToken() *serpent.Command {
}
func (r *RootCmd) removeToken() *serpent.Command {
var deleteToken bool
cmd := &serpent.Command{
Use: "remove <name|id|token>",
Aliases: []string{"delete"},
Short: "Expire or delete a token",
Long: "Remove a token by expiring it. Use --delete to permanently hard-" +
"delete the token instead.",
Short: "Delete a token",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
@@ -349,7 +338,7 @@ func (r *RootCmd) removeToken() *serpent.Command {
token, err := client.APIKeyByName(inv.Context(), codersdk.Me, inv.Args[0])
if err != nil {
// If it's a token, we need to extract the ID.
// If it's a token, we need to extract the ID
maybeID := strings.Split(inv.Args[0], "-")[0]
token, err = client.APIKeyByID(inv.Context(), codersdk.Me, maybeID)
if err != nil {
@@ -357,29 +346,17 @@ func (r *RootCmd) removeToken() *serpent.Command {
}
}
if deleteToken {
err = client.DeleteAPIKey(inv.Context(), codersdk.Me, token.ID)
if err != nil {
return xerrors.Errorf("delete api key: %w", err)
}
cliui.Infof(inv.Stdout, "Token has been deleted.")
return nil
}
err = client.ExpireAPIKey(inv.Context(), codersdk.Me, token.ID)
err = client.DeleteAPIKey(inv.Context(), codersdk.Me, token.ID)
if err != nil {
return xerrors.Errorf("expire api key: %w", err)
return xerrors.Errorf("delete api key: %w", err)
}
cliui.Infof(inv.Stdout, "Token has been expired.")
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "delete",
Description: "Permanently delete the token instead of expiring it. This removes the audit trail.",
Value: serpent.BoolOf(&deleteToken),
cliui.Infof(
inv.Stdout,
"Token has been deleted.",
)
return nil
},
}
+17 -153
View File
@@ -6,16 +6,12 @@ import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
@@ -26,7 +22,7 @@ func TestTokens(t *testing.T) {
adminUser := coderdtest.CreateFirstUser(t, client)
secondUserClient, secondUser := coderdtest.CreateAnotherUser(t, client, adminUser.OrganizationID)
thirdUserClient, thirdUser := coderdtest.CreateAnotherUser(t, client, adminUser.OrganizationID)
_, thirdUser := coderdtest.CreateAnotherUser(t, client, adminUser.OrganizationID)
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancelFunc()
@@ -159,7 +155,7 @@ func TestTokens(t *testing.T) {
require.Len(t, scopedToken.AllowList, 1)
require.Equal(t, allowSpec, scopedToken.AllowList[0].String())
// Delete by name (default behavior is now expire)
// Delete by name
inv, root = clitest.New(t, "tokens", "rm", "token-one")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
@@ -168,42 +164,21 @@ func TestTokens(t *testing.T) {
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
require.Contains(t, res, "expired")
// Regular users cannot expire other users' tokens (expire is default now).
inv, root = clitest.New(t, "tokens", "rm", secondTokenID)
clitest.SetupConfig(t, thirdUserClient, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
// Only admin users can expire other users' tokens (expire is default now).
inv, root = clitest.New(t, "tokens", "rm", secondTokenID)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
// Validate that token was expired
if token, err := client.APIKeyByName(ctx, secondUser.ID.String(), "token-two"); assert.NoError(t, err) {
require.True(t, token.ExpiresAt.Before(time.Now()))
}
// Delete by ID (explicit delete flag)
inv, root = clitest.New(t, "tokens", "rm", "--delete", secondTokenID)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
// Delete scoped token by ID (explicit delete flag)
inv, root = clitest.New(t, "tokens", "rm", "--delete", scopedTokenID)
// Delete by ID
inv, root = clitest.New(t, "tokens", "rm", secondTokenID)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
// Delete scoped token by ID
inv, root = clitest.New(t, "tokens", "rm", scopedTokenID)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
@@ -224,8 +199,8 @@ func TestTokens(t *testing.T) {
require.NotEmpty(t, res)
fourthToken := res
// Delete by token (explicit delete flag)
inv, root = clitest.New(t, "tokens", "rm", "--delete", fourthToken)
// Delete by token
inv, root = clitest.New(t, "tokens", "rm", fourthToken)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
@@ -235,114 +210,3 @@ func TestTokens(t *testing.T) {
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
}
func TestTokensListExpiredFiltering(t *testing.T) {
t.Parallel()
client, _, api := coderdtest.NewWithAPI(t, nil)
owner := coderdtest.CreateFirstUser(t, client)
// Create a valid (non-expired) token
validToken, _ := dbgen.APIKey(t, api.Database, database.APIKey{
UserID: owner.UserID,
ExpiresAt: time.Now().Add(24 * time.Hour),
LoginType: database.LoginTypeToken,
TokenName: "valid-token",
})
// Create an expired token
expiredToken, _ := dbgen.APIKey(t, api.Database, database.APIKey{
UserID: owner.UserID,
ExpiresAt: time.Now().Add(-24 * time.Hour),
LoginType: database.LoginTypeToken,
TokenName: "expired-token",
})
t.Run("HidesExpiredByDefault", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
inv, root := clitest.New(t, "tokens", "ls")
clitest.SetupConfig(t, client, root)
buf := new(bytes.Buffer)
inv.Stdout = buf
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
res := buf.String()
require.Contains(t, res, validToken.ID)
require.Contains(t, res, "valid-token")
require.NotContains(t, res, expiredToken.ID)
require.NotContains(t, res, "expired-token")
})
t.Run("ShowsExpiredWithFlag", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
inv, root := clitest.New(t, "tokens", "ls", "--include-expired")
clitest.SetupConfig(t, client, root)
buf := new(bytes.Buffer)
inv.Stdout = buf
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
res := buf.String()
require.Contains(t, res, validToken.ID)
require.Contains(t, res, "valid-token")
require.Contains(t, res, expiredToken.ID)
require.Contains(t, res, "expired-token")
})
t.Run("JSONOutputRespectsFilter", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Default (no expired)
inv, root := clitest.New(t, "tokens", "ls", "--output=json")
clitest.SetupConfig(t, client, root)
buf := new(bytes.Buffer)
inv.Stdout = buf
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
res := buf.String()
require.Contains(t, res, "valid-token")
require.NotContains(t, res, "expired-token")
// With --include-expired
inv, root = clitest.New(t, "tokens", "ls", "--output=json", "--include-expired")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.Contains(t, res, "valid-token")
require.Contains(t, res, "expired-token")
})
t.Run("AllUsersWithIncludeExpired", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
inv, root := clitest.New(t, "tokens", "ls", "--all", "--include-expired")
clitest.SetupConfig(t, client, root)
buf := new(bytes.Buffer)
inv.Stdout = buf
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
res := buf.String()
// Should show both valid and expired tokens
require.Contains(t, res, validToken.ID)
require.Contains(t, res, "valid-token")
require.Contains(t, res, expiredToken.ID)
require.Contains(t, res, "expired-token")
})
}
-70
View File
@@ -990,74 +990,4 @@ func TestUpdateValidateRichParameters(t *testing.T) {
_ = testutil.TryReceive(ctx, t, doneChan)
})
t.Run("NewImmutableParameterViaFlag", func(t *testing.T) {
t.Parallel()
// Create template and workspace with only a mutable parameter.
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
templateParameters := []*proto.RichParameter{
{Name: stringParameterName, Type: "string", Mutable: true, Required: true, Options: []*proto.RichParameterOption{
{Name: "First option", Description: "This is first option", Value: "1st"},
{Name: "Second option", Description: "This is second option", Value: "2nd"},
}},
}
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(templateParameters))
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
inv, root := clitest.New(t, "create", "my-workspace", "--yes", "--template", template.Name, "--parameter", fmt.Sprintf("%s=%s", stringParameterName, "1st"))
clitest.SetupConfig(t, member, root)
err := inv.Run()
require.NoError(t, err)
// Update template: add a new immutable parameter.
updatedTemplateParameters := []*proto.RichParameter{
templateParameters[0],
{Name: immutableParameterName, Type: "string", Mutable: false, Required: true, Options: []*proto.RichParameterOption{
{Name: "fir", Description: "First option for immutable parameter", Value: "I"},
{Name: "sec", Description: "Second option for immutable parameter", Value: "II"},
}},
}
updatedVersion := coderdtest.UpdateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(updatedTemplateParameters), template.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, updatedVersion.ID)
err = client.UpdateActiveTemplateVersion(context.Background(), template.ID, codersdk.UpdateActiveTemplateVersion{
ID: updatedVersion.ID,
})
require.NoError(t, err)
// Update workspace, supplying the new immutable parameter via
// the --parameter flag. This should succeed because it's the
// first time this parameter is being set.
inv, root = clitest.New(t, "update", "my-workspace",
"--parameter", fmt.Sprintf("%s=%s", immutableParameterName, "II"))
clitest.SetupConfig(t, member, root)
pty := ptytest.New(t).Attach(inv)
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatch("Planning workspace")
ctx := testutil.Context(t, testutil.WaitLong)
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify the immutable parameter was set correctly.
workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
require.NoError(t, err)
require.Contains(t, actualParameters, codersdk.WorkspaceBuildParameter{
Name: immutableParameterName,
Value: "II",
})
})
}
-2
View File
@@ -179,8 +179,6 @@ func New(opts Options, workspace database.Workspace) *API {
Database: opts.Database,
Log: opts.Log,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
Clock: opts.Clock,
NotificationsEnqueuer: opts.NotificationsEnqueuer,
}
api.MetadataAPI = &MetadataAPI{
-240
View File
@@ -2,10 +2,6 @@ package agentapi
import (
"context"
"database/sql"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
@@ -13,14 +9,7 @@ import (
"cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/notifications"
strutil "github.com/coder/coder/v2/coderd/util/strings"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
type AppsAPI struct {
@@ -28,8 +17,6 @@ type AppsAPI struct {
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
NotificationsEnqueuer notifications.Enqueuer
Clock quartz.Clock
}
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
@@ -117,230 +104,3 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
}
return &agentproto.BatchUpdateAppHealthResponse{}, nil
}
func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
if len(req.Message) > 160 {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Message is too long.",
Detail: "Message must be less than 160 characters.",
Validations: []codersdk.ValidationError{
{Field: "message", Detail: "Message must be less than 160 characters."},
},
})
}
var dbState database.WorkspaceAppStatusState
switch req.State {
case agentproto.UpdateAppStatusRequest_COMPLETE:
dbState = database.WorkspaceAppStatusStateComplete
case agentproto.UpdateAppStatusRequest_FAILURE:
dbState = database.WorkspaceAppStatusStateFailure
case agentproto.UpdateAppStatusRequest_WORKING:
dbState = database.WorkspaceAppStatusStateWorking
case agentproto.UpdateAppStatusRequest_IDLE:
dbState = database.WorkspaceAppStatusStateIdle
default:
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Invalid state provided.",
Detail: fmt.Sprintf("invalid state: %q", req.State),
Validations: []codersdk.ValidationError{
{Field: "state", Detail: "State must be one of: complete, failure, working, idle."},
},
})
}
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: workspaceAgent.ID,
Slug: req.Slug,
})
if err != nil {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace app.",
Detail: fmt.Sprintf("No app found with slug %q", req.Slug),
})
}
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace.",
Detail: err.Error(),
})
}
// Treat the message as untrusted input.
cleaned := strutil.UISanitize(req.Message)
// Get the latest status for the workspace app to detect no-op updates
// nolint:gocritic // This is a system restricted operation.
latestAppStatus, err := a.Database.GetLatestWorkspaceAppStatusByAppID(dbauthz.AsSystemRestricted(ctx), app.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get latest workspace app status.",
Detail: err.Error(),
})
}
// If no rows found, latestAppStatus will be a zero-value struct (ID == uuid.Nil)
// nolint:gocritic // This is a system restricted operation.
_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
ID: uuid.New(),
CreatedAt: dbtime.Now(),
WorkspaceID: workspace.ID,
AgentID: workspaceAgent.ID,
AppID: app.ID,
State: dbState,
Message: cleaned,
Uri: sql.NullString{
String: req.Uri,
Valid: req.Uri != "",
},
})
if err != nil {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to insert workspace app status.",
Detail: err.Error(),
})
}
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
if err != nil {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to publish workspace update.",
Detail: err.Error(),
})
}
}
// Notify on state change to Working/Idle for AI tasks
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent)
if shouldBump(dbState, latestAppStatus) {
// We pass time.Time{} for nextAutostart since we don't have access to
// TemplateScheduleStore here. The activity bump logic handles this by
// defaulting to the template's activity_bump duration (typically 1 hour).
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{})
}
// just return a blank response because it doesn't contain any settable fields at present.
return new(agentproto.UpdateAppStatusResponse), nil
}
func shouldBump(dbState database.WorkspaceAppStatusState, latestAppStatus database.WorkspaceAppStatus) bool {
// Bump deadline when agent reports working or transitions away from working.
// This prevents auto-pause during active work and gives users time to interact
// after work completes.
// Bump if reporting working state.
if dbState == database.WorkspaceAppStatusStateWorking {
return true
}
// Bump if transitioning away from working state.
if latestAppStatus.ID != uuid.Nil {
prevState := latestAppStatus.State
if prevState == database.WorkspaceAppStatusStateWorking {
return true
}
}
return false
}
// enqueueAITaskStateNotification enqueues a notification when an AI task's app
// transitions to Working or Idle.
// No-op if:
// - the workspace agent app isn't configured as an AI task,
// - the new state equals the latest persisted state,
// - the workspace agent is not ready (still starting up).
func (a *AppsAPI) enqueueAITaskStateNotification(
ctx context.Context,
appID uuid.UUID,
latestAppStatus database.WorkspaceAppStatus,
newAppStatus database.WorkspaceAppStatusState,
workspace database.Workspace,
agent database.WorkspaceAgent,
) {
var notificationTemplate uuid.UUID
switch newAppStatus {
case database.WorkspaceAppStatusStateWorking:
notificationTemplate = notifications.TemplateTaskWorking
case database.WorkspaceAppStatusStateIdle:
notificationTemplate = notifications.TemplateTaskIdle
case database.WorkspaceAppStatusStateComplete:
notificationTemplate = notifications.TemplateTaskCompleted
case database.WorkspaceAppStatusStateFailure:
notificationTemplate = notifications.TemplateTaskFailed
default:
// Not a notifiable state, do nothing
return
}
if !workspace.TaskID.Valid {
// Workspace has no task ID, do nothing.
return
}
// Only send notifications when the agent is ready. We want to skip
// any state transitions that occur whilst the workspace is starting
// up as it doesn't make sense to receive them.
if agent.LifecycleState != database.WorkspaceAgentLifecycleStateReady {
a.Log.Debug(ctx, "skipping AI task notification because agent is not ready",
slog.F("agent_id", agent.ID),
slog.F("lifecycle_state", agent.LifecycleState),
slog.F("new_app_status", newAppStatus),
)
return
}
task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
if err != nil {
a.Log.Warn(ctx, "failed to get task", slog.Error(err))
return
}
if !task.WorkspaceAppID.Valid || task.WorkspaceAppID.UUID != appID {
// Non-task app, do nothing.
return
}
// Skip if the latest persisted state equals the new state (no new transition)
// Note: uuid.Nil check is valid here. If no previous status exists,
// GetLatestWorkspaceAppStatusByAppID returns sql.ErrNoRows and we get a zero-value struct.
if latestAppStatus.ID != uuid.Nil && latestAppStatus.State == newAppStatus {
return
}
// Skip the initial "Working" notification when the task first starts.
// This is obvious to the user since they just created the task.
// We still notify on the first "Idle" status and all subsequent transitions.
if latestAppStatus.ID == uuid.Nil && newAppStatus == database.WorkspaceAppStatusStateWorking {
return
}
if _, err := a.NotificationsEnqueuer.EnqueueWithData(
// nolint:gocritic // Need notifier actor to enqueue notifications
dbauthz.AsNotifier(ctx),
workspace.OwnerID,
notificationTemplate,
map[string]string{
"task": task.Name,
"workspace": workspace.Name,
},
map[string]any{
// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
// allowing identical content to resend within the same day
// (but not more than once every 10s).
"dedupe_bypass_ts": a.Clock.Now().UTC().Truncate(time.Minute),
},
"api-workspace-agent-app-status",
// Associate this notification with related entities
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
); err != nil {
a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
return
}
}
-115
View File
@@ -1,115 +0,0 @@
package agentapi
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/util/ptr"
)
func TestShouldBump(t *testing.T) {
t.Parallel()
tests := []struct {
name string
prevState *database.WorkspaceAppStatusState // nil means no previous state
newState database.WorkspaceAppStatusState
shouldBump bool
}{
{
name: "FirstStatusBumps",
prevState: nil,
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
{
name: "WorkingToIdleBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: true,
},
{
name: "WorkingToCompleteBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: true,
},
{
name: "CompleteToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "CompleteToCompleteNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: false,
},
{
name: "FailureToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "FailureToFailureNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateFailure,
shouldBump: false,
},
{
name: "CompleteToWorkingBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
{
name: "FailureToCompleteNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: false,
},
{
name: "WorkingToFailureBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateFailure,
shouldBump: true,
},
{
name: "IdleToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "IdleToWorkingBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var prevAppStatus database.WorkspaceAppStatus
// If there's a previous state, report it first.
if tt.prevState != nil {
prevAppStatus.ID = uuid.UUID{1}
prevAppStatus.State = *tt.prevState
}
didBump := shouldBump(tt.newState, prevAppStatus)
if tt.shouldBump {
require.True(t, didBump, "wanted deadline to bump but it didn't")
} else {
require.False(t, didBump, "wanted deadline not to bump but it did")
}
})
}
}
-188
View File
@@ -2,13 +2,9 @@ package agentapi_test
import (
"context"
"database/sql"
"net/http"
"strings"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
@@ -16,12 +12,8 @@ import (
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestBatchUpdateAppHealths(t *testing.T) {
@@ -261,183 +253,3 @@ func TestBatchUpdateAppHealths(t *testing.T) {
require.Nil(t, resp)
})
}
func TestWorkspaceAgentAppStatus(t *testing.T) {
t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
fEnq := &notificationstest.FakeEnqueuer{}
mClock := quartz.NewMock(t)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
assert.Equal(t, *agnt, agent)
testutil.AssertSend(ctx, t, workspaceUpdates, kind)
return nil
},
NotificationsEnqueuer: fEnq,
Clock: mClock,
}
app := database.WorkspaceApp{
ID: uuid.UUID{8},
}
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: agent.ID,
Slug: "vscode",
}).Times(1).Return(app, nil)
task := database.Task{
ID: uuid.UUID{7},
WorkspaceAppID: uuid.NullUUID{
Valid: true,
UUID: app.ID,
},
}
mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
workspace := database.Workspace{
ID: uuid.UUID{9},
TaskID: uuid.NullUUID{
Valid: true,
UUID: task.ID,
},
}
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
appStatus := database.WorkspaceAppStatus{
ID: uuid.UUID{6},
}
mDB.EXPECT().GetLatestWorkspaceAppStatusByAppID(gomock.Any(), app.ID).Times(1).Return(appStatus, nil)
mDB.EXPECT().InsertWorkspaceAppStatus(
gomock.Any(),
gomock.Cond(func(params database.InsertWorkspaceAppStatusParams) bool {
if params.AgentID == agent.ID && params.AppID == app.ID {
assert.Equal(t, "testing", params.Message)
assert.Equal(t, database.WorkspaceAppStatusStateComplete, params.State)
assert.True(t, params.Uri.Valid)
assert.Equal(t, "https://example.com", params.Uri.String)
return true
}
return false
})).Times(1).Return(database.WorkspaceAppStatus{}, nil)
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: "testing",
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.NoError(t, err)
kind := testutil.RequireReceive(ctx, t, workspaceUpdates)
require.Equal(t, wspubsub.WorkspaceEventKindAgentAppStatusUpdate, kind)
sent := fEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskCompleted))
require.Len(t, sent, 1)
})
t.Run("FailUnknownApp", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), gomock.Any()).
Times(1).
Return(database.WorkspaceApp{}, sql.ErrNoRows)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "unknown",
Message: "testing",
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.ErrorContains(t, err, "No app found with slug")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("FailUnknownState", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: "testing",
Uri: "https://example.com",
State: 77,
})
require.ErrorContains(t, err, "Invalid state")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("FailTooLong", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: strings.Repeat("a", 161),
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.ErrorContains(t, err, "Message is too long")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
}
+1 -1
View File
@@ -128,7 +128,7 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
Name: agentName,
ResourceID: parentAgent.ResourceID,
AuthToken: uuid.New(),
AuthInstanceID: sql.NullString{},
AuthInstanceID: parentAgent.AuthInstanceID,
Architecture: req.Architecture,
EnvironmentVariables: pqtype.NullRawMessage{},
OperatingSystem: req.OperatingSystem,
+1 -46
View File
@@ -175,52 +175,6 @@ func TestSubAgentAPI(t *testing.T) {
}
})
// Context: https://github.com/coder/coder/pull/22196
t.Run("CreateSubAgentDoesNotInheritAuthInstanceID", func(t *testing.T) {
t.Parallel()
var (
log = testutil.Logger(t)
clock = quartz.NewMock(t)
db, org = newDatabaseWithOrg(t)
user, agent = newUserWithWorkspaceAgent(t, db, org)
)
// Given: The parent agent has an AuthInstanceID set
ctx := testutil.Context(t, testutil.WaitShort)
parentAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agent.ID)
require.NoError(t, err)
require.True(t, parentAgent.AuthInstanceID.Valid, "parent agent should have an AuthInstanceID")
require.NotEmpty(t, parentAgent.AuthInstanceID.String)
api := newAgentAPI(t, log, db, clock, user, org, agent)
// When: We create a sub agent
createResp, err := api.CreateSubAgent(ctx, &proto.CreateSubAgentRequest{
Name: "sub-agent",
Directory: "/workspaces/test",
Architecture: "amd64",
OperatingSystem: "linux",
})
require.NoError(t, err)
subAgentID, err := uuid.FromBytes(createResp.Agent.Id)
require.NoError(t, err)
// Then: The sub-agent must NOT re-use the parent's AuthInstanceID.
subAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), subAgentID)
require.NoError(t, err)
assert.False(t, subAgent.AuthInstanceID.Valid, "sub-agent should not have an AuthInstanceID")
assert.Empty(t, subAgent.AuthInstanceID.String, "sub-agent AuthInstanceID string should be empty")
// Double-check: looking up by the parent's instance ID must
// still return the parent, not the sub-agent.
lookedUp, err := db.GetWorkspaceAgentByInstanceID(dbauthz.AsSystemRestricted(ctx), parentAgent.AuthInstanceID.String)
require.NoError(t, err)
assert.Equal(t, parentAgent.ID, lookedUp.ID, "instance ID lookup should still return the parent agent")
})
type expectedAppError struct {
index int32
field string
@@ -1366,6 +1320,7 @@ func TestSubAgentAPI(t *testing.T) {
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
+3 -37
View File
@@ -21,12 +21,10 @@ import (
agentapisdk "github.com/coder/agentapi-sdk-go"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpapi/httperror"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/searchquery"
@@ -192,8 +190,7 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
})
defer commitAuditWS()
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, &createWorkspaceOptions{
remoteAddr: r.RemoteAddr,
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, r, &createWorkspaceOptions{
// Before creating the workspace, ensure that this task can be created.
preCreateInTX: func(ctx context.Context, tx database.Store) error {
// Create task record in the database before creating the workspace so that
@@ -467,6 +464,7 @@ func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks
apiWorkspaces, err := convertWorkspaces(
ctx,
api.Experiments,
api.Logger,
requesterID,
workspaces,
@@ -546,6 +544,7 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
ws, err := convertWorkspace(
ctx,
api.Experiments,
api.Logger,
apiKey.UserID,
workspace,
@@ -1301,23 +1300,6 @@ func (api *API) pauseTask(rw http.ResponseWriter, r *http.Request) {
return
}
if _, err := api.NotificationsEnqueuer.Enqueue(
// nolint:gocritic // Need notifier actor to enqueue notifications.
dbauthz.AsNotifier(ctx),
workspace.OwnerID,
notifications.TemplateTaskPaused,
map[string]string{
"task": task.Name,
"task_id": task.ID.String(),
"workspace": workspace.Name,
"pause_reason": "manual",
},
"api-task-pause",
workspace.ID, workspace.OwnerID, workspace.OrganizationID,
); err != nil {
api.Logger.Warn(ctx, "failed to notify of task paused", slog.Error(err), slog.F("task_id", task.ID), slog.F("workspace_id", workspace.ID))
}
httpapi.Write(ctx, rw, http.StatusAccepted, codersdk.PauseTaskResponse{
WorkspaceBuild: &build,
})
@@ -1405,22 +1387,6 @@ func (api *API) resumeTask(rw http.ResponseWriter, r *http.Request) {
httperror.WriteWorkspaceBuildError(ctx, rw, err)
return
}
if _, err := api.NotificationsEnqueuer.Enqueue(
// nolint:gocritic // Need notifier actor to enqueue notifications.
dbauthz.AsNotifier(ctx),
workspace.OwnerID,
notifications.TemplateTaskResumed,
map[string]string{
"task": task.Name,
"task_id": task.ID.String(),
"workspace": workspace.Name,
},
"api-task-resume",
workspace.ID, workspace.OwnerID, workspace.OrganizationID,
); err != nil {
api.Logger.Warn(ctx, "failed to notify of task resumed", slog.Error(err), slog.F("task_id", task.ID), slog.F("workspace_id", workspace.ID))
}
httpapi.Write(ctx, rw, http.StatusAccepted, codersdk.ResumeTaskResponse{
WorkspaceBuild: &build,
})
+39 -111
View File
@@ -45,10 +45,10 @@ import (
)
// createTaskInState is a helper to create a task in the desired state.
// It returns a function that takes context, test, and status, and returns the task.
// It returns a function that takes context, test, and status, and returns the task ID.
// The caller is responsible for setting up the database, owner, and user.
func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID, userID uuid.UUID) func(context.Context, *testing.T, database.TaskStatus) database.Task {
return func(ctx context.Context, t *testing.T, status database.TaskStatus) database.Task {
func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID, userID uuid.UUID) func(context.Context, *testing.T, database.TaskStatus) uuid.UUID {
return func(ctx context.Context, t *testing.T, status database.TaskStatus) uuid.UUID {
ctx = dbauthz.As(ctx, ownerSubject)
builder := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -65,9 +65,6 @@ func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID,
builder = builder.Pending()
case database.TaskStatusInitializing:
builder = builder.Starting()
case database.TaskStatusActive:
// Default builder produces a succeeded start build.
// Post-processing below sets agent and app to active.
case database.TaskStatusPaused:
builder = builder.Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
@@ -79,32 +76,31 @@ func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID,
}
resp := builder.Do()
taskID := resp.Task.ID
// Post-process by manipulating agent and app state.
if status == database.TaskStatusActive || status == database.TaskStatusError {
// Set agent to ready state so agent_status returns 'active'.
if status == database.TaskStatusError {
// First, set agent to ready state so agent_status returns 'active'.
// This ensures the cascade reaches app_status.
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: resp.Agents[0].ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
})
require.NoError(t, err)
// Then set workspace app health to unhealthy to trigger error state.
apps, err := db.GetWorkspaceAppsByAgentID(ctx, resp.Agents[0].ID)
require.NoError(t, err)
require.Len(t, apps, 1, "expected exactly one app for task")
appHealth := database.WorkspaceAppHealthHealthy
if status == database.TaskStatusError {
appHealth = database.WorkspaceAppHealthUnhealthy
}
err = db.UpdateWorkspaceAppHealthByID(ctx, database.UpdateWorkspaceAppHealthByIDParams{
ID: apps[0].ID,
Health: appHealth,
Health: database.WorkspaceAppHealthUnhealthy,
})
require.NoError(t, err)
}
return resp.Task
return taskID
}
}
@@ -832,7 +828,7 @@ func TestTasks(t *testing.T) {
t.Run("SendToNonActiveStates", func(t *testing.T) {
t.Parallel()
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{})
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -849,9 +845,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPaused)
taskID := createTask(ctx, t, database.TaskStatusPaused)
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
@@ -867,9 +863,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusInitializing)
taskID := createTask(ctx, t, database.TaskStatusInitializing)
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
@@ -885,9 +881,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
@@ -903,9 +899,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusError)
taskID := createTask(ctx, t, database.TaskStatusError)
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
@@ -1124,16 +1120,16 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: task.ID,
TaskID: taskID,
LogSnapshot: json.RawMessage(snapshotJSON),
LogSnapshotCreatedAt: snapshotTime,
})
require.NoError(t, err, "upserting task snapshot")
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
logsResp, err := client.TaskLogs(ctx, "me", taskID)
require.NoError(t, err, "fetching task logs")
verifySnapshotLogs(t, logsResp)
})
@@ -1142,16 +1138,16 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusInitializing)
taskID := createTask(ctx, t, database.TaskStatusInitializing)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: task.ID,
TaskID: taskID,
LogSnapshot: json.RawMessage(snapshotJSON),
LogSnapshotCreatedAt: snapshotTime,
})
require.NoError(t, err, "upserting task snapshot")
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
logsResp, err := client.TaskLogs(ctx, "me", taskID)
require.NoError(t, err, "fetching task logs")
verifySnapshotLogs(t, logsResp)
})
@@ -1160,16 +1156,16 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPaused)
taskID := createTask(ctx, t, database.TaskStatusPaused)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: task.ID,
TaskID: taskID,
LogSnapshot: json.RawMessage(snapshotJSON),
LogSnapshotCreatedAt: snapshotTime,
})
require.NoError(t, err, "upserting task snapshot")
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
logsResp, err := client.TaskLogs(ctx, "me", taskID)
require.NoError(t, err, "fetching task logs")
verifySnapshotLogs(t, logsResp)
})
@@ -1178,9 +1174,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
logsResp, err := client.TaskLogs(ctx, "me", taskID)
require.NoError(t, err)
assert.True(t, logsResp.Snapshot)
@@ -1192,7 +1188,7 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)
invalidEnvelope := coderd.TaskLogSnapshotEnvelope{
Format: "unknown-format",
@@ -1202,13 +1198,13 @@ func TestTasks(t *testing.T) {
require.NoError(t, err)
err = db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: task.ID,
TaskID: taskID,
LogSnapshot: json.RawMessage(invalidJSON),
LogSnapshotCreatedAt: snapshotTime,
})
require.NoError(t, err)
_, err = client.TaskLogs(ctx, "me", task.ID)
_, err = client.TaskLogs(ctx, "me", taskID)
require.Error(t, err)
var sdkErr *codersdk.Error
@@ -1221,16 +1217,16 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: task.ID,
TaskID: taskID,
LogSnapshot: json.RawMessage(`{"format":"agentapi","data":"not an object"}`),
LogSnapshotCreatedAt: snapshotTime,
})
require.NoError(t, err)
_, err = client.TaskLogs(ctx, "me", task.ID)
_, err = client.TaskLogs(ctx, "me", taskID)
require.Error(t, err)
var sdkErr *codersdk.Error
@@ -1242,9 +1238,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusError)
taskID := createTask(ctx, t, database.TaskStatusError)
_, err := client.TaskLogs(ctx, "me", task.ID)
_, err := client.TaskLogs(ctx, "me", taskID)
require.Error(t, err)
var sdkErr *codersdk.Error
@@ -2567,6 +2563,7 @@ func TestPauseTask(t *testing.T) {
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
task, _ := setupWorkspaceTask(t, db, owner)
userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, tc.roles...)
@@ -2790,41 +2787,6 @@ func TestPauseTask(t *testing.T) {
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusInternalServerError, apiErr.StatusCode())
})
t.Run("Notification", func(t *testing.T) {
t.Parallel()
var (
notifyEnq = &notificationstest.FakeEnqueuer{}
ownerClient, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{NotificationsEnqueuer: notifyEnq})
owner = coderdtest.CreateFirstUser(t, ownerClient)
)
ctx := testutil.Context(t, testutil.WaitMedium)
ownerUser, err := ownerClient.User(ctx, owner.UserID.String())
require.NoError(t, err)
createTask := createTaskInState(db, coderdtest.AuthzUserSubject(ownerUser), owner.OrganizationID, owner.UserID)
// Given: A task in an active state
task := createTask(ctx, t, database.TaskStatusActive)
workspace, err := ownerClient.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
// When: We pause the task
_, err = ownerClient.PauseTask(ctx, codersdk.Me, task.ID)
require.NoError(t, err)
// Then: A notification should be sent
sent := notifyEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskPaused))
require.Len(t, sent, 1)
require.Equal(t, owner.UserID, sent[0].UserID)
require.Equal(t, task.Name, sent[0].Labels["task"])
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
require.Equal(t, "manual", sent[0].Labels["pause_reason"])
})
}
func TestResumeTask(t *testing.T) {
@@ -3154,38 +3116,4 @@ func TestResumeTask(t *testing.T) {
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusInternalServerError, apiErr.StatusCode())
})
t.Run("Notification", func(t *testing.T) {
t.Parallel()
var (
notifyEnq = &notificationstest.FakeEnqueuer{}
ownerClient, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{NotificationsEnqueuer: notifyEnq})
owner = coderdtest.CreateFirstUser(t, ownerClient)
)
ctx := testutil.Context(t, testutil.WaitMedium)
ownerUser, err := ownerClient.User(ctx, owner.UserID.String())
require.NoError(t, err)
createTask := createTaskInState(db, coderdtest.AuthzUserSubject(ownerUser), owner.OrganizationID, owner.UserID)
// Given: A task in a paused state
task := createTask(ctx, t, database.TaskStatusPaused)
workspace, err := ownerClient.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
// When: We resume the task
_, err = ownerClient.ResumeTask(ctx, codersdk.Me, task.ID)
require.NoError(t, err)
// Then: A notification should be sent
sent := notifyEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskResumed))
require.Len(t, sent, 1)
require.Equal(t, owner.UserID, sent[0].UserID)
require.Equal(t, task.Name, sent[0].Labels["task"])
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
})
}
+21 -1069
View File
File diff suppressed because it is too large Load Diff
+21 -993
View File
File diff suppressed because it is too large Load Diff
+8 -77
View File
@@ -307,26 +307,20 @@ func (api *API) apiKeyByName(rw http.ResponseWriter, r *http.Request) {
// @Tags Users
// @Param user path string true "User ID, name, or me"
// @Success 200 {array} codersdk.APIKey
// @Param include_expired query bool false "Include expired tokens in the list"
// @Router /users/{user}/keys/tokens [get]
func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
user = httpmw.UserParam(r)
keys []database.APIKey
err error
queryStr = r.URL.Query().Get("include_all")
includeAll, _ = strconv.ParseBool(queryStr)
expiredStr = r.URL.Query().Get("include_expired")
includeExpired, _ = strconv.ParseBool(expiredStr)
ctx = r.Context()
user = httpmw.UserParam(r)
keys []database.APIKey
err error
queryStr = r.URL.Query().Get("include_all")
includeAll, _ = strconv.ParseBool(queryStr)
)
if includeAll {
// get tokens for all users
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.GetAPIKeysByLoginTypeParams{
LoginType: database.LoginTypeToken,
IncludeExpired: includeExpired,
})
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.LoginTypeToken)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching API keys.",
@@ -336,7 +330,7 @@ func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
}
} else {
// get user's tokens only
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID, IncludeExpired: includeExpired})
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching API keys.",
@@ -427,69 +421,6 @@ func (api *API) deleteAPIKey(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNoContent)
}
// @Summary Expire API key
// @ID expire-api-key
// @Security CoderSessionToken
// @Tags Users
// @Param user path string true "User ID, name, or me"
// @Param keyid path string true "Key ID" format(string)
// @Success 204
// @Failure 404 {object} codersdk.Response
// @Failure 500 {object} codersdk.Response
// @Router /users/{user}/keys/{keyid}/expire [put]
func (api *API) expireAPIKey(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
keyID = chi.URLParam(r, "keyid")
auditor = api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.APIKey](rw, &audit.RequestParams{
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
if err := api.Database.InTx(func(db database.Store) error {
key, err := db.GetAPIKeyByID(ctx, keyID)
if err != nil {
return xerrors.Errorf("fetch API key: %w", err)
}
if !key.ExpiresAt.After(api.Clock.Now()) {
return nil // Already expired
}
aReq.Old = key
if err := db.UpdateAPIKeyByID(ctx, database.UpdateAPIKeyByIDParams{
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: dbtime.Now(),
IPAddress: key.IPAddress,
}); err != nil {
return xerrors.Errorf("expire API key: %w", err)
}
// Fetch the updated key for audit log.
newKey, err := db.GetAPIKeyByID(ctx, keyID)
if err != nil {
api.Logger.Warn(ctx, "failed to fetch updated API key for audit log", slog.Error(err))
} else {
aReq.New = newKey
}
return nil
}, nil); httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
} else if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error expiring API key.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
// @Summary Get token config
// @ID get-token-config
// @Security CoderSessionToken
+3 -196
View File
@@ -69,44 +69,6 @@ func TestTokenCRUD(t *testing.T) {
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[numLogs-1].Action)
}
func TestTokensFilterExpired(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
adminClient := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, adminClient)
// Create a token.
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
Lifetime: time.Hour * 24 * 7,
})
require.NoError(t, err)
keyID := strings.Split(res.Key, "-")[0]
// List tokens without including expired - should see the token.
keys, err := adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
require.NoError(t, err)
require.Len(t, keys, 1)
// Expire the token.
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
require.NoError(t, err)
// List tokens without including expired - should NOT see expired token.
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
require.NoError(t, err)
require.Empty(t, keys)
// List tokens WITH including expired - should see expired token.
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{
IncludeExpired: true,
})
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, keyID, keys[0].ID)
}
func TestTokenScoped(t *testing.T) {
t.Parallel()
@@ -477,7 +439,7 @@ func TestAPIKey_PrebuildsNotAllowed(t *testing.T) {
DeploymentValues: dc,
})
setupCtx := testutil.Context(t, testutil.WaitLong)
ctx := testutil.Context(t, testutil.WaitLong)
// Given: an existing api token for the prebuilds user
_, prebuildsToken := dbgen.APIKey(t, db, database.APIKey{
@@ -486,167 +448,12 @@ func TestAPIKey_PrebuildsNotAllowed(t *testing.T) {
client.SetSessionToken(prebuildsToken)
// When: the prebuilds user tries to create an API key
_, err := client.CreateAPIKey(setupCtx, database.PrebuildsSystemUserID.String())
_, err := client.CreateAPIKey(ctx, database.PrebuildsSystemUserID.String())
// Then: denied.
require.ErrorContains(t, err, httpapi.ResourceForbiddenResponse.Message)
// When: the prebuilds user tries to create a token
_, err = client.CreateToken(setupCtx, database.PrebuildsSystemUserID.String(), codersdk.CreateTokenRequest{})
_, err = client.CreateToken(ctx, database.PrebuildsSystemUserID.String(), codersdk.CreateTokenRequest{})
// Then: also denied.
require.ErrorContains(t, err, httpapi.ResourceForbiddenResponse.Message)
}
//nolint:tparallel,paralleltest // Subtests share the same coderdtest instance and auditor.
func TestExpireAPIKey(t *testing.T) {
t.Parallel()
auditor := audit.NewMock()
adminClient := coderdtest.New(t, &coderdtest.Options{Auditor: auditor})
admin := coderdtest.CreateFirstUser(t, adminClient)
memberClient, member := coderdtest.CreateAnotherUser(t, adminClient, admin.OrganizationID)
t.Run("OwnerCanExpireOwnToken", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
// Create a token.
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
Lifetime: time.Hour * 24 * 7,
})
require.NoError(t, err)
keyID := strings.Split(res.Key, "-")[0]
// Verify the token is not expired.
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.NoError(t, err)
require.True(t, key.ExpiresAt.After(time.Now()))
auditor.ResetLogs()
// Expire the token.
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
require.NoError(t, err)
// Verify the token is expired.
key, err = adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.NoError(t, err)
require.True(t, key.ExpiresAt.Before(time.Now()))
// Verify audit log.
als := auditor.AuditLogs()
require.Len(t, als, 1)
require.Equal(t, database.AuditActionWrite, als[0].Action)
require.Equal(t, database.ResourceTypeApiKey, als[0].ResourceType)
require.Equal(t, admin.UserID.String(), als[0].UserID.String())
})
t.Run("AdminCanExpireOtherUsersToken", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
// Create a token for the member.
res, err := memberClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
Lifetime: time.Hour * 24 * 7,
})
require.NoError(t, err)
keyID := strings.Split(res.Key, "-")[0]
// Admin expires the member's token.
err = adminClient.ExpireAPIKey(ctx, member.ID.String(), keyID)
require.NoError(t, err)
// Verify the token is expired.
key, err := memberClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.NoError(t, err)
require.True(t, key.ExpiresAt.Before(time.Now()))
})
t.Run("MemberCannotExpireOtherUsersToken", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
// Create a token for the admin.
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
Lifetime: time.Hour * 24 * 7,
})
require.NoError(t, err)
keyID := strings.Split(res.Key, "-")[0]
// Member attempts to expire admin's token.
err = memberClient.ExpireAPIKey(ctx, admin.UserID.String(), keyID)
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
// Members cannot read other users, so they get a 404 Not Found
// from the authorization layer.
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
t.Run("NotFound", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
// Try to expire a non-existent token.
err := adminClient.ExpireAPIKey(ctx, codersdk.Me, "nonexistent")
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
t.Run("ExpiringAlreadyExpiredTokenSucceeds", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
// Create and expire a token.
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
Lifetime: time.Hour * 24 * 7,
})
require.NoError(t, err)
keyID := strings.Split(res.Key, "-")[0]
// Expire it once.
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
require.NoError(t, err)
// Invariant: make sure it's actually expired
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.NoError(t, err)
require.LessOrEqual(t, key.ExpiresAt, time.Now(), "key should be expired")
// Expire it again - should succeed (idempotent).
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
require.NoError(t, err)
// Token should still be just as expired as before. No more, no less.
keyAgain, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.NoError(t, err)
require.Equal(t, key.ExpiresAt, keyAgain.ExpiresAt, "expiration should be idempotent")
})
t.Run("DeletingExpiredTokenSucceeds", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
// Create a token.
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
Lifetime: time.Hour * 24 * 7,
})
require.NoError(t, err)
keyID := strings.Split(res.Key, "-")[0]
// Expire it first.
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
require.NoError(t, err)
// Verify it's expired.
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.NoError(t, err)
require.True(t, key.ExpiresAt.Before(time.Now()))
// Delete the expired token - should succeed.
err = adminClient.DeleteAPIKey(ctx, codersdk.Me, keyID)
require.NoError(t, err)
// Verify it's gone.
_, err = adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
}
-48
View File
@@ -1,7 +1,6 @@
package coderd
import (
"context"
"fmt"
"net/http"
@@ -9,7 +8,6 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
@@ -93,36 +91,6 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action policy.Action, object
return true
}
// AuthorizeContext checks whether the RBAC subject on the context
// is authorized to perform the given action. The subject must have
// been set via dbauthz.As or the ExtractAPIKey middleware. Returns
// false if the subject is missing or unauthorized.
func (h *HTTPAuthorizer) AuthorizeContext(ctx context.Context, action policy.Action, object rbac.Objecter) bool {
roles, ok := dbauthz.ActorFromContext(ctx)
if !ok {
h.Logger.Error(ctx, "no authorization actor in context")
return false
}
err := h.Authorizer.Authorize(ctx, roles, action, object.RBACObject())
if err != nil {
internalError := new(rbac.UnauthorizedError)
logger := h.Logger
if xerrors.As(err, internalError) {
logger = h.Logger.With(slog.F("internal_error", internalError.Internal()))
}
logger.Warn(ctx, "requester is not authorized to access the object",
slog.F("roles", roles.SafeRoleNames()),
slog.F("actor_id", roles.ID),
slog.F("actor_name", roles),
slog.F("scope", roles.SafeScopeName()),
slog.F("action", action),
slog.F("object", object),
)
return false
}
return true
}
// AuthorizeSQLFilter returns an authorization filter that can used in a
// SQL 'WHERE' clause. If the filter is used, the resulting rows returned
// from postgres are already authorized, and the caller does not need to
@@ -138,22 +106,6 @@ func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action policy.Actio
return prepared, nil
}
// AuthorizeSQLFilterContext is like AuthorizeSQLFilter but reads the
// RBAC subject from the context directly rather than from an
// *http.Request. The subject must have been set via dbauthz.As.
func (h *HTTPAuthorizer) AuthorizeSQLFilterContext(ctx context.Context, action policy.Action, objectType string) (rbac.PreparedAuthorized, error) {
roles, ok := dbauthz.ActorFromContext(ctx)
if !ok {
return nil, xerrors.New("no authorization actor in context")
}
prepared, err := h.Authorizer.Prepare(ctx, roles, action, objectType)
if err != nil {
return nil, xerrors.Errorf("prepare filter: %w", err)
}
return prepared, nil
}
// checkAuthorization returns if the current API key can use the given
// permissions, factoring in the current user's roles and the API key scopes.
//
-35
View File
@@ -231,7 +231,6 @@ func (e *Executor) runOnce(t time.Time) Stats {
job *database.ProvisionerJob
auditLog *auditParams
shouldNotifyDormancy bool
shouldNotifyTaskPause bool
nextBuild *database.WorkspaceBuild
activeTemplateVersion database.TemplateVersion
ws database.Workspace
@@ -317,10 +316,6 @@ func (e *Executor) runOnce(t time.Time) Stats {
return nil
}
if reason == database.BuildReasonTaskAutoPause {
shouldNotifyTaskPause = true
}
// Get the template version job to access tags
templateVersionJob, err := tx.GetProvisionerJobByID(e.ctx, activeTemplateVersion.JobID)
if err != nil {
@@ -487,28 +482,6 @@ func (e *Executor) runOnce(t time.Time) Stats {
log.Warn(e.ctx, "failed to notify of workspace marked as dormant", slog.Error(err), slog.F("workspace_id", ws.ID))
}
}
if shouldNotifyTaskPause {
task, err := e.db.GetTaskByID(e.ctx, ws.TaskID.UUID)
if err != nil {
log.Warn(e.ctx, "failed to get task for pause notification", slog.Error(err), slog.F("task_id", ws.TaskID.UUID), slog.F("workspace_id", ws.ID))
} else {
if _, err := e.notificationsEnqueuer.Enqueue(
e.ctx,
ws.OwnerID,
notifications.TemplateTaskPaused,
map[string]string{
"task": task.Name,
"task_id": task.ID.String(),
"workspace": ws.Name,
"pause_reason": "idle timeout",
},
"lifecycle_executor",
ws.ID, ws.OwnerID, ws.OrganizationID,
); err != nil {
log.Warn(e.ctx, "failed to notify of task paused", slog.Error(err), slog.F("task_id", ws.TaskID.UUID), slog.F("workspace_id", ws.ID))
}
}
}
return nil
}()
if err != nil && !xerrors.Is(err, context.Canceled) {
@@ -552,18 +525,10 @@ func getNextTransition(
) {
switch {
case isEligibleForAutostop(user, ws, latestBuild, latestJob, currentTick):
// Use task-specific reason for AI task workspaces.
if ws.TaskID.Valid {
return database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil
}
return database.WorkspaceTransitionStop, database.BuildReasonAutostop, nil
case isEligibleForAutostart(user, ws, latestBuild, latestJob, templateSchedule, currentTick):
return database.WorkspaceTransitionStart, database.BuildReasonAutostart, nil
case isEligibleForFailedStop(latestBuild, latestJob, templateSchedule, currentTick):
// Use task-specific reason for AI task workspaces.
if ws.TaskID.Valid {
return database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil
}
return database.WorkspaceTransitionStop, database.BuildReasonAutostop, nil
case isEligibleForDormantStop(ws, templateSchedule, currentTick):
// Only stop started workspaces.
@@ -5,113 +5,12 @@ import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/schedule"
)
func Test_getNextTransition_TaskAutoPause(t *testing.T) {
t.Parallel()
// Set up a workspace that is eligible for autostop (past deadline).
now := time.Now()
pastDeadline := now.Add(-time.Hour)
okUser := database.User{Status: database.UserStatusActive}
okBuild := database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStart,
Deadline: pastDeadline,
}
okJob := database.ProvisionerJob{
JobStatus: database.ProvisionerJobStatusSucceeded,
}
okTemplateSchedule := schedule.TemplateScheduleOptions{}
// Failed build setup for failedstop tests.
failedBuild := database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStart,
}
failedJob := database.ProvisionerJob{
JobStatus: database.ProvisionerJobStatusFailed,
CompletedAt: sql.NullTime{Time: now.Add(-time.Hour), Valid: true},
}
failedTemplateSchedule := schedule.TemplateScheduleOptions{
FailureTTL: time.Minute, // TTL already elapsed since job completed an hour ago.
}
testCases := []struct {
Name string
Workspace database.Workspace
Build database.WorkspaceBuild
Job database.ProvisionerJob
TemplateSchedule schedule.TemplateScheduleOptions
ExpectedReason database.BuildReason
}{
{
Name: "RegularWorkspace_Autostop",
Workspace: database.Workspace{
DormantAt: sql.NullTime{Valid: false},
},
Build: okBuild,
Job: okJob,
TemplateSchedule: okTemplateSchedule,
ExpectedReason: database.BuildReasonAutostop,
},
{
Name: "TaskWorkspace_Autostop_UsesTaskAutoPause",
Workspace: database.Workspace{
DormantAt: sql.NullTime{Valid: false},
TaskID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
},
Build: okBuild,
Job: okJob,
TemplateSchedule: okTemplateSchedule,
ExpectedReason: database.BuildReasonTaskAutoPause,
},
{
Name: "RegularWorkspace_FailedStop",
Workspace: database.Workspace{
DormantAt: sql.NullTime{Valid: false},
},
Build: failedBuild,
Job: failedJob,
TemplateSchedule: failedTemplateSchedule,
ExpectedReason: database.BuildReasonAutostop,
},
{
Name: "TaskWorkspace_FailedStop_UsesTaskAutoPause",
Workspace: database.Workspace{
DormantAt: sql.NullTime{Valid: false},
TaskID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
},
Build: failedBuild,
Job: failedJob,
TemplateSchedule: failedTemplateSchedule,
ExpectedReason: database.BuildReasonTaskAutoPause,
},
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
transition, reason, err := getNextTransition(
okUser,
tc.Workspace,
tc.Build,
tc.Job,
tc.TemplateSchedule,
now,
)
require.NoError(t, err)
require.Equal(t, database.WorkspaceTransitionStop, transition)
require.Equal(t, tc.ExpectedReason, reason)
})
}
}
func Test_isEligibleForAutostart(t *testing.T) {
t.Parallel()
@@ -2019,69 +2019,5 @@ func TestExecutorTaskWorkspace(t *testing.T) {
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
assert.Equal(t, database.WorkspaceTransitionStop, stats.Transitions[workspace.ID], "should autostop the workspace")
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
// Then: The build reason should be TaskAutoPause (not regular Autostop)
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
_ = coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
assert.Equal(t, codersdk.BuildReasonTaskAutoPause, workspace.LatestBuild.Reason, "task workspace should use TaskAutoPause build reason")
})
t.Run("AutostopNotification", func(t *testing.T) {
t.Parallel()
var (
tickCh = make(chan time.Time)
statsCh = make(chan autobuild.Stats)
notifyEnq = notificationstest.FakeEnqueuer{}
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
AutobuildTicker: tickCh,
IncludeProvisionerDaemon: true,
AutobuildStats: statsCh,
NotificationsEnqueuer: &notifyEnq,
})
admin = coderdtest.CreateFirstUser(t, client)
)
// Given: A task workspace with an 8 hour deadline
ctx := testutil.Context(t, testutil.WaitShort)
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 8*time.Hour)
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostop notification")
// Given: The workspace is currently running
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
require.Equal(t, codersdk.WorkspaceTransitionStart, workspace.LatestBuild.Transition)
require.NotZero(t, workspace.LatestBuild.Deadline, "workspace should have a deadline for autostop")
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
require.NoError(t, err)
// When: the autobuild executor ticks after the deadline
go func() {
tickTime := workspace.LatestBuild.Deadline.Time.Add(time.Minute)
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
tickCh <- tickTime
close(tickCh)
}()
// Then: We expect to see a stop transition
stats := <-statsCh
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
assert.Equal(t, database.WorkspaceTransitionStop, stats.Transitions[workspace.ID], "should autostop the workspace")
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
// Then: A task paused notification was sent with "idle timeout" reason
require.True(t, workspace.TaskID.Valid, "workspace should have a task ID")
task, err := db.GetTaskByID(dbauthz.AsSystemRestricted(ctx), workspace.TaskID.UUID)
require.NoError(t, err)
sent := notifyEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskPaused))
require.Len(t, sent, 1)
require.Equal(t, workspace.OwnerID, sent[0].UserID)
require.Equal(t, task.Name, sent[0].Labels["task"])
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
require.Equal(t, "idle timeout", sent[0].Labels["pause_reason"])
})
}
File diff suppressed because it is too large Load Diff
-305
View File
@@ -1,305 +0,0 @@
package chatd_test
import (
"context"
"database/sql"
"encoding/json"
"testing"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replicaA := newTestServer(t, db, ps, uuid.New())
replicaB := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replicaA.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "interrupt-me",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
runningWorker := uuid.New()
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: runningWorker, Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
_, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil)
require.True(t, ok)
t.Cleanup(cancel)
updated := replicaA.InterruptChat(ctx, chat)
require.Equal(t, database.ChatStatusWaiting, updated.Status)
require.False(t, updated.WorkerID.Valid)
require.Eventually(t, func() bool {
select {
case event := <-events:
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
return false
}
return event.Status.Status == codersdk.ChatStatusWaiting
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
}
func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "db-transition",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
updated := replica.InterruptChat(ctx, chat)
require.Equal(t, database.ChatStatusWaiting, updated.Status)
require.False(t, updated.WorkerID.Valid)
fromDB, err := db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
require.False(t, fromDB.WorkerID.Valid)
}
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "heartbeat-ownership",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
workerID := uuid.New()
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: uuid.New(),
})
require.NoError(t, err)
require.Equal(t, int64(0), rows)
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: workerID,
})
require.NoError(t, err)
require.Equal(t, int64(1), rows)
}
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "queue-when-busy",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
workerID := uuid.New()
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []fantasy.Content{fantasy.TextContent{Text: "queued"}},
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
})
require.NoError(t, err)
require.True(t, result.Queued)
require.NotNil(t, result.QueuedMessage)
require.Equal(t, database.ChatStatusRunning, result.Chat.Status)
require.Equal(t, workerID, result.Chat.WorkerID.UUID)
require.True(t, result.Chat.WorkerID.Valid)
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, queued, 1)
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, messages, 1)
}
func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "interrupt-when-busy",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
})
require.NoError(t, err)
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []fantasy.Content{fantasy.TextContent{Text: "interrupt"}},
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
})
require.NoError(t, err)
require.False(t, result.Queued)
require.Equal(t, database.ChatStatusPending, result.Chat.Status)
require.False(t, result.Chat.WorkerID.Valid)
fromDB, err := db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusPending, fromDB.Status)
require.False(t, fromDB.WorkerID.Valid)
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, queued, 0)
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, messages, 2)
require.Equal(t, int64(messages[len(messages)-1].ID), result.Message.ID)
}
func newTestServer(
t *testing.T,
db database.Store,
ps dbpubsub.Pubsub,
replicaID uuid.UUID,
) *chatd.Server {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := chatd.New(chatd.Config{
Logger: logger,
Database: db,
ReplicaID: replicaID,
Pubsub: ps,
PendingChatAcquireInterval: testutil.WaitSuperLong,
})
t.Cleanup(func() {
require.NoError(t, server.Close())
})
return server
}
func seedChatDependencies(
ctx context.Context,
t *testing.T,
db database.Store,
) (database.User, database.ChatModelConfig) {
t.Helper()
user := dbgen.User(t, db, database.User{})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseUrl: "",
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
})
require.NoError(t, err)
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "gpt-4o-mini",
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 70,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
return user, model
}
-676
View File
@@ -1,676 +0,0 @@
package chatloop
import (
"context"
"database/sql"
"encoding/json"
"errors"
"strconv"
"strings"
"sync"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
)
const (
interruptedToolResultErrorMessage = "tool call was interrupted before it produced a result"
)
var ErrInterrupted = xerrors.New("chat interrupted")
// PersistedStep contains the full content of a completed or
// interrupted agent step. Content includes both assistant blocks
// (text, reasoning, tool calls) and tool result blocks, mirroring
// what fantasy provides in StepResult.Content. The persistence
// layer is responsible for splitting these into separate database
// messages by role.
type PersistedStep struct {
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
}
// RunOptions configures a single streaming chat loop run.
type RunOptions struct {
Model fantasy.LanguageModel
Messages []fantasy.Message
Tools []fantasy.AgentTool
StreamCall fantasy.AgentStreamCall
MaxSteps int
ActiveTools []string
ContextLimitFallback int64
PersistStep func(context.Context, PersistedStep) error
PublishMessagePart func(
role fantasy.MessageRole,
part codersdk.ChatMessagePart,
)
Compaction *CompactionOptions
OnInterruptedPersistError func(error)
}
// Run executes the chat step-stream loop and delegates persistence/publishing to callbacks.
func Run(ctx context.Context, opts RunOptions) (*fantasy.AgentResult, error) {
if opts.Model == nil {
return nil, xerrors.New("chat model is required")
}
if opts.PersistStep == nil {
return nil, xerrors.New("persist step callback is required")
}
if opts.MaxSteps <= 0 {
opts.MaxSteps = 1
}
publishMessagePart := func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
if opts.PublishMessagePart == nil {
return
}
opts.PublishMessagePart(role, part)
}
var (
stepStateMu sync.Mutex
streamToolNames map[string]string
streamReasoningTitles map[string]string
streamReasoningText map[string]string
// stepToolResultContents tracks tool results received during
// streaming. These are needed for the interrupted-step path
// where OnStepFinish never fires.
stepToolResultContents []fantasy.ToolResultContent
stepAssistantDraft []fantasy.Content
stepToolCallIndexByID map[string]int
)
resetStepState := func() {
stepStateMu.Lock()
streamToolNames = make(map[string]string)
streamReasoningTitles = make(map[string]string)
streamReasoningText = make(map[string]string)
stepToolResultContents = nil
stepAssistantDraft = nil
stepToolCallIndexByID = make(map[string]int)
stepStateMu.Unlock()
}
setReasoningTitleFromText := func(id string, text string) {
if id == "" || strings.TrimSpace(text) == "" {
return
}
stepStateMu.Lock()
defer stepStateMu.Unlock()
if streamReasoningTitles[id] != "" {
return
}
streamReasoningText[id] += text
if !strings.ContainsAny(streamReasoningText[id], "\r\n") {
return
}
title := chatprompt.ReasoningTitleFromFirstLine(streamReasoningText[id])
if title == "" {
return
}
streamReasoningTitles[id] = title
}
appendDraftText := func(text string) {
if text == "" {
return
}
stepStateMu.Lock()
defer stepStateMu.Unlock()
if len(stepAssistantDraft) > 0 {
lastIndex := len(stepAssistantDraft) - 1
switch last := stepAssistantDraft[lastIndex].(type) {
case fantasy.TextContent:
last.Text += text
stepAssistantDraft[lastIndex] = last
return
case *fantasy.TextContent:
last.Text += text
stepAssistantDraft[lastIndex] = fantasy.TextContent{Text: last.Text}
return
}
}
stepAssistantDraft = append(stepAssistantDraft, fantasy.TextContent{Text: text})
}
appendDraftReasoning := func(text string) {
if text == "" {
return
}
stepStateMu.Lock()
defer stepStateMu.Unlock()
if len(stepAssistantDraft) > 0 {
lastIndex := len(stepAssistantDraft) - 1
switch last := stepAssistantDraft[lastIndex].(type) {
case fantasy.ReasoningContent:
last.Text += text
stepAssistantDraft[lastIndex] = last
return
case *fantasy.ReasoningContent:
last.Text += text
stepAssistantDraft[lastIndex] = fantasy.ReasoningContent{Text: last.Text}
return
}
}
stepAssistantDraft = append(stepAssistantDraft, fantasy.ReasoningContent{Text: text})
}
upsertDraftToolCall := func(toolCallID, toolName, input string, appendInput bool) {
if toolCallID == "" {
return
}
stepStateMu.Lock()
defer stepStateMu.Unlock()
if strings.TrimSpace(toolName) != "" {
streamToolNames[toolCallID] = toolName
}
index, exists := stepToolCallIndexByID[toolCallID]
if !exists {
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
ToolCallID: toolCallID,
ToolName: toolName,
Input: input,
})
return
}
if index < 0 || index >= len(stepAssistantDraft) {
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
ToolCallID: toolCallID,
ToolName: toolName,
Input: input,
})
return
}
existingCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](stepAssistantDraft[index])
if !ok {
if ptrCall, ptrOK := fantasy.AsContentType[*fantasy.ToolCallContent](stepAssistantDraft[index]); ptrOK && ptrCall != nil {
existingCall = *ptrCall
ok = true
}
}
if !ok {
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
ToolCallID: toolCallID,
ToolName: toolName,
Input: input,
})
return
}
if strings.TrimSpace(toolName) != "" {
existingCall.ToolName = toolName
}
if appendInput {
existingCall.Input += input
} else if input != "" || existingCall.Input == "" {
existingCall.Input = input
}
stepAssistantDraft[index] = existingCall
}
appendDraftSource := func(source fantasy.SourceContent) {
stepStateMu.Lock()
stepAssistantDraft = append(stepAssistantDraft, source)
stepStateMu.Unlock()
}
persistInterruptedStep := func() error {
stepStateMu.Lock()
draft := append([]fantasy.Content(nil), stepAssistantDraft...)
toolResults := append([]fantasy.ToolResultContent(nil), stepToolResultContents...)
toolNameByCallID := make(map[string]string, len(streamToolNames))
for id, name := range streamToolNames {
toolNameByCallID[id] = name
}
stepStateMu.Unlock()
if len(draft) == 0 && len(toolResults) == 0 {
return nil
}
// Track which tool calls already have results.
answeredToolCalls := make(map[string]struct{}, len(toolResults))
for _, tr := range toolResults {
if tr.ToolCallID != "" {
answeredToolCalls[tr.ToolCallID] = struct{}{}
}
}
// Build the combined content: draft + received tool results
// + synthetic interrupted results for unanswered tool calls.
content := make([]fantasy.Content, 0, len(draft)+len(toolResults))
content = append(content, draft...)
for _, tr := range toolResults {
content = append(content, tr)
}
for _, block := range draft {
toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block)
if !ok {
if ptrCall, ptrOK := fantasy.AsContentType[*fantasy.ToolCallContent](block); ptrOK && ptrCall != nil {
toolCall = *ptrCall
ok = true
}
}
if !ok || toolCall.ToolCallID == "" {
continue
}
if _, exists := answeredToolCalls[toolCall.ToolCallID]; exists {
continue
}
toolName := strings.TrimSpace(toolCall.ToolName)
if toolName == "" {
toolName = strings.TrimSpace(toolNameByCallID[toolCall.ToolCallID])
}
content = append(content, fantasy.ToolResultContent{
ToolCallID: toolCall.ToolCallID,
ToolName: toolName,
Result: fantasy.ToolResultOutputContentError{
Error: xerrors.New(interruptedToolResultErrorMessage),
},
})
answeredToolCalls[toolCall.ToolCallID] = struct{}{}
}
persistCtx := context.WithoutCancel(ctx)
return opts.PersistStep(persistCtx, PersistedStep{
Content: content,
})
}
resetStepState()
agent := fantasy.NewAgent(
opts.Model,
fantasy.WithTools(opts.Tools...),
fantasy.WithStopConditions(fantasy.StepCountIs(opts.MaxSteps)),
)
applyAnthropicCaching := shouldApplyAnthropicPromptCaching(opts.Model)
// Fantasy's AgentStreamCall currently requires a non-empty Prompt and always
// appends it as a user message. chatd already supplies the full history in
// Messages, so we pass and then strip a sentinel user message in PrepareStep.
sentinelPrompt := "__chatd_agent_prompt_sentinel_" + uuid.NewString()
streamCall := opts.StreamCall
streamCall.Prompt = sentinelPrompt
streamCall.Messages = opts.Messages
streamCall.PrepareStep = func(
stepCtx context.Context,
options fantasy.PrepareStepFunctionOptions,
) (context.Context, fantasy.PrepareStepResult, error) {
return stepCtx, prepareStepResult(
options.Messages,
sentinelPrompt,
opts.ActiveTools,
applyAnthropicCaching,
), nil
}
streamCall.OnStepStart = func(_ int) error {
resetStepState()
return nil
}
streamCall.OnTextDelta = func(_ string, text string) error {
appendDraftText(text)
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: text,
})
return nil
}
streamCall.OnReasoningDelta = func(id string, text string) error {
appendDraftReasoning(text)
setReasoningTitleFromText(id, text)
stepStateMu.Lock()
title := streamReasoningTitles[id]
stepStateMu.Unlock()
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: text,
Title: title,
})
return nil
}
streamCall.OnReasoningEnd = func(id string, _ fantasy.ReasoningContent) error {
stepStateMu.Lock()
if streamReasoningTitles[id] == "" {
// At the end of reasoning we have the full text, so we can
// safely evaluate first-line title format even if no newline
// ever arrived in deltas.
streamReasoningTitles[id] = chatprompt.ReasoningTitleFromFirstLine(
streamReasoningText[id],
)
}
title := streamReasoningTitles[id]
stepStateMu.Unlock()
if title != "" {
// Publish a title-only reasoning part so clients can update the
// reasoning header when metadata arrives at the end of streaming.
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Title: title,
})
}
return nil
}
streamCall.OnToolInputStart = func(id, toolName string) error {
upsertDraftToolCall(id, toolName, "", false)
return nil
}
streamCall.OnToolInputDelta = func(id, delta string) error {
stepStateMu.Lock()
toolName := streamToolNames[id]
stepStateMu.Unlock()
upsertDraftToolCall(id, toolName, delta, true)
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: id,
ToolName: toolName,
ArgsDelta: delta,
})
return nil
}
streamCall.OnToolCall = func(toolCall fantasy.ToolCallContent) error {
upsertDraftToolCall(toolCall.ToolCallID, toolCall.ToolName, toolCall.Input, false)
publishMessagePart(
fantasy.MessageRoleAssistant,
chatprompt.PartFromContent(toolCall),
)
return nil
}
streamCall.OnSource = func(source fantasy.SourceContent) error {
appendDraftSource(source)
publishMessagePart(
fantasy.MessageRoleAssistant,
chatprompt.PartFromContent(source),
)
return nil
}
streamCall.OnToolResult = func(result fantasy.ToolResultContent) error {
publishMessagePart(
fantasy.MessageRoleTool,
chatprompt.PartFromContent(result),
)
stepStateMu.Lock()
if result.ToolCallID != "" && strings.TrimSpace(result.ToolName) != "" {
streamToolNames[result.ToolCallID] = result.ToolName
}
stepToolResultContents = append(stepToolResultContents, result)
stepStateMu.Unlock()
return nil
}
streamCall.OnStepFinish = func(stepResult fantasy.StepResult) error {
contextLimit := extractContextLimit(stepResult.ProviderMetadata)
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
contextLimit = sql.NullInt64{
Int64: opts.ContextLimitFallback,
Valid: true,
}
}
return opts.PersistStep(ctx, PersistedStep{
Content: stepResult.Content,
Usage: stepResult.Usage,
ContextLimit: contextLimit,
})
}
result, err := agent.Stream(ctx, streamCall)
if err != nil {
if errors.Is(err, context.Canceled) &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
if persistErr := persistInterruptedStep(); persistErr != nil {
if opts.OnInterruptedPersistError != nil {
opts.OnInterruptedPersistError(persistErr)
}
}
return nil, ErrInterrupted
}
return nil, xerrors.Errorf("stream response: %w", err)
}
if opts.Compaction != nil {
if err := maybeCompact(ctx, opts, result); err != nil {
if opts.Compaction.OnError != nil {
opts.Compaction.OnError(err)
}
}
}
return result, nil
}
//nolint:revive // Boolean controls Anthropic-specific caching behavior.
func prepareStepResult(
messages []fantasy.Message,
sentinel string,
activeTools []string,
anthropicCaching bool,
) fantasy.PrepareStepResult {
filtered := make([]fantasy.Message, 0, len(messages))
removed := false
for _, message := range messages {
if !removed &&
message.Role == fantasy.MessageRoleUser &&
len(message.Content) == 1 {
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
if ok && textPart.Text == sentinel {
removed = true
continue
}
}
filtered = append(filtered, message)
}
result := fantasy.PrepareStepResult{
Messages: filtered,
}
if anthropicCaching {
result.Messages = addAnthropicPromptCaching(result.Messages)
}
if len(activeTools) > 0 {
result.ActiveTools = append([]string(nil), activeTools...)
}
return result
}
func shouldApplyAnthropicPromptCaching(model fantasy.LanguageModel) bool {
if model == nil {
return false
}
return model.Provider() == fantasyanthropic.Name
}
func addAnthropicPromptCaching(messages []fantasy.Message) []fantasy.Message {
for i := range messages {
messages[i].ProviderOptions = nil
}
providerOption := fantasy.ProviderOptions{
fantasyanthropic.Name: &fantasyanthropic.ProviderCacheControlOptions{
CacheControl: fantasyanthropic.CacheControl{Type: "ephemeral"},
},
}
lastSystemRoleIdx := -1
systemMessageUpdated := false
for i, msg := range messages {
if msg.Role == fantasy.MessageRoleSystem {
lastSystemRoleIdx = i
} else if !systemMessageUpdated && lastSystemRoleIdx >= 0 {
messages[lastSystemRoleIdx].ProviderOptions = providerOption
systemMessageUpdated = true
}
if i > len(messages)-3 {
messages[i].ProviderOptions = providerOption
}
}
return messages
}
func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 {
if len(metadata) == 0 {
return sql.NullInt64{}
}
encoded, err := json.Marshal(metadata)
if err != nil || len(encoded) == 0 {
return sql.NullInt64{}
}
var payload any
if err := json.Unmarshal(encoded, &payload); err != nil {
return sql.NullInt64{}
}
limit, ok := findContextLimitValue(payload)
if !ok {
return sql.NullInt64{}
}
return sql.NullInt64{
Int64: limit,
Valid: true,
}
}
func findContextLimitValue(value any) (int64, bool) {
var (
limit int64
found bool
)
collectContextLimitValues(value, func(candidate int64) {
if !found || candidate > limit {
limit = candidate
found = true
}
})
return limit, found
}
func collectContextLimitValues(value any, onValue func(int64)) {
switch typed := value.(type) {
case map[string]any:
for key, child := range typed {
if isContextLimitKey(key) {
if numeric, ok := numericContextLimitValue(child); ok {
onValue(numeric)
}
}
collectContextLimitValues(child, onValue)
}
case []any:
for _, child := range typed {
collectContextLimitValues(child, onValue)
}
}
}
func isContextLimitKey(key string) bool {
normalized := normalizeMetadataKey(key)
if normalized == "" {
return false
}
switch normalized {
case
"contextlimit",
"contextwindow",
"contextlength",
"maxcontext",
"maxcontexttokens",
"maxinputtokens",
"maxinputtoken",
"inputtokenlimit":
return true
}
return strings.Contains(normalized, "context") &&
(strings.Contains(normalized, "limit") ||
strings.Contains(normalized, "window") ||
strings.Contains(normalized, "length") ||
strings.HasPrefix(normalized, "max"))
}
func normalizeMetadataKey(key string) string {
var b strings.Builder
b.Grow(len(key))
for _, r := range key {
switch {
case r >= 'a' && r <= 'z':
_, _ = b.WriteRune(r)
case r >= 'A' && r <= 'Z':
_, _ = b.WriteRune(r + ('a' - 'A'))
case r >= '0' && r <= '9':
_, _ = b.WriteRune(r)
}
}
return b.String()
}
func numericContextLimitValue(value any) (int64, bool) {
switch typed := value.(type) {
case int64:
return positiveInt64(typed)
case int32:
return positiveInt64(int64(typed))
case int:
return positiveInt64(int64(typed))
case float64:
casted := int64(typed)
if typed > 0 && float64(casted) == typed {
return casted, true
}
case string:
parsed, err := strconv.ParseInt(strings.TrimSpace(typed), 10, 64)
if err == nil {
return positiveInt64(parsed)
}
case json.Number:
parsed, err := typed.Int64()
if err == nil {
return positiveInt64(parsed)
}
}
return 0, false
}
func positiveInt64(value int64) (int64, bool) {
if value <= 0 {
return 0, false
}
return value, true
}
-289
View File
@@ -1,289 +0,0 @@
package chatloop //nolint:testpackage // Uses internal symbols.
import (
"context"
"iter"
"strings"
"testing"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
const activeToolName = "read_file"
func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
t.Parallel()
var capturedCall fantasy.Call
model := &loopTestModel{
provider: fantasyanthropic.Name,
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
capturedCall = call
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}), nil
},
}
persistStepCalls := 0
var persistedStep PersistedStep
_, err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleSystem, "sys-1"),
textMessage(fantasy.MessageRoleSystem, "sys-2"),
textMessage(fantasy.MessageRoleUser, "hello"),
textMessage(fantasy.MessageRoleAssistant, "working"),
textMessage(fantasy.MessageRoleUser, "continue"),
},
Tools: []fantasy.AgentTool{
newNoopTool(activeToolName),
newNoopTool("write_file"),
},
MaxSteps: 3,
ActiveTools: []string{activeToolName},
ContextLimitFallback: 4096,
PersistStep: func(_ context.Context, step PersistedStep) error {
persistStepCalls++
persistedStep = step
return nil
},
})
require.NoError(t, err)
require.Equal(t, 1, persistStepCalls)
require.True(t, persistedStep.ContextLimit.Valid)
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
require.NotEmpty(t, capturedCall.Prompt)
require.False(t, containsPromptSentinel(capturedCall.Prompt))
require.Len(t, capturedCall.Tools, 1)
require.Equal(t, activeToolName, capturedCall.Tools[0].GetName())
require.Len(t, capturedCall.Prompt, 5)
require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[0]))
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[1]))
require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[2]))
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[3]))
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[4]))
}
func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
t.Parallel()
started := make(chan struct{})
model := &loopTestModel{
provider: "fake",
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
parts := []fantasy.StreamPart{
{
Type: fantasy.StreamPartTypeToolInputStart,
ID: "interrupt-tool-1",
ToolCallName: "read_file",
},
{
Type: fantasy.StreamPartTypeToolInputDelta,
ID: "interrupt-tool-1",
ToolCallName: "read_file",
Delta: `{"path":"main.go"`,
},
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "partial assistant output"},
}
for _, part := range parts {
if !yield(part) {
return
}
}
select {
case <-started:
default:
close(started)
}
<-ctx.Done()
_ = yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: ctx.Err(),
})
}), nil
},
}
ctx, cancel := context.WithCancelCause(context.Background())
defer cancel(nil)
go func() {
<-started
cancel(ErrInterrupted)
}()
persistedAssistantCtxErr := xerrors.New("unset")
var persistedContent []fantasy.Content
_, err := Run(ctx, RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
Tools: []fantasy.AgentTool{
newNoopTool("read_file"),
},
MaxSteps: 3,
PersistStep: func(persistCtx context.Context, step PersistedStep) error {
persistedAssistantCtxErr = persistCtx.Err()
persistedContent = append([]fantasy.Content(nil), step.Content...)
return nil
},
})
require.ErrorIs(t, err, ErrInterrupted)
require.NoError(t, persistedAssistantCtxErr)
require.NotEmpty(t, persistedContent)
var (
foundText bool
foundToolCall bool
foundToolResult bool
)
for _, block := range persistedContent {
if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok {
if strings.Contains(text.Text, "partial assistant output") {
foundText = true
}
continue
}
if toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block); ok {
if toolCall.ToolCallID == "interrupt-tool-1" &&
toolCall.ToolName == "read_file" &&
strings.Contains(toolCall.Input, `"path":"main.go"`) {
foundToolCall = true
}
continue
}
if toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok {
if toolResult.ToolCallID == "interrupt-tool-1" &&
toolResult.ToolName == "read_file" {
_, isErr := toolResult.Result.(fantasy.ToolResultOutputContentError)
require.True(t, isErr, "interrupted tool result should be an error")
foundToolResult = true
}
}
}
require.True(t, foundText)
require.True(t, foundToolCall)
require.True(t, foundToolResult)
}
type loopTestModel struct {
provider string
model string
generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
streamFn func(context.Context, fantasy.Call) (fantasy.StreamResponse, error)
}
func (m *loopTestModel) Provider() string {
if m.provider != "" {
return m.provider
}
return "fake"
}
func (m *loopTestModel) Model() string {
if m.model != "" {
return m.model
}
return "fake"
}
func (m *loopTestModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
if m.generateFn != nil {
return m.generateFn(ctx, call)
}
return &fantasy.Response{}, nil
}
func (m *loopTestModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
if m.streamFn != nil {
return m.streamFn(ctx, call)
}
return streamFromParts([]fantasy.StreamPart{{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
}}), nil
}
func (*loopTestModel) GenerateObject(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
return nil, xerrors.New("not implemented")
}
func (*loopTestModel) StreamObject(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
return nil, xerrors.New("not implemented")
}
func streamFromParts(parts []fantasy.StreamPart) fantasy.StreamResponse {
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
for _, part := range parts {
if !yield(part) {
return
}
}
})
}
func newNoopTool(name string) fantasy.AgentTool {
return fantasy.NewAgentTool(
name,
"test noop tool",
func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.ToolResponse{}, nil
},
)
}
func textMessage(role fantasy.MessageRole, text string) fantasy.Message {
return fantasy.Message{
Role: role,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: text},
},
}
}
func containsPromptSentinel(prompt []fantasy.Message) bool {
for _, message := range prompt {
if message.Role != fantasy.MessageRoleUser || len(message.Content) != 1 {
continue
}
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
if !ok {
continue
}
if strings.HasPrefix(textPart.Text, "__chatd_agent_prompt_sentinel_") {
return true
}
}
return false
}
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
if len(message.ProviderOptions) == 0 {
return false
}
options, ok := message.ProviderOptions[fantasyanthropic.Name]
if !ok {
return false
}
cacheOptions, ok := options.(*fantasyanthropic.ProviderCacheControlOptions)
return ok && cacheOptions.CacheControl.Type == "ephemeral"
}
-209
View File
@@ -1,209 +0,0 @@
package chatloop
import (
"context"
"strings"
"time"
"charm.land/fantasy"
"golang.org/x/xerrors"
)
const (
defaultCompactionThresholdPercent = int32(70)
minCompactionThresholdPercent = int32(0)
maxCompactionThresholdPercent = int32(100)
defaultCompactionSummaryPrompt = "Summarize the current chat so a " +
"new assistant can continue seamlessly. Include the user's goals, " +
"decisions made, concrete technical details (files, commands, APIs), " +
"errors encountered and fixes, and open questions. Be dense and factual. " +
"Omit pleasantries and next-step suggestions."
defaultCompactionSystemSummaryPrefix = "Summary of earlier chat context:"
defaultCompactionTimeout = 90 * time.Second
)
type CompactionOptions struct {
ThresholdPercent int32
ContextLimit int64
SummaryPrompt string
SystemSummaryPrefix string
Timeout time.Duration
Persist func(context.Context, CompactionResult) error
OnError func(error)
}
type CompactionResult struct {
SystemSummary string
SummaryReport string
ThresholdPercent int32
UsagePercent float64
ContextTokens int64
ContextLimit int64
}
func maybeCompact(
ctx context.Context,
runOpts RunOptions,
runResult *fantasy.AgentResult,
) error {
if runResult == nil || runOpts.Compaction == nil {
return nil
}
config := *runOpts.Compaction
if config.Persist == nil {
return xerrors.New("compaction persist callback is required")
}
if strings.TrimSpace(config.SummaryPrompt) == "" {
config.SummaryPrompt = defaultCompactionSummaryPrompt
}
if strings.TrimSpace(config.SystemSummaryPrefix) == "" {
config.SystemSummaryPrefix = defaultCompactionSystemSummaryPrefix
}
if config.Timeout <= 0 {
config.Timeout = defaultCompactionTimeout
}
if config.ThresholdPercent < minCompactionThresholdPercent ||
config.ThresholdPercent > maxCompactionThresholdPercent {
config.ThresholdPercent = defaultCompactionThresholdPercent
}
if config.ThresholdPercent >= maxCompactionThresholdPercent {
return nil
}
if runOpts.MaxSteps > 0 && len(runResult.Steps) >= runOpts.MaxSteps {
lastStep := runResult.Steps[len(runResult.Steps)-1]
if lastStep.FinishReason == fantasy.FinishReasonToolCalls &&
len(lastStep.Content.ToolCalls()) > 0 {
return nil
}
}
contextTokens := int64(0)
contextLimitFromMetadata := int64(0)
for i := len(runResult.Steps) - 1; i >= 0; i-- {
usage := runResult.Steps[i].Usage
total := int64(0)
hasContextTokens := false
if usage.InputTokens > 0 {
total += usage.InputTokens
hasContextTokens = true
}
if usage.CacheReadTokens > 0 {
total += usage.CacheReadTokens
hasContextTokens = true
}
if usage.CacheCreationTokens > 0 {
total += usage.CacheCreationTokens
hasContextTokens = true
}
if !hasContextTokens && usage.TotalTokens > 0 {
total = usage.TotalTokens
hasContextTokens = true
}
if !hasContextTokens || total <= 0 {
continue
}
contextTokens = total
metadataLimit := extractContextLimit(runResult.Steps[i].ProviderMetadata)
if metadataLimit.Valid && metadataLimit.Int64 > 0 {
contextLimitFromMetadata = metadataLimit.Int64
}
break
}
if contextTokens <= 0 {
return nil
}
contextLimit := contextLimitFromMetadata
if contextLimit <= 0 && config.ContextLimit > 0 {
contextLimit = config.ContextLimit
}
if contextLimit <= 0 && runOpts.ContextLimitFallback > 0 {
contextLimit = runOpts.ContextLimitFallback
}
if contextLimit <= 0 {
return nil
}
usagePercent := (float64(contextTokens) / float64(contextLimit)) * 100
if usagePercent < float64(config.ThresholdPercent) {
return nil
}
summary, err := generateCompactionSummary(
ctx,
runOpts.Model,
runOpts.Messages,
runResult.Steps,
config,
)
if err != nil {
return err
}
if summary == "" {
return nil
}
systemSummary := strings.TrimSpace(
config.SystemSummaryPrefix + "\n\n" + summary,
)
return config.Persist(ctx, CompactionResult{
SystemSummary: systemSummary,
SummaryReport: summary,
ThresholdPercent: config.ThresholdPercent,
UsagePercent: usagePercent,
ContextTokens: contextTokens,
ContextLimit: contextLimit,
})
}
func generateCompactionSummary(
ctx context.Context,
model fantasy.LanguageModel,
messages []fantasy.Message,
steps []fantasy.StepResult,
options CompactionOptions,
) (string, error) {
summaryPrompt := make([]fantasy.Message, 0, len(messages)+len(steps)+1)
summaryPrompt = append(summaryPrompt, messages...)
for _, step := range steps {
summaryPrompt = append(summaryPrompt, step.Messages...)
}
summaryPrompt = append(summaryPrompt, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: options.SummaryPrompt},
},
})
toolChoice := fantasy.ToolChoiceNone
summaryCtx, cancel := context.WithTimeout(ctx, options.Timeout)
defer cancel()
response, err := model.Generate(summaryCtx, fantasy.Call{
Prompt: summaryPrompt,
ToolChoice: &toolChoice,
})
if err != nil {
return "", xerrors.Errorf("generate summary text: %w", err)
}
parts := make([]string, 0, len(response.Content))
for _, block := range response.Content {
textBlock, ok := fantasy.AsContentType[fantasy.TextContent](block)
if !ok {
continue
}
text := strings.TrimSpace(textBlock.Text)
if text == "" {
continue
}
parts = append(parts, text)
}
return strings.TrimSpace(strings.Join(parts, " ")), nil
}
-132
View File
@@ -1,132 +0,0 @@
package chatloop //nolint:testpackage // Uses internal symbols.
import (
"context"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
func TestRun_Compaction(t *testing.T) {
t.Parallel()
t.Run("PersistsWhenThresholdReached", func(t *testing.T) {
t.Parallel()
persistCompactionCalls := 0
var persistedCompaction CompactionResult
const summaryText = "summary text for compaction"
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 80,
TotalTokens: 85,
},
},
}), nil
},
generateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) {
require.NotEmpty(t, call.Prompt)
lastPrompt := call.Prompt[len(call.Prompt)-1]
require.Equal(t, fantasy.MessageRoleUser, lastPrompt.Role)
require.Len(t, lastPrompt.Content, 1)
instruction, ok := fantasy.AsMessagePart[fantasy.TextPart](lastPrompt.Content[0])
require.True(t, ok)
require.Equal(t, "summarize now", instruction.Text)
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
},
}, nil
},
}
_, err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
SummaryPrompt: "summarize now",
Persist: func(_ context.Context, result CompactionResult) error {
persistCompactionCalls++
persistedCompaction = result
return nil
},
},
})
require.NoError(t, err)
require.Equal(t, 1, persistCompactionCalls)
require.Contains(t, persistedCompaction.SystemSummary, summaryText)
require.Equal(t, summaryText, persistedCompaction.SummaryReport)
require.Equal(t, int64(80), persistedCompaction.ContextTokens)
require.Equal(t, int64(100), persistedCompaction.ContextLimit)
require.InDelta(t, 80.0, persistedCompaction.UsagePercent, 0.0001)
})
t.Run("ErrorsAreReported", func(t *testing.T) {
t.Parallel()
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{
InputTokens: 80,
},
},
}), nil
},
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return nil, xerrors.New("generate failed")
},
}
compactionErr := xerrors.New("unset")
_, err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
Persist: func(_ context.Context, _ CompactionResult) error {
return nil
},
OnError: func(err error) {
compactionErr = err
},
},
})
require.NoError(t, err)
require.Error(t, compactionErr)
require.ErrorContains(t, compactionErr, "generate summary text")
})
}
-982
View File
@@ -1,982 +0,0 @@
package chatprompt
import (
"encoding/json"
"regexp"
"strings"
"charm.land/fantasy"
fantasyopenai "charm.land/fantasy/providers/openai"
"github.com/sqlc-dev/pqtype"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
)
var toolCallIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
func ConvertMessages(
messages []database.ChatMessage,
) ([]fantasy.Message, error) {
prompt := make([]fantasy.Message, 0, len(messages))
toolNameByCallID := make(map[string]string)
for _, message := range messages {
visibility := message.Visibility
if visibility == "" {
visibility = database.ChatMessageVisibilityBoth
}
if visibility != database.ChatMessageVisibilityModel &&
visibility != database.ChatMessageVisibilityBoth {
continue
}
switch message.Role {
case string(fantasy.MessageRoleSystem):
content, err := parseSystemContent(message.Content)
if err != nil {
return nil, err
}
if strings.TrimSpace(content) == "" {
continue
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: content},
},
})
case string(fantasy.MessageRoleUser):
content, err := ParseContent(string(fantasy.MessageRoleUser), message.Content)
if err != nil {
return nil, err
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: ToMessageParts(content),
})
case string(fantasy.MessageRoleAssistant):
content, err := ParseContent(string(fantasy.MessageRoleAssistant), message.Content)
if err != nil {
return nil, err
}
parts := normalizeAssistantToolCallInputs(ToMessageParts(content))
for _, toolCall := range ExtractToolCalls(parts) {
if toolCall.ToolCallID == "" || strings.TrimSpace(toolCall.ToolName) == "" {
continue
}
toolNameByCallID[sanitizeToolCallID(toolCall.ToolCallID)] = toolCall.ToolName
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: parts,
})
case string(fantasy.MessageRoleTool):
rows, err := parseToolResultRows(message.Content)
if err != nil {
return nil, err
}
parts := make([]fantasy.MessagePart, 0, len(rows))
for _, row := range rows {
if row.ToolCallID != "" && row.ToolName != "" {
toolNameByCallID[sanitizeToolCallID(row.ToolCallID)] = row.ToolName
}
parts = append(parts, row.toToolResultPart())
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: parts,
})
default:
return nil, xerrors.Errorf("unsupported chat message role %q", message.Role)
}
}
prompt = injectMissingToolResults(prompt)
prompt = injectMissingToolUses(
prompt,
toolNameByCallID,
)
return prompt, nil
}
// PrependSystem prepends a system message unless an existing system
// message already mentions create_workspace guidance.
func PrependSystem(prompt []fantasy.Message, instruction string) []fantasy.Message {
instruction = strings.TrimSpace(instruction)
if instruction == "" {
return prompt
}
for _, message := range prompt {
if message.Role != fantasy.MessageRoleSystem {
continue
}
for _, part := range message.Content {
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
if !ok {
continue
}
if strings.Contains(strings.ToLower(textPart.Text), "create_workspace") {
return prompt
}
}
}
out := make([]fantasy.Message, 0, len(prompt)+1)
out = append(out, fantasy.Message{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: instruction},
},
})
out = append(out, prompt...)
return out
}
// InsertSystem inserts a system message after the existing system
// block and before the first non-system message.
func InsertSystem(prompt []fantasy.Message, instruction string) []fantasy.Message {
instruction = strings.TrimSpace(instruction)
if instruction == "" {
return prompt
}
systemMessage := fantasy.Message{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: instruction},
},
}
out := make([]fantasy.Message, 0, len(prompt)+1)
inserted := false
for _, message := range prompt {
if !inserted && message.Role != fantasy.MessageRoleSystem {
out = append(out, systemMessage)
inserted = true
}
out = append(out, message)
}
if !inserted {
out = append(out, systemMessage)
}
return out
}
// AppendUser appends an instruction as a user message at the end of
// the prompt.
func AppendUser(prompt []fantasy.Message, instruction string) []fantasy.Message {
instruction = strings.TrimSpace(instruction)
if instruction == "" {
return prompt
}
out := make([]fantasy.Message, 0, len(prompt)+1)
out = append(out, prompt...)
out = append(out, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: instruction},
},
})
return out
}
// ParseContent decodes persisted chat message content blocks.
func ParseContent(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return nil, nil
}
var text string
if err := json.Unmarshal(raw.RawMessage, &text); err == nil {
return []fantasy.Content{fantasy.TextContent{Text: text}}, nil
}
var rawBlocks []json.RawMessage
if err := json.Unmarshal(raw.RawMessage, &rawBlocks); err != nil {
return nil, xerrors.Errorf("parse %s content: %w", role, err)
}
content := make([]fantasy.Content, 0, len(rawBlocks))
for i, rawBlock := range rawBlocks {
block, err := fantasy.UnmarshalContent(rawBlock)
if err != nil {
return nil, xerrors.Errorf("parse %s content block %d: %w", role, i, err)
}
content = append(content, block)
}
return content, nil
}
// toolResultRaw is an untyped representation of a persisted tool
// result row. We intentionally avoid a strict Go struct so that
// historical shapes are never rejected.
type toolResultRaw struct {
ToolCallID string `json:"tool_call_id"`
ToolName string `json:"tool_name"`
Result json.RawMessage `json:"result"`
IsError bool `json:"is_error,omitempty"`
}
// parseToolResultRows decodes persisted tool result rows.
func parseToolResultRows(raw pqtype.NullRawMessage) ([]toolResultRaw, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return nil, nil
}
var rows []toolResultRaw
if err := json.Unmarshal(raw.RawMessage, &rows); err != nil {
return nil, xerrors.Errorf("parse tool content: %w", err)
}
return rows, nil
}
func (r toolResultRaw) toToolResultPart() fantasy.ToolResultPart {
toolCallID := sanitizeToolCallID(r.ToolCallID)
resultText := string(r.Result)
if resultText == "" || resultText == "null" {
resultText = "{}"
}
if r.IsError {
message := strings.TrimSpace(resultText)
if extracted := extractErrorString(r.Result); extracted != "" {
message = extracted
}
return fantasy.ToolResultPart{
ToolCallID: toolCallID,
Output: fantasy.ToolResultOutputContentError{
Error: xerrors.New(message),
},
}
}
return fantasy.ToolResultPart{
ToolCallID: toolCallID,
Output: fantasy.ToolResultOutputContentText{
Text: resultText,
},
}
}
// extractErrorString pulls the "error" field from a JSON object if
// present, returning it as a string. Returns "" if the field is
// missing or the input is not an object.
func extractErrorString(raw json.RawMessage) string {
var fields map[string]json.RawMessage
if err := json.Unmarshal(raw, &fields); err != nil {
return ""
}
errField, ok := fields["error"]
if !ok {
return ""
}
var s string
if err := json.Unmarshal(errField, &s); err != nil {
return ""
}
return strings.TrimSpace(s)
}
// ToMessageParts converts fantasy content blocks into message parts.
func ToMessageParts(content []fantasy.Content) []fantasy.MessagePart {
parts := make([]fantasy.MessagePart, 0, len(content))
for _, block := range content {
switch value := block.(type) {
case fantasy.TextContent:
parts = append(parts, fantasy.TextPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.TextContent:
parts = append(parts, fantasy.TextPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.ReasoningContent:
parts = append(parts, fantasy.ReasoningPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.ReasoningContent:
parts = append(parts, fantasy.ReasoningPart{
Text: value.Text,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.ToolCallContent:
parts = append(parts, fantasy.ToolCallPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
ToolName: value.ToolName,
Input: value.Input,
ProviderExecuted: value.ProviderExecuted,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.ToolCallContent:
parts = append(parts, fantasy.ToolCallPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
ToolName: value.ToolName,
Input: value.Input,
ProviderExecuted: value.ProviderExecuted,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.FileContent:
parts = append(parts, fantasy.FilePart{
Data: value.Data,
MediaType: value.MediaType,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.FileContent:
parts = append(parts, fantasy.FilePart{
Data: value.Data,
MediaType: value.MediaType,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case fantasy.ToolResultContent:
parts = append(parts, fantasy.ToolResultPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
Output: value.Result,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
case *fantasy.ToolResultContent:
parts = append(parts, fantasy.ToolResultPart{
ToolCallID: sanitizeToolCallID(value.ToolCallID),
Output: value.Result,
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
})
}
}
return parts
}
func normalizeAssistantToolCallInputs(
parts []fantasy.MessagePart,
) []fantasy.MessagePart {
normalized := make([]fantasy.MessagePart, 0, len(parts))
for _, part := range parts {
toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
if !ok {
normalized = append(normalized, part)
continue
}
toolCall.Input = normalizeToolCallInput(toolCall.Input)
normalized = append(normalized, toolCall)
}
return normalized
}
// normalizeToolCallInput guarantees tool call input is a JSON object string.
// Anthropic drops assistant tool calls with malformed input, which can leave
// following tool results orphaned.
func normalizeToolCallInput(input string) string {
input = strings.TrimSpace(input)
if input == "" {
return "{}"
}
var object map[string]any
if err := json.Unmarshal([]byte(input), &object); err != nil || object == nil {
return "{}"
}
return input
}
// ExtractToolCalls returns all tool call parts as content blocks.
func ExtractToolCalls(parts []fantasy.MessagePart) []fantasy.ToolCallContent {
toolCalls := make([]fantasy.ToolCallContent, 0, len(parts))
for _, part := range parts {
toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
if !ok {
continue
}
toolCalls = append(toolCalls, fantasy.ToolCallContent{
ToolCallID: toolCall.ToolCallID,
ToolName: toolCall.ToolName,
Input: toolCall.Input,
ProviderExecuted: toolCall.ProviderExecuted,
})
}
return toolCalls
}
// MarshalContent encodes message content blocks for persistence.
func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
if len(blocks) == 0 {
return pqtype.NullRawMessage{}, nil
}
encodedBlocks := make([]json.RawMessage, 0, len(blocks))
for i, block := range blocks {
encoded, err := marshalContentBlock(block)
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf(
"encode content block %d: %w",
i,
err,
)
}
encodedBlocks = append(encodedBlocks, encoded)
}
data, err := json.Marshal(encodedBlocks)
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf("encode content blocks: %w", err)
}
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
}
// MarshalToolResult encodes a single tool result for persistence as
// an opaque JSON blob. The stored shape is
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}].
func MarshalToolResult(toolCallID, toolName string, result json.RawMessage, isError bool) (pqtype.NullRawMessage, error) {
row := toolResultRaw{
ToolCallID: toolCallID,
ToolName: toolName,
Result: result,
IsError: isError,
}
data, err := json.Marshal([]toolResultRaw{row})
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf("encode tool result: %w", err)
}
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
}
// MarshalToolResultContent encodes a fantasy tool result content
// block for persistence. It extracts the raw fields and delegates
// to MarshalToolResult.
func MarshalToolResultContent(content fantasy.ToolResultContent) (pqtype.NullRawMessage, error) {
var result json.RawMessage
var isError bool
switch output := content.Result.(type) {
case fantasy.ToolResultOutputContentError:
isError = true
if output.Error != nil {
result, _ = json.Marshal(map[string]any{"error": output.Error.Error()})
} else {
result = []byte(`{"error":""}`)
}
case fantasy.ToolResultOutputContentText:
result = json.RawMessage(output.Text)
if !json.Valid(result) {
result, _ = json.Marshal(map[string]any{"output": output.Text})
}
case fantasy.ToolResultOutputContentMedia:
result, _ = json.Marshal(map[string]any{
"data": output.Data,
"mime_type": output.MediaType,
"text": output.Text,
})
default:
result = []byte(`{}`)
}
return MarshalToolResult(content.ToolCallID, content.ToolName, result, isError)
}
// PartFromContent converts fantasy content into a SDK chat message part.
func PartFromContent(block fantasy.Content) codersdk.ChatMessagePart {
switch value := block.(type) {
case fantasy.TextContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: value.Text,
}
case *fantasy.TextContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: value.Text,
}
case fantasy.ReasoningContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: value.Text,
Title: reasoningSummaryTitle(value.ProviderMetadata),
}
case *fantasy.ReasoningContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: value.Text,
Title: reasoningSummaryTitle(value.ProviderMetadata),
}
case fantasy.ToolCallContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: value.ToolCallID,
ToolName: value.ToolName,
Args: []byte(value.Input),
}
case *fantasy.ToolCallContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: value.ToolCallID,
ToolName: value.ToolName,
Args: []byte(value.Input),
}
case fantasy.SourceContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSource,
SourceID: value.ID,
URL: value.URL,
Title: value.Title,
}
case *fantasy.SourceContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSource,
SourceID: value.ID,
URL: value.URL,
Title: value.Title,
}
case fantasy.FileContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeFile,
MediaType: value.MediaType,
Data: value.Data,
}
case *fantasy.FileContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeFile,
MediaType: value.MediaType,
Data: value.Data,
}
case fantasy.ToolResultContent:
return toolResultContentToPart(value)
case *fantasy.ToolResultContent:
return toolResultContentToPart(*value)
default:
return codersdk.ChatMessagePart{}
}
}
// ToolResultToPart converts a tool call ID, raw result, and error
// flag into a ChatMessagePart. This is the minimal conversion used
// both during streaming and when reading from the database.
func ToolResultToPart(toolCallID, toolName string, result json.RawMessage, isError bool) codersdk.ChatMessagePart {
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: toolCallID,
ToolName: toolName,
Result: result,
IsError: isError,
}
}
// toolResultContentToPart converts a fantasy ToolResultContent
// directly into a ChatMessagePart without an intermediate struct.
func toolResultContentToPart(content fantasy.ToolResultContent) codersdk.ChatMessagePart {
var result json.RawMessage
var isError bool
switch output := content.Result.(type) {
case fantasy.ToolResultOutputContentError:
isError = true
if output.Error != nil {
result, _ = json.Marshal(map[string]any{"error": output.Error.Error()})
} else {
result = []byte(`{"error":""}`)
}
case fantasy.ToolResultOutputContentText:
result = json.RawMessage(output.Text)
// Ensure valid JSON; wrap in an object if not.
if !json.Valid(result) {
result, _ = json.Marshal(map[string]any{"output": output.Text})
}
case fantasy.ToolResultOutputContentMedia:
result, _ = json.Marshal(map[string]any{
"data": output.Data,
"mime_type": output.MediaType,
"text": output.Text,
})
default:
result = []byte(`{}`)
}
return ToolResultToPart(content.ToolCallID, content.ToolName, result, isError)
}
// ReasoningTitleFromFirstLine extracts a compact markdown title.
func ReasoningTitleFromFirstLine(text string) string {
text = strings.TrimSpace(text)
if text == "" {
return ""
}
firstLine := text
if idx := strings.IndexAny(firstLine, "\r\n"); idx >= 0 {
firstLine = firstLine[:idx]
}
firstLine = strings.TrimSpace(firstLine)
if firstLine == "" || !strings.HasPrefix(firstLine, "**") {
return ""
}
rest := firstLine[2:]
end := strings.Index(rest, "**")
if end < 0 {
return ""
}
title := strings.TrimSpace(rest[:end])
if title == "" {
return ""
}
// Require the first line to be exactly "**title**" (ignoring
// surrounding whitespace) so providers without this format don't
// accidentally emit a title.
if strings.TrimSpace(rest[end+2:]) != "" {
return ""
}
return compactReasoningSummaryTitle(title)
}
func injectMissingToolResults(prompt []fantasy.Message) []fantasy.Message {
result := make([]fantasy.Message, 0, len(prompt))
for i := 0; i < len(prompt); i++ {
msg := prompt[i]
result = append(result, msg)
if msg.Role != fantasy.MessageRoleAssistant {
continue
}
toolCalls := ExtractToolCalls(msg.Content)
if len(toolCalls) == 0 {
continue
}
// Collect the tool call IDs that have results in the
// following tool message(s).
answered := make(map[string]struct{})
j := i + 1
for ; j < len(prompt); j++ {
if prompt[j].Role != fantasy.MessageRoleTool {
break
}
for _, part := range prompt[j].Content {
tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
if !ok {
continue
}
answered[tr.ToolCallID] = struct{}{}
}
}
if i+1 < j {
// Preserve persisted tool result ordering and inject any
// synthetic results after the existing contiguous tool messages.
result = append(result, prompt[i+1:j]...)
i = j - 1
}
// Build synthetic results for any unanswered tool calls.
var missing []fantasy.MessagePart
for _, tc := range toolCalls {
if _, ok := answered[tc.ToolCallID]; !ok {
missing = append(missing, fantasy.ToolResultPart{
ToolCallID: tc.ToolCallID,
Output: fantasy.ToolResultOutputContentError{
Error: xerrors.New("tool call was interrupted and did not receive a result"),
},
})
}
}
if len(missing) > 0 {
result = append(result, fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: missing,
})
}
}
return result
}
func injectMissingToolUses(
prompt []fantasy.Message,
toolNameByCallID map[string]string,
) []fantasy.Message {
result := make([]fantasy.Message, 0, len(prompt))
for _, msg := range prompt {
if msg.Role != fantasy.MessageRoleTool {
result = append(result, msg)
continue
}
toolResults := make([]fantasy.ToolResultPart, 0, len(msg.Content))
for _, part := range msg.Content {
toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
if !ok {
continue
}
toolResults = append(toolResults, toolResult)
}
if len(toolResults) == 0 {
result = append(result, msg)
continue
}
// Walk backwards through the result to find the nearest
// preceding assistant message (skipping over other tool
// messages that belong to the same batch of results).
answeredByPrevious := make(map[string]struct{})
for k := len(result) - 1; k >= 0; k-- {
if result[k].Role == fantasy.MessageRoleAssistant {
for _, toolCall := range ExtractToolCalls(result[k].Content) {
toolCallID := sanitizeToolCallID(toolCall.ToolCallID)
if toolCallID == "" {
continue
}
answeredByPrevious[toolCallID] = struct{}{}
}
break
}
if result[k].Role != fantasy.MessageRoleTool {
break
}
}
matchingResults := make([]fantasy.ToolResultPart, 0, len(toolResults))
orphanResults := make([]fantasy.ToolResultPart, 0, len(toolResults))
for _, toolResult := range toolResults {
toolCallID := sanitizeToolCallID(toolResult.ToolCallID)
if _, ok := answeredByPrevious[toolCallID]; ok {
matchingResults = append(matchingResults, toolResult)
continue
}
orphanResults = append(orphanResults, toolResult)
}
if len(orphanResults) == 0 {
result = append(result, msg)
continue
}
syntheticToolUse := syntheticToolUseMessage(
orphanResults,
toolNameByCallID,
)
if len(syntheticToolUse.Content) == 0 {
result = append(result, msg)
continue
}
if len(matchingResults) > 0 {
result = append(result, toolMessageFromToolResultParts(matchingResults))
}
result = append(result, syntheticToolUse)
result = append(result, toolMessageFromToolResultParts(orphanResults))
}
return result
}
func toolMessageFromToolResultParts(results []fantasy.ToolResultPart) fantasy.Message {
parts := make([]fantasy.MessagePart, 0, len(results))
for _, result := range results {
parts = append(parts, result)
}
return fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: parts,
}
}
func syntheticToolUseMessage(
toolResults []fantasy.ToolResultPart,
toolNameByCallID map[string]string,
) fantasy.Message {
parts := make([]fantasy.MessagePart, 0, len(toolResults))
seen := make(map[string]struct{}, len(toolResults))
for _, toolResult := range toolResults {
toolCallID := sanitizeToolCallID(toolResult.ToolCallID)
if toolCallID == "" {
continue
}
if _, ok := seen[toolCallID]; ok {
continue
}
toolName := strings.TrimSpace(toolNameByCallID[toolCallID])
if toolName == "" {
continue
}
seen[toolCallID] = struct{}{}
parts = append(parts, fantasy.ToolCallPart{
ToolCallID: toolCallID,
ToolName: toolName,
Input: "{}",
})
}
return fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: parts,
}
}
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return "", nil
}
var content string
if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
return "", xerrors.Errorf("parse system message content: %w", err)
}
return content, nil
}
func sanitizeToolCallID(id string) string {
if id == "" {
return ""
}
return toolCallIDSanitizer.ReplaceAllString(id, "_")
}
func marshalContentBlock(block fantasy.Content) (json.RawMessage, error) {
encoded, err := json.Marshal(block)
if err != nil {
return nil, err
}
title, ok := reasoningTitleFromContent(block)
if !ok || title == "" {
return encoded, nil
}
var envelope struct {
Type string `json:"type"`
Data map[string]any `json:"data"`
}
if err := json.Unmarshal(encoded, &envelope); err != nil {
return nil, err
}
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeReasoning)) {
return encoded, nil
}
if envelope.Data == nil {
envelope.Data = map[string]any{}
}
envelope.Data["title"] = title
encodedWithTitle, err := json.Marshal(envelope)
if err != nil {
return nil, err
}
return encodedWithTitle, nil
}
func reasoningTitleFromContent(block fantasy.Content) (string, bool) {
switch value := block.(type) {
case fantasy.ReasoningContent:
return ReasoningTitleFromFirstLine(value.Text), true
case *fantasy.ReasoningContent:
if value == nil {
return "", false
}
return ReasoningTitleFromFirstLine(value.Text), true
default:
return "", false
}
}
func reasoningSummaryTitle(metadata fantasy.ProviderMetadata) string {
if len(metadata) == 0 {
return ""
}
reasoningMetadata := fantasyopenai.GetReasoningMetadata(
fantasy.ProviderOptions(metadata),
)
if reasoningMetadata == nil {
return ""
}
for _, summary := range reasoningMetadata.Summary {
if title := compactReasoningSummaryTitle(summary); title != "" {
return title
}
}
return ""
}
func compactReasoningSummaryTitle(summary string) string {
const maxWords = 8
const maxRunes = 80
summary = strings.TrimSpace(summary)
if summary == "" {
return ""
}
summary = strings.Trim(summary, "\"'`")
summary = reasoningSummaryHeadline(summary)
words := strings.Fields(summary)
if len(words) == 0 {
return ""
}
truncated := false
if len(words) > maxWords {
words = words[:maxWords]
truncated = true
}
title := strings.Join(words, " ")
if truncated {
title += "…"
}
return truncateRunes(title, maxRunes)
}
func reasoningSummaryHeadline(summary string) string {
summary = strings.TrimSpace(summary)
if summary == "" {
return ""
}
// OpenAI summary_text may be markdown like:
// "**Title**\n\nLonger explanation ...".
// Keep only the heading segment for UI titles.
if idx := strings.Index(summary, "\n\n"); idx >= 0 {
summary = summary[:idx]
}
if idx := strings.IndexAny(summary, "\r\n"); idx >= 0 {
summary = summary[:idx]
}
summary = strings.TrimSpace(summary)
if summary == "" {
return ""
}
if strings.HasPrefix(summary, "**") {
rest := summary[2:]
if end := strings.Index(rest, "**"); end >= 0 {
bold := strings.TrimSpace(rest[:end])
if bold != "" {
summary = bold
}
}
}
return strings.TrimSpace(strings.Trim(summary, "\"'`"))
}
func truncateRunes(value string, maxLen int) string {
if maxLen <= 0 {
return ""
}
runes := []rune(value)
if len(runes) <= maxLen {
return value
}
return string(runes[:maxLen])
}
@@ -1,90 +0,0 @@
package chatprompt
import (
"encoding/json"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
)
func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "empty input",
input: "",
expected: "{}",
},
{
name: "invalid json",
input: "{\"command\":",
expected: "{}",
},
{
name: "non-object json",
input: "[]",
expected: "{}",
},
{
name: "valid object json",
input: "{\"command\":\"ls\"}",
expected: "{\"command\":\"ls\"}",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
assistantContent, err := MarshalContent([]fantasy.Content{
fantasy.ToolCallContent{
ToolCallID: "toolu_01C4PqN6F2493pi7Ebag8Vg7",
ToolName: "execute",
Input: tc.input,
},
})
require.NoError(t, err)
toolContent, err := MarshalToolResult(
"toolu_01C4PqN6F2493pi7Ebag8Vg7",
"execute",
json.RawMessage(`{"error":"tool call was interrupted before it produced a result"}`),
true,
)
require.NoError(t, err)
prompt, err := ConvertMessages([]database.ChatMessage{
{
Role: string(fantasy.MessageRoleAssistant),
Visibility: database.ChatMessageVisibilityBoth,
Content: assistantContent,
},
{
Role: string(fantasy.MessageRoleTool),
Visibility: database.ChatMessageVisibilityBoth,
Content: toolContent,
},
})
require.NoError(t, err)
require.Len(t, prompt, 2)
require.Equal(t, fantasy.MessageRoleAssistant, prompt[0].Role)
toolCalls := ExtractToolCalls(prompt[0].Content)
require.Len(t, toolCalls, 1)
require.Equal(t, tc.expected, toolCalls[0].Input)
require.Equal(t, "execute", toolCalls[0].ToolName)
require.Equal(t, "toolu_01C4PqN6F2493pi7Ebag8Vg7", toolCalls[0].ToolCallID)
require.Equal(t, fantasy.MessageRoleTool, prompt[1].Role)
})
}
}
File diff suppressed because it is too large Load Diff
@@ -1,191 +0,0 @@
package chatprovider_test
import (
"testing"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
fantasyopenai "charm.land/fantasy/providers/openai"
fantasyopenrouter "charm.land/fantasy/providers/openrouter"
fantasyvercel "charm.land/fantasy/providers/vercel"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/codersdk"
)
func TestReasoningEffortFromChat(t *testing.T) {
t.Parallel()
tests := []struct {
name string
provider string
input *string
want *string
}{
{
name: "OpenAICaseInsensitive",
provider: "openai",
input: stringPtr(" HIGH "),
want: stringPtr(string(fantasyopenai.ReasoningEffortHigh)),
},
{
name: "AnthropicEffort",
provider: "anthropic",
input: stringPtr("max"),
want: stringPtr(string(fantasyanthropic.EffortMax)),
},
{
name: "OpenRouterEffort",
provider: "openrouter",
input: stringPtr("medium"),
want: stringPtr(string(fantasyopenrouter.ReasoningEffortMedium)),
},
{
name: "VercelEffort",
provider: "vercel",
input: stringPtr("xhigh"),
want: stringPtr(string(fantasyvercel.ReasoningEffortXHigh)),
},
{
name: "InvalidEffortReturnsNil",
provider: "openai",
input: stringPtr("unknown"),
want: nil,
},
{
name: "UnsupportedProviderReturnsNil",
provider: "bedrock",
input: stringPtr("high"),
want: nil,
},
{
name: "NilInputReturnsNil",
provider: "openai",
input: nil,
want: nil,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := chatprovider.ReasoningEffortFromChat(tt.provider, tt.input)
require.Equal(t, tt.want, got)
})
}
}
func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
t.Parallel()
options := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(true),
},
Provider: &codersdk.ChatModelOpenRouterProvider{
Order: []string{"openai"},
},
},
}
defaults := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(false),
Exclude: boolPtr(true),
MaxTokens: int64Ptr(123),
Effort: stringPtr("high"),
},
IncludeUsage: boolPtr(true),
Provider: &codersdk.ChatModelOpenRouterProvider{
Order: []string{"anthropic"},
AllowFallbacks: boolPtr(true),
RequireParameters: boolPtr(false),
DataCollection: stringPtr("allow"),
Only: []string{"openai"},
Ignore: []string{"foo"},
Quantizations: []string{"int8"},
Sort: stringPtr("latency"),
},
},
}
chatprovider.MergeMissingProviderOptions(&options, defaults)
require.NotNil(t, options)
require.NotNil(t, options.OpenRouter)
require.NotNil(t, options.OpenRouter.Reasoning)
require.True(t, *options.OpenRouter.Reasoning.Enabled)
require.Equal(t, true, *options.OpenRouter.Reasoning.Exclude)
require.EqualValues(t, 123, *options.OpenRouter.Reasoning.MaxTokens)
require.Equal(t, "high", *options.OpenRouter.Reasoning.Effort)
require.NotNil(t, options.OpenRouter.IncludeUsage)
require.True(t, *options.OpenRouter.IncludeUsage)
require.NotNil(t, options.OpenRouter.Provider)
require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Order)
require.NotNil(t, options.OpenRouter.Provider.AllowFallbacks)
require.True(t, *options.OpenRouter.Provider.AllowFallbacks)
require.NotNil(t, options.OpenRouter.Provider.RequireParameters)
require.False(t, *options.OpenRouter.Provider.RequireParameters)
require.Equal(t, "allow", *options.OpenRouter.Provider.DataCollection)
require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Only)
require.Equal(t, []string{"foo"}, options.OpenRouter.Provider.Ignore)
require.Equal(t, []string{"int8"}, options.OpenRouter.Provider.Quantizations)
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
}
func TestMergeMissingCallConfig_FillsUnsetFields(t *testing.T) {
t.Parallel()
dst := codersdk.ChatModelCallConfig{
Temperature: float64Ptr(0.2),
ProviderOptions: &codersdk.ChatModelProviderOptions{
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
User: stringPtr("alice"),
},
},
}
defaults := codersdk.ChatModelCallConfig{
MaxOutputTokens: int64Ptr(512),
Temperature: float64Ptr(0.9),
TopP: float64Ptr(0.8),
ProviderOptions: &codersdk.ChatModelProviderOptions{
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
User: stringPtr("bob"),
ReasoningEffort: stringPtr("medium"),
},
},
}
chatprovider.MergeMissingCallConfig(&dst, defaults)
require.NotNil(t, dst.MaxOutputTokens)
require.EqualValues(t, 512, *dst.MaxOutputTokens)
require.NotNil(t, dst.Temperature)
require.Equal(t, 0.2, *dst.Temperature)
require.NotNil(t, dst.TopP)
require.Equal(t, 0.8, *dst.TopP)
require.NotNil(t, dst.ProviderOptions)
require.NotNil(t, dst.ProviderOptions.OpenAI)
require.Equal(t, "alice", *dst.ProviderOptions.OpenAI.User)
require.Equal(t, "medium", *dst.ProviderOptions.OpenAI.ReasoningEffort)
}
func stringPtr(value string) *string {
return &value
}
func boolPtr(value bool) *bool {
return &value
}
func int64Ptr(value int64) *int64 {
return &value
}
func float64Ptr(value float64) *float64 {
return &value
}
-402
View File
@@ -1,402 +0,0 @@
package chattest
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/google/uuid"
)
// AnthropicHandler handles Anthropic API requests and returns a response.
type AnthropicHandler func(req *AnthropicRequest) AnthropicResponse
// AnthropicResponse represents a response to an Anthropic request.
// Either StreamingChunks or Response should be set, not both.
type AnthropicResponse struct {
StreamingChunks <-chan AnthropicChunk
Response *AnthropicMessage
}
// AnthropicRequest represents an Anthropic messages request.
type AnthropicRequest struct {
*http.Request // Embed http.Request
Model string `json:"model"`
Messages []AnthropicRequestMessage `json:"messages"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Options map[string]interface{} `json:",inline"`
}
// AnthropicRequestMessage represents a message in an Anthropic request.
// Content may be either a string or a structured content array.
type AnthropicRequestMessage struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
}
// AnthropicMessage represents a message in an Anthropic response.
type AnthropicMessage struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Role string `json:"role"`
Content string `json:"content,omitempty"`
Model string `json:"model,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
Usage AnthropicUsage `json:"usage,omitempty"`
}
// AnthropicUsage represents usage information in an Anthropic response.
type AnthropicUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// AnthropicChunk represents a streaming chunk from Anthropic.
type AnthropicChunk struct {
Type string `json:"type"`
Index int `json:"index,omitempty"`
Message AnthropicChunkMessage `json:"message,omitempty"`
ContentBlock AnthropicContentBlock `json:"content_block,omitempty"`
Delta AnthropicDeltaBlock `json:"delta,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
StopSequence *string `json:"stop_sequence,omitempty"`
Usage AnthropicUsage `json:"usage,omitempty"`
}
// AnthropicChunkMessage represents message metadata in a chunk.
type AnthropicChunkMessage struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Model string `json:"model"`
}
// AnthropicContentBlock represents a content block in a chunk.
type AnthropicContentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
}
// AnthropicDeltaBlock represents a delta block in a chunk.
type AnthropicDeltaBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
}
// anthropicServer is a test server that mocks the Anthropic API.
type anthropicServer struct {
mu sync.Mutex
server *httptest.Server
handler AnthropicHandler
request *AnthropicRequest
}
// NewAnthropic creates a new Anthropic test server with a handler function.
// The handler is called for each request and should return either a streaming
// response (via channel) or a non-streaming response.
// Returns the base URL of the server.
func NewAnthropic(t testing.TB, handler AnthropicHandler) string {
t.Helper()
s := &anthropicServer{
handler: handler,
}
mux := http.NewServeMux()
mux.HandleFunc("POST /v1/messages", s.handleMessages)
s.server = httptest.NewServer(mux)
t.Cleanup(func() {
s.server.Close()
})
return s.server.URL
}
func (s *anthropicServer) handleMessages(w http.ResponseWriter, r *http.Request) {
var req AnthropicRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
// Return a more detailed error for debugging
http.Error(w, fmt.Sprintf("decode request: %v", err), http.StatusBadRequest)
return
}
req.Request = r // Embed the original http.Request
s.mu.Lock()
s.request = &req
s.mu.Unlock()
resp := s.handler(&req)
s.writeResponse(w, &req, resp)
}
func (s *anthropicServer) writeResponse(w http.ResponseWriter, req *AnthropicRequest, resp AnthropicResponse) {
hasStreaming := resp.StreamingChunks != nil
hasNonStreaming := resp.Response != nil
switch {
case hasStreaming && hasNonStreaming:
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
return
case !hasStreaming && !hasNonStreaming:
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
return
case req.Stream && !hasStreaming:
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
return
case !req.Stream && !hasNonStreaming:
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
s.writeStreamingResponse(w, resp.StreamingChunks)
default:
s.writeNonStreamingResponse(w, resp.Response)
}
}
func (s *anthropicServer) writeStreamingResponse(w http.ResponseWriter, chunks <-chan AnthropicChunk) {
_ = s // receiver unused but kept for consistency
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("anthropic-version", "2023-06-01")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
for chunk := range chunks {
chunkData := make(map[string]interface{})
chunkData["type"] = chunk.Type
switch chunk.Type {
case "message_start":
chunkData["message"] = chunk.Message
case "content_block_start":
chunkData["index"] = chunk.Index
chunkData["content_block"] = chunk.ContentBlock
case "content_block_delta":
chunkData["index"] = chunk.Index
chunkData["delta"] = chunk.Delta
case "content_block_stop":
chunkData["index"] = chunk.Index
case "message_delta":
chunkData["delta"] = map[string]interface{}{
"stop_reason": chunk.StopReason,
"stop_sequence": chunk.StopSequence,
}
chunkData["usage"] = chunk.Usage
case "message_stop":
// No additional fields
}
chunkBytes, err := json.Marshal(chunkData)
if err != nil {
return
}
// Send both event and data lines to match Anthropic API format
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", chunk.Type, chunkBytes); err != nil {
return
}
flusher.Flush()
}
}
func (s *anthropicServer) writeNonStreamingResponse(w http.ResponseWriter, resp *AnthropicMessage) {
_ = s // receiver unused but kept for consistency
response := map[string]interface{}{
"id": resp.ID,
"type": resp.Type,
"role": resp.Role,
"model": resp.Model,
"content": []map[string]interface{}{
{
"type": "text",
"text": resp.Content,
},
},
"stop_reason": resp.StopReason,
"usage": resp.Usage,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("anthropic-version", "2023-06-01")
_ = json.NewEncoder(w).Encode(response)
}
// AnthropicStreamingResponse creates a streaming response from chunks.
func AnthropicStreamingResponse(chunks ...AnthropicChunk) AnthropicResponse {
ch := make(chan AnthropicChunk, len(chunks))
go func() {
for _, chunk := range chunks {
ch <- chunk
}
close(ch)
}()
return AnthropicResponse{StreamingChunks: ch}
}
// AnthropicNonStreamingResponse creates a non-streaming response with the given text.
func AnthropicNonStreamingResponse(text string) AnthropicResponse {
return AnthropicResponse{
Response: &AnthropicMessage{
ID: fmt.Sprintf("msg-%s", uuid.New().String()[:8]),
Type: "message",
Role: "assistant",
Content: text,
Model: "claude-3-opus-20240229",
StopReason: "end_turn",
Usage: AnthropicUsage{
InputTokens: 10,
OutputTokens: 5,
},
},
}
}
// AnthropicTextChunks creates a complete streaming response with text deltas.
// Takes text deltas and creates all required chunks (message_start,
// content_block_start, content_block_delta for each delta,
// content_block_stop, message_delta, message_stop).
func AnthropicTextChunks(deltas ...string) []AnthropicChunk {
if len(deltas) == 0 {
return nil
}
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
model := "claude-3-opus-20240229"
chunks := []AnthropicChunk{
{
Type: "message_start",
Message: AnthropicChunkMessage{
ID: messageID,
Type: "message",
Role: "assistant",
Model: model,
},
},
{
Type: "content_block_start",
Index: 0,
ContentBlock: AnthropicContentBlock{
Type: "text",
Text: "", // According to Anthropic API spec, text should be empty in content_block_start
},
},
}
// Add a delta chunk for each delta
for _, delta := range deltas {
chunks = append(chunks, AnthropicChunk{
Type: "content_block_delta",
Index: 0,
Delta: AnthropicDeltaBlock{
Type: "text_delta",
Text: delta,
},
})
}
chunks = append(chunks,
AnthropicChunk{
Type: "content_block_stop",
Index: 0,
},
AnthropicChunk{
Type: "message_delta",
StopReason: "end_turn",
Usage: AnthropicUsage{
InputTokens: 10,
OutputTokens: 5,
},
},
AnthropicChunk{
Type: "message_stop",
},
)
return chunks
}
// AnthropicToolCallChunks creates a complete streaming response for a tool call.
// Input JSON can be split across multiple deltas, matching Anthropic's
// input_json_delta streaming behavior.
func AnthropicToolCallChunks(toolName string, inputJSONDeltas ...string) []AnthropicChunk {
if len(inputJSONDeltas) == 0 {
return nil
}
if toolName == "" {
toolName = "tool"
}
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
model := "claude-3-opus-20240229"
toolCallID := fmt.Sprintf("toolu_%s", uuid.New().String()[:8])
chunks := []AnthropicChunk{
{
Type: "message_start",
Message: AnthropicChunkMessage{
ID: messageID,
Type: "message",
Role: "assistant",
Model: model,
},
},
{
Type: "content_block_start",
Index: 0,
ContentBlock: AnthropicContentBlock{
Type: "tool_use",
ID: toolCallID,
Name: toolName,
Input: json.RawMessage("{}"),
},
},
}
for _, delta := range inputJSONDeltas {
chunks = append(chunks, AnthropicChunk{
Type: "content_block_delta",
Index: 0,
Delta: AnthropicDeltaBlock{
Type: "input_json_delta",
PartialJSON: delta,
},
})
}
chunks = append(chunks,
AnthropicChunk{
Type: "content_block_stop",
Index: 0,
},
AnthropicChunk{
Type: "message_delta",
StopReason: "tool_use",
Usage: AnthropicUsage{
InputTokens: 10,
OutputTokens: 5,
},
},
AnthropicChunk{
Type: "message_stop",
},
)
return chunks
}
-221
View File
@@ -1,221 +0,0 @@
package chattest_test
import (
"context"
"sync/atomic"
"testing"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chattest"
)
func TestAnthropic_Streaming(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicStreamingResponse(
chattest.AnthropicTextChunks("Hello", " world", "!")...,
)
})
// Create fantasy client pointing to our test server
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "claude-3-opus-20240229")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Say hello"},
},
},
},
}
stream, err := model.Stream(ctx, call)
require.NoError(t, err)
expectedDeltas := []string{"Hello", " world", "!"}
deltaIndex := 0
var allParts []fantasy.StreamPart
for part := range stream {
allParts = append(allParts, part)
if part.Type == fantasy.StreamPartTypeTextDelta {
require.Less(t, deltaIndex, len(expectedDeltas), "Received more deltas than expected")
require.Equal(t, expectedDeltas[deltaIndex], part.Delta,
"Delta at index %d should be %q, got %q", deltaIndex, expectedDeltas[deltaIndex], part.Delta)
deltaIndex++
}
}
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d. Total parts received: %d", len(expectedDeltas), deltaIndex, len(allParts))
}
func TestAnthropic_ToolCalls(t *testing.T) {
t.Parallel()
var requestCount atomic.Int32
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
switch requestCount.Add(1) {
case 1:
return chattest.AnthropicStreamingResponse(
chattest.AnthropicToolCallChunks("get_weather", `{"location":"San Francisco"}`)...,
)
default:
return chattest.AnthropicStreamingResponse(
chattest.AnthropicTextChunks("The weather in San Francisco is 72F.")...,
)
}
})
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
require.NoError(t, err)
type weatherInput struct {
Location string `json:"location"`
}
var toolCallCount atomic.Int32
weatherTool := fantasy.NewAgentTool(
"get_weather",
"Get weather for a location.",
func(ctx context.Context, input weatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
toolCallCount.Add(1)
require.Equal(t, "San Francisco", input.Location)
return fantasy.NewTextResponse("72F"), nil
},
)
agent := fantasy.NewAgent(
model,
fantasy.WithSystemPrompt("You are a helpful assistant."),
fantasy.WithTools(weatherTool),
)
result, err := agent.Stream(context.Background(), fantasy.AgentStreamCall{
Prompt: "What's the weather in San Francisco?",
})
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, int32(1), toolCallCount.Load(), "expected exactly one tool execution")
require.GreaterOrEqual(t, requestCount.Load(), int32(2), "expected follow-up model call after tool execution")
}
func TestAnthropic_NonStreaming(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicNonStreamingResponse("Response text")
})
// Create fantasy client pointing to our test server
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "claude-3-opus-20240229")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Test message"},
},
},
},
}
response, err := model.Generate(ctx, call)
require.NoError(t, err)
require.NotNil(t, response)
}
func TestAnthropic_Streaming_MismatchReturnsErrorPart(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicNonStreamingResponse("wrong response type")
})
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
require.NoError(t, err)
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.NoError(t, err)
var streamErr error
for part := range stream {
if part.Type == fantasy.StreamPartTypeError {
streamErr = part.Error
break
}
}
require.Error(t, streamErr)
require.Contains(t, streamErr.Error(), "500 Internal Server Error")
}
func TestAnthropic_NonStreaming_MismatchReturnsError(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicStreamingResponse(
chattest.AnthropicTextChunks("wrong", " response")...,
)
})
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
require.NoError(t, err)
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "500 Internal Server Error")
}
-457
View File
@@ -1,457 +0,0 @@
package chattest
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/google/uuid"
)
// OpenAIHandler handles OpenAI API requests and returns a response.
type OpenAIHandler func(req *OpenAIRequest) OpenAIResponse
// OpenAIResponse represents a response to an OpenAI request.
// Either StreamingChunks or Response should be set, not both.
type OpenAIResponse struct {
StreamingChunks <-chan OpenAIChunk
Response *OpenAICompletion
}
// OpenAIRequest represents an OpenAI chat completion request.
type OpenAIRequest struct {
*http.Request
Model string `json:"model"`
Messages []OpenAIMessage `json:"messages"`
Stream bool `json:"stream,omitempty"`
Prompt []interface{} `json:"prompt,omitempty"` // For responses API
Options map[string]interface{} `json:",inline"`
}
// OpenAIMessage represents a message in an OpenAI request.
type OpenAIMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// OpenAIToolCallFunction represents the function details in a tool call.
type OpenAIToolCallFunction struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
// OpenAIToolCall represents a tool call in a streaming chunk or completion.
type OpenAIToolCall struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function OpenAIToolCallFunction `json:"function,omitempty"`
Index int `json:"index,omitempty"` // For streaming deltas
}
// OpenAIChunkChoice represents a choice in a streaming chunk.
type OpenAIChunkChoice struct {
Index int `json:"index"`
Delta string `json:"delta,omitempty"`
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
}
// OpenAIChunk represents a streaming chunk from OpenAI.
type OpenAIChunk struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAIChunkChoice `json:"choices"`
}
// OpenAICompletionChoice represents a choice in a completion response.
type OpenAICompletionChoice struct {
Index int `json:"index"`
Message OpenAIMessage `json:"message"`
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
}
// OpenAICompletionUsage represents usage information in a completion response.
type OpenAICompletionUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// OpenAICompletion represents a non-streaming OpenAI completion response.
type OpenAICompletion struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAICompletionChoice `json:"choices"`
Usage OpenAICompletionUsage `json:"usage"`
}
// openAIServer is a test server that mocks the OpenAI API.
type openAIServer struct {
mu sync.Mutex
server *httptest.Server
handler OpenAIHandler
request *OpenAIRequest
}
// NewOpenAI creates a new OpenAI test server with a handler function.
// The handler is called for each request and should return either a streaming
// response (via channel) or a non-streaming response.
// Returns the base URL of the server.
func NewOpenAI(t testing.TB, handler OpenAIHandler) string {
t.Helper()
s := &openAIServer{
handler: handler,
}
mux := http.NewServeMux()
mux.HandleFunc("POST /chat/completions", s.handleChatCompletions)
mux.HandleFunc("POST /responses", s.handleResponses)
s.server = httptest.NewServer(mux)
t.Cleanup(func() {
s.server.Close()
})
return s.server.URL
}
func (s *openAIServer) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
var req OpenAIRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
req.Request = r
s.mu.Lock()
s.request = &req
s.mu.Unlock()
resp := s.handler(&req)
s.writeChatCompletionsResponse(w, &req, resp)
}
func (s *openAIServer) handleResponses(w http.ResponseWriter, r *http.Request) {
var req OpenAIRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
req.Request = r
s.mu.Lock()
s.request = &req
s.mu.Unlock()
resp := s.handler(&req)
s.writeResponsesAPIResponse(w, &req, resp)
}
func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
hasStreaming := resp.StreamingChunks != nil
hasNonStreaming := resp.Response != nil
switch {
case hasStreaming && hasNonStreaming:
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
return
case !hasStreaming && !hasNonStreaming:
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
return
case req.Stream && !hasStreaming:
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
return
case !req.Stream && !hasNonStreaming:
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
s.writeChatCompletionsStreaming(w, resp.StreamingChunks)
default:
s.writeChatCompletionsNonStreaming(w, resp.Response)
}
}
func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
hasStreaming := resp.StreamingChunks != nil
hasNonStreaming := resp.Response != nil
switch {
case hasStreaming && hasNonStreaming:
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
return
case !hasStreaming && !hasNonStreaming:
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
return
case req.Stream && !hasStreaming:
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
return
case !req.Stream && !hasNonStreaming:
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
s.writeResponsesAPIStreaming(w, resp.StreamingChunks)
default:
s.writeResponsesAPINonStreaming(w, resp.Response)
}
}
func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) {
_ = s // receiver unused but kept for consistency
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
for chunk := range chunks {
choicesData := make([]map[string]interface{}, len(chunk.Choices))
for i, choice := range chunk.Choices {
choiceData := map[string]interface{}{
"index": choice.Index,
}
if choice.Delta != "" {
choiceData["delta"] = map[string]interface{}{
"content": choice.Delta,
}
}
if len(choice.ToolCalls) > 0 {
// Tool calls come in the delta
if choiceData["delta"] == nil {
choiceData["delta"] = make(map[string]interface{})
}
delta, ok := choiceData["delta"].(map[string]interface{})
if !ok {
delta = make(map[string]interface{})
choiceData["delta"] = delta
}
delta["tool_calls"] = choice.ToolCalls
}
if choice.FinishReason != "" {
choiceData["finish_reason"] = choice.FinishReason
}
choicesData[i] = choiceData
}
chunkData := map[string]interface{}{
"id": chunk.ID,
"object": chunk.Object,
"created": chunk.Created,
"model": chunk.Model,
"choices": choicesData,
}
chunkBytes, err := json.Marshal(chunkData)
if err != nil {
return
}
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
return
}
flusher.Flush()
}
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
}
func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) {
_ = s // receiver unused but kept for consistency
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
itemIDs := make(map[int]string)
for chunk := range chunks {
// Responses API sends one event per choice
for outputIndex, choice := range chunk.Choices {
if choice.Index != 0 {
outputIndex = choice.Index
}
itemID, found := itemIDs[outputIndex]
if !found {
itemID = fmt.Sprintf("msg_%s", uuid.New().String()[:8])
itemIDs[outputIndex] = itemID
}
chunkData := map[string]interface{}{
"type": "response.output_text.delta",
"item_id": itemID,
"output_index": outputIndex,
"created": chunk.Created,
"model": chunk.Model,
"content_index": 0,
"delta": choice.Delta,
}
chunkBytes, err := json.Marshal(chunkData)
if err != nil {
return
}
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
return
}
flusher.Flush()
}
}
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
}
func (s *openAIServer) writeChatCompletionsNonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
_ = s // receiver unused but kept for consistency
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}
func (s *openAIServer) writeResponsesAPINonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
_ = s // receiver unused but kept for consistency
// Convert all choices to output format
outputs := make([]map[string]interface{}, len(resp.Choices))
for i, choice := range resp.Choices {
outputs[i] = map[string]interface{}{
"id": uuid.New().String(),
"type": "message",
"role": "assistant",
"content": []map[string]interface{}{
{
"type": "output_text",
"text": choice.Message.Content,
},
},
}
}
response := map[string]interface{}{
"id": resp.ID,
"object": "response",
"created": resp.Created,
"model": resp.Model,
"output": outputs,
"usage": resp.Usage,
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(response)
}
// OpenAIStreamingResponse creates a streaming response from chunks.
func OpenAIStreamingResponse(chunks ...OpenAIChunk) OpenAIResponse {
ch := make(chan OpenAIChunk, len(chunks))
go func() {
for _, chunk := range chunks {
ch <- chunk
}
close(ch)
}()
return OpenAIResponse{StreamingChunks: ch}
}
// OpenAINonStreamingResponse creates a non-streaming response with the given text.
func OpenAINonStreamingResponse(text string) OpenAIResponse {
return OpenAIResponse{
Response: &OpenAICompletion{
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
Object: "chat.completion",
Created: time.Now().Unix(),
Model: "gpt-4",
Choices: []OpenAICompletionChoice{
{
Index: 0,
Message: OpenAIMessage{
Role: "assistant",
Content: text,
},
FinishReason: "stop",
},
},
Usage: OpenAICompletionUsage{
PromptTokens: 10,
CompletionTokens: 5,
TotalTokens: 15,
},
},
}
}
// OpenAITextChunks creates streaming chunks with text deltas.
// Each delta string becomes a separate chunk with a single choice.
// Returns a slice of chunks, one per delta, with each choice having its index (0, 1, 2, ...).
func OpenAITextChunks(deltas ...string) []OpenAIChunk {
if len(deltas) == 0 {
return nil
}
chunkID := fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8])
now := time.Now().Unix()
chunks := make([]OpenAIChunk, len(deltas))
for i, delta := range deltas {
chunks[i] = OpenAIChunk{
ID: chunkID,
Object: "chat.completion.chunk",
Created: now,
Model: "gpt-4",
Choices: []OpenAIChunkChoice{
{
Index: i,
Delta: delta,
},
},
}
}
return chunks
}
// OpenAIToolCallChunk creates a streaming chunk with a tool call.
// Takes the tool name and arguments JSON string, creates a tool call for choice index 0.
func OpenAIToolCallChunk(toolName, arguments string) OpenAIChunk {
return OpenAIChunk{
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: "gpt-4",
Choices: []OpenAIChunkChoice{
{
Index: 0,
ToolCalls: []OpenAIToolCall{
{
Index: 0,
ID: fmt.Sprintf("call_%s", uuid.New().String()[:8]),
Type: "function",
Function: OpenAIToolCallFunction{
Name: toolName,
Arguments: arguments,
},
},
},
},
},
}
}
-367
View File
@@ -1,367 +0,0 @@
package chattest_test
import (
"context"
"sync/atomic"
"testing"
"charm.land/fantasy"
fantasyopenai "charm.land/fantasy/providers/openai"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chattest"
)
func TestOpenAI_Streaming(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(
append(
append(
chattest.OpenAITextChunks("Hello", "Hi"),
chattest.OpenAITextChunks(" world", " there")...,
),
chattest.OpenAITextChunks("!", "!")...,
)...,
)
})
// Create fantasy client pointing to our test server
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Say hello"},
},
},
},
}
stream, err := model.Stream(ctx, call)
require.NoError(t, err)
// We expect chunks in order: one choice per chunk
// So we get: "Hello" (choice 0), "Hi" (choice 1), " world" (choice 0), " there" (choice 1), "!" (choice 0), "!" (choice 1)
expectedDeltas := []string{"Hello", "Hi", " world", " there", "!", "!"}
deltaIndex := 0
for part := range stream {
if part.Type == fantasy.StreamPartTypeTextDelta {
// Verify we're getting deltas in the expected order
require.Less(t, deltaIndex, len(expectedDeltas), "Received more deltas than expected")
require.Equal(t, expectedDeltas[deltaIndex], part.Delta,
"Delta at index %d should be %q, got %q", deltaIndex, expectedDeltas[deltaIndex], part.Delta)
deltaIndex++
}
}
// Verify we received all expected deltas
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d", len(expectedDeltas), deltaIndex)
}
func TestOpenAI_Streaming_ResponsesAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(
append(
append(
chattest.OpenAITextChunks("First", "Second"),
chattest.OpenAITextChunks(" output", " output")...,
),
chattest.OpenAITextChunks("!", "!")...,
)...,
)
})
// Create fantasy client pointing to our test server (responses API)
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
fantasyopenai.WithUseResponsesAPI(),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Say hello"},
},
},
},
}
stream, err := model.Stream(ctx, call)
require.NoError(t, err)
var parts []fantasy.StreamPart
for part := range stream {
parts = append(parts, part)
}
// Verify we received the chunks in order
require.Greater(t, len(parts), 0)
// Extract text deltas from parts and verify they match expected chunks in order
// We expect: "First", " output", "!" for choice 0, and "Second", " output", "!" for choice 1
var allDeltas []string
for _, part := range parts {
if part.Type == fantasy.StreamPartTypeTextDelta {
allDeltas = append(allDeltas, part.Delta)
}
}
// Verify we received deltas (responses API may handle multiple choices differently)
// If we got text deltas, verify the content
if len(allDeltas) > 0 {
allText := ""
for _, delta := range allDeltas {
allText += delta
}
require.Contains(t, allText, "First")
require.Contains(t, allText, "Second")
require.Contains(t, allText, "output")
require.Contains(t, allText, "!")
} else {
// If no text deltas, at least verify we got some parts (may be different format)
require.Greater(t, len(parts), 0, "Expected at least one stream part")
}
}
func TestOpenAI_NonStreaming_CompletionsAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAINonStreamingResponse("First response")
})
// Create fantasy client pointing to our test server (completions API)
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Test message"},
},
},
},
}
response, err := model.Generate(ctx, call)
require.NoError(t, err)
require.NotNil(t, response)
}
func TestOpenAI_ToolCalls(t *testing.T) {
t.Parallel()
var requestCount atomic.Int32
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
switch requestCount.Add(1) {
case 1:
return chattest.OpenAIStreamingResponse(
chattest.OpenAIToolCallChunk("get_weather", `{"location":"San Francisco"}`),
)
default:
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("The weather in San Francisco is 72F.")...,
)
}
})
// Create fantasy client pointing to our test server
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
type weatherInput struct {
Location string `json:"location"`
}
var toolCallCount atomic.Int32
weatherTool := fantasy.NewAgentTool(
"get_weather",
"Get weather for a location.",
func(ctx context.Context, input weatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
toolCallCount.Add(1)
require.Equal(t, "San Francisco", input.Location)
return fantasy.NewTextResponse("72F"), nil
},
)
agent := fantasy.NewAgent(
model,
fantasy.WithSystemPrompt("You are a helpful assistant."),
fantasy.WithTools(weatherTool),
)
result, err := agent.Stream(ctx, fantasy.AgentStreamCall{
Prompt: "What's the weather in San Francisco?",
})
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, int32(1), toolCallCount.Load(), "expected exactly one tool execution")
require.GreaterOrEqual(t, requestCount.Load(), int32(2), "expected follow-up model call after tool execution")
}
func TestOpenAI_NonStreaming_ResponsesAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAINonStreamingResponse("First output")
})
// Create fantasy client pointing to our test server (responses API)
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
fantasyopenai.WithUseResponsesAPI(),
)
require.NoError(t, err)
ctx := context.Background()
model, err := client.LanguageModel(ctx, "gpt-4")
require.NoError(t, err)
call := fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "Test message"},
},
},
},
}
response, err := model.Generate(ctx, call)
require.NoError(t, err)
require.NotNil(t, response)
}
func TestOpenAI_Streaming_MismatchReturnsErrorPart(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAINonStreamingResponse("wrong response type")
})
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "gpt-4")
require.NoError(t, err)
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.NoError(t, err)
var streamErr error
for part := range stream {
if part.Type == fantasy.StreamPartTypeError {
streamErr = part.Error
break
}
}
require.Error(t, streamErr)
require.Contains(t, streamErr.Error(), "non-streaming response for streaming request")
}
func TestOpenAI_NonStreaming_MismatchReturnsError_CompletionsAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("wrong response type")...)
})
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "gpt-4")
require.NoError(t, err)
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "streaming response for non-streaming request")
}
func TestOpenAI_NonStreaming_MismatchReturnsError_ResponsesAPI(t *testing.T) {
t.Parallel()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("wrong response type")...)
})
client, err := fantasyopenai.New(
fantasyopenai.WithAPIKey("test-key"),
fantasyopenai.WithBaseURL(serverURL),
fantasyopenai.WithUseResponsesAPI(),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "gpt-4")
require.NoError(t, err)
_, err = model.Generate(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "streaming response for non-streaming request")
}
-50
View File
@@ -1,50 +0,0 @@
package chattool
import (
"encoding/json"
"strings"
"unicode/utf8"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
)
// toolResponse builds a fantasy.ToolResponse from a JSON-serializable
// result payload.
func toolResponse(result map[string]any) fantasy.ToolResponse {
data, err := json.Marshal(result)
if err != nil {
return fantasy.NewTextResponse("{}")
}
return fantasy.NewTextResponse(string(data))
}
// parseOwnerID parses a UUID string into a uuid.UUID, returning
// an error if the string is empty or not a valid UUID.
func parseOwnerID(raw string) (uuid.UUID, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return uuid.Nil, xerrors.New("owner ID is empty")
}
id, err := uuid.Parse(raw)
if err != nil {
return uuid.Nil, xerrors.Errorf("invalid owner ID %q: %w", raw, err)
}
return id, nil
}
func truncateRunes(value string, maxLen int) string {
if maxLen <= 0 || value == "" {
return ""
}
if utf8.RuneCountInString(value) <= maxLen {
return value
}
runes := []rune(value)
if maxLen > len(runes) {
maxLen = len(runes)
}
return string(runes[:maxLen])
}
-426
View File
@@ -1,426 +0,0 @@
package chattool
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/util/namesgenerator"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// buildPollInterval is how often we check if the workspace
// build has completed.
buildPollInterval = 2 * time.Second
// buildTimeout is the maximum time to wait for a workspace
// build to complete before giving up.
buildTimeout = 10 * time.Minute
// agentConnectTimeout is the maximum time to wait for the
// workspace agent to become reachable after a successful build.
agentConnectTimeout = 2 * time.Minute
// agentRetryInterval is how often we retry connecting to the
// workspace agent.
agentRetryInterval = 2 * time.Second
// agentAttemptTimeout is the timeout for a single connection
// attempt to the workspace agent during the retry loop.
agentAttemptTimeout = 5 * time.Second
// agentPingTimeout is the timeout for a single agent ping
// when checking whether an existing workspace is alive.
agentPingTimeout = 5 * time.Second
)
// CreateWorkspaceFn creates a workspace for the given owner.
type CreateWorkspaceFn func(
ctx context.Context,
ownerID uuid.UUID,
req codersdk.CreateWorkspaceRequest,
) (codersdk.Workspace, error)
// AgentConnFunc provides access to workspace agent connections.
type AgentConnFunc func(
ctx context.Context,
agentID uuid.UUID,
) (workspacesdk.AgentConn, func(), error)
// CreateWorkspaceOptions configures the create_workspace tool.
type CreateWorkspaceOptions struct {
DB database.Store
OwnerID uuid.UUID
ChatID uuid.UUID
CreateFn CreateWorkspaceFn
AgentConnFn AgentConnFunc
WorkspaceMu *sync.Mutex
}
type createWorkspaceArgs struct {
TemplateID string `json:"template_id"`
Name string `json:"name,omitempty"`
Parameters map[string]string `json:"parameters,omitempty"`
}
// CreateWorkspace returns a tool that creates a new workspace from a
// template. The tool is idempotent: if the chat already has a
// workspace that is building or running, it returns the existing
// workspace instead of creating a new one. A mutex prevents parallel
// calls from creating duplicate workspaces.
func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"create_workspace",
"Create a new workspace from a template. Requires a "+
"template_id (from list_templates). Optionally provide "+
"a name and parameter values (from read_template). "+
"If no name is given, one will be generated. "+
"This tool is idempotent — if the chat already has a "+
"workspace that is building or running, the existing "+
"workspace is returned.",
func(ctx context.Context, args createWorkspaceArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.CreateFn == nil {
return fantasy.NewTextErrorResponse("workspace creator is not configured"), nil
}
templateIDStr := strings.TrimSpace(args.TemplateID)
if templateIDStr == "" {
return fantasy.NewTextErrorResponse("template_id is required; use list_templates to find one"), nil
}
templateID, err := uuid.Parse(templateIDStr)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("invalid template_id: %w", err).Error(),
), nil
}
// Serialize workspace creation to prevent parallel
// tool calls from creating duplicate workspaces.
if options.WorkspaceMu != nil {
options.WorkspaceMu.Lock()
defer options.WorkspaceMu.Unlock()
}
// Check for an existing workspace on the chat.
if options.DB != nil && options.ChatID != uuid.Nil {
existing, done, existErr := checkExistingWorkspace(
ctx, options.DB, options.ChatID,
options.AgentConnFn,
)
if existErr != nil {
return fantasy.NewTextErrorResponse(existErr.Error()), nil
}
if done {
return toolResponse(existing), nil
}
}
ownerID := options.OwnerID
// Set up dbauthz context for DB lookups.
if options.DB != nil {
ownerCtx, ownerErr := asOwner(ctx, options.DB, ownerID)
if ownerErr != nil {
return fantasy.NewTextErrorResponse(ownerErr.Error()), nil
}
ctx = ownerCtx
}
createReq := codersdk.CreateWorkspaceRequest{
TemplateID: templateID,
}
// Resolve workspace name.
name := strings.TrimSpace(args.Name)
if name == "" {
seed := "workspace"
if options.DB != nil {
if t, lookupErr := options.DB.GetTemplateByID(ctx, templateID); lookupErr == nil {
seed = t.Name
}
}
name = generatedWorkspaceName(seed)
} else if err := codersdk.NameValid(name); err != nil {
name = generatedWorkspaceName(name)
}
createReq.Name = name
// Map parameters.
for k, v := range args.Parameters {
createReq.RichParameterValues = append(
createReq.RichParameterValues,
codersdk.WorkspaceBuildParameter{Name: k, Value: v},
)
}
workspace, err := options.CreateFn(ctx, ownerID, createReq)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
// Wait for the build to complete and the agent to
// come online so subsequent tools can use the
// workspace immediately.
if options.DB != nil {
if err := waitForBuild(ctx, options.DB, workspace.ID); err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("workspace build failed: %w", err).Error(),
), nil
}
}
// Look up the first agent so we can link it to the chat.
workspaceAgentID := uuid.Nil
if options.DB != nil {
agents, agentErr := options.DB.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
if agentErr == nil && len(agents) > 0 {
workspaceAgentID = agents[0].ID
}
}
// Persist workspace + agent association on the chat.
if options.DB != nil && options.ChatID != uuid.Nil {
_, _ = options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
ID: options.ChatID,
WorkspaceID: uuid.NullUUID{
UUID: workspace.ID,
Valid: true,
},
WorkspaceAgentID: uuid.NullUUID{
UUID: workspaceAgentID,
Valid: workspaceAgentID != uuid.Nil,
},
})
}
// Wait for the agent to come online.
if workspaceAgentID != uuid.Nil && options.AgentConnFn != nil {
if err := waitForAgent(ctx, options.AgentConnFn, workspaceAgentID); err != nil {
// Non-fatal: the workspace was created
// successfully, the agent just isn't ready
// yet. The model can retry.
return toolResponse(map[string]any{
"created": true,
"workspace_name": workspace.FullName(),
"agent_status": "not_ready",
"agent_error": err.Error(),
}), nil
}
}
return toolResponse(map[string]any{
"created": true,
"workspace_name": workspace.FullName(),
}), nil
},
)
}
// checkExistingWorkspace checks whether the chat already has a usable
// workspace. Returns the result map and true if the caller should
// return early (workspace exists and is alive or building). Returns
// false if the caller should proceed with creation (workspace is dead
// or missing).
func checkExistingWorkspace(
ctx context.Context,
db database.Store,
chatID uuid.UUID,
agentConnFn AgentConnFunc,
) (map[string]any, bool, error) {
chat, err := db.GetChatByID(ctx, chatID)
if err != nil {
return nil, false, xerrors.Errorf("load chat: %w", err)
}
if !chat.WorkspaceID.Valid {
return nil, false, nil
}
// Check if workspace still exists.
ws, err := db.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID)
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
// Workspace was deleted — allow creation.
return nil, false, nil
}
return nil, false, xerrors.Errorf("load workspace: %w", err)
}
// Check the latest build status.
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
if err != nil {
// Can't determine status — allow creation.
return nil, false, nil
}
job, err := db.GetProvisionerJobByID(ctx, build.JobID)
if err != nil {
return nil, false, nil
}
switch job.JobStatus {
case database.ProvisionerJobStatusPending,
database.ProvisionerJobStatusRunning:
// Build is in progress — wait for it instead of
// creating a new workspace.
if err := waitForBuild(ctx, db, ws.ID); err != nil {
return nil, false, xerrors.Errorf(
"existing workspace build failed: %w", err,
)
}
return map[string]any{
"created": false,
"workspace_name": ws.Name,
"status": "already_exists",
"message": "workspace was already being built and is now ready",
}, true, nil
case database.ProvisionerJobStatusSucceeded:
// Build succeeded — check if agent is reachable.
if chat.WorkspaceAgentID.Valid && agentConnFn != nil {
pingCtx, cancel := context.WithTimeout(
ctx, agentPingTimeout,
)
defer cancel()
conn, release, connErr := agentConnFn(
pingCtx, chat.WorkspaceAgentID.UUID,
)
if connErr == nil {
release()
_ = conn
return map[string]any{
"created": false,
"workspace_name": ws.Name,
"status": "already_exists",
"message": "workspace is already running and reachable",
}, true, nil
}
// Agent unreachable — workspace is dead, allow
// creation.
}
// No agent ID or no conn func — allow creation.
return nil, false, nil
default:
// Failed, canceled, etc — allow creation.
return nil, false, nil
}
}
// waitForBuild polls the workspace's latest build until it
// completes or the context expires.
func waitForBuild(
ctx context.Context,
db database.Store,
workspaceID uuid.UUID,
) error {
buildCtx, cancel := context.WithTimeout(ctx, buildTimeout)
defer cancel()
ticker := time.NewTicker(buildPollInterval)
defer ticker.Stop()
for {
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(
buildCtx, workspaceID,
)
if err != nil {
return xerrors.Errorf("get latest build: %w", err)
}
job, err := db.GetProvisionerJobByID(buildCtx, build.JobID)
if err != nil {
return xerrors.Errorf("get provisioner job: %w", err)
}
switch job.JobStatus {
case database.ProvisionerJobStatusSucceeded:
return nil
case database.ProvisionerJobStatusFailed:
errMsg := "build failed"
if job.Error.Valid {
errMsg = job.Error.String
}
return xerrors.New(errMsg)
case database.ProvisionerJobStatusCanceled:
return xerrors.New("build was canceled")
case database.ProvisionerJobStatusPending,
database.ProvisionerJobStatusRunning,
database.ProvisionerJobStatusCanceling:
// Still in progress — keep waiting.
default:
return xerrors.Errorf("unexpected job status: %s", job.JobStatus)
}
select {
case <-buildCtx.Done():
return xerrors.Errorf(
"timed out waiting for workspace build: %w",
buildCtx.Err(),
)
case <-ticker.C:
}
}
}
// waitForAgent retries connecting to the workspace agent until it
// succeeds or the timeout expires.
func waitForAgent(
ctx context.Context,
agentConnFn AgentConnFunc,
agentID uuid.UUID,
) error {
agentCtx, cancel := context.WithTimeout(ctx, agentConnectTimeout)
defer cancel()
ticker := time.NewTicker(agentRetryInterval)
defer ticker.Stop()
var lastErr error
for {
attemptCtx, attemptCancel := context.WithTimeout(agentCtx, agentAttemptTimeout)
conn, release, err := agentConnFn(attemptCtx, agentID)
attemptCancel()
if err == nil {
release()
_ = conn
return nil
}
lastErr = err
select {
case <-agentCtx.Done():
return xerrors.Errorf(
"timed out waiting for workspace agent: %w",
lastErr,
)
case <-ticker.C:
}
}
}
func generatedWorkspaceName(seed string) string {
base := codersdk.UsernameFrom(strings.TrimSpace(strings.ToLower(seed)))
if strings.TrimSpace(base) == "" {
base = "workspace"
}
suffix := strings.ReplaceAll(uuid.NewString(), "-", "")[:4]
if len(base) > 27 {
base = strings.Trim(base[:27], "-")
}
if base == "" {
base = "workspace"
}
name := fmt.Sprintf("%s-%s", base, suffix)
if err := codersdk.NameValid(name); err == nil {
return name
}
return namesgenerator.NameDigitWith("-")
}
-50
View File
@@ -1,50 +0,0 @@
package chattool
import (
"context"
"charm.land/fantasy"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
type EditFilesOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
}
type EditFilesArgs struct {
Files []workspacesdk.FileEdits `json:"files"`
}
func EditFiles(options EditFilesOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"edit_files",
"Perform search-and-replace edits on one or more files in the workspace."+
" Each file can have multiple edits applied atomically.",
func(ctx context.Context, args EditFilesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeEditFilesTool(ctx, conn, args)
},
)
}
func executeEditFilesTool(
ctx context.Context,
conn workspacesdk.AgentConn,
args EditFilesArgs,
) (fantasy.ToolResponse, error) {
if len(args.Files) == 0 {
return fantasy.NewTextErrorResponse("files is required"), nil
}
if err := conn.EditFiles(ctx, workspacesdk.FileEditRequest{Files: args.Files}); err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolResponse(map[string]any{"ok": true}), nil
}
-133
View File
@@ -1,133 +0,0 @@
package chattool
import (
"context"
"time"
"charm.land/fantasy"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
defaultExecuteTimeout = 60 * time.Second
chatAgentEnvVar = "CODER_CHAT_AGENT"
gitAuthRequiredPrefix = "CODER_GITAUTH_REQUIRED:"
authRequiredResultReason = "authentication_required"
)
type ExecuteOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
DefaultTimeout time.Duration
}
type ExecuteArgs struct {
Command string `json:"command"`
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
}
func Execute(options ExecuteOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"execute",
"Execute a shell command in the workspace.",
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeTool(ctx, conn, args, options.DefaultTimeout), nil
},
)
}
func executeTool(
ctx context.Context,
conn workspacesdk.AgentConn,
args ExecuteArgs,
defaultTimeout time.Duration,
) fantasy.ToolResponse {
if args.Command == "" {
return fantasy.NewTextErrorResponse("command is required")
}
timeout := defaultTimeout
if timeout <= 0 {
timeout = defaultExecuteTimeout
}
if args.TimeoutSeconds != nil {
timeout = time.Duration(*args.TimeoutSeconds) * time.Second
}
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
output, exitCode, err := runCommand(cmdCtx, conn, args.Command)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error())
}
return toolResponse(map[string]any{
"output": output,
"exit_code": exitCode,
})
}
func runCommand(
ctx context.Context,
conn workspacesdk.AgentConn,
command string,
) (string, int, error) {
sshClient, err := conn.SSHClient(ctx)
if err != nil {
return "", 0, err
}
defer sshClient.Close()
session, err := sshClient.NewSession()
if err != nil {
return "", 0, err
}
defer session.Close()
if err := session.Setenv(chatAgentEnvVar, "true"); err != nil {
return "", 0, xerrors.Errorf("set %s: %w", chatAgentEnvVar, err)
}
resultCh := make(chan struct {
output string
exitCode int
err error
}, 1)
go func() {
output, err := session.CombinedOutput(command)
exitCode := 0
if err != nil {
var exitErr *ssh.ExitError
if xerrors.As(err, &exitErr) {
exitCode = exitErr.ExitStatus()
} else {
exitCode = 1
}
}
resultCh <- struct {
output string
exitCode int
err error
}{
output: string(output),
exitCode: exitCode,
err: err,
}
}()
select {
case <-ctx.Done():
_ = session.Close()
return "", 0, ctx.Err()
case result := <-resultCh:
return result.output, result.exitCode, result.err
}
}
-94
View File
@@ -1,94 +0,0 @@
package chattool
import (
"context"
"database/sql"
"strings"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
)
// ListTemplatesOptions configures the list_templates tool.
type ListTemplatesOptions struct {
DB database.Store
OwnerID uuid.UUID
}
type listTemplatesArgs struct {
Query string `json:"query,omitempty"`
}
// ListTemplates returns a tool that lists available workspace templates.
// The agent uses this to discover templates before creating a workspace.
func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"list_templates",
"List available workspace templates. Optionally filter by a "+
"search query matching template name or description. "+
"Use this to find a template before creating a workspace.",
func(ctx context.Context, args listTemplatesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.DB == nil {
return fantasy.NewTextErrorResponse("database is not configured"), nil
}
ctx, err := asOwner(ctx, options.DB, options.OwnerID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
filterParams := database.GetTemplatesWithFilterParams{
Deleted: false,
Deprecated: sql.NullBool{
Bool: false,
Valid: true,
},
}
query := strings.TrimSpace(args.Query)
if query != "" {
filterParams.FuzzyName = query
}
templates, err := options.DB.GetTemplatesWithFilter(ctx, filterParams)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
items := make([]map[string]any, 0, len(templates))
for _, t := range templates {
item := map[string]any{
"id": t.ID.String(),
"name": t.Name,
}
if display := strings.TrimSpace(t.DisplayName); display != "" {
item["display_name"] = display
}
if desc := strings.TrimSpace(t.Description); desc != "" {
item["description"] = truncateRunes(desc, 200)
}
items = append(items, item)
}
return toolResponse(map[string]any{
"templates": items,
"count": len(items),
}), nil
},
)
}
// asOwner sets up a dbauthz context for the given owner so that
// subsequent database calls are scoped to what that user can access.
func asOwner(ctx context.Context, db database.Store, ownerID uuid.UUID) (context.Context, error) {
actor, _, err := httpmw.UserRBACSubject(ctx, db, ownerID, rbac.ScopeAll)
if err != nil {
return ctx, xerrors.Errorf("load user authorization: %w", err)
}
return dbauthz.As(ctx, actor), nil
}
-72
View File
@@ -1,72 +0,0 @@
package chattool
import (
"context"
"io"
"charm.land/fantasy"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
type ReadFileOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
}
type ReadFileArgs struct {
Path string `json:"path"`
Offset *int64 `json:"offset,omitempty"`
Limit *int64 `json:"limit,omitempty"`
}
func ReadFile(options ReadFileOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"read_file",
"Read a file from the workspace.",
func(ctx context.Context, args ReadFileArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeReadFileTool(ctx, conn, args)
},
)
}
func executeReadFileTool(
ctx context.Context,
conn workspacesdk.AgentConn,
args ReadFileArgs,
) (fantasy.ToolResponse, error) {
if args.Path == "" {
return fantasy.NewTextErrorResponse("path is required"), nil
}
offset := int64(0)
limit := int64(0)
if args.Offset != nil {
offset = *args.Offset
}
if args.Limit != nil {
limit = *args.Limit
}
reader, mimeType, err := conn.ReadFile(ctx, args.Path, offset, limit)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolResponse(map[string]any{
"content": string(data),
"mime_type": mimeType,
}), nil
}
-130
View File
@@ -1,130 +0,0 @@
package chattool
import (
"context"
"encoding/json"
"strings"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
)
// ReadTemplateOptions configures the read_template tool.
type ReadTemplateOptions struct {
DB database.Store
OwnerID uuid.UUID
}
type readTemplateArgs struct {
TemplateID string `json:"template_id"`
}
// ReadTemplate returns a tool that retrieves details about a specific
// template, including its configurable rich parameters. The agent
// uses this after list_templates and before create_workspace.
func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"read_template",
"Get details about a workspace template, including its "+
"configurable parameters. Use this after finding a "+
"template with list_templates and before creating a "+
"workspace with create_workspace.",
func(ctx context.Context, args readTemplateArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.DB == nil {
return fantasy.NewTextErrorResponse("database is not configured"), nil
}
templateIDStr := strings.TrimSpace(args.TemplateID)
if templateIDStr == "" {
return fantasy.NewTextErrorResponse("template_id is required"), nil
}
templateID, err := uuid.Parse(templateIDStr)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("invalid template_id: %w", err).Error(),
), nil
}
ctx, err = asOwner(ctx, options.DB, options.OwnerID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
template, err := options.DB.GetTemplateByID(ctx, templateID)
if err != nil {
return fantasy.NewTextErrorResponse("template not found"), nil
}
params, err := options.DB.GetTemplateVersionParameters(ctx, template.ActiveVersionID)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("failed to get template parameters: %w", err).Error(),
), nil
}
templateInfo := map[string]any{
"id": template.ID.String(),
"name": template.Name,
"active_version_id": template.ActiveVersionID.String(),
}
if display := strings.TrimSpace(template.DisplayName); display != "" {
templateInfo["display_name"] = display
}
if desc := strings.TrimSpace(template.Description); desc != "" {
templateInfo["description"] = desc
}
paramList := make([]map[string]any, 0, len(params))
for _, p := range params {
param := map[string]any{
"name": p.Name,
"type": p.Type,
"required": p.Required,
}
if display := strings.TrimSpace(p.DisplayName); display != "" {
param["display_name"] = display
}
if desc := strings.TrimSpace(p.Description); desc != "" {
param["description"] = truncateRunes(desc, 300)
}
if p.DefaultValue != "" {
param["default"] = p.DefaultValue
}
if p.Mutable {
param["mutable"] = true
}
if p.Ephemeral {
param["ephemeral"] = true
}
if p.FormType != "" {
param["form_type"] = string(p.FormType)
}
if len(p.Options) > 0 && string(p.Options) != "null" && string(p.Options) != "[]" {
var opts []map[string]any
if err := json.Unmarshal(p.Options, &opts); err == nil && len(opts) > 0 {
param["options"] = opts
}
}
if p.ValidationRegex != "" {
param["validation_regex"] = p.ValidationRegex
}
if p.ValidationMin.Valid {
param["validation_min"] = p.ValidationMin.Int32
}
if p.ValidationMax.Valid {
param["validation_max"] = p.ValidationMax.Int32
}
paramList = append(paramList, param)
}
return toolResponse(map[string]any{
"template": templateInfo,
"parameters": paramList,
}), nil
},
)
}
-51
View File
@@ -1,51 +0,0 @@
package chattool
import (
"context"
"strings"
"charm.land/fantasy"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
type WriteFileOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
}
type WriteFileArgs struct {
Path string `json:"path"`
Content string `json:"content"`
}
func WriteFile(options WriteFileOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"write_file",
"Write a file to the workspace.",
func(ctx context.Context, args WriteFileArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
}
conn, err := options.GetWorkspaceConn(ctx)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeWriteFileTool(ctx, conn, args)
},
)
}
func executeWriteFileTool(
ctx context.Context,
conn workspacesdk.AgentConn,
args WriteFileArgs,
) (fantasy.ToolResponse, error) {
if args.Path == "" {
return fantasy.NewTextErrorResponse("path is required"), nil
}
if err := conn.WriteFile(ctx, args.Path, strings.NewReader(args.Content)); err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolResponse(map[string]any{"ok": true}), nil
}
-126
View File
@@ -1,126 +0,0 @@
package chatd
import (
"context"
"io"
"net/http"
"regexp"
"strings"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
coderHomeInstructionDir = ".coder"
coderHomeInstructionFile = "AGENTS.md"
maxInstructionFileBytes = 64 * 1024
)
var markdownCommentPattern = regexp.MustCompile(`<!--[\s\S]*?-->`)
func readHomeInstructionFile(
ctx context.Context,
conn workspacesdk.AgentConn,
) (content string, sourcePath string, truncated bool, err error) {
if conn == nil {
return "", "", false, nil
}
coderDir, err := conn.LS(ctx, "", workspacesdk.LSRequest{
Path: []string{coderHomeInstructionDir},
Relativity: workspacesdk.LSRelativityHome,
})
if err != nil {
if isCodersdkStatusCode(err, http.StatusNotFound) {
return "", "", false, nil
}
return "", "", false, xerrors.Errorf("list home instruction directory: %w", err)
}
var filePath string
for _, entry := range coderDir.Contents {
if entry.IsDir {
continue
}
if strings.EqualFold(strings.TrimSpace(entry.Name), coderHomeInstructionFile) {
filePath = strings.TrimSpace(entry.AbsolutePathString)
break
}
}
if filePath == "" {
return "", "", false, nil
}
reader, _, err := conn.ReadFile(
ctx,
filePath,
0,
maxInstructionFileBytes+1,
)
if err != nil {
if isCodersdkStatusCode(err, http.StatusNotFound) {
return "", "", false, nil
}
return "", "", false, xerrors.Errorf("read home instruction file: %w", err)
}
defer reader.Close()
raw, err := io.ReadAll(reader)
if err != nil {
return "", "", false, xerrors.Errorf("read home instruction bytes: %w", err)
}
truncated = int64(len(raw)) > maxInstructionFileBytes
if truncated {
raw = raw[:maxInstructionFileBytes]
}
content = sanitizeInstructionMarkdown(string(raw))
if content == "" {
return "", "", truncated, nil
}
return content, filePath, truncated, nil
}
func sanitizeInstructionMarkdown(content string) string {
content = strings.ReplaceAll(content, "\r\n", "\n")
content = strings.ReplaceAll(content, "\r", "\n")
content = markdownCommentPattern.ReplaceAllString(content, "")
return strings.TrimSpace(content)
}
//nolint:revive // Boolean indicates content was truncated.
func formatHomeInstruction(content string, sourcePath string, truncated bool) string {
content = strings.TrimSpace(content)
if content == "" {
return ""
}
sourcePath = strings.TrimSpace(sourcePath)
if sourcePath == "" {
sourcePath = "~/.coder/AGENTS.md"
}
var b strings.Builder
_, _ = b.WriteString("<coder-home-instructions>\n")
_, _ = b.WriteString("Source: ")
_, _ = b.WriteString(sourcePath)
if truncated {
_, _ = b.WriteString(" (truncated to 64KiB)")
}
_, _ = b.WriteString("\n\n")
_, _ = b.WriteString(content)
_, _ = b.WriteString("\n</coder-home-instructions>")
return b.String()
}
func isCodersdkStatusCode(err error, statusCode int) bool {
var sdkErr *codersdk.Error
if !xerrors.As(err, &sdkErr) {
return false
}
return sdkErr.StatusCode() == statusCode
}
-134
View File
@@ -1,134 +0,0 @@
package chatd //nolint:testpackage // Uses internal symbols.
import (
"context"
"io"
"strings"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
)
func TestSanitizeInstructionMarkdown(t *testing.T) {
t.Parallel()
input := "line 1\r\n<!-- hidden -->\r\nline 2\r\n"
require.Equal(t, "line 1\n\nline 2", sanitizeInstructionMarkdown(input))
}
func TestReadHomeInstructionFileNotFound(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).DoAndReturn(
func(context.Context, string, workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
return workspacesdk.LSResponse{}, codersdk.NewTestError(404, "POST", "/api/v0/list-directory")
},
)
content, sourcePath, truncated, err := readHomeInstructionFile(context.Background(), conn)
require.NoError(t, err)
require.Empty(t, content)
require.Empty(t, sourcePath)
require.False(t, truncated)
}
func TestReadHomeInstructionFileSuccess(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).DoAndReturn(
func(context.Context, string, workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
return workspacesdk.LSResponse{
Contents: []workspacesdk.LSFile{{
Name: "AGENTS.md",
AbsolutePathString: "/home/coder/.coder/AGENTS.md",
}},
}, nil
},
)
conn.EXPECT().ReadFile(
gomock.Any(),
"/home/coder/.coder/AGENTS.md",
int64(0),
int64(maxInstructionFileBytes+1),
).Return(
io.NopCloser(strings.NewReader("base\n<!-- hidden -->\nlocal")),
"text/markdown",
nil,
)
content, sourcePath, truncated, err := readHomeInstructionFile(context.Background(), conn)
require.NoError(t, err)
require.Equal(t, "base\n\nlocal", content)
require.Equal(t, "/home/coder/.coder/AGENTS.md", sourcePath)
require.False(t, truncated)
}
func TestReadHomeInstructionFileTruncates(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
conn := agentconnmock.NewMockAgentConn(ctrl)
content := strings.Repeat("a", maxInstructionFileBytes+8)
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
workspacesdk.LSResponse{
Contents: []workspacesdk.LSFile{{
Name: "AGENTS.md",
AbsolutePathString: "/home/coder/.coder/AGENTS.md",
}},
},
nil,
)
conn.EXPECT().ReadFile(
gomock.Any(),
"/home/coder/.coder/AGENTS.md",
int64(0),
int64(maxInstructionFileBytes+1),
).Return(io.NopCloser(strings.NewReader(content)), "text/markdown", nil)
got, _, truncated, err := readHomeInstructionFile(context.Background(), conn)
require.NoError(t, err)
require.True(t, truncated)
require.Len(t, got, maxInstructionFileBytes)
}
func TestInsertSystemInstructionAfterSystemMessages(t *testing.T) {
t.Parallel()
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "base"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "hello"},
},
},
}
got := chatprompt.InsertSystem(prompt, "project rules")
require.Len(t, got, 3)
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
require.Equal(t, fantasy.MessageRoleSystem, got[1].Role)
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
require.True(t, ok)
require.Equal(t, "project rules", part.Text)
}
-73
View File
@@ -1,73 +0,0 @@
package chatd
// DefaultSystemPrompt is used for new chats when no deployment override is
// configured.
const DefaultSystemPrompt = `You are the Coder agent an interactive chat tool that helps users with software-engineering tasks inside of the Coder product.
Use the instructions below and the tools available to you to assist User.
IMPORTANT obey every rule in this prompt before anything else.
Do EXACTLY what the User asked, never more, never less.
<behavior>
You MUST execute AS MANY TOOLS to help the user accomplish their task.
You are COMFORTABLE with vague tasks - using your tools to collect the most relevant answer possible.
If a user asks how something works, no matter how vague, you MUST use your tools to collect the most relevant answer possible.
DO NOT ask the user for clarification - just use your tools.
</behavior>
<personality>
Analytical You break problems into measurable steps, relying on tool output and data rather than intuition.
Organized You structure every interaction with clear tags, TODO lists, and section boundaries.
Precision-Oriented You insist on exact formatting, package-manager choice, and rule adherence.
Efficiency-Focused You minimize chatter, run tasks in parallel, and favor small, complete answers.
Clarity-Seeking You ask for missing details instead of guessing, avoiding any ambiguity.
</personality>
<communication>
Be concise, direct, and to the point.
NO emojis unless the User explicitly asks for them.
If a task appears incomplete or ambiguous, **pause and ask the User** rather than guessing or marking "done".
Prefer accuracy over reassurance; confirm facts with tool calls instead of assuming the User is right.
If you face an architectural, tooling, or package-manager choice, **ask the User's preference first**.
Default to the project's existing package manager / tooling; never substitute without confirmation.
You MUST avoid text before/after your response, such as "The answer is" or "Short answer:", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
Mimic the style of the User's messages.
Do not remind the User you are happy to help.
Do not inherently assume the User is correct; they may be making assumptions.
If you are not confident in your answer, DO NOT provide an answer. Use your tools to collect more information, or ask the User for help.
Do not act with sycophantic flattery or over-the-top enthusiasm.
Here are examples to demonstrate appropriate communication style and level of verbosity:
<example>
user: find me a good issue to work on
assistant: Issue [#1234](https://example) indicates a bug in the frontend, which you've contributed to in the past.
</example>
<example>
user: work on this issue <url>
...assistant does work...
assistant: I've put up this pull request: https://github.com/example/example/pull/1824. Please let me know your thoughts!
</example>
<example>
user: what is 2+2?
assistant: 4
</example>
<example>
user: how does X work in <popular-repository-name>?
assistant: Let me take a look at the code...
[tool calls to investigate the repository]
</example>
</communication>
<collaboration>
When a user asks for help with a task or there is ambiguity on the objective, always start by asking clarifying questions to understand:
- What specific aspect they want to focus on
- Their goals and vision for the changes
- Their preferences for approach or style
- What problems they're trying to solve
Don't assume what needs to be done - collaborate to define the scope together.
</collaboration>`
-513
View File
@@ -1,513 +0,0 @@
package chatd
import (
"context"
"encoding/json"
"sort"
"strings"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
)
var ErrSubagentNotDescendant = xerrors.New("target chat is not a descendant of current chat")
const (
subagentAwaitPollInterval = 200 * time.Millisecond
defaultSubagentWaitTimeout = 5 * time.Minute
)
type spawnAgentArgs struct {
Prompt string `json:"prompt"`
Title string `json:"title,omitempty"`
}
type waitAgentArgs struct {
ChatID string `json:"chat_id"`
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
}
type messageAgentArgs struct {
ChatID string `json:"chat_id"`
Message string `json:"message"`
Interrupt bool `json:"interrupt,omitempty"`
}
type closeAgentArgs struct {
ChatID string `json:"chat_id"`
}
func (p *Server) subagentTools(currentChat func() database.Chat) []fantasy.AgentTool {
return []fantasy.AgentTool{
fantasy.NewAgentTool(
"spawn_agent",
"Spawn a delegated child agent chat from the root chat.",
func(ctx context.Context, args spawnAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if currentChat == nil {
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
}
parent := currentChat()
if parent.ParentChatID.Valid {
return fantasy.NewTextErrorResponse("delegated chats cannot create child subagents"), nil
}
parent, err := p.db.GetChatByID(ctx, parent.ID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
childChat, err := p.createChildSubagentChat(
ctx,
parent,
args.Prompt,
args.Title,
)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolJSONResponse(map[string]any{
"chat_id": childChat.ID.String(),
"title": childChat.Title,
"status": string(childChat.Status),
}), nil
},
),
fantasy.NewAgentTool(
"wait_agent",
"Wait until a delegated descendant agent reaches a non-streaming status.",
func(ctx context.Context, args waitAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if currentChat == nil {
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
}
targetChatID, err := parseSubagentToolChatID(args.ChatID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
timeout := defaultSubagentWaitTimeout
if args.TimeoutSeconds != nil {
timeout = time.Duration(*args.TimeoutSeconds) * time.Second
}
parent := currentChat()
targetChat, report, err := p.awaitSubagentCompletion(
ctx,
parent.ID,
targetChatID,
timeout,
)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolJSONResponse(map[string]any{
"chat_id": targetChatID.String(),
"title": targetChat.Title,
"report": report,
"status": string(targetChat.Status),
}), nil
},
),
fantasy.NewAgentTool(
"message_agent",
"Send a message to a delegated descendant agent. Use wait_agent to collect a response.",
func(ctx context.Context, args messageAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if currentChat == nil {
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
}
targetChatID, err := parseSubagentToolChatID(args.ChatID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
parent := currentChat()
targetChat, err := p.sendSubagentMessage(
ctx,
parent.ID,
targetChatID,
args.Message,
args.Interrupt,
)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolJSONResponse(map[string]any{
"chat_id": targetChatID.String(),
"title": targetChat.Title,
"status": string(targetChat.Status),
"interrupted": args.Interrupt,
}), nil
},
),
fantasy.NewAgentTool(
"close_agent",
"Interrupt a delegated descendant agent immediately.",
func(ctx context.Context, args closeAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if currentChat == nil {
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
}
targetChatID, err := parseSubagentToolChatID(args.ChatID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
parent := currentChat()
targetChat, err := p.closeSubagent(
ctx,
parent.ID,
targetChatID,
)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolJSONResponse(map[string]any{
"chat_id": targetChatID.String(),
"title": targetChat.Title,
"terminated": true,
"status": string(targetChat.Status),
}), nil
},
),
}
}
func parseSubagentToolChatID(raw string) (uuid.UUID, error) {
chatID, err := uuid.Parse(strings.TrimSpace(raw))
if err != nil {
return uuid.Nil, xerrors.New("chat_id must be a valid UUID")
}
return chatID, nil
}
func (p *Server) createChildSubagentChat(
ctx context.Context,
parent database.Chat,
prompt string,
title string,
) (database.Chat, error) {
if parent.ParentChatID.Valid {
return database.Chat{}, xerrors.New("delegated chats cannot create child subagents")
}
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return database.Chat{}, xerrors.New("prompt is required")
}
title = strings.TrimSpace(title)
if title == "" {
title = subagentFallbackChatTitle(prompt)
}
rootChatID := parent.ID
if parent.RootChatID.Valid {
rootChatID = parent.RootChatID.UUID
}
if parent.LastModelConfigID == uuid.Nil {
return database.Chat{}, xerrors.New("parent chat model config id is required")
}
child, err := p.CreateChat(ctx, CreateOptions{
OwnerID: parent.OwnerID,
WorkspaceID: parent.WorkspaceID,
WorkspaceAgentID: parent.WorkspaceAgentID,
ParentChatID: uuid.NullUUID{
UUID: parent.ID,
Valid: true,
},
RootChatID: uuid.NullUUID{
UUID: rootChatID,
Valid: true,
},
ModelConfigID: parent.LastModelConfigID,
Title: title,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: prompt}},
})
if err != nil {
return database.Chat{}, xerrors.Errorf("create child chat: %w", err)
}
return child, nil
}
func (p *Server) sendSubagentMessage(
ctx context.Context,
parentChatID uuid.UUID,
targetChatID uuid.UUID,
message string,
interrupt bool,
) (database.Chat, error) {
message = strings.TrimSpace(message)
if message == "" {
return database.Chat{}, xerrors.New("message is required")
}
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
if err != nil {
return database.Chat{}, err
}
if !isDescendant {
return database.Chat{}, ErrSubagentNotDescendant
}
busyBehavior := SendMessageBusyBehaviorQueue
if interrupt {
busyBehavior = SendMessageBusyBehaviorInterrupt
}
sendResult, err := p.SendMessage(ctx, SendMessageOptions{
ChatID: targetChatID,
Content: []fantasy.Content{fantasy.TextContent{Text: message}},
BusyBehavior: busyBehavior,
})
if err != nil {
return database.Chat{}, err
}
return sendResult.Chat, nil
}
func (p *Server) awaitSubagentCompletion(
ctx context.Context,
parentChatID uuid.UUID,
targetChatID uuid.UUID,
timeout time.Duration,
) (database.Chat, string, error) {
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
if err != nil {
return database.Chat{}, "", err
}
if !isDescendant {
return database.Chat{}, "", ErrSubagentNotDescendant
}
if timeout <= 0 {
timeout = defaultSubagentWaitTimeout
}
timer := time.NewTimer(timeout)
defer timer.Stop()
ticker := time.NewTicker(subagentAwaitPollInterval)
defer ticker.Stop()
for {
targetChat, report, done, checkErr := p.checkSubagentCompletion(ctx, targetChatID)
if checkErr != nil {
return database.Chat{}, "", checkErr
}
if done {
if targetChat.Status == database.ChatStatusError {
reason := strings.TrimSpace(report)
if reason == "" {
reason = "agent reached error status"
}
return database.Chat{}, "", xerrors.New(reason)
}
return targetChat, report, nil
}
select {
case <-ticker.C:
case <-timer.C:
return database.Chat{}, "", xerrors.New("timed out waiting for delegated subagent completion")
case <-ctx.Done():
return database.Chat{}, "", ctx.Err()
}
}
}
func (p *Server) closeSubagent(
ctx context.Context,
parentChatID uuid.UUID,
targetChatID uuid.UUID,
) (database.Chat, error) {
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
if err != nil {
return database.Chat{}, err
}
if !isDescendant {
return database.Chat{}, ErrSubagentNotDescendant
}
targetChat, err := p.db.GetChatByID(ctx, targetChatID)
if err != nil {
return database.Chat{}, xerrors.Errorf("get target chat: %w", err)
}
if targetChat.Status == database.ChatStatusWaiting {
return targetChat, nil
}
updatedChat := p.InterruptChat(ctx, targetChat)
if updatedChat.Status != database.ChatStatusWaiting {
return database.Chat{}, xerrors.New("set target chat waiting")
}
return updatedChat, nil
}
func (p *Server) checkSubagentCompletion(
ctx context.Context,
chatID uuid.UUID,
) (database.Chat, string, bool, error) {
chat, err := p.db.GetChatByID(ctx, chatID)
if err != nil {
return database.Chat{}, "", false, xerrors.Errorf("get chat: %w", err)
}
if chat.Status == database.ChatStatusPending || chat.Status == database.ChatStatusRunning {
return database.Chat{}, "", false, nil
}
report, err := latestSubagentAssistantMessage(ctx, p.db, chatID)
if err != nil {
return database.Chat{}, "", false, err
}
return chat, report, true, nil
}
func latestSubagentAssistantMessage(
ctx context.Context,
store database.Store,
chatID uuid.UUID,
) (string, error) {
messages, err := store.GetChatMessagesByChatID(ctx, chatID)
if err != nil {
return "", xerrors.Errorf("get chat messages: %w", err)
}
sort.Slice(messages, func(i, j int) bool {
if messages[i].CreatedAt.Equal(messages[j].CreatedAt) {
return messages[i].ID < messages[j].ID
}
return messages[i].CreatedAt.Before(messages[j].CreatedAt)
})
for i := len(messages) - 1; i >= 0; i-- {
message := messages[i]
if message.Role != string(fantasy.MessageRoleAssistant) ||
message.Visibility == database.ChatMessageVisibilityModel {
continue
}
content, parseErr := chatprompt.ParseContent(message.Role, message.Content)
if parseErr != nil {
continue
}
text := strings.TrimSpace(contentBlocksToText(content))
if text == "" {
continue
}
return text, nil
}
return "", nil
}
func isSubagentDescendant(
ctx context.Context,
store database.Store,
ancestorChatID uuid.UUID,
targetChatID uuid.UUID,
) (bool, error) {
if ancestorChatID == targetChatID {
return false, nil
}
descendants, err := listSubagentDescendants(ctx, store, ancestorChatID)
if err != nil {
return false, err
}
for _, descendant := range descendants {
if descendant.ID == targetChatID {
return true, nil
}
}
return false, nil
}
func listSubagentDescendants(
ctx context.Context,
store database.Store,
chatID uuid.UUID,
) ([]database.Chat, error) {
queue := []uuid.UUID{chatID}
visited := map[uuid.UUID]struct{}{chatID: {}}
out := make([]database.Chat, 0)
for len(queue) > 0 {
parentChatID := queue[0]
queue = queue[1:]
children, err := store.ListChildChatsByParentID(ctx, parentChatID)
if err != nil {
return nil, xerrors.Errorf("list child chats for %s: %w", parentChatID, err)
}
for _, child := range children {
if _, ok := visited[child.ID]; ok {
continue
}
visited[child.ID] = struct{}{}
out = append(out, child)
queue = append(queue, child.ID)
}
}
return out, nil
}
func subagentFallbackChatTitle(message string) string {
const maxWords = 6
const maxRunes = 80
words := strings.Fields(message)
if len(words) == 0 {
return "New Chat"
}
truncated := false
if len(words) > maxWords {
words = words[:maxWords]
truncated = true
}
title := strings.Join(words, " ")
if truncated {
title += "..."
}
return subagentTruncateRunes(title, maxRunes)
}
func subagentTruncateRunes(value string, max int) string {
if max <= 0 {
return ""
}
runes := []rune(value)
if len(runes) <= max {
return value
}
return string(runes[:max])
}
func toolJSONResponse(result map[string]any) fantasy.ToolResponse {
data, err := json.Marshal(result)
if err != nil {
return fantasy.NewTextResponse("{}")
}
return fantasy.NewTextResponse(string(data))
}
-216
View File
@@ -1,216 +0,0 @@
package chatd
import (
"context"
"strings"
"time"
"charm.land/fantasy"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
)
const titleGenerationPrompt = "Generate a concise title (max 8 words, under 128 characters) for " +
"the user's first message. Return plain text only — no quotes, no emoji, " +
"no markdown, no special characters."
// maybeGenerateChatTitle generates an AI title for the chat when
// appropriate (first user message, no assistant reply yet, and the
// current title is either empty or still the fallback truncation).
// It is a best-effort operation that logs and swallows errors.
func (p *Server) maybeGenerateChatTitle(
ctx context.Context,
chat database.Chat,
messages []database.ChatMessage,
model fantasy.LanguageModel,
logger slog.Logger,
) {
input, ok := titleInput(chat, messages)
if !ok {
return
}
titleCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
title, err := generateTitle(titleCtx, model, input)
if err != nil {
logger.Debug(ctx, "failed to generate chat title",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
return
}
if title == "" || title == chat.Title {
return
}
_, err = p.db.UpdateChatByID(ctx, database.UpdateChatByIDParams{
ID: chat.ID,
Title: title,
})
if err != nil {
logger.Warn(ctx, "failed to update generated chat title",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
return
}
chat.Title = title
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange)
}
// generateTitle calls the model with a title-generation system prompt
// and returns the normalized result.
func generateTitle(
ctx context.Context,
model fantasy.LanguageModel,
input string,
) (string, error) {
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: titleGenerationPrompt},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: input},
},
},
}
toolChoice := fantasy.ToolChoiceNone
response, err := model.Generate(ctx, fantasy.Call{
Prompt: prompt,
ToolChoice: &toolChoice,
})
if err != nil {
return "", xerrors.Errorf("generate title text: %w", err)
}
title := normalizeTitleOutput(contentBlocksToText(response.Content))
if title == "" {
return "", xerrors.New("generated title was empty")
}
return title, nil
}
// titleInput returns the first user message text and whether title
// generation should proceed. It returns false when the chat already
// has assistant/tool replies, has more than one visible user message,
// or the current title doesn't look like a candidate for replacement.
func titleInput(
chat database.Chat,
messages []database.ChatMessage,
) (string, bool) {
userCount := 0
firstUserText := ""
for _, message := range messages {
if message.Visibility == database.ChatMessageVisibilityModel {
continue
}
switch message.Role {
case string(fantasy.MessageRoleAssistant), string(fantasy.MessageRoleTool):
return "", false
case string(fantasy.MessageRoleUser):
userCount++
if firstUserText == "" {
parsed, err := chatprompt.ParseContent(
string(fantasy.MessageRoleUser), message.Content,
)
if err != nil {
return "", false
}
firstUserText = strings.TrimSpace(
contentBlocksToText(parsed),
)
}
}
}
if userCount != 1 || firstUserText == "" {
return "", false
}
currentTitle := strings.TrimSpace(chat.Title)
if currentTitle == "" {
return firstUserText, true
}
if currentTitle != fallbackChatTitle(firstUserText) {
return "", false
}
return firstUserText, true
}
func normalizeTitleOutput(title string) string {
title = strings.TrimSpace(title)
if title == "" {
return ""
}
title = strings.Trim(title, "\"'`")
title = strings.Join(strings.Fields(title), " ")
return truncateRunes(title, 80)
}
func fallbackChatTitle(message string) string {
const maxWords = 6
const maxRunes = 80
words := strings.Fields(message)
if len(words) == 0 {
return "New Chat"
}
truncated := false
if len(words) > maxWords {
words = words[:maxWords]
truncated = true
}
title := strings.Join(words, " ")
if truncated {
title += "…"
}
return truncateRunes(title, maxRunes)
}
// contentBlocksToText concatenates the text parts of content blocks
// into a single space-separated string.
func contentBlocksToText(content []fantasy.Content) string {
parts := make([]string, 0, len(content))
for _, block := range content {
textBlock, ok := fantasy.AsContentType[fantasy.TextContent](block)
if !ok {
continue
}
text := strings.TrimSpace(textBlock.Text)
if text == "" {
continue
}
parts = append(parts, text)
}
return strings.Join(parts, " ")
}
func truncateRunes(value string, maxLen int) string {
if maxLen <= 0 {
return ""
}
runes := []rune(value)
if len(runes) <= maxLen {
return value
}
return string(runes[:maxLen])
}
-3229
View File
File diff suppressed because it is too large Load Diff
-336
View File
@@ -1,336 +0,0 @@
package coderd_test
import (
"net/http"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func TestChats(t *testing.T) {
t.Parallel()
t.Run("PostChats", func(t *testing.T) {
t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "hello from chats route tests",
},
},
})
require.NoError(t, err)
require.NotEqual(t, uuid.Nil, chat.ID)
require.Equal(t, user.UserID, chat.OwnerID)
require.Equal(t, modelConfig.ID, chat.LastModelConfigID)
require.Equal(t, "hello from chats route tests", chat.Title)
require.Equal(t, codersdk.ChatStatusPending, chat.Status)
require.NotZero(t, chat.CreatedAt)
require.NotZero(t, chat.UpdatedAt)
require.Nil(t, chat.WorkspaceID)
require.Nil(t, chat.WorkspaceAgentID)
require.NotNil(t, chat.RootChatID)
require.Equal(t, chat.ID, *chat.RootChatID)
chatWithMessages, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, chat.ID, chatWithMessages.Chat.ID)
foundUserMessage := false
for _, message := range chatWithMessages.Messages {
if message.Role != "user" {
continue
}
for _, part := range message.Content {
if part.Type == codersdk.ChatMessagePartTypeText &&
part.Text == "hello from chats route tests" {
foundUserMessage = true
break
}
}
}
require.True(t, foundUserMessage)
})
t.Run("HidesSystemPromptMessages", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
_ = createChatModelConfig(t, client)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "verify hidden system prompt",
},
},
})
require.NoError(t, err)
chatWithMessages, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
for _, message := range chatWithMessages.Messages {
require.NotEqual(t, "system", message.Role)
}
})
t.Run("WorkspaceNotAccessible", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient, db := coderdtest.NewWithDatabase(t, nil)
firstUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: firstUser.OrganizationID,
OwnerID: firstUser.UserID,
}).WithAgent().Do()
_, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
},
},
WorkspaceID: &workspaceBuild.Workspace.ID,
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Workspace not found or you do not have access to this resource", sdkErr.Message)
})
t.Run("WorkspaceNotFound", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
workspaceID := uuid.New()
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
},
},
WorkspaceID: &workspaceID,
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Workspace not found or you do not have access to this resource", sdkErr.Message)
})
t.Run("WorkspaceSelectsFirstAgent", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent().Do()
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
},
},
WorkspaceID: &workspaceBuild.Workspace.ID,
})
require.NoError(t, err)
require.NotNil(t, chat.WorkspaceID)
require.Equal(t, workspaceBuild.Workspace.ID, *chat.WorkspaceID)
require.NotNil(t, chat.WorkspaceAgentID)
require.Equal(t, workspaceBuild.Agents[0].ID, *chat.WorkspaceAgentID)
require.Equal(t, modelConfig.ID, chat.LastModelConfigID)
})
t.Run("MissingDefaultModelConfig", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
},
},
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "No default chat model config is configured.", sdkErr.Message)
})
t.Run("EmptyContent", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: nil,
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Content is required.", sdkErr.Message)
require.Equal(t, "Content cannot be empty.", sdkErr.Detail)
})
t.Run("EmptyText", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: " ",
},
},
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid input part.", sdkErr.Message)
require.Equal(t, "content[0].text cannot be empty.", sdkErr.Detail)
})
t.Run("UnsupportedPartType", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartType("image"),
Text: "hello",
},
},
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid input part.", sdkErr.Message)
require.Equal(t, `content[0].type "image" is not supported.`, sdkErr.Detail)
})
})
t.Run("ListChatModels", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
_ = createChatModelConfig(t, client)
models, err := client.ListChatModels(ctx)
require.NoError(t, err)
var openAIProvider *codersdk.ChatModelProvider
for i := range models.Providers {
if models.Providers[i].Provider == "openai" {
openAIProvider = &models.Providers[i]
break
}
}
require.NotNil(t, openAIProvider)
require.True(t, openAIProvider.Available)
foundModel := false
for _, model := range openAIProvider.Models {
if model.Provider == "openai" && model.Model == "gpt-4o-mini" {
foundModel = true
break
}
}
require.True(t, foundModel)
})
t.Run("ListChatProviders", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
_ = createChatModelConfig(t, client)
providers, err := client.ListChatProviders(ctx)
require.NoError(t, err)
var openAIProvider *codersdk.ChatProviderConfig
for i := range providers {
if providers[i].Provider == "openai" {
openAIProvider = &providers[i]
break
}
}
require.NotNil(t, openAIProvider)
require.Equal(t, codersdk.ChatProviderConfigSourceDatabase, openAIProvider.Source)
require.True(t, openAIProvider.Enabled)
require.True(t, openAIProvider.HasAPIKey)
})
}
func createChatModelConfig(t *testing.T, client *codersdk.Client) codersdk.ChatModelConfig {
t.Helper()
ctx := testutil.Context(t, testutil.WaitLong)
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-api-key",
})
require.NoError(t, err)
contextLimit := int64(4096)
isDefault := true
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
})
require.NoError(t, err)
return modelConfig
}
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
t.Helper()
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, expectedStatus, sdkErr.StatusCode())
return sdkErr
}
+7 -63
View File
@@ -49,7 +49,6 @@ import (
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/awsidentity"
"github.com/coder/coder/v2/coderd/boundaryusage"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
@@ -239,9 +238,6 @@ type Options struct {
SSHConfig codersdk.SSHConfigResponse
HTTPClient *http.Client
// ChatRemotePartsProvider provides cross-replica message_part streaming.
// Set by enterprise for HA deployments. Nil in AGPL single-replica.
ChatRemotePartsProvider chatd.RemotePartsProvider
UpdateAgentMetrics func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
StatsBatcher workspacestats.Batcher
@@ -592,6 +588,7 @@ func New(options *Options) *API {
var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker]
var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{}
buildUsageChecker.Store(&noopUsageChecker)
api := &API{
ctx: ctx,
cancel: cancel,
@@ -757,17 +754,6 @@ func New(options *Options) *API {
panic("failed to setup server tailnet: " + err.Error())
}
api.agentProvider = stn
api.chatDaemon = chatd.New(chatd.Config{
Logger: options.Logger.Named("chats"),
Database: options.Database,
ReplicaID: api.ID,
RemotePartsProvider: options.ChatRemotePartsProvider,
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
CreateWorkspace: api.chatCreateWorkspace,
Pubsub: options.Pubsub,
})
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(stn)
api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry)
@@ -914,7 +900,6 @@ func New(options *Options) *API {
sharedhttpmw.Recover(api.Logger),
httpmw.WithProfilingLabels,
tracing.StatusWriterMiddleware,
options.DeploymentValues.HTTPCookies.Middleware,
tracing.Middleware(api.TracerProvider),
httpmw.AttachRequestID,
httpmw.ExtractRealIP(api.RealIPConfig),
@@ -1174,43 +1159,6 @@ func New(options *Options) *API {
r.Get("/", api.auditLogs)
r.Post("/testgenerate", api.generateFakeAuditLog)
})
r.Route("/chats", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Get("/", api.listChats)
r.Post("/", api.postChats)
r.Get("/models", api.listChatModels)
r.Get("/watch", api.watchChats)
r.Route("/providers", func(r chi.Router) {
r.Get("/", api.listChatProviders)
r.Post("/", api.createChatProvider)
r.Route("/{providerConfig}", func(r chi.Router) {
r.Patch("/", api.updateChatProvider)
r.Delete("/", api.deleteChatProvider)
})
})
r.Route("/model-configs", func(r chi.Router) {
r.Get("/", api.listChatModelConfigs)
r.Post("/", api.createChatModelConfig)
r.Route("/{modelConfig}", func(r chi.Router) {
r.Patch("/", api.updateChatModelConfig)
r.Delete("/", api.deleteChatModelConfig)
})
})
r.Route("/{chat}", func(r chi.Router) {
r.Use(httpmw.ExtractChatParam(options.Database))
r.Get("/", api.getChat)
r.Delete("/", api.deleteChat)
r.Post("/messages", api.postChatMessages)
r.Get("/stream", api.streamChat)
r.Post("/interrupt", api.interruptChat)
r.Get("/diff-status", api.getChatDiffStatus)
r.Get("/diff", api.getChatDiffContents)
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
r.Delete("/", api.deleteChatQueuedMessage)
r.Post("/promote", api.promoteChatQueuedMessage)
})
})
})
r.Route("/files", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
@@ -1284,10 +1232,7 @@ func New(options *Options) *API {
r.Get("/", api.organizationMember)
r.Delete("/", api.deleteOrganizationMember)
r.Put("/roles", api.putMemberRoles)
r.Route("/workspaces", func(r chi.Router) {
r.Post("/", api.postWorkspacesByOrganization)
r.Get("/available-users", api.workspaceAvailableUsers)
})
r.Post("/workspaces", api.postWorkspacesByOrganization)
})
})
})
@@ -1454,7 +1399,6 @@ func New(options *Options) *API {
r.Route("/{keyid}", func(r chi.Router) {
r.Get("/", api.apiKeyByID)
r.Delete("/", api.deleteAPIKey)
r.Put("/expire", api.expireAPIKey)
})
})
@@ -1577,6 +1521,10 @@ func New(options *Options) *API {
})
r.Get("/timings", api.workspaceTimings)
r.Route("/acl", func(r chi.Router) {
r.Use(
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceSharing),
)
r.Get("/", api.workspaceACL)
r.Patch("/", api.patchWorkspaceACL)
r.Delete("/", api.deleteWorkspaceACL)
@@ -1953,8 +1901,6 @@ type API struct {
// dbRolluper rolls up template usage stats from raw agent and app
// stats. This is used to provide insights in the WebUI.
dbRolluper *dbrollup.Rolluper
// chatDaemon handles background processing of pending chats.
chatDaemon *chatd.Server
}
// Close waits for all WebSocket connections to drain before returning.
@@ -1983,10 +1929,8 @@ func (api *API) Close() error {
case <-timer.C:
api.Logger.Warn(api.ctx, "websocket shutdown timed out after 10 seconds")
}
api.dbRolluper.Close()
if err := api.chatDaemon.Close(); err != nil {
api.Logger.Warn(api.ctx, "close chat processor", slog.Error(err))
}
api.metricsCache.Close()
if api.updateChecker != nil {
api.updateChecker.Close()
+13 -16
View File
@@ -6,20 +6,17 @@ type CheckConstraint string
// CheckConstraint enums.
const (
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
)
-332
View File
@@ -2,7 +2,6 @@
package db2sdk
import (
"database/sql"
"encoding/json"
"fmt"
"net/url"
@@ -12,7 +11,6 @@ import (
"strings"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/hashicorp/hcl/v2"
"github.com/sqlc-dev/pqtype"
@@ -20,7 +18,6 @@ import (
"tailscale.com/tailcfg"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
@@ -1053,332 +1050,3 @@ func jsonOrEmptyMap(rawMessage pqtype.NullRawMessage) map[string]any {
}
return m
}
func ChatMessage(m database.ChatMessage) codersdk.ChatMessage {
modelConfigID := &m.ModelConfigID.UUID
if !m.ModelConfigID.Valid {
modelConfigID = nil
}
msg := codersdk.ChatMessage{
ID: m.ID,
ChatID: m.ChatID,
ModelConfigID: modelConfigID,
CreatedAt: m.CreatedAt,
Role: m.Role,
}
if m.Content.Valid {
parts, err := chatMessageParts(m.Role, m.Content)
if err == nil {
msg.Content = parts
}
}
usage := chatMessageUsage(m)
if usage != nil {
msg.Usage = usage
}
return msg
}
// chatMessageUsage builds a ChatMessageUsage from the database row,
// returning nil when no token fields are populated.
func chatMessageUsage(m database.ChatMessage) *codersdk.ChatMessageUsage {
inputTokens := nullInt64Ptr(m.InputTokens)
outputTokens := nullInt64Ptr(m.OutputTokens)
totalTokens := nullInt64Ptr(m.TotalTokens)
reasoningTokens := nullInt64Ptr(m.ReasoningTokens)
cacheCreationTokens := nullInt64Ptr(m.CacheCreationTokens)
cacheReadTokens := nullInt64Ptr(m.CacheReadTokens)
contextLimit := nullInt64Ptr(m.ContextLimit)
if inputTokens == nil && outputTokens == nil && totalTokens == nil &&
reasoningTokens == nil && cacheCreationTokens == nil &&
cacheReadTokens == nil && contextLimit == nil {
return nil
}
return &codersdk.ChatMessageUsage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
TotalTokens: totalTokens,
ReasoningTokens: reasoningTokens,
CacheCreationTokens: cacheCreationTokens,
CacheReadTokens: cacheReadTokens,
ContextLimit: contextLimit,
}
}
// ChatQueuedMessages converts a slice of database queued messages
// to their SDK representation.
func ChatQueuedMessages(messages []database.ChatQueuedMessage) []codersdk.ChatQueuedMessage {
out := make([]codersdk.ChatQueuedMessage, 0, len(messages))
for _, message := range messages {
out = append(out, codersdk.ChatQueuedMessage{
ID: message.ID,
ChatID: message.ChatID,
Content: message.Content,
CreatedAt: message.CreatedAt,
})
}
return out
}
func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMessagePart, error) {
switch role {
case string(fantasy.MessageRoleSystem):
content, err := parseSystemContent(raw)
if err != nil {
return nil, err
}
if strings.TrimSpace(content) == "" {
return nil, nil
}
return []codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: content,
}}, nil
case string(fantasy.MessageRoleUser), string(fantasy.MessageRoleAssistant):
content, err := parseContentBlocks(role, raw)
if err != nil {
return nil, err
}
var rawBlocks []json.RawMessage
if role == string(fantasy.MessageRoleAssistant) {
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
}
parts := make([]codersdk.ChatMessagePart, 0, len(content))
for i, block := range content {
part := contentBlockToPart(block)
if part.Type == "" {
continue
}
if part.Type == codersdk.ChatMessagePartTypeReasoning {
part.Title = ""
if i < len(rawBlocks) {
part.Title = reasoningStoredTitle(rawBlocks[i])
}
}
parts = append(parts, part)
}
return parts, nil
case string(fantasy.MessageRoleTool):
results, err := parseToolResults(raw)
if err != nil {
return nil, err
}
parts := make([]codersdk.ChatMessagePart, 0, len(results))
for _, result := range results {
parts = append(parts, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: result.ToolCallID,
ToolName: result.ToolName,
Result: result.Result,
IsError: result.IsError,
})
}
return parts, nil
default:
return nil, nil
}
}
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return "", nil
}
var content string
if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
return "", xerrors.Errorf("parse system content: %w", err)
}
return content, nil
}
func parseContentBlocks(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return nil, nil
}
if role == string(fantasy.MessageRoleUser) {
var text string
if err := json.Unmarshal(raw.RawMessage, &text); err == nil {
return []fantasy.Content{
fantasy.TextContent{Text: text},
}, nil
}
}
var blocks []json.RawMessage
if err := json.Unmarshal(raw.RawMessage, &blocks); err != nil {
return nil, xerrors.Errorf("parse content blocks: %w", err)
}
content := make([]fantasy.Content, 0, len(blocks))
for _, block := range blocks {
decoded, err := fantasy.UnmarshalContent(block)
if err != nil {
return nil, xerrors.Errorf("parse content block: %w", err)
}
content = append(content, decoded)
}
return content, nil
}
// toolResultRow is used only for extracting top-level fields from
// persisted tool result JSON. The result payload is kept as raw JSON.
type toolResultRow struct {
ToolCallID string `json:"tool_call_id"`
ToolName string `json:"tool_name"`
Result json.RawMessage `json:"result"`
IsError bool `json:"is_error,omitempty"`
}
func parseToolResults(raw pqtype.NullRawMessage) ([]toolResultRow, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return nil, nil
}
var results []toolResultRow
if err := json.Unmarshal(raw.RawMessage, &results); err != nil {
return nil, xerrors.Errorf("parse tool results: %w", err)
}
return results, nil
}
func reasoningStoredTitle(raw json.RawMessage) string {
var envelope struct {
Type string `json:"type"`
Data struct {
Title string `json:"title"`
} `json:"data"`
}
if err := json.Unmarshal(raw, &envelope); err != nil {
return ""
}
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeReasoning)) {
return ""
}
return strings.TrimSpace(envelope.Data.Title)
}
func contentBlockToPart(block fantasy.Content) codersdk.ChatMessagePart {
switch value := block.(type) {
case fantasy.TextContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: value.Text,
}
case *fantasy.TextContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: value.Text,
}
case fantasy.ReasoningContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: value.Text,
}
case *fantasy.ReasoningContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: value.Text,
}
case fantasy.ToolCallContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: value.ToolCallID,
ToolName: value.ToolName,
Args: []byte(value.Input),
}
case *fantasy.ToolCallContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: value.ToolCallID,
ToolName: value.ToolName,
Args: []byte(value.Input),
}
case fantasy.SourceContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSource,
SourceID: value.ID,
URL: value.URL,
Title: value.Title,
}
case *fantasy.SourceContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSource,
SourceID: value.ID,
URL: value.URL,
Title: value.Title,
}
case fantasy.FileContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeFile,
MediaType: value.MediaType,
Data: value.Data,
}
case *fantasy.FileContent:
return codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeFile,
MediaType: value.MediaType,
Data: value.Data,
}
case fantasy.ToolResultContent:
return chatprompt.ToolResultToPart(
value.ToolCallID,
value.ToolName,
toolResultOutputToRawJSON(value.Result),
toolResultOutputIsError(value.Result),
)
case *fantasy.ToolResultContent:
return chatprompt.ToolResultToPart(
value.ToolCallID,
value.ToolName,
toolResultOutputToRawJSON(value.Result),
toolResultOutputIsError(value.Result),
)
default:
return codersdk.ChatMessagePart{}
}
}
func toolResultOutputToRawJSON(output fantasy.ToolResultOutputContent) json.RawMessage {
switch v := output.(type) {
case fantasy.ToolResultOutputContentError:
if v.Error != nil {
data, _ := json.Marshal(map[string]any{"error": v.Error.Error()})
return data
}
return json.RawMessage(`{"error":""}`)
case fantasy.ToolResultOutputContentText:
raw := json.RawMessage(v.Text)
if json.Valid(raw) {
return raw
}
data, _ := json.Marshal(map[string]any{"output": v.Text})
return data
case fantasy.ToolResultOutputContentMedia:
data, _ := json.Marshal(map[string]any{
"data": v.Data,
"mime_type": v.MediaType,
"text": v.Text,
})
return data
default:
return json.RawMessage(`{}`)
}
}
func toolResultOutputIsError(output fantasy.ToolResultOutputContent) bool {
_, ok := output.(fantasy.ToolResultOutputContentError)
return ok
}
func nullInt64Ptr(v sql.NullInt64) *int64 {
if !v.Valid {
return nil
}
value := v.Int64
return &value
}
-78
View File
@@ -8,8 +8,6 @@ import (
"testing"
"time"
"charm.land/fantasy"
fantasyopenai "charm.land/fantasy/providers/openai"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
@@ -437,79 +435,3 @@ func TestAIBridgeInterception(t *testing.T) {
})
}
}
func TestChatMessage_ReasoningPartWithoutPersistedTitleIsEmpty(t *testing.T) {
t.Parallel()
assistantContent, err := json.Marshal([]fantasy.Content{
fantasy.ReasoningContent{
Text: "Plan migration",
ProviderMetadata: fantasy.ProviderMetadata{
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
ItemID: "reasoning-1",
Summary: []string{"Plan migration"},
},
},
},
})
require.NoError(t, err)
message := db2sdk.ChatMessage(database.ChatMessage{
ID: 1,
ChatID: uuid.New(),
CreatedAt: time.Now(),
Role: string(fantasy.MessageRoleAssistant),
Content: pqtype.NullRawMessage{
RawMessage: assistantContent,
Valid: true,
},
})
require.Len(t, message.Content, 1)
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
require.Equal(t, "Plan migration", message.Content[0].Text)
require.Empty(t, message.Content[0].Title)
}
func TestChatMessage_ReasoningPartPrefersPersistedTitle(t *testing.T) {
t.Parallel()
reasoningContent, err := json.Marshal(fantasy.ReasoningContent{
Text: "Verify schema updates, then apply changes in order.",
ProviderMetadata: fantasy.ProviderMetadata{
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
ItemID: "reasoning-1",
Summary: []string{
"**Metadata-derived title**\n\nLonger explanation.",
},
},
},
})
require.NoError(t, err)
var envelope map[string]any
require.NoError(t, json.Unmarshal(reasoningContent, &envelope))
dataValue, ok := envelope["data"].(map[string]any)
require.True(t, ok)
dataValue["title"] = "Persisted stream title"
encodedReasoning, err := json.Marshal(envelope)
require.NoError(t, err)
assistantContent, err := json.Marshal([]json.RawMessage{encodedReasoning})
require.NoError(t, err)
message := db2sdk.ChatMessage(database.ChatMessage{
ID: 1,
ChatID: uuid.New(),
CreatedAt: time.Now(),
Role: string(fantasy.MessageRoleAssistant),
Content: pqtype.NullRawMessage{
RawMessage: assistantContent,
Valid: true,
},
})
require.Len(t, message.Content, 1)
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
require.Equal(t, "Persisted stream title", message.Content[0].Title)
}
+13 -450
View File
@@ -453,7 +453,6 @@ var (
rbac.ResourceProvisionerJobs.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreate},
rbac.ResourceOauth2App.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
rbac.ResourceOauth2AppSecret.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
}),
User: []rbac.Permission{},
ByOrgID: map[string]rbac.OrgPermissions{},
@@ -669,31 +668,6 @@ var (
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectWorkspaceBuilder = rbac.Subject{
Type: rbac.SubjectTypeWorkspaceBuilder,
FriendlyName: "Workspace Builder",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Identifier: rbac.RoleIdentifier{Name: "workspace-builder"},
DisplayName: "Workspace Builder",
Site: rbac.Permissions(map[string][]policy.Action{
// Reading provisioner daemons to check eligibility.
rbac.ResourceProvisionerDaemon.Type: {policy.ActionRead},
// Updating provisioner jobs (e.g. marking prebuild
// jobs complete).
rbac.ResourceProvisionerJobs.Type: {policy.ActionUpdate},
// Reading provisioner state requires template update
// permission.
rbac.ResourceTemplate.Type: {policy.ActionUpdate},
}),
User: []rbac.Permission{},
ByOrgID: map[string]rbac.OrgPermissions{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
)
// AsProvisionerd returns a context with an actor that has permissions required
@@ -800,14 +774,6 @@ func AsBoundaryUsageTracker(ctx context.Context) context.Context {
return As(ctx, subjectBoundaryUsageTracker)
}
// AsWorkspaceBuilder returns a context with an actor that has permissions
// required for the workspace builder to prepare workspace builds. This
// includes reading provisioner daemons, updating provisioner jobs, and
// reading provisioner state (which requires template update permission).
func AsWorkspaceBuilder(ctx context.Context) context.Context {
return As(ctx, subjectWorkspaceBuilder)
}
var AsRemoveActor = rbac.Subject{
ID: "remove-actor",
}
@@ -1485,15 +1451,6 @@ func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.Prov
return nil
}
func (q *querier) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
// AcquireChat is a system-level operation used by the chat processor.
// Authorization is done at the system level, not per-user.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return database.Chat{}, err
}
return q.db.AcquireChat(ctx, arg)
}
func (q *querier) AcquireLock(ctx context.Context, id int64) error {
return q.db.AcquireLock(ctx, id)
}
@@ -1722,17 +1679,6 @@ func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) e
return q.db.DeleteAPIKeysByUserID(ctx, userID)
}
func (q *querier) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.DeleteAllChatQueuedMessages(ctx, chatID)
}
func (q *querier) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil {
return err
@@ -1757,54 +1703,6 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
}
func (q *querier) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, id)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionDelete, chat); err != nil {
return err
}
return q.db.DeleteChatByID(ctx, id)
}
func (q *querier) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
// Authorize delete on the parent chat.
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionDelete, chat); err != nil {
return err
}
return q.db.DeleteChatMessagesByChatID(ctx, chatID)
}
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatModelConfigByID(ctx, id)
}
func (q *querier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatProviderByID(ctx, id)
}
func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.DeleteChatQueuedMessage(ctx, arg)
}
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -2263,12 +2161,12 @@ func (q *querier) GetAPIKeyByName(ctx context.Context, arg database.GetAPIKeyByN
return fetch(q.log, q.auth, q.db.GetAPIKeyByName)(ctx, arg)
}
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByLoginType)(ctx, loginType)
}
func (q *querier) GetAPIKeysByUserID(ctx context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, params)
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, database.GetAPIKeysByUserIDParams{LoginType: params.LoginType, UserID: params.UserID})
}
func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]database.APIKey, error) {
@@ -2359,7 +2257,7 @@ func (q *querier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditL
}
func (q *querier) GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow, error) {
// This is a system function.
// This is a system function
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow{}, err
}
@@ -2373,131 +2271,6 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
return q.db.GetAuthorizationUserRoles(ctx, userID)
}
func (q *querier) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
return fetch(q.log, q.auth, q.db.GetChatByID)(ctx, id)
}
func (q *querier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
return fetch(q.log, q.auth, q.db.GetChatByIDForUpdate)(ctx, id)
}
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
if err != nil {
return database.ChatDiffStatus{}, err
}
return q.db.GetChatDiffStatusByChatID(ctx, chatID)
}
func (q *querier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uuid.UUID) ([]database.ChatDiffStatus, error) {
if len(chatIDs) == 0 {
return []database.ChatDiffStatus{}, nil
}
actor, ok := ActorFromContext(ctx)
if ok && actor.Type == rbac.SubjectTypeSystemRestricted {
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
}
for _, chatID := range chatIDs {
// Authorize read on each parent chat.
_, err := q.GetChatByID(ctx, chatID)
if err != nil {
return nil, err
}
}
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
}
func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
// ChatMessages are authorized through their parent Chat.
// We need to fetch the message first to get its chat_id.
msg, err := q.db.GetChatMessageByID(ctx, id)
if err != nil {
return database.ChatMessage{}, err
}
// Authorize read on the parent chat.
_, err = q.GetChatByID(ctx, msg.ChatID)
if err != nil {
return database.ChatMessage{}, err
}
return msg, nil
}
func (q *querier) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
if err != nil {
return nil, err
}
return q.db.GetChatMessagesByChatID(ctx, chatID)
}
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
if err != nil {
return nil, err
}
return q.db.GetChatMessagesForPromptByChatID(ctx, chatID)
}
func (q *querier) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.GetChatModelConfigByID(ctx, id)
}
func (q *querier) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.GetChatModelConfigByProviderAndModel(ctx, arg)
}
func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetChatModelConfigs(ctx)
}
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.GetChatProviderByID(ctx, id)
}
func (q *querier) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.GetChatProviderByProvider(ctx, provider)
}
func (q *querier) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetChatProviders(ctx)
}
func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
_, err := q.GetChatByID(ctx, chatID)
if err != nil {
return nil, err
}
return q.db.GetChatQueuedMessages(ctx, chatID)
}
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
// Just like with the audit logs query, shortcut if the user is an owner.
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
@@ -2555,13 +2328,6 @@ func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
return q.db.GetDERPMeshKey(ctx)
}
func (q *querier) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.GetDefaultChatModelConfig(ctx)
}
func (q *querier) GetDefaultOrganization(ctx context.Context) (database.Organization, error) {
return fetch(q.log, q.auth, func(ctx context.Context, _ any) (database.Organization, error) {
return q.db.GetDefaultOrganization(ctx)
@@ -2602,20 +2368,6 @@ func (q *querier) GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.C
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetEligibleProvisionerDaemonsByProvisionerJobIDs)(ctx, provisionerJobIDs)
}
func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetEnabledChatModelConfigs(ctx)
}
func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetEnabledChatProviders(ctx)
}
func (q *querier) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetExternalAuthLink)(ctx, arg)
}
@@ -3315,14 +3067,6 @@ func (q *querier) GetRuntimeConfig(ctx context.Context, key string) (string, err
return q.db.GetRuntimeConfig(ctx, key)
}
func (q *querier) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
// GetStaleChats is a system-level operation used by the chat processor for recovery.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat); err != nil {
return nil, err
}
return q.db.GetStaleChats(ctx, staleThreshold)
}
func (q *querier) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return nil, err
@@ -3389,13 +3133,6 @@ func (q *querier) GetTelemetryItems(ctx context.Context) ([]database.TelemetryIt
return q.db.GetTelemetryItems(ctx)
}
func (q *querier) GetTelemetryTaskEvents(ctx context.Context, arg database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTask.All()); err != nil {
return nil, err
}
return q.db.GetTelemetryTaskEvents(ctx, arg)
}
func (q *querier) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
if err := q.authorizeTemplateInsights(ctx, arg.TemplateIDs); err != nil {
return nil, err
@@ -4177,11 +3914,6 @@ func (q *querier) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, wor
return q.db.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, prep)
}
func (q *querier) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, buildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
// Fetching the provisioner state requires Update permission on the template.
return fetchWithAction(q.log, q.auth, policy.ActionUpdate, q.db.GetWorkspaceBuildProvisionerStateByID)(ctx, buildID)
}
func (q *querier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
@@ -4446,47 +4178,6 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo
return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg)
}
func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
}
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
// Authorize create on the parent chat (using update permission).
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatMessage{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatMessage{}, err
}
return q.db.InsertChatMessage(ctx, arg)
}
func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.InsertChatModelConfig(ctx, arg)
}
func (q *querier) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.InsertChatProvider(ctx, arg)
}
func (q *querier) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatQueuedMessage{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatQueuedMessage{}, err
}
return q.db.InsertChatQueuedMessage(ctx, arg)
}
func (q *querier) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -5055,14 +4746,6 @@ func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Contex
return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
}
func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep)
}
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
// This function is a system function until we implement a join for aibridge interceptions.
// Matches the behavior of the workspaces listing endpoint.
@@ -5093,14 +4776,6 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
}
func (q *querier) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChatsByRootID)(ctx, rootChatID)
}
func (q *querier) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChildChatsByParentID)(ctx, parentChatID)
}
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
}
@@ -5181,17 +4856,6 @@ func (q *querier) PaginatedOrganizationMembers(ctx context.Context, arg database
return q.db.PaginatedOrganizationMembers(ctx, arg)
}
func (q *querier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return database.ChatQueuedMessage{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatQueuedMessage{}, err
}
return q.db.PopNextQueuedMessage(ctx, chatID)
}
func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
template, err := q.db.GetTemplateByID(ctx, templateID)
if err != nil {
@@ -5270,13 +4934,6 @@ func (q *querier) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error {
return update(q.log, q.auth, fetch, q.db.UnfavoriteWorkspace)(ctx, id)
}
func (q *querier) UnsetDefaultChatModelConfigs(ctx context.Context) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.UnsetDefaultChatModelConfigs(ctx)
}
func (q *querier) UpdateAIBridgeInterceptionEnded(ctx context.Context, params database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, params.ID); err != nil {
return database.AIBridgeInterception{}, err
@@ -5291,75 +4948,6 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg)
}
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatByID(ctx, arg)
}
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return 0, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.UpdateChatHeartbeat(ctx, arg)
}
func (q *querier) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.UpdateChatModelConfig(ctx, arg)
}
func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.UpdateChatProvider(ctx, arg)
}
func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
// UpdateChatStatus is used by the chat processor to change chat status.
// It should be called with system context.
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatStatus(ctx, arg)
}
func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
// UpdateChatWorkspace is manually implemented for chat tables and may not be
// present on every wrapped store interface yet.
chatWorkspaceUpdater, ok := q.db.(interface {
UpdateChatWorkspace(context.Context, database.UpdateChatWorkspaceParams) (database.Chat, error)
})
if !ok {
return database.Chat{}, xerrors.New("update chat workspace is not implemented by wrapped store")
}
return chatWorkspaceUpdater.UpdateChatWorkspace(ctx, arg)
}
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -6415,30 +6003,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
return q.db.UpsertBoundaryUsageStats(ctx, arg)
}
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
// Authorize update on the parent chat.
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDiffStatus{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDiffStatus{}, err
}
return q.db.UpsertChatDiffStatus(ctx, arg)
}
func (q *querier) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
// Authorize update on the parent chat.
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDiffStatus{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDiffStatus{}, err
}
return q.db.UpsertChatDiffStatusReference(ctx, arg)
}
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return database.ConnectionLog{}, err
@@ -6730,17 +6294,16 @@ func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg databas
return q.CountConnectionLogs(ctx, arg)
}
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
return q.db.ListAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
}
func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
}
func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, _ rbac.PreparedAuthorized) ([]string, error) {
// TODO: Delete this function, all ListAIBridgeModels should be authorized. For now just call ListAIBridgeModels on the authz querier.
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
// TODO: Delete this function, all ListAIBridgeInterceptions should be authorized. For now just call ListAIBridgeInterceptions on the authz querier.
// This cannot be deleted for now because it's included in the
// database.Store interface, so dbauthz needs to implement it.
return q.ListAIBridgeModels(ctx, arg)
return q.ListAIBridgeInterceptions(ctx, arg)
}
func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) (int64, error) {
// TODO: Delete this function, all CountAIBridgeInterceptions should be authorized. For now just call CountAIBridgeInterceptions on the authz querier.
// This cannot be deleted for now because it's included in the
// database.Store interface, so dbauthz needs to implement it.
return q.CountAIBridgeInterceptions(ctx, arg)
}
+3 -376
View File
@@ -170,7 +170,6 @@ func TestDBAuthzRecursive(t *testing.T) {
Groups: []string{},
Scope: rbac.ScopeAll,
}
preparedAuthorizedType := reflect.TypeOf((*rbac.PreparedAuthorized)(nil)).Elem()
for i := 0; i < reflect.TypeOf(q).NumMethod(); i++ {
var ins []reflect.Value
ctx := dbauthz.As(context.Background(), actor)
@@ -178,13 +177,7 @@ func TestDBAuthzRecursive(t *testing.T) {
ins = append(ins, reflect.ValueOf(ctx))
method := reflect.TypeOf(q).Method(i)
for i := 2; i < method.Type.NumIn(); i++ {
inType := method.Type.In(i)
if inType.Implements(preparedAuthorizedType) {
ins = append(ins, reflect.ValueOf(emptyPreparedAuthorized{}))
continue
}
ins = append(ins, reflect.New(inType).Elem())
ins = append(ins, reflect.New(method.Type.In(i)).Elem())
}
if method.Name == "InTx" ||
method.Name == "Ping" ||
@@ -244,8 +237,8 @@ func (s *MethodTestSuite) TestAPIKey() {
s.Run("GetAPIKeysByLoginType", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
a := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
b := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Return([]database.APIKey{a, b}, nil).AnyTimes()
check.Args(database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.LoginTypePassword).Return([]database.APIKey{a, b}, nil).AnyTimes()
check.Args(database.LoginTypePassword).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
}))
s.Run("GetAPIKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u1 := testutil.Fake(s.T(), faker, database.User{})
@@ -371,344 +364,6 @@ func (s *MethodTestSuite) TestConnectionLogs() {
}))
}
func (s *MethodTestSuite) TestChats() {
s.Run("AcquireChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.AcquireChatParams{
StartedAt: dbtime.Now(),
WorkerID: uuid.New(),
}
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().AcquireChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(chat)
}))
s.Run("DeleteAllChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteAllChatQueuedMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("DeleteChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionDelete).Returns()
}))
s.Run("DeleteChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatMessagesByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionDelete).Returns()
}))
s.Run("DeleteChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
id := uuid.New()
dbm.EXPECT().DeleteChatModelConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
id := uuid.New()
dbm.EXPECT().DeleteChatProviderByID(gomock.Any(), id).Return(nil).AnyTimes()
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
args := database.DeleteChatQueuedMessageParams{ID: 123, ChatID: chat.ID}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatQueuedMessage(gomock.Any(), args).Return(nil).AnyTimes()
check.Args(args).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("GetChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
}))
s.Run("GetChatByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByIDForUpdate(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
}))
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatDiffStatusByChatID(gomock.Any(), chat.ID).Return(diffStatus, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(diffStatus)
}))
s.Run("GetChatDiffStatusesByChatIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chatA := testutil.Fake(s.T(), faker, database.Chat{})
chatB := testutil.Fake(s.T(), faker, database.Chat{})
ids := []uuid.UUID{chatA.ID, chatB.ID}
diffStatusA := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chatA.ID})
diffStatusB := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chatB.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chatA.ID).Return(chatA, nil).AnyTimes()
dbm.EXPECT().GetChatByID(gomock.Any(), chatB.ID).Return(chatB, nil).AnyTimes()
dbm.EXPECT().GetChatDiffStatusesByChatIDs(gomock.Any(), ids).Return([]database.ChatDiffStatus{diffStatusA, diffStatusB}, nil).AnyTimes()
check.Args(ids).
Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).
Returns([]database.ChatDiffStatus{diffStatusA, diffStatusB})
}))
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
dbm.EXPECT().GetChatMessageByID(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
check.Args(msg.ID).Asserts(chat, policy.ActionRead).Returns(msg)
}))
s.Run("GetChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetChatMessagesForPromptByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
dbm.EXPECT().GetChatModelConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes()
check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetDefaultChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
dbm.EXPECT().GetDefaultChatModelConfig(gomock.Any()).Return(config, nil).AnyTimes()
check.Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetChatModelConfigByProviderAndModel", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
args := database.GetChatModelConfigByProviderAndModelParams{
Provider: config.Provider,
Model: config.Model,
}
dbm.EXPECT().GetChatModelConfigByProviderAndModel(gomock.Any(), args).Return(config, nil).AnyTimes()
check.Args(args).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
dbm.EXPECT().GetChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
}))
s.Run("GetChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
}))
s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerName := "test-provider"
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
}))
s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
}))
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
c1 := testutil.Fake(s.T(), faker, database.Chat{})
c2 := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), c1.OwnerID).Return([]database.Chat{c1, c2}, nil).AnyTimes()
check.Args(c1.OwnerID).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
}))
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
qms := []database.ChatQueuedMessage{testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatQueuedMessages(gomock.Any(), chat.ID).Return(qms, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(qms)
}))
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
dbm.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
}))
s.Run("GetEnabledChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
dbm.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
}))
s.Run("ListChatsByRootID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
rootChatID := uuid.New()
chatA := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
chatB := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
dbm.EXPECT().ListChatsByRootID(gomock.Any(), rootChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
check.Args(rootChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
}))
s.Run("ListChildChatsByParentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
parentChatID := uuid.New()
chatA := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
chatB := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
dbm.EXPECT().ListChildChatsByParentID(gomock.Any(), parentChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
check.Args(parentChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
}))
s.Run("GetStaleChats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
threshold := dbtime.Now()
chats := []database.Chat{testutil.Fake(s.T(), faker, database.Chat{})}
dbm.EXPECT().GetStaleChats(gomock.Any(), threshold).Return(chats, nil).AnyTimes()
check.Args(threshold).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(chats)
}))
s.Run("InsertChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := testutil.Fake(s.T(), faker, database.InsertChatParams{})
chat := testutil.Fake(s.T(), faker, database.Chat{OwnerID: arg.OwnerID})
dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat)
}))
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().InsertChatMessage(gomock.Any(), arg).Return(msg, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msg)
}))
s.Run("InsertChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := testutil.Fake(s.T(), faker, database.InsertChatQueuedMessageParams{ChatID: chat.ID})
qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().InsertChatQueuedMessage(gomock.Any(), arg).Return(qm, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(qm)
}))
s.Run("InsertChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.InsertChatModelConfigParams{
Provider: "test-provider",
Model: "test-model",
DisplayName: "Test Model",
Enabled: true,
}
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{Provider: arg.Provider, Model: arg.Model, DisplayName: arg.DisplayName, Enabled: arg.Enabled})
dbm.EXPECT().InsertChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("InsertChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.InsertChatProviderParams{
Provider: "test-provider",
DisplayName: "Test Provider",
APIKey: "test-api-key",
Enabled: true,
}
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: arg.Provider, DisplayName: arg.DisplayName, APIKey: arg.APIKey, Enabled: arg.Enabled})
dbm.EXPECT().InsertChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
}))
s.Run("PopNextQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().PopNextQueuedMessage(gomock.Any(), chat.ID).Return(qm, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns(qm)
}))
s.Run("UpdateChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatByIDParams{
ID: chat.ID,
Title: "Updated title",
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: uuid.New(),
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
}))
s.Run("UpdateChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
arg := database.UpdateChatModelConfigParams{
ID: config.ID,
Provider: "updated-provider",
Model: "updated-model",
DisplayName: "Updated Model",
Enabled: true,
}
dbm.EXPECT().UpdateChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpdateChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
arg := database.UpdateChatProviderParams{
ID: provider.ID,
DisplayName: "Updated Provider",
APIKey: "updated-api-key",
Enabled: true,
}
dbm.EXPECT().UpdateChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
}))
s.Run("UpdateChatStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateChatWorkspace", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatWorkspaceParams{
ID: chat.ID,
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
WorkspaceAgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
}
updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatWorkspace(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat)
}))
s.Run("UnsetDefaultChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UnsetDefaultChatModelConfigs(gomock.Any()).Return(nil).AnyTimes()
check.Args().Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("UpsertChatDiffStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
now := dbtime.Now()
arg := database.UpsertChatDiffStatusParams{
ChatID: chat.ID,
Url: sql.NullString{String: "https://example.com/pr/123", Valid: true},
PullRequestState: sql.NullString{String: "open", Valid: true},
ChangesRequested: false,
Additions: 10,
Deletions: 5,
ChangedFiles: 2,
RefreshedAt: now,
StaleAt: now.Add(time.Hour),
}
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpsertChatDiffStatus(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
}))
s.Run("UpsertChatDiffStatusReference", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpsertChatDiffStatusReferenceParams{
ChatID: chat.ID,
Url: sql.NullString{String: "https://example.com/pr/123", Valid: true},
GitBranch: "feature/test",
GitRemoteOrigin: "origin",
StaleAt: dbtime.Now().Add(time.Hour),
}
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
}))
}
func (s *MethodTestSuite) TestFile() {
s.Run("GetFileByHashAndCreator", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
f := testutil.Fake(s.T(), faker, database.File{})
@@ -1671,11 +1326,6 @@ func (s *MethodTestSuite) TestTemplate() {
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
}))
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetTelemetryTaskEventsParams{}
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceTask.All(), policy.ActionRead)
}))
s.Run("GetTemplateAppInsights", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetTemplateAppInsightsParams{}
dbm.EXPECT().GetTemplateAppInsights(gomock.Any(), arg).Return([]database.GetTemplateAppInsightsRow{}, nil).AnyTimes()
@@ -2319,15 +1969,6 @@ func (s *MethodTestSuite) TestWorkspace() {
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
check.Args(build.ID).Asserts(ws, policy.ActionRead).Returns(build)
}))
s.Run("GetWorkspaceBuildProvisionerStateByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
row := database.GetWorkspaceBuildProvisionerStateByIDRow{
ProvisionerState: []byte("state"),
TemplateID: uuid.New(),
TemplateOrganizationID: uuid.New(),
}
dbm.EXPECT().GetWorkspaceBuildProvisionerStateByID(gomock.Any(), gomock.Any()).Return(row, nil).AnyTimes()
check.Args(uuid.New()).Asserts(row, policy.ActionUpdate).Returns(row)
}))
s.Run("GetWorkspaceBuildByJobID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
build := testutil.Fake(s.T(), faker, database.WorkspaceBuild{WorkspaceID: ws.ID})
@@ -5105,20 +4746,6 @@ func (s *MethodTestSuite) TestAIBridge() {
check.Args(params, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("ListAIBridgeModels", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
params := database.ListAIBridgeModelsParams{}
db.EXPECT().ListAuthorizedAIBridgeModels(gomock.Any(), params, gomock.Any()).Return([]string{}, nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params).Asserts()
}))
s.Run("ListAuthorizedAIBridgeModels", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
params := database.ListAIBridgeModelsParams{}
db.EXPECT().ListAuthorizedAIBridgeModels(gomock.Any(), params, gomock.Any()).Return([]string{}, nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ids := []uuid.UUID{{1}}
db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes()
-19
View File
@@ -67,8 +67,6 @@ type WorkspaceBuildBuilder struct {
jobError string // Error message for failed jobs
jobErrorCode string // Error code for failed jobs
provisionerState []byte
}
// BuilderOption is a functional option for customizing job timestamps
@@ -140,15 +138,6 @@ func (b WorkspaceBuildBuilder) Seed(seed database.WorkspaceBuild) WorkspaceBuild
return b
}
// ProvisionerState sets the provisioner state for the workspace build.
// This is stored separately from the seed because ProvisionerState is
// not part of the WorkspaceBuild view struct.
func (b WorkspaceBuildBuilder) ProvisionerState(state []byte) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
b.provisionerState = state
return b
}
func (b WorkspaceBuildBuilder) Resource(resource ...*sdkproto.Resource) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
b.resources = append(b.resources, resource...)
@@ -475,14 +464,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
}
resp.Build = dbgen.WorkspaceBuild(b.t, b.db, b.seed)
if len(b.provisionerState) > 0 {
err = b.db.UpdateWorkspaceBuildProvisionerStateByID(ownerCtx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
ID: resp.Build.ID,
UpdatedAt: dbtime.Now(),
ProvisionerState: b.provisionerState,
})
require.NoError(b.t, err, "update provisioner state")
}
b.logger.Debug(context.Background(), "created workspace build",
slog.F("build_id", resp.Build.ID),
slog.F("workspace_id", resp.Workspace.ID),
+1 -3
View File
@@ -504,7 +504,7 @@ func WorkspaceBuild(t testing.TB, db database.Store, orig database.WorkspaceBuil
Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart),
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
JobID: jobID,
ProvisionerState: []byte{},
ProvisionerState: takeFirstSlice(orig.ProvisionerState, []byte{}),
Deadline: takeFirst(orig.Deadline, dbtime.Now().Add(time.Hour)),
MaxDeadline: takeFirst(orig.MaxDeadline, time.Time{}),
Reason: takeFirst(orig.Reason, database.BuildReasonInitiator),
@@ -1373,8 +1373,6 @@ func OAuth2ProviderAppCode(t testing.TB, db database.Store, seed database.OAuth2
ResourceUri: seed.ResourceUri,
CodeChallenge: seed.CodeChallenge,
CodeChallengeMethod: seed.CodeChallengeMethod,
StateHash: seed.StateHash,
RedirectUri: seed.RedirectUri,
})
require.NoError(t, err, "insert oauth2 app code")
return code
+1 -378
View File
@@ -104,14 +104,6 @@ func (m queryMetricsStore) DeleteOrganization(ctx context.Context, id uuid.UUID)
return r0
}
func (m queryMetricsStore) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.AcquireChat(ctx, arg)
m.queryLatencies.WithLabelValues("AcquireChat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AcquireChat").Inc()
return r0, r1
}
func (m queryMetricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
start := time.Now()
r0 := m.s.AcquireLock(ctx, pgAdvisoryXactLock)
@@ -164,7 +156,6 @@ func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context
start := time.Now()
r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
m.queryLatencies.WithLabelValues("BatchUpdateWorkspaceAgentMetadata").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpdateWorkspaceAgentMetadata").Inc()
return r0
}
@@ -320,14 +311,6 @@ func (m queryMetricsStore) DeleteAPIKeysByUserID(ctx context.Context, userID uui
return r0
}
func (m queryMetricsStore) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteAllChatQueuedMessages(ctx, chatID)
m.queryLatencies.WithLabelValues("DeleteAllChatQueuedMessages").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteAllChatQueuedMessages").Inc()
return r0
}
func (m queryMetricsStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
start := time.Now()
r0 := m.s.DeleteAllTailnetTunnels(ctx, arg)
@@ -352,46 +335,6 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
return r0
}
func (m queryMetricsStore) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatByID(ctx, id)
m.queryLatencies.WithLabelValues("DeleteChatByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatByID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatMessagesByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("DeleteChatMessagesByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatMessagesByChatID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
m.queryLatencies.WithLabelValues("DeleteChatModelConfigByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatModelConfigByID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatProviderByID(ctx, id)
m.queryLatencies.WithLabelValues("DeleteChatProviderByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatProviderByID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
start := time.Now()
r0 := m.s.DeleteChatQueuedMessage(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteChatQueuedMessage").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatQueuedMessage").Inc()
return r0
}
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
@@ -831,7 +774,7 @@ func (m queryMetricsStore) GetAPIKeyByName(ctx context.Context, arg database.Get
return r0, r1
}
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
start := time.Now()
r0, r1 := m.s.GetAPIKeysByLoginType(ctx, loginType)
m.queryLatencies.WithLabelValues("GetAPIKeysByLoginType").Observe(time.Since(start).Seconds())
@@ -959,126 +902,6 @@ func (m queryMetricsStore) GetAuthorizationUserRoles(ctx context.Context, userID
return r0, r1
}
func (m queryMetricsStore) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetChatByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetChatByIDForUpdate(ctx, id)
m.queryLatencies.WithLabelValues("GetChatByIDForUpdate").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatByIDForUpdate").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatDiffStatusByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDiffStatusByChatID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uuid.UUID) ([]database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
m.queryLatencies.WithLabelValues("GetChatDiffStatusesByChatIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDiffStatusesByChatIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessageByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatMessageByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessageByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatMessagesByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesForPromptByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatMessagesForPromptByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesForPromptByChatID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetChatModelConfigByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatModelConfigByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetChatModelConfigByProviderAndModel(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatModelConfigByProviderAndModel").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigByProviderAndModel").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetChatModelConfigs(ctx)
m.queryLatencies.WithLabelValues("GetChatModelConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviderByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatProviderByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviderByProvider(ctx, provider)
m.queryLatencies.WithLabelValues("GetChatProviderByProvider").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProvider").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviders(ctx)
m.queryLatencies.WithLabelValues("GetChatProviders").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviders").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatQueuedMessages(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatQueuedMessages").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatQueuedMessages").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
start := time.Now()
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
@@ -1135,14 +958,6 @@ func (m queryMetricsStore) GetDERPMeshKey(ctx context.Context) (string, error) {
return r0, r1
}
func (m queryMetricsStore) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetDefaultChatModelConfig(ctx)
m.queryLatencies.WithLabelValues("GetDefaultChatModelConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetDefaultChatModelConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) GetDefaultOrganization(ctx context.Context) (database.Organization, error) {
start := time.Now()
r0, r1 := m.s.GetDefaultOrganization(ctx)
@@ -1207,22 +1022,6 @@ func (m queryMetricsStore) GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx
return r0, r1
}
func (m queryMetricsStore) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetEnabledChatModelConfigs(ctx)
m.queryLatencies.WithLabelValues("GetEnabledChatModelConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatModelConfigs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetEnabledChatProviders(ctx)
m.queryLatencies.WithLabelValues("GetEnabledChatProviders").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatProviders").Inc()
return r0, r1
}
func (m queryMetricsStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
start := time.Now()
r0, r1 := m.s.GetExternalAuthLink(ctx, arg)
@@ -1919,14 +1718,6 @@ func (m queryMetricsStore) GetRuntimeConfig(ctx context.Context, key string) (st
return r0, r1
}
func (m queryMetricsStore) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetStaleChats(ctx, staleThreshold)
m.queryLatencies.WithLabelValues("GetStaleChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetStaleChats").Inc()
return r0, r1
}
func (m queryMetricsStore) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
start := time.Now()
r0, r1 := m.s.GetTailnetPeers(ctx, id)
@@ -1999,14 +1790,6 @@ func (m queryMetricsStore) GetTelemetryItems(ctx context.Context) ([]database.Te
return r0, r1
}
func (m queryMetricsStore) GetTelemetryTaskEvents(ctx context.Context, createdAfter database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTelemetryTaskEvents(ctx, createdAfter)
m.queryLatencies.WithLabelValues("GetTelemetryTaskEvents").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTelemetryTaskEvents").Inc()
return r0, r1
}
func (m queryMetricsStore) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTemplateAppInsights(ctx, arg)
@@ -2647,14 +2430,6 @@ func (m queryMetricsStore) GetWorkspaceBuildParametersByBuildIDs(ctx context.Con
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuildID)
m.queryLatencies.WithLabelValues("GetWorkspaceBuildProvisionerStateByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceBuildProvisionerStateByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceBuildStatsByTemplates(ctx, since)
@@ -2919,46 +2694,6 @@ func (m queryMetricsStore) InsertAuditLog(ctx context.Context, arg database.Inse
return r0, r1
}
func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.InsertChat(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChat").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.InsertChatMessage(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatMessage").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessage").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.InsertChatModelConfig(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatModelConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatModelConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.InsertChatProvider(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatProvider").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatProvider").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
start := time.Now()
r0, r1 := m.s.InsertChatQueuedMessage(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatQueuedMessage").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatQueuedMessage").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.InsertCryptoKey(ctx, arg)
@@ -3463,14 +3198,6 @@ func (m queryMetricsStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx conte
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeModels(ctx, arg)
m.queryLatencies.WithLabelValues("ListAIBridgeModels").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeModels").Inc()
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
@@ -3495,22 +3222,6 @@ func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.
return r0, r1
}
func (m queryMetricsStore) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.ListChatsByRootID(ctx, rootChatID)
m.queryLatencies.WithLabelValues("ListChatsByRootID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatsByRootID").Inc()
return r0, r1
}
func (m queryMetricsStore) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.ListChildChatsByParentID(ctx, parentChatID)
m.queryLatencies.WithLabelValues("ListChildChatsByParentID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChildChatsByParentID").Inc()
return r0, r1
}
func (m queryMetricsStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
start := time.Now()
r0, r1 := m.s.ListProvisionerKeysByOrganization(ctx, organizationID)
@@ -3591,14 +3302,6 @@ func (m queryMetricsStore) PaginatedOrganizationMembers(ctx context.Context, arg
return r0, r1
}
func (m queryMetricsStore) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
start := time.Now()
r0, r1 := m.s.PopNextQueuedMessage(ctx, chatID)
m.queryLatencies.WithLabelValues("PopNextQueuedMessage").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "PopNextQueuedMessage").Inc()
return r0, r1
}
func (m queryMetricsStore) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
start := time.Now()
r0 := m.s.ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx, templateID)
@@ -3671,14 +3374,6 @@ func (m queryMetricsStore) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID
return r0
}
func (m queryMetricsStore) UnsetDefaultChatModelConfigs(ctx context.Context) error {
start := time.Now()
r0 := m.s.UnsetDefaultChatModelConfigs(ctx)
m.queryLatencies.WithLabelValues("UnsetDefaultChatModelConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UnsetDefaultChatModelConfigs").Inc()
return r0
}
func (m queryMetricsStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
start := time.Now()
r0, r1 := m.s.UpdateAIBridgeInterceptionEnded(ctx, arg)
@@ -3695,54 +3390,6 @@ func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.Up
return r0
}
func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatByID(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatByID").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatModelConfig(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatModelConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatModelConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatProvider(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatProvider").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatProvider").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatStatus(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatStatus").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatStatus").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatWorkspace(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatWorkspace").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspace").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.UpdateCryptoKeyDeletesAt(ctx, arg)
@@ -4454,22 +4101,6 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
return r0, r1
}
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatDiffStatus").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDiffStatus").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatDiffStatusReference(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatDiffStatusReference").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDiffStatusReference").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
start := time.Now()
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
@@ -4781,11 +4412,3 @@ func (m queryMetricsStore) CountAuthorizedAIBridgeInterceptions(ctx context.Cont
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAuthorizedAIBridgeInterceptions").Inc()
return r0, r1
}
func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error) {
start := time.Now()
r0, r1 := m.s.ListAuthorizedAIBridgeModels(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeModels").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeModels").Inc()
return r0, r1
}

Some files were not shown because too many files have changed in this diff Show More