Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d15a6be1c7 | |||
| fe9202ab6f | |||
| e36944f9cc | |||
| 670801f310 | |||
| 3a826bcbb9 | |||
| 42b467acf7 | |||
| 98bc5f04fb | |||
| 02aeef5a91 | |||
| 1cbc8929ab |
@@ -4,7 +4,7 @@ description: |
|
||||
inputs:
|
||||
version:
|
||||
description: "The Go version to use."
|
||||
default: "1.25.7"
|
||||
default: "1.25.6"
|
||||
use-preinstalled-go:
|
||||
description: "Whether to use preinstalled Go."
|
||||
default: "false"
|
||||
|
||||
@@ -7,5 +7,5 @@ runs:
|
||||
- name: Install Terraform
|
||||
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2
|
||||
with:
|
||||
terraform_version: 1.14.5
|
||||
terraform_version: 1.14.1
|
||||
terraform_wrapper: false
|
||||
|
||||
@@ -489,14 +489,6 @@ jobs:
|
||||
# macOS will output "The default interactive shell is now zsh" intermittently in CI.
|
||||
touch ~/.bash_profile && echo "export BASH_SILENCE_DEPRECATION_WARNING=1" >> ~/.bash_profile
|
||||
|
||||
- name: Increase PTY limit (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
shell: bash
|
||||
run: |
|
||||
# Increase PTY limit to avoid exhaustion during tests.
|
||||
# Default is 511; 999 is the maximum value on CI runner.
|
||||
sudo sysctl -w kern.tty.ptmx_max=999
|
||||
|
||||
- name: Test with PostgreSQL Database (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
uses: ./.github/actions/test-go-pg
|
||||
|
||||
@@ -146,7 +146,7 @@ jobs:
|
||||
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8
|
||||
with:
|
||||
image-ref: ${{ steps.build.outputs.image }}
|
||||
format: sarif
|
||||
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
egress-policy: audit
|
||||
|
||||
- name: stale
|
||||
uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
|
||||
uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
|
||||
with:
|
||||
stale-issue-label: "stale"
|
||||
stale-pr-label: "stale"
|
||||
|
||||
@@ -98,6 +98,3 @@ AGENTS.local.md
|
||||
|
||||
# Ignore plans written by AI agents.
|
||||
PLAN.md
|
||||
|
||||
# Ignore any dev licenses
|
||||
license.txt
|
||||
|
||||
@@ -854,7 +854,7 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
|
||||
# -C sets the directory for the go run command
|
||||
go run -C ./scripts/apitypings main.go > $@
|
||||
./scripts/biome_format.sh src/api/typesGenerated.ts
|
||||
(cd site/ && pnpm exec biome format --write src/api/typesGenerated.ts)
|
||||
touch "$@"
|
||||
|
||||
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
|
||||
@@ -863,7 +863,7 @@ site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/prot
|
||||
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
|
||||
go run ./scripts/gensite/ -icons "$@"
|
||||
./scripts/biome_format.sh src/theme/icons.json
|
||||
(cd site/ && pnpm exec biome format --write src/theme/icons.json)
|
||||
touch "$@"
|
||||
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
|
||||
@@ -901,12 +901,12 @@ codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scope
|
||||
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
|
||||
go run scripts/typegen/main.go rbac typescript > "$@"
|
||||
./scripts/biome_format.sh src/api/rbacresourcesGenerated.ts
|
||||
(cd site/ && pnpm exec biome format --write src/api/rbacresourcesGenerated.ts)
|
||||
touch "$@"
|
||||
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
|
||||
go run scripts/typegen/main.go countries > "$@"
|
||||
./scripts/biome_format.sh src/api/countriesGenerated.ts
|
||||
(cd site/ && pnpm exec biome format --write src/api/countriesGenerated.ts)
|
||||
touch "$@"
|
||||
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES)
|
||||
@@ -950,11 +950,11 @@ coderd/apidoc/.gen: \
|
||||
touch "$@"
|
||||
|
||||
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
|
||||
./scripts/biome_format.sh ../docs/manifest.json
|
||||
(cd site/ && pnpm exec biome format --write ../docs/manifest.json)
|
||||
touch "$@"
|
||||
|
||||
coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
|
||||
./scripts/biome_format.sh ../coderd/apidoc/swagger.json
|
||||
(cd site/ && pnpm exec biome format --write ../coderd/apidoc/swagger.json)
|
||||
touch "$@"
|
||||
|
||||
update-golden-files:
|
||||
@@ -999,19 +999,11 @@ enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/
|
||||
touch "$@"
|
||||
|
||||
helm/coder/tests/testdata/.gen-golden: $(wildcard helm/coder/tests/testdata/*.yaml) $(wildcard helm/coder/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/coder/tests/*_test.go)
|
||||
if command -v helm >/dev/null 2>&1; then
|
||||
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
|
||||
else
|
||||
echo "WARNING: helm not found; skipping helm/coder golden generation" >&2
|
||||
fi
|
||||
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
|
||||
touch "$@"
|
||||
|
||||
helm/provisioner/tests/testdata/.gen-golden: $(wildcard helm/provisioner/tests/testdata/*.yaml) $(wildcard helm/provisioner/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/provisioner/tests/*_test.go)
|
||||
if command -v helm >/dev/null 2>&1; then
|
||||
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
|
||||
else
|
||||
echo "WARNING: helm not found; skipping helm/provisioner golden generation" >&2
|
||||
fi
|
||||
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
|
||||
touch "$@"
|
||||
|
||||
coderd/.gen-golden: $(wildcard coderd/testdata/*/*.golden) $(GO_SRC_FILES) $(wildcard coderd/*_test.go)
|
||||
|
||||
+2
-10
@@ -111,12 +111,6 @@ type Client interface {
|
||||
ConnectRPC28(ctx context.Context) (
|
||||
proto.DRPCAgentClient28, tailnetproto.DRPCTailnetClient28, error,
|
||||
)
|
||||
// ConnectRPC28WithRole is like ConnectRPC28 but sends an explicit
|
||||
// role query parameter to the server. The workspace agent should
|
||||
// use role "agent" to enable connection monitoring.
|
||||
ConnectRPC28WithRole(ctx context.Context, role string) (
|
||||
proto.DRPCAgentClient28, tailnetproto.DRPCTailnetClient28, error,
|
||||
)
|
||||
tailnet.DERPMapRewriter
|
||||
agentsdk.RefreshableSessionTokenProvider
|
||||
}
|
||||
@@ -1003,10 +997,8 @@ func (a *agent) run() (retErr error) {
|
||||
return xerrors.Errorf("refresh token: %w", err)
|
||||
}
|
||||
|
||||
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs.
|
||||
// We pass role "agent" to enable connection monitoring on the server, which tracks
|
||||
// the agent's connectivity state (first_connected_at, last_connected_at, disconnected_at).
|
||||
aAPI, tAPI, err := a.client.ConnectRPC28WithRole(a.hardCtx, "agent")
|
||||
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
|
||||
aAPI, tAPI, err := a.client.ConnectRPC28(a.hardCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,22 +1,37 @@
|
||||
package agentsocket_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("agentsocket is not supported on Windows")
|
||||
}
|
||||
|
||||
t.Run("StartStop", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
|
||||
require.NoError(t, err)
|
||||
@@ -26,7 +41,7 @@ func TestServer(t *testing.T) {
|
||||
t.Run("AlreadyStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server1, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
|
||||
require.NoError(t, err)
|
||||
@@ -34,4 +49,90 @@ func TestServer(t *testing.T) {
|
||||
_, err = agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
|
||||
require.ErrorContains(t, err, "create socket")
|
||||
})
|
||||
|
||||
t.Run("AutoSocketPath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerWindowsNotSupported(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("this test only runs on Windows")
|
||||
}
|
||||
|
||||
t.Run("NewServer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
_, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
|
||||
require.ErrorContains(t, err, "agentsocket is not supported on Windows")
|
||||
})
|
||||
|
||||
t.Run("NewClient", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := agentsocket.NewClient(context.Background(), agentsocket.WithPath("test.sock"))
|
||||
require.ErrorContains(t, err, "agentsocket is not supported on Windows")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAgentInitializesOnWindowsWithoutSocketServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("this test only runs on Windows")
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := testutil.Logger(t).Named("agent")
|
||||
|
||||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
|
||||
coordinator := tailnet.NewCoordinator(logger)
|
||||
t.Cleanup(func() {
|
||||
_ = coordinator.Close()
|
||||
})
|
||||
|
||||
statsCh := make(chan *agentproto.Stats, 50)
|
||||
agentID := uuid.New()
|
||||
manifest := agentsdk.Manifest{
|
||||
AgentID: agentID,
|
||||
AgentName: "test-agent",
|
||||
WorkspaceName: "test-workspace",
|
||||
OwnerName: "test-user",
|
||||
WorkspaceID: uuid.New(),
|
||||
DERPMap: derpMap,
|
||||
}
|
||||
|
||||
client := agenttest.NewClient(t, logger.Named("agenttest"), agentID, manifest, statsCh, coordinator)
|
||||
t.Cleanup(client.Close)
|
||||
|
||||
options := agent.Options{
|
||||
Client: client,
|
||||
Filesystem: afero.NewMemMapFs(),
|
||||
Logger: logger.Named("agent"),
|
||||
ReconnectingPTYTimeout: testutil.WaitShort,
|
||||
EnvironmentVariables: map[string]string{},
|
||||
SocketPath: "",
|
||||
}
|
||||
|
||||
agnt := agent.New(options)
|
||||
t.Cleanup(func() {
|
||||
_ = agnt.Close()
|
||||
})
|
||||
|
||||
startup := testutil.TryReceive(ctx, t, client.GetStartup())
|
||||
require.NotNil(t, startup, "agent should send startup message")
|
||||
|
||||
err := agnt.Close()
|
||||
require.NoError(t, err, "agent should close cleanly")
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package agentsocket_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -28,10 +30,14 @@ func newSocketClient(ctx context.Context, t *testing.T, socketPath string) *agen
|
||||
func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("agentsocket is not supported on Windows")
|
||||
}
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
@@ -51,7 +57,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
|
||||
|
||||
t.Run("NewUnit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
@@ -73,7 +79,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Run("UnitAlreadyStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
@@ -103,7 +109,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Run("UnitAlreadyCompleted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
@@ -142,7 +148,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Run("UnitNotReady", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
@@ -172,7 +178,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Run("NewUnits", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
@@ -197,7 +203,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Run("DependencyAlreadyRegistered", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
@@ -232,7 +238,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Run("DependencyAddedAfterDependentStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
@@ -274,7 +280,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Run("UnregisteredUnit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
@@ -293,7 +299,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Run("UnitNotReady", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
@@ -317,7 +323,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Run("UnitReady", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
|
||||
@@ -4,60 +4,19 @@ package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"strings"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
const defaultSocketPath = `\\.\pipe\com.coder.agentsocket`
|
||||
|
||||
func createSocket(path string) (net.Listener, error) {
|
||||
if path == "" {
|
||||
path = defaultSocketPath
|
||||
}
|
||||
if !strings.HasPrefix(path, `\\.\pipe\`) {
|
||||
return nil, xerrors.Errorf("%q is not a valid local socket path", path)
|
||||
}
|
||||
|
||||
user, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to look up current user: %w", err)
|
||||
}
|
||||
sid := user.Uid
|
||||
|
||||
// SecurityDescriptor is in SDDL format. c.f.
|
||||
// https://learn.microsoft.com/en-us/windows/win32/secauthz/security-descriptor-string-format for full details.
|
||||
// D: indicates this is a Discretionary Access Control List (DACL), which is Windows-speak for ACLs that allow or
|
||||
// deny access (as opposed to SACL which controls audit logging).
|
||||
// P indicates that this DACL is "protected" from being modified thru inheritance
|
||||
// () delimit access control entries (ACEs), here we only have one, which, allows (A) generic all (GA) access to our
|
||||
// specific user's security ID (SID).
|
||||
//
|
||||
// Note that although Microsoft docs at https://learn.microsoft.com/en-us/windows/win32/ipc/named-pipes warns that
|
||||
// named pipes are accessible from remote machines in the general case, the `winio` package sets the flag
|
||||
// windows.FILE_PIPE_REJECT_REMOTE_CLIENTS when creating pipes, so connections from remote machines are always
|
||||
// denied. This is important because we sort of expect customers to run the Coder agent under a generic user
|
||||
// account unless they are very sophisticated. We don't want this socket to cross the boundary of the local machine.
|
||||
configuration := &winio.PipeConfig{
|
||||
SecurityDescriptor: fmt.Sprintf("D:P(A;;GA;;;%s)", sid),
|
||||
}
|
||||
|
||||
listener, err := winio.ListenPipe(path, configuration)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to open named pipe: %w", err)
|
||||
}
|
||||
return listener, nil
|
||||
func createSocket(_ string) (net.Listener, error) {
|
||||
return nil, xerrors.New("agentsocket is not supported on Windows")
|
||||
}
|
||||
|
||||
func cleanupSocket(path string) error {
|
||||
return os.Remove(path)
|
||||
func cleanupSocket(_ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func dialSocket(ctx context.Context, path string) (net.Conn, error) {
|
||||
return winio.DialPipeContext(ctx, path)
|
||||
func dialSocket(_ context.Context, _ string) (net.Conn, error) {
|
||||
return nil, xerrors.New("agentsocket is not supported on Windows")
|
||||
}
|
||||
|
||||
@@ -124,12 +124,6 @@ func (c *Client) Close() {
|
||||
c.derpMapOnce.Do(func() { close(c.derpMapUpdates) })
|
||||
}
|
||||
|
||||
func (c *Client) ConnectRPC28WithRole(ctx context.Context, _ string) (
|
||||
agentproto.DRPCAgentClient28, proto.DRPCTailnetClient28, error,
|
||||
) {
|
||||
return c.ConnectRPC28(ctx)
|
||||
}
|
||||
|
||||
func (c *Client) ConnectRPC28(ctx context.Context) (
|
||||
agentproto.DRPCAgentClient28, proto.DRPCTailnetClient28, error,
|
||||
) {
|
||||
@@ -235,10 +229,6 @@ type FakeAgentAPI struct {
|
||||
pushResourcesMonitoringUsageFunc func(*agentproto.PushResourcesMonitoringUsageRequest) (*agentproto.PushResourcesMonitoringUsageResponse, error)
|
||||
}
|
||||
|
||||
func (*FakeAgentAPI) UpdateAppStatus(context.Context, *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
|
||||
return f.manifest, nil
|
||||
}
|
||||
|
||||
+330
-544
File diff suppressed because it is too large
Load Diff
+1
-20
@@ -436,7 +436,7 @@ message CreateSubAgentRequest {
|
||||
}
|
||||
|
||||
repeated DisplayApp display_apps = 6;
|
||||
|
||||
|
||||
optional bytes id = 7;
|
||||
}
|
||||
|
||||
@@ -494,24 +494,6 @@ message ReportBoundaryLogsRequest {
|
||||
|
||||
message ReportBoundaryLogsResponse {}
|
||||
|
||||
// UpdateAppStatusRequest updates the given Workspace App's status. c.f. agentsdk.PatchAppStatus
|
||||
message UpdateAppStatusRequest {
|
||||
string slug = 1;
|
||||
|
||||
enum AppStatusState {
|
||||
WORKING = 0;
|
||||
IDLE = 1;
|
||||
COMPLETE = 2;
|
||||
FAILURE = 3;
|
||||
}
|
||||
AppStatusState state = 2;
|
||||
|
||||
string message = 3;
|
||||
string uri = 4;
|
||||
}
|
||||
|
||||
message UpdateAppStatusResponse {}
|
||||
|
||||
service Agent {
|
||||
rpc GetManifest(GetManifestRequest) returns (Manifest);
|
||||
rpc GetServiceBanner(GetServiceBannerRequest) returns (ServiceBanner);
|
||||
@@ -530,5 +512,4 @@ service Agent {
|
||||
rpc DeleteSubAgent(DeleteSubAgentRequest) returns (DeleteSubAgentResponse);
|
||||
rpc ListSubAgents(ListSubAgentsRequest) returns (ListSubAgentsResponse);
|
||||
rpc ReportBoundaryLogs(ReportBoundaryLogsRequest) returns (ReportBoundaryLogsResponse);
|
||||
rpc UpdateAppStatus(UpdateAppStatusRequest) returns (UpdateAppStatusResponse);
|
||||
}
|
||||
|
||||
@@ -56,7 +56,6 @@ type DRPCAgentClient interface {
|
||||
DeleteSubAgent(ctx context.Context, in *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
|
||||
ListSubAgents(ctx context.Context, in *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
|
||||
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
|
||||
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
|
||||
}
|
||||
|
||||
type drpcAgentClient struct {
|
||||
@@ -222,15 +221,6 @@ func (c *drpcAgentClient) ReportBoundaryLogs(ctx context.Context, in *ReportBoun
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentClient) UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
|
||||
out := new(UpdateAppStatusResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type DRPCAgentServer interface {
|
||||
GetManifest(context.Context, *GetManifestRequest) (*Manifest, error)
|
||||
GetServiceBanner(context.Context, *GetServiceBannerRequest) (*ServiceBanner, error)
|
||||
@@ -249,7 +239,6 @@ type DRPCAgentServer interface {
|
||||
DeleteSubAgent(context.Context, *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
|
||||
ListSubAgents(context.Context, *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
|
||||
ReportBoundaryLogs(context.Context, *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
|
||||
UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
|
||||
}
|
||||
|
||||
type DRPCAgentUnimplementedServer struct{}
|
||||
@@ -322,13 +311,9 @@ func (s *DRPCAgentUnimplementedServer) ReportBoundaryLogs(context.Context, *Repo
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentUnimplementedServer) UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
type DRPCAgentDescription struct{}
|
||||
|
||||
func (DRPCAgentDescription) NumMethods() int { return 18 }
|
||||
func (DRPCAgentDescription) NumMethods() int { return 17 }
|
||||
|
||||
func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
@@ -485,15 +470,6 @@ func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
|
||||
in1.(*ReportBoundaryLogsRequest),
|
||||
)
|
||||
}, DRPCAgentServer.ReportBoundaryLogs, true
|
||||
case 17:
|
||||
return "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentServer).
|
||||
UpdateAppStatus(
|
||||
ctx,
|
||||
in1.(*UpdateAppStatusRequest),
|
||||
)
|
||||
}, DRPCAgentServer.UpdateAppStatus, true
|
||||
default:
|
||||
return "", nil, nil, nil, false
|
||||
}
|
||||
@@ -774,19 +750,3 @@ func (x *drpcAgent_ReportBoundaryLogsStream) SendAndClose(m *ReportBoundaryLogsR
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgent_UpdateAppStatusStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*UpdateAppStatusResponse) error
|
||||
}
|
||||
|
||||
type drpcAgent_UpdateAppStatusStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgent_UpdateAppStatusStream) SendAndClose(m *UpdateAppStatusResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
@@ -73,13 +73,9 @@ type DRPCAgentClient27 interface {
|
||||
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
|
||||
}
|
||||
|
||||
// DRPCAgentClient28 is the Agent API at v2.8. It adds
|
||||
// - a SubagentId field to the WorkspaceAgentDevcontainer message
|
||||
// - an Id field to the CreateSubAgentRequest message.
|
||||
// - UpdateAppStatus RPC.
|
||||
//
|
||||
// Compatible with Coder v2.31+
|
||||
// DRPCAgentClient28 is the Agent API at v2.8. It adds a SubagentId field to the
|
||||
// WorkspaceAgentDevcontainer message, and a Id field to the CreateSubAgentRequest
|
||||
// message. Compatible with Coder v2.31+
|
||||
type DRPCAgentClient28 interface {
|
||||
DRPCAgentClient27
|
||||
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
|
||||
}
|
||||
|
||||
+45
-50
@@ -10,7 +10,6 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
@@ -24,7 +23,6 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/codersdk/toolsdk"
|
||||
"github.com/coder/retry"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
@@ -541,6 +539,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
|
||||
defer cancel()
|
||||
defer srv.queue.Close()
|
||||
|
||||
cliui.Infof(inv.Stderr, "Failed to watch screen events")
|
||||
// Start the reporter, watcher, and server. These are all tied to the
|
||||
// lifetime of the MCP server, which is itself tied to the lifetime of the
|
||||
// AI agent.
|
||||
@@ -614,51 +613,48 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
|
||||
}
|
||||
|
||||
func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
|
||||
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for retrier := retry.New(time.Second, 30*time.Second); retrier.Wait(ctx); {
|
||||
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
|
||||
if err == nil {
|
||||
retrier.Reset()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case event := <-eventsCh:
|
||||
switch ev := event.(type) {
|
||||
case agentapi.EventStatusChange:
|
||||
// If the screen is stable, report idle.
|
||||
state := codersdk.WorkspaceAppStatusStateWorking
|
||||
if ev.Status == agentapi.StatusStable {
|
||||
state = codersdk.WorkspaceAppStatusStateIdle
|
||||
}
|
||||
err := s.queue.Push(taskReport{
|
||||
state: state,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
case event := <-eventsCh:
|
||||
switch ev := event.(type) {
|
||||
case agentapi.EventStatusChange:
|
||||
state := codersdk.WorkspaceAppStatusStateWorking
|
||||
if ev.Status == agentapi.StatusStable {
|
||||
state = codersdk.WorkspaceAppStatusStateIdle
|
||||
}
|
||||
err := s.queue.Push(taskReport{
|
||||
state: state,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
}
|
||||
case agentapi.EventMessageUpdate:
|
||||
if ev.Role == agentapi.RoleUser {
|
||||
err := s.queue.Push(taskReport{
|
||||
messageID: &ev.Id,
|
||||
state: codersdk.WorkspaceAppStatusStateWorking,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
case agentapi.EventMessageUpdate:
|
||||
if ev.Role == agentapi.RoleUser {
|
||||
err := s.queue.Push(taskReport{
|
||||
messageID: &ev.Id,
|
||||
state: codersdk.WorkspaceAppStatusStateWorking,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
}
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
|
||||
}
|
||||
break loop
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -696,14 +692,13 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
|
||||
// Add tool dependencies.
|
||||
toolOpts := []func(*toolsdk.Deps){
|
||||
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
|
||||
state := codersdk.WorkspaceAppStatusState(args.State)
|
||||
// The agent does not reliably report idle, so when AgentAPI is
|
||||
// enabled we override idle to working and let the screen watcher
|
||||
// detect the real idle via StatusStable. Final states (failure,
|
||||
// complete) are trusted from the agent since the screen watcher
|
||||
// cannot produce them.
|
||||
if s.aiAgentAPIClient != nil && state == codersdk.WorkspaceAppStatusStateIdle {
|
||||
state = codersdk.WorkspaceAppStatusStateWorking
|
||||
// The agent does not reliably report its status correctly. If AgentAPI
|
||||
// is enabled, we will always set the status to "working" when we get an
|
||||
// MCP message, and rely on the screen watcher to eventually catch the
|
||||
// idle state.
|
||||
state := codersdk.WorkspaceAppStatusStateWorking
|
||||
if s.aiAgentAPIClient == nil {
|
||||
state = codersdk.WorkspaceAppStatusState(args.State)
|
||||
}
|
||||
return s.queue.Push(taskReport{
|
||||
link: args.Link,
|
||||
|
||||
+1
-185
@@ -921,7 +921,7 @@ func TestExpMcpReporter(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
// We override idle from the agent to working, but trust final states.
|
||||
// We ignore the state from the agent and assume "working".
|
||||
{
|
||||
name: "IgnoreAgentState",
|
||||
// AI agent reports that it is finished but the summary says it is doing
|
||||
@@ -953,46 +953,6 @@ func TestExpMcpReporter(t *testing.T) {
|
||||
Message: "finished",
|
||||
},
|
||||
},
|
||||
// Agent reports failure; trusted even with AgentAPI enabled.
|
||||
{
|
||||
state: codersdk.WorkspaceAppStatusStateFailure,
|
||||
summary: "something broke",
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateFailure,
|
||||
Message: "something broke",
|
||||
},
|
||||
},
|
||||
// After failure, watcher reports stable -> idle.
|
||||
{
|
||||
event: makeStatusEvent(agentapi.StatusStable),
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "something broke",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Final states pass through with AgentAPI enabled.
|
||||
{
|
||||
name: "AllowFinalStates",
|
||||
tests: []test{
|
||||
{
|
||||
state: codersdk.WorkspaceAppStatusStateWorking,
|
||||
summary: "doing work",
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateWorking,
|
||||
Message: "doing work",
|
||||
},
|
||||
},
|
||||
// Agent reports complete; not overridden.
|
||||
{
|
||||
state: codersdk.WorkspaceAppStatusStateComplete,
|
||||
summary: "all done",
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateComplete,
|
||||
Message: "all done",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// When AgentAPI is not being used, we accept agent state updates as-is.
|
||||
@@ -1150,148 +1110,4 @@ func TestExpMcpReporter(t *testing.T) {
|
||||
<-cmdDone
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Reconnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a test deployment and workspace.
|
||||
client, db := coderdtest.NewWithDatabase(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
|
||||
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user2.ID,
|
||||
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
|
||||
a[0].Apps = []*proto.App{
|
||||
{
|
||||
Slug: "vscode",
|
||||
},
|
||||
}
|
||||
return a
|
||||
}).Do()
|
||||
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitLong))
|
||||
|
||||
// Watch the workspace for changes.
|
||||
watcher, err := client.WatchWorkspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
var lastAppStatus codersdk.WorkspaceAppStatus
|
||||
nextUpdate := func() codersdk.WorkspaceAppStatus {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
require.FailNow(t, "timed out waiting for status update")
|
||||
case w, ok := <-watcher:
|
||||
require.True(t, ok, "watch channel closed")
|
||||
if w.LatestAppStatus != nil && w.LatestAppStatus.ID != lastAppStatus.ID {
|
||||
t.Logf("Got status update: %s > %s", lastAppStatus.State, w.LatestAppStatus.State)
|
||||
lastAppStatus = *w.LatestAppStatus
|
||||
return lastAppStatus
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mock AI AgentAPI server that supports disconnect/reconnect.
|
||||
disconnect := make(chan struct{})
|
||||
listening := make(chan func(sse codersdk.ServerSentEvent) error)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create a cancelable context so we can stop the SSE sender
|
||||
// goroutine on disconnect without waiting for the HTTP
|
||||
// serve loop to cancel r.Context().
|
||||
sseCtx, sseCancel := context.WithCancel(r.Context())
|
||||
defer sseCancel()
|
||||
r = r.WithContext(sseCtx)
|
||||
|
||||
send, closed, err := httpapi.ServerSentEventSender(w, r)
|
||||
if err != nil {
|
||||
httpapi.Write(sseCtx, w, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error setting up server-sent events.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Send initial message so the watcher knows the agent is active.
|
||||
send(*makeMessageEvent(0, agentapi.RoleAgent))
|
||||
select {
|
||||
case listening <- send:
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-closed:
|
||||
case <-disconnect:
|
||||
sseCancel()
|
||||
<-closed
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
inv, _ := clitest.New(t,
|
||||
"exp", "mcp", "server",
|
||||
"--agent-url", client.URL.String(),
|
||||
"--agent-token", r.AgentToken,
|
||||
"--app-status-slug", "vscode",
|
||||
"--allowed-tools=coder_report_task",
|
||||
"--ai-agentapi-url", srv.URL,
|
||||
)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
pty := ptytest.New(t)
|
||||
inv.Stdin = pty.Input()
|
||||
inv.Stdout = pty.Output()
|
||||
stderr := ptytest.New(t)
|
||||
inv.Stderr = stderr.Output()
|
||||
|
||||
// Run the MCP server.
|
||||
clitest.Start(t, inv)
|
||||
|
||||
// Initialize.
|
||||
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
|
||||
pty.WriteLine(payload)
|
||||
_ = pty.ReadLine(ctx) // ignore echo
|
||||
_ = pty.ReadLine(ctx) // ignore init response
|
||||
|
||||
// Get first sender from the initial SSE connection.
|
||||
sender := testutil.RequireReceive(ctx, t, listening)
|
||||
|
||||
// Self-report a working status via tool call.
|
||||
toolPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"doing work","link":""}}}`
|
||||
pty.WriteLine(toolPayload)
|
||||
_ = pty.ReadLine(ctx) // ignore echo
|
||||
_ = pty.ReadLine(ctx) // ignore response
|
||||
got := nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
|
||||
require.Equal(t, "doing work", got.Message)
|
||||
|
||||
// Watcher sends stable, verify idle is reported.
|
||||
err = sender(*makeStatusEvent(agentapi.StatusStable))
|
||||
require.NoError(t, err)
|
||||
got = nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
|
||||
|
||||
// Disconnect the SSE connection by signaling the handler to return.
|
||||
testutil.RequireSend(ctx, t, disconnect, struct{}{})
|
||||
|
||||
// Wait for the watcher to reconnect and get the new sender.
|
||||
sender = testutil.RequireReceive(ctx, t, listening)
|
||||
|
||||
// After reconnect, self-report a working status again.
|
||||
toolPayload = `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"reconnected","link":""}}}`
|
||||
pty.WriteLine(toolPayload)
|
||||
_ = pty.ReadLine(ctx) // ignore echo
|
||||
_ = pty.ReadLine(ctx) // ignore response
|
||||
got = nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
|
||||
require.Equal(t, "reconnected", got.Message)
|
||||
|
||||
// Verify the watcher still processes events after reconnect.
|
||||
err = sender(*makeStatusEvent(agentapi.StatusStable))
|
||||
require.NoError(t, err)
|
||||
got = nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
|
||||
|
||||
cancel()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -29,7 +29,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
|
||||
templateVersionJobTimeout time.Duration
|
||||
prebuildWorkspaceTimeout time.Duration
|
||||
noCleanup bool
|
||||
provisionerTags []string
|
||||
|
||||
tracingFlags = &scaletestTracingFlags{}
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
@@ -112,16 +111,10 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
|
||||
|
||||
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
|
||||
|
||||
tags, err := ParseProvisionerTags(provisionerTags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range numTemplates {
|
||||
id := strconv.Itoa(int(i))
|
||||
cfg := prebuilds.Config{
|
||||
OrganizationID: me.OrganizationIDs[0],
|
||||
ProvisionerTags: tags,
|
||||
NumPresets: int(numPresets),
|
||||
NumPresetPrebuilds: int(numPresetPrebuilds),
|
||||
TemplateVersionJobTimeout: templateVersionJobTimeout,
|
||||
@@ -290,11 +283,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
|
||||
Description: "Skip cleanup (deletion test) and leave resources intact.",
|
||||
Value: serpent.BoolOf(&noCleanup),
|
||||
},
|
||||
{
|
||||
Flag: "provisioner-tag",
|
||||
Description: "Specify a set of tags to target provisioner daemons.",
|
||||
Value: serpent.StringArrayOf(&provisionerTags),
|
||||
},
|
||||
}
|
||||
|
||||
tracingFlags.attach(&cmd.Options)
|
||||
|
||||
+1
-45
@@ -4,9 +4,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
@@ -19,29 +16,6 @@ import (
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
// detectGitRef attempts to resolve the current git branch and remote
|
||||
// origin URL from the given working directory. These are sent to the
|
||||
// control plane so it can look up PR/diff status via the GitHub API
|
||||
// without SSHing into the workspace. Failures are silently ignored
|
||||
// since this is best-effort.
|
||||
func detectGitRef(workingDirectory string) (branch string, remoteOrigin string) {
|
||||
run := func(args ...string) string {
|
||||
//nolint:gosec
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
if workingDirectory != "" {
|
||||
cmd.Dir = workingDirectory
|
||||
}
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
branch = run("git", "rev-parse", "--abbrev-ref", "HEAD")
|
||||
remoteOrigin = run("git", "config", "--get", "remote.origin.url")
|
||||
return branch, remoteOrigin
|
||||
}
|
||||
|
||||
// gitAskpass is used by the Coder agent to automatically authenticate
|
||||
// with Git providers based on a hostname.
|
||||
func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
@@ -64,20 +38,8 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
workingDirectory, err := os.Getwd()
|
||||
if err != nil {
|
||||
workingDirectory = ""
|
||||
}
|
||||
|
||||
// Detect the current git branch and remote origin so
|
||||
// the control plane can resolve diffs without needing
|
||||
// to SSH back into the workspace.
|
||||
gitBranch, gitRemoteOrigin := detectGitRef(workingDirectory)
|
||||
|
||||
token, err := client.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
|
||||
Match: host,
|
||||
GitBranch: gitBranch,
|
||||
GitRemoteOrigin: gitRemoteOrigin,
|
||||
Match: host,
|
||||
})
|
||||
if err != nil {
|
||||
var apiError *codersdk.Error
|
||||
@@ -96,12 +58,6 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
return xerrors.Errorf("get git token: %w", err)
|
||||
}
|
||||
if token.URL != "" {
|
||||
// This is to help the agent authenticate with Git.
|
||||
if inv.Environ.Get("CODER_CHAT_AGENT") == "true" {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, `You must notify the user to authenticate with Git.\n\nThe URL is: %s\n`, token.URL)
|
||||
return cliui.ErrCanceled
|
||||
}
|
||||
|
||||
if err := openURL(inv, token.URL); err == nil {
|
||||
cliui.Infof(inv.Stderr, "Your browser has been opened to authenticate with Git:\n%s", token.URL)
|
||||
} else {
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
@@ -111,58 +108,4 @@ func TestGitAskpass(t *testing.T) {
|
||||
})
|
||||
stdout.ExpectMatch("username")
|
||||
})
|
||||
|
||||
t.Run("ChatAgentAuthRequired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
listenCalls := atomic.Int64{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Has("listen") {
|
||||
listenCalls.Add(1)
|
||||
}
|
||||
httpapi.Write(context.Background(), w, http.StatusOK, agentsdk.ExternalAuthResponse{
|
||||
URL: "https://coder.example.com/external-auth/github",
|
||||
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
|
||||
})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
inv, _ := clitest.New(
|
||||
t,
|
||||
"--agent-url",
|
||||
srv.URL,
|
||||
"--no-open",
|
||||
"Username for 'https://github.com':",
|
||||
)
|
||||
inv.Environ.Set("GIT_PREFIX", "/")
|
||||
inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token")
|
||||
inv.Environ.Set("CODER_CHAT_AGENT", "true")
|
||||
|
||||
var stderr bytes.Buffer
|
||||
inv.Stderr = &stderr
|
||||
|
||||
err := inv.Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "exit code")
|
||||
require.Zero(t, listenCalls.Load())
|
||||
|
||||
output := stderr.String()
|
||||
require.Contains(t, output, "CODER_GITAUTH_REQUIRED:")
|
||||
require.NotContains(t, output, "Open the following URL to authenticate")
|
||||
require.NotContains(t, output, "Your browser has been opened")
|
||||
|
||||
_, markerRaw, found := strings.Cut(output, "CODER_GITAUTH_REQUIRED:")
|
||||
require.True(t, found)
|
||||
var marker struct {
|
||||
ProviderID string `json:"provider_id"`
|
||||
ProviderType string `json:"provider_type"`
|
||||
ProviderDisplayName string `json:"provider_display_name"`
|
||||
AuthenticateURL string `json:"authenticate_url"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimSpace(markerRaw)), &marker))
|
||||
require.Equal(t, "github", marker.ProviderID)
|
||||
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), marker.ProviderType)
|
||||
require.Equal(t, "GitHub", marker.ProviderDisplayName)
|
||||
require.Equal(t, "https://coder.example.com/external-auth/github", marker.AuthenticateURL)
|
||||
})
|
||||
}
|
||||
|
||||
+5
-1
@@ -106,7 +106,11 @@ func TestList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, member = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
sharedWorkspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
|
||||
@@ -297,7 +297,7 @@ func (pr *ParameterResolver) verifyConstraints(resolved []codersdk.WorkspaceBuil
|
||||
return xerrors.Errorf("ephemeral parameter %q can be used only with --prompt-ephemeral-parameters or --ephemeral-parameter flag", r.Name)
|
||||
}
|
||||
|
||||
if !tvp.Mutable && action != WorkspaceCreate && !pr.isFirstTimeUse(r.Name) {
|
||||
if !tvp.Mutable && action != WorkspaceCreate {
|
||||
return xerrors.Errorf("parameter %q is immutable and cannot be updated", r.Name)
|
||||
}
|
||||
}
|
||||
|
||||
+43
-120
@@ -137,15 +137,6 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse oidc oauth callback url: %w", err)
|
||||
}
|
||||
|
||||
if vals.OIDC.RedirectURL.String() != "" {
|
||||
redirectURL, err = vals.OIDC.RedirectURL.Value().Parse("/api/v2/users/oidc/callback")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse oidc redirect url %q", err)
|
||||
}
|
||||
logger.Warn(ctx, "custom OIDC redirect URL used instead of 'access_url', ensure this matches the value configured in your OIDC provider")
|
||||
}
|
||||
|
||||
// If the scopes contain 'groups', we enable group support.
|
||||
// Do not override any custom value set by the user.
|
||||
if slice.Contains(vals.OIDC.Scopes, "groups") && vals.OIDC.GroupField == "" {
|
||||
@@ -617,8 +608,28 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
}
|
||||
}
|
||||
|
||||
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read external auth providers from env: %w", err)
|
||||
}
|
||||
|
||||
promRegistry := prometheus.NewRegistry()
|
||||
oauthInstrument := promoauth.NewFactory(promRegistry)
|
||||
vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...)
|
||||
externalAuthConfigs, err := externalauth.ConvertConfig(
|
||||
oauthInstrument,
|
||||
vals.ExternalAuthConfigs.Value,
|
||||
vals.AccessURL.Value(),
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("convert external auth config: %w", err)
|
||||
}
|
||||
for _, c := range externalAuthConfigs {
|
||||
logger.Debug(
|
||||
ctx, "loaded external auth config",
|
||||
slog.F("id", c.ID),
|
||||
)
|
||||
}
|
||||
|
||||
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
|
||||
if err != nil {
|
||||
@@ -649,7 +660,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
Pubsub: nil,
|
||||
CacheDir: cacheDir,
|
||||
GoogleTokenValidator: googleTokenValidator,
|
||||
ExternalAuthConfigs: nil,
|
||||
ExternalAuthConfigs: externalAuthConfigs,
|
||||
RealIPConfig: realIPConfig,
|
||||
SSHKeygenAlgorithm: sshKeygenAlgorithm,
|
||||
TracerProvider: tracerProvider,
|
||||
@@ -809,40 +820,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
return xerrors.Errorf("set deployment id: %w", err)
|
||||
}
|
||||
|
||||
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read external auth providers from env: %w", err)
|
||||
}
|
||||
mergedExternalAuthProviders := append([]codersdk.ExternalAuthConfig{}, vals.ExternalAuthConfigs.Value...)
|
||||
mergedExternalAuthProviders = append(mergedExternalAuthProviders, extAuthEnv...)
|
||||
vals.ExternalAuthConfigs.Value = mergedExternalAuthProviders
|
||||
|
||||
mergedExternalAuthProviders, err = maybeAppendDefaultGithubExternalAuthProvider(
|
||||
ctx,
|
||||
options.Logger,
|
||||
options.Database,
|
||||
vals,
|
||||
mergedExternalAuthProviders,
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("maybe append default github external auth provider: %w", err)
|
||||
}
|
||||
|
||||
options.ExternalAuthConfigs, err = externalauth.ConvertConfig(
|
||||
oauthInstrument,
|
||||
mergedExternalAuthProviders,
|
||||
vals.AccessURL.Value(),
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("convert external auth config: %w", err)
|
||||
}
|
||||
for _, c := range options.ExternalAuthConfigs {
|
||||
logger.Debug(
|
||||
ctx, "loaded external auth config",
|
||||
slog.F("id", c.ID),
|
||||
)
|
||||
}
|
||||
|
||||
// Manage push notifications.
|
||||
experiments := coderd.ReadExperiments(options.Logger, options.DeploymentValues.Experiments.Value())
|
||||
if experiments.Enabled(codersdk.ExperimentWebPush) {
|
||||
@@ -1940,79 +1917,6 @@ type githubOAuth2ConfigParams struct {
|
||||
enterpriseBaseURL string
|
||||
}
|
||||
|
||||
func isDeploymentEligibleForGithubDefaultProvider(ctx context.Context, db database.Store) (bool, error) {
|
||||
// We want to enable the default provider only for new deployments, and avoid
|
||||
// enabling it if a deployment was upgraded from an older version.
|
||||
// nolint:gocritic // Requires system privileges
|
||||
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return false, xerrors.Errorf("get github default eligible: %w", err)
|
||||
}
|
||||
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
|
||||
|
||||
if defaultEligibleNotSet {
|
||||
// nolint:gocritic // User count requires system privileges
|
||||
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
|
||||
if err != nil {
|
||||
return false, xerrors.Errorf("get user count: %w", err)
|
||||
}
|
||||
// We check if a deployment is new by checking if it has any users.
|
||||
defaultEligible = userCount == 0
|
||||
// nolint:gocritic // Requires system privileges
|
||||
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
|
||||
return false, xerrors.Errorf("upsert github default eligible: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return defaultEligible, nil
|
||||
}
|
||||
|
||||
func maybeAppendDefaultGithubExternalAuthProvider(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
db database.Store,
|
||||
vals *codersdk.DeploymentValues,
|
||||
mergedExplicitProviders []codersdk.ExternalAuthConfig,
|
||||
) ([]codersdk.ExternalAuthConfig, error) {
|
||||
if !vals.ExternalAuthGithubDefaultProviderEnable.Value() {
|
||||
logger.Info(ctx, "default github external auth provider suppressed",
|
||||
slog.F("reason", "disabled by configuration"),
|
||||
slog.F("flag", "external-auth-github-default-provider-enable"),
|
||||
)
|
||||
return mergedExplicitProviders, nil
|
||||
}
|
||||
|
||||
if len(mergedExplicitProviders) > 0 {
|
||||
logger.Info(ctx, "default github external auth provider suppressed",
|
||||
slog.F("reason", "explicit external auth providers configured"),
|
||||
slog.F("provider_count", len(mergedExplicitProviders)),
|
||||
)
|
||||
return mergedExplicitProviders, nil
|
||||
}
|
||||
|
||||
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !defaultEligible {
|
||||
logger.Info(ctx, "default github external auth provider suppressed",
|
||||
slog.F("reason", "deployment is not eligible"),
|
||||
)
|
||||
return mergedExplicitProviders, nil
|
||||
}
|
||||
|
||||
logger.Info(ctx, "injecting default github external auth provider",
|
||||
slog.F("type", codersdk.EnhancedExternalAuthProviderGitHub.String()),
|
||||
slog.F("client_id", GithubOAuth2DefaultProviderClientID),
|
||||
slog.F("device_flow", GithubOAuth2DefaultProviderDeviceFlow),
|
||||
)
|
||||
return append(mergedExplicitProviders, codersdk.ExternalAuthConfig{
|
||||
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
|
||||
ClientID: GithubOAuth2DefaultProviderClientID,
|
||||
DeviceFlow: GithubOAuth2DefaultProviderDeviceFlow,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *codersdk.DeploymentValues) (*githubOAuth2ConfigParams, error) {
|
||||
params := githubOAuth2ConfigParams{
|
||||
accessURL: vals.AccessURL.Value(),
|
||||
@@ -2037,9 +1941,28 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Check if the deployment is eligible for the default GitHub OAuth2 provider.
|
||||
// We want to enable it only for new deployments, and avoid enabling it
|
||||
// if a deployment was upgraded from an older version.
|
||||
// nolint:gocritic // Requires system privileges
|
||||
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get github default eligible: %w", err)
|
||||
}
|
||||
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
|
||||
|
||||
if defaultEligibleNotSet {
|
||||
// nolint:gocritic // User count requires system privileges
|
||||
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get user count: %w", err)
|
||||
}
|
||||
// We check if a deployment is new by checking if it has any users.
|
||||
defaultEligible = userCount == 0
|
||||
// nolint:gocritic // Requires system privileges
|
||||
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
|
||||
return nil, xerrors.Errorf("upsert github default eligible: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !defaultEligible {
|
||||
|
||||
@@ -53,7 +53,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/migrations"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
"github.com/coder/coder/v2/coderd/userpassword"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/cryptorand"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
@@ -1741,18 +1740,6 @@ func TestServer(t *testing.T) {
|
||||
|
||||
// Next, we instruct the same server to display the YAML config
|
||||
// and then save it.
|
||||
// Because this is literally the same invocation, DefaultFn sets the
|
||||
// value of 'Default'. Which triggers a mutually exclusive error
|
||||
// on the next parse.
|
||||
// Usually we only parse flags once, so this is not an issue
|
||||
for _, c := range inv.Command.Children {
|
||||
if c.Name() == "server" {
|
||||
for i := range c.Options {
|
||||
c.Options[i].DefaultFn = nil
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
inv = inv.WithContext(testutil.Context(t, testutil.WaitMedium))
|
||||
//nolint:gocritic
|
||||
inv.Args = append(args, "--write-config")
|
||||
@@ -1806,155 +1793,6 @@ func TestServer(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // This test sets environment variables.
|
||||
func TestServer_ExternalAuthGitHubDefaultProvider(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
args []string
|
||||
env map[string]string
|
||||
createUserPreStart bool
|
||||
expectedProviders []string
|
||||
}
|
||||
|
||||
run := func(t *testing.T, tc testCase) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
unsetPrefixedEnv := func(prefix string) {
|
||||
t.Helper()
|
||||
for _, envVar := range os.Environ() {
|
||||
envKey, _, found := strings.Cut(envVar, "=")
|
||||
if !found || !strings.HasPrefix(envKey, prefix) {
|
||||
continue
|
||||
}
|
||||
value, had := os.LookupEnv(envKey)
|
||||
require.True(t, had)
|
||||
require.NoError(t, os.Unsetenv(envKey))
|
||||
keyCopy := envKey
|
||||
valueCopy := value
|
||||
t.Cleanup(func() {
|
||||
// This is for setting/unsetting a number of prefixed env vars.
|
||||
// t.Setenv doesn't cover this use case.
|
||||
// nolint:usetesting
|
||||
_ = os.Setenv(keyCopy, valueCopy)
|
||||
})
|
||||
}
|
||||
}
|
||||
unsetPrefixedEnv("CODER_EXTERNAL_AUTH_")
|
||||
unsetPrefixedEnv("CODER_GITAUTH_")
|
||||
|
||||
dbURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
db, _ := dbtestutil.NewDB(t, dbtestutil.WithURL(dbURL))
|
||||
|
||||
const (
|
||||
existingUserEmail = "existing-user@coder.com"
|
||||
existingUserUsername = "existing-user"
|
||||
existingUserPassword = "SomeSecurePassword!"
|
||||
)
|
||||
if tc.createUserPreStart {
|
||||
hashedPassword, err := userpassword.Hash(existingUserPassword)
|
||||
require.NoError(t, err)
|
||||
_ = dbgen.User(t, db, database.User{
|
||||
Email: existingUserEmail,
|
||||
Username: existingUserUsername,
|
||||
HashedPassword: []byte(hashedPassword),
|
||||
})
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"server",
|
||||
"--postgres-url", dbURL,
|
||||
"--http-address", ":0",
|
||||
"--access-url", "https://example.com",
|
||||
}
|
||||
args = append(args, tc.args...)
|
||||
|
||||
inv, cfg := clitest.New(t, args...)
|
||||
for envKey, value := range tc.env {
|
||||
t.Setenv(envKey, value)
|
||||
}
|
||||
clitest.Start(t, inv)
|
||||
|
||||
accessURL := waitAccessURL(t, cfg)
|
||||
client := codersdk.New(accessURL)
|
||||
|
||||
if tc.createUserPreStart {
|
||||
loginResp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
|
||||
Email: existingUserEmail,
|
||||
Password: existingUserPassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
client.SetSessionToken(loginResp.SessionToken)
|
||||
} else {
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
}
|
||||
|
||||
externalAuthResp, err := client.ListExternalAuths(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotProviders := map[string]codersdk.ExternalAuthLinkProvider{}
|
||||
for _, provider := range externalAuthResp.Providers {
|
||||
gotProviders[provider.ID] = provider
|
||||
}
|
||||
require.Len(t, gotProviders, len(tc.expectedProviders))
|
||||
|
||||
for _, providerID := range tc.expectedProviders {
|
||||
provider, ok := gotProviders[providerID]
|
||||
require.Truef(t, ok, "expected provider %q to be configured", providerID)
|
||||
if providerID == codersdk.EnhancedExternalAuthProviderGitHub.String() {
|
||||
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), provider.Type)
|
||||
require.True(t, provider.Device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range []testCase{
|
||||
{
|
||||
name: "NewDeployment_NoExplicitProviders_InjectsDefaultGithub",
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitHub.String()},
|
||||
},
|
||||
{
|
||||
name: "ExistingDeployment_DoesNotInjectDefaultGithub",
|
||||
createUserPreStart: true,
|
||||
expectedProviders: nil,
|
||||
},
|
||||
{
|
||||
name: "DefaultProviderDisabled_DoesNotInjectDefaultGithub",
|
||||
args: []string{
|
||||
"--external-auth-github-default-provider-enable=false",
|
||||
},
|
||||
expectedProviders: nil,
|
||||
},
|
||||
{
|
||||
name: "ExplicitProviderViaConfig_DoesNotInjectDefaultGithub",
|
||||
args: []string{
|
||||
`--external-auth-providers=[{"type":"gitlab","client_id":"config-client-id"}]`,
|
||||
},
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
|
||||
},
|
||||
{
|
||||
name: "ExplicitProviderViaEnv_DoesNotInjectDefaultGithub",
|
||||
env: map[string]string{
|
||||
"CODER_EXTERNAL_AUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
|
||||
"CODER_EXTERNAL_AUTH_0_CLIENT_ID": "env-client-id",
|
||||
},
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
|
||||
},
|
||||
{
|
||||
name: "ExplicitProviderViaLegacyEnv_DoesNotInjectDefaultGithub",
|
||||
env: map[string]string{
|
||||
"CODER_GITAUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
|
||||
"CODER_GITAUTH_0_CLIENT_ID": "legacy-env-client-id",
|
||||
},
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
run(t, tc)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // This test sets environment variables.
|
||||
func TestServer_Logging_NoParallel(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
+31
-7
@@ -25,7 +25,11 @@ func TestSharingShare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -64,8 +68,12 @@ func TestSharingShare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -119,7 +127,11 @@ func TestSharingShare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -170,7 +182,11 @@ func TestSharingStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -214,7 +230,11 @@ func TestSharingRemove(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -271,7 +291,11 @@ func TestSharingRemove(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
|
||||
+1
-1
@@ -120,7 +120,7 @@ func (r *RootCmd) start() *serpent.Command {
|
||||
func buildWorkspaceStartRequest(inv *serpent.Invocation, client *codersdk.Client, workspace codersdk.Workspace, parameterFlags workspaceParameterFlags, buildFlags buildFlags, action WorkspaceCLIAction) (codersdk.CreateWorkspaceBuildRequest, error) {
|
||||
version := workspace.LatestBuild.TemplateVersionID
|
||||
|
||||
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || workspace.TemplateRequireActiveVersion || action == WorkspaceUpdate {
|
||||
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || action == WorkspaceUpdate {
|
||||
version = workspace.TemplateActiveVersionID
|
||||
if version != workspace.LatestBuild.TemplateVersionID {
|
||||
action = WorkspaceUpdate
|
||||
|
||||
+4
-4
@@ -33,7 +33,7 @@ func TestStatePull(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
|
||||
Do()
|
||||
statefilePath := filepath.Join(t.TempDir(), "state")
|
||||
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name, statefilePath)
|
||||
@@ -54,7 +54,7 @@ func TestStatePull(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
|
||||
Do()
|
||||
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name)
|
||||
var gotState bytes.Buffer
|
||||
@@ -74,7 +74,7 @@ func TestStatePull(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
|
||||
Do()
|
||||
inv, root := clitest.New(t, "state", "pull", taUser.Username+"/"+r.Workspace.Name,
|
||||
"--build", fmt.Sprintf("%d", r.Build.BuildNumber))
|
||||
@@ -170,7 +170,7 @@ func TestStatePush(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(initialState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
|
||||
Do()
|
||||
wantState := []byte("updated state")
|
||||
stateFile, err := os.CreateTemp(t.TempDir(), "")
|
||||
|
||||
+7
-9
@@ -1,3 +1,5 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
@@ -5,7 +7,6 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -24,15 +25,12 @@ func setupSocketServer(t *testing.T) (path string, cleanup func()) {
|
||||
t.Helper()
|
||||
|
||||
// Use a temporary socket path for each test
|
||||
socketPath := testutil.AgentSocketPath(t)
|
||||
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
|
||||
|
||||
// Create parent directory if needed. Not necessary on Windows because named pipes live in an abstract namespace
|
||||
// not tied to any real files.
|
||||
if runtime.GOOS != "windows" {
|
||||
parentDir := filepath.Dir(socketPath)
|
||||
err := os.MkdirAll(parentDir, 0o700)
|
||||
require.NoError(t, err, "create socket directory")
|
||||
}
|
||||
// Create parent directory if needed
|
||||
parentDir := filepath.Dir(socketPath)
|
||||
err := os.MkdirAll(parentDir, 0o700)
|
||||
require.NoError(t, err, "create socket directory")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
|
||||
@@ -18,7 +18,6 @@ func (r *RootCmd) tasksCommand() *serpent.Command {
|
||||
r.taskList(),
|
||||
r.taskLogs(),
|
||||
r.taskPause(),
|
||||
r.taskResume(),
|
||||
r.taskSend(),
|
||||
r.taskStatus(),
|
||||
},
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) taskResume() *serpent.Command {
|
||||
var noWait bool
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "resume <task>",
|
||||
Short: "Resume a task",
|
||||
Long: FormatExamples(
|
||||
Example{
|
||||
Description: "Resume a task by name",
|
||||
Command: "coder task resume my-task",
|
||||
},
|
||||
Example{
|
||||
Description: "Resume another user's task",
|
||||
Command: "coder task resume alice/my-task",
|
||||
},
|
||||
Example{
|
||||
Description: "Resume a task without confirmation",
|
||||
Command: "coder task resume my-task --yes",
|
||||
},
|
||||
),
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Flag: "no-wait",
|
||||
Description: "Return immediately after resuming the task.",
|
||||
Value: serpent.BoolOf(&noWait),
|
||||
},
|
||||
cliui.SkipPromptOption(),
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
task, err := client.TaskByIdentifier(ctx, inv.Args[0])
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve task %q: %w", inv.Args[0], err)
|
||||
}
|
||||
|
||||
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
|
||||
if task.Status == codersdk.TaskStatusError || task.Status == codersdk.TaskStatusUnknown {
|
||||
return xerrors.Errorf("task %q is in %s state and cannot be resumed; check the workspace build logs and agent status for details", display, task.Status)
|
||||
} else if task.Status != codersdk.TaskStatusPaused {
|
||||
return xerrors.Errorf("task %q cannot be resumed (current status: %s)", display, task.Status)
|
||||
}
|
||||
|
||||
_, err = cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: fmt.Sprintf("Resume task %s?", pretty.Sprint(cliui.DefaultStyles.Code, display)),
|
||||
IsConfirm: true,
|
||||
Default: cliui.ConfirmNo,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resume task %q: %w", display, err)
|
||||
} else if resp.WorkspaceBuild == nil {
|
||||
return xerrors.Errorf("resume task %q: no workspace build returned", display)
|
||||
}
|
||||
|
||||
if noWait {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Resuming task %q in the background.\n", cliui.Keyword(display))
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = cliui.WorkspaceBuild(ctx, inv.Stdout, client, resp.WorkspaceBuild.ID); err != nil {
|
||||
return xerrors.Errorf("watch resume build for task %q: %w", display, err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "\nThe %s task has been resumed.\n", cliui.Keyword(display))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestExpTaskResume(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// pauseTask is a helper that pauses a task and waits for the stop
|
||||
// build to complete.
|
||||
pauseTask := func(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pauseResp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
|
||||
}
|
||||
|
||||
t.Run("WithYesFlag", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been resumed")
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
|
||||
// OtherUserTask verifies that an admin can resume a task owned by
|
||||
// another user using the "owner/name" identifier format.
|
||||
t.Run("OtherUserTask", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: A different user's paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
|
||||
// When: We attempt to resume their task
|
||||
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
inv, root := clitest.New(t, "task", "resume", identifier, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, adminClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been resumed")
|
||||
|
||||
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("NoWait", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
|
||||
// When: We attempt to resume the task (and specify no wait)
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes", "--no-wait")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed in the background
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "in the background")
|
||||
|
||||
// And: The task to eventually be resumed
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws := coderdtest.MustWorkspace(t, userClient, task.WorkspaceID.UUID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, ws.LatestBuild.ID)
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("PromptConfirm", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// And: We confirm we want to resume the task
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
inv = inv.WithContext(ctx)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
pty.ExpectMatchContext(ctx, "Resume task")
|
||||
pty.WriteLine("yes")
|
||||
|
||||
// Then: We expect the task to be resumed
|
||||
pty.ExpectMatchContext(ctx, "has been resumed")
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("PromptDecline", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// But: Say no at the confirmation screen
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
inv = inv.WithContext(ctx)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
pty.ExpectMatchContext(ctx, "Resume task")
|
||||
pty.WriteLine("no")
|
||||
require.Error(t, w.Wait())
|
||||
|
||||
// Then: We expect the task to still be paused
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("TaskNotPaused", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to resume the task that is not paused
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
// Then: We expect to get an error that the task is not paused
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, "cannot be resumed")
|
||||
})
|
||||
}
|
||||
@@ -137,23 +137,6 @@ func Test_Tasks(t *testing.T) {
|
||||
require.Equal(t, codersdk.TaskStatusPaused, task.Status, "task should be paused")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "resume task",
|
||||
cmdArgs: []string{"task", "resume", taskName, "--yes"},
|
||||
assertFn: func(stdout string, userClient *codersdk.Client) {
|
||||
require.Contains(t, stdout, "has been resumed", "resume output should confirm task was resumed")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get task status after resume",
|
||||
cmdArgs: []string{"task", "status", taskName, "--output", "json"},
|
||||
assertFn: func(stdout string, userClient *codersdk.Client) {
|
||||
var task codersdk.Task
|
||||
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&task), "should unmarshal task status")
|
||||
require.Equal(t, taskName, task.Name, "task name should match")
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, task.Status, "task should be initializing after resume")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete task",
|
||||
cmdArgs: []string{"task", "delete", taskName, "--yes"},
|
||||
|
||||
@@ -139,10 +139,8 @@ func (r *RootCmd) templateVersionsList() *serpent.Command {
|
||||
type templateVersionRow struct {
|
||||
// For json format:
|
||||
TemplateVersion codersdk.TemplateVersion `table:"-"`
|
||||
ActiveJSON bool `json:"active" table:"-"`
|
||||
|
||||
// For table format:
|
||||
ID string `json:"-" table:"id"`
|
||||
Name string `json:"-" table:"name,default_sort"`
|
||||
CreatedAt time.Time `json:"-" table:"created at"`
|
||||
CreatedBy string `json:"-" table:"created by"`
|
||||
@@ -168,8 +166,6 @@ func templateVersionsToRows(activeVersionID uuid.UUID, templateVersions ...coder
|
||||
|
||||
rows[i] = templateVersionRow{
|
||||
TemplateVersion: templateVersion,
|
||||
ActiveJSON: templateVersion.ID == activeVersionID,
|
||||
ID: templateVersion.ID.String(),
|
||||
Name: templateVersion.Name,
|
||||
CreatedAt: templateVersion.CreatedAt,
|
||||
CreatedBy: templateVersion.CreatedBy.Username,
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -42,33 +40,6 @@ func TestTemplateVersions(t *testing.T) {
|
||||
pty.ExpectMatch(version.CreatedBy.Username)
|
||||
pty.ExpectMatch("Active")
|
||||
})
|
||||
|
||||
t.Run("ListVersionsJSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
|
||||
_ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
|
||||
|
||||
inv, root := clitest.New(t, "templates", "versions", "list", template.Name, "--output", "json")
|
||||
clitest.SetupConfig(t, member, root)
|
||||
|
||||
var stdout bytes.Buffer
|
||||
inv.Stdout = &stdout
|
||||
|
||||
require.NoError(t, inv.Run())
|
||||
|
||||
var rows []struct {
|
||||
TemplateVersion codersdk.TemplateVersion `json:"TemplateVersion"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(stdout.Bytes(), &rows))
|
||||
require.Len(t, rows, 1)
|
||||
assert.Equal(t, version.ID, rows[0].TemplateVersion.ID)
|
||||
assert.True(t, rows[0].Active)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTemplateVersionsPromote(t *testing.T) {
|
||||
|
||||
+5
-11
@@ -49,9 +49,10 @@ OPTIONS:
|
||||
security purposes if a --wildcard-access-url is configured.
|
||||
|
||||
--disable-workspace-sharing bool, $CODER_DISABLE_WORKSPACE_SHARING
|
||||
Disable workspace sharing. Workspace ACL checking is disabled and only
|
||||
owners can have ssh, apps and terminal access to workspaces. Access
|
||||
based on the 'owner' role is also allowed unless disabled via
|
||||
Disable workspace sharing (requires the "workspace-sharing" experiment
|
||||
to be enabled). Workspace ACL checking is disabled and only owners can
|
||||
have ssh, apps and terminal access to workspaces. Access based on the
|
||||
'owner' role is also allowed unless disabled via
|
||||
--disable-owner-workspace-access.
|
||||
|
||||
--swagger-enable bool, $CODER_SWAGGER_ENABLE
|
||||
@@ -62,9 +63,6 @@ OPTIONS:
|
||||
Separate multiple experiments with commas, or enter '*' to opt-in to
|
||||
all available experiments.
|
||||
|
||||
--external-auth-github-default-provider-enable bool, $CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE (default: true)
|
||||
Enable the default GitHub external auth provider managed by Coder.
|
||||
|
||||
--postgres-auth password|awsiamrds, $CODER_PG_AUTH (default: password)
|
||||
Type of auth to use when connecting to postgres. For AWS RDS, using
|
||||
IAM authentication (awsiamrds) is recommended.
|
||||
@@ -385,17 +383,13 @@ NETWORKING OPTIONS:
|
||||
--samesite-auth-cookie lax|none, $CODER_SAMESITE_AUTH_COOKIE (default: lax)
|
||||
Controls the 'SameSite' property is set on browser session cookies.
|
||||
|
||||
--secure-auth-cookie bool, $CODER_SECURE_AUTH_COOKIE (default: false)
|
||||
--secure-auth-cookie bool, $CODER_SECURE_AUTH_COOKIE
|
||||
Controls if the 'Secure' property is set on browser session cookies.
|
||||
|
||||
--wildcard-access-url string, $CODER_WILDCARD_ACCESS_URL
|
||||
Specifies the wildcard hostname to use for workspace applications in
|
||||
the form "*.example.com".
|
||||
|
||||
--host-prefix-cookie bool, $CODER_HOST_PREFIX_COOKIE (default: false)
|
||||
Recommended to be enabled. Enables `__Host-` prefix for cookies to
|
||||
guarantee they are only set by the right domain.
|
||||
|
||||
NETWORKING / DERP OPTIONS:
|
||||
Most Coder deployments never have to think about DERP because all connections
|
||||
between workspaces and users are peer-to-peer. However, when Coder cannot
|
||||
|
||||
-1
@@ -13,7 +13,6 @@ SUBCOMMANDS:
|
||||
list List tasks
|
||||
logs Show a task's logs
|
||||
pause Pause a task
|
||||
resume Resume a task
|
||||
send Send input to a task
|
||||
status Show the status of a task.
|
||||
|
||||
|
||||
-28
@@ -1,28 +0,0 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder task resume [flags] <task>
|
||||
|
||||
Resume a task
|
||||
|
||||
- Resume a task by name:
|
||||
|
||||
$ coder task resume my-task
|
||||
|
||||
- Resume another user's task:
|
||||
|
||||
$ coder task resume alice/my-task
|
||||
|
||||
- Resume a task without confirmation:
|
||||
|
||||
$ coder task resume my-task --yes
|
||||
|
||||
OPTIONS:
|
||||
--no-wait bool
|
||||
Return immediately after resuming the task.
|
||||
|
||||
-y, --yes bool
|
||||
Bypass confirmation prompts.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
@@ -9,7 +9,7 @@ OPTIONS:
|
||||
-O, --org string, $CODER_ORGANIZATION
|
||||
Select which organization (uuid or name) to use.
|
||||
|
||||
-c, --column [id|name|created at|created by|status|active|archived] (default: name,created at,created by,status,active)
|
||||
-c, --column [name|created at|created by|status|active|archived] (default: name,created at,created by,status,active)
|
||||
Columns to display in table output.
|
||||
|
||||
--include-archived bool
|
||||
|
||||
+1
-1
@@ -27,7 +27,7 @@ USAGE:
|
||||
SUBCOMMANDS:
|
||||
create Create a token
|
||||
list List tokens
|
||||
remove Expire or delete a token
|
||||
remove Delete a token
|
||||
view Display detailed information about a token
|
||||
|
||||
———
|
||||
|
||||
@@ -15,10 +15,6 @@ OPTIONS:
|
||||
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at)
|
||||
Columns to display in table output.
|
||||
|
||||
--include-expired bool
|
||||
Include expired tokens in the output. By default, expired tokens are
|
||||
hidden.
|
||||
|
||||
-o, --output table|json (default: table)
|
||||
Output format.
|
||||
|
||||
|
||||
+2
-10
@@ -1,19 +1,11 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder tokens remove [flags] <name|id|token>
|
||||
coder tokens remove <name|id|token>
|
||||
|
||||
Expire or delete a token
|
||||
Delete a token
|
||||
|
||||
Aliases: delete, rm
|
||||
|
||||
Remove a token by expiring it. Use --delete to permanently hard-delete the
|
||||
token instead.
|
||||
|
||||
OPTIONS:
|
||||
--delete bool
|
||||
Permanently delete the token instead of expiring it. This removes the
|
||||
audit trail.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
+5
-17
@@ -176,15 +176,11 @@ networking:
|
||||
# (default: <unset>, type: string-array)
|
||||
proxyTrustedOrigins: []
|
||||
# Controls if the 'Secure' property is set on browser session cookies.
|
||||
# (default: false, type: bool)
|
||||
# (default: <unset>, type: bool)
|
||||
secureAuthCookie: false
|
||||
# Controls the 'SameSite' property is set on browser session cookies.
|
||||
# (default: lax, type: enum[lax\|none])
|
||||
sameSiteAuthCookie: lax
|
||||
# Recommended to be enabled. Enables `__Host-` prefix for cookies to guarantee
|
||||
# they are only set by the right domain.
|
||||
# (default: false, type: bool)
|
||||
hostPrefixCookie: false
|
||||
# Whether Coder only allows connections to workspaces via the browser.
|
||||
# (default: <unset>, type: bool)
|
||||
browserOnly: false
|
||||
@@ -421,11 +417,6 @@ oidc:
|
||||
# an insecure OIDC configuration. It is not recommended to use this flag.
|
||||
# (default: <unset>, type: bool)
|
||||
dangerousSkipIssuerChecks: false
|
||||
# Optional override of the default redirect url which uses the deployment's access
|
||||
# url. Useful in situations where a deployment has more than 1 domain. Using this
|
||||
# setting can also break OIDC, so use with caution.
|
||||
# (default: <unset>, type: url)
|
||||
oidc-redirect-url:
|
||||
# Telemetry is critical to our ability to improve Coder. We strip all personal
|
||||
# information before sending data to our servers. Please only disable telemetry
|
||||
# when required by your organization's security policy.
|
||||
@@ -523,10 +514,10 @@ disablePathApps: false
|
||||
# workspaces.
|
||||
# (default: <unset>, type: bool)
|
||||
disableOwnerWorkspaceAccess: false
|
||||
# Disable workspace sharing. Workspace ACL checking is disabled and only owners
|
||||
# can have ssh, apps and terminal access to workspaces. Access based on the
|
||||
# 'owner' role is also allowed unless disabled via
|
||||
# --disable-owner-workspace-access.
|
||||
# Disable workspace sharing (requires the "workspace-sharing" experiment to be
|
||||
# enabled). Workspace ACL checking is disabled and only owners can have ssh, apps
|
||||
# and terminal access to workspaces. Access based on the 'owner' role is also
|
||||
# allowed unless disabled via --disable-owner-workspace-access.
|
||||
# (default: <unset>, type: bool)
|
||||
disableWorkspaceSharing: false
|
||||
# These options change the behavior of how clients interact with the Coder.
|
||||
@@ -563,9 +554,6 @@ supportLinks: []
|
||||
# External Authentication providers.
|
||||
# (default: <unset>, type: struct[[]codersdk.ExternalAuthConfig])
|
||||
externalAuthProviders: []
|
||||
# Enable the default GitHub external auth provider managed by Coder.
|
||||
# (default: true, type: bool)
|
||||
externalAuthGithubDefaultProviderEnable: true
|
||||
# Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By
|
||||
# default, this will pick the best available wgtunnel server hosted by Coder. e.g.
|
||||
# "tunnel.example.com".
|
||||
|
||||
+14
-37
@@ -218,10 +218,9 @@ func (r *RootCmd) listTokens() *serpent.Command {
|
||||
}
|
||||
|
||||
var (
|
||||
all bool
|
||||
includeExpired bool
|
||||
displayTokens []tokenListRow
|
||||
formatter = cliui.NewOutputFormatter(
|
||||
all bool
|
||||
displayTokens []tokenListRow
|
||||
formatter = cliui.NewOutputFormatter(
|
||||
cliui.TableFormat([]tokenListRow{}, defaultCols),
|
||||
cliui.JSONFormat(),
|
||||
)
|
||||
@@ -241,8 +240,7 @@ func (r *RootCmd) listTokens() *serpent.Command {
|
||||
}
|
||||
|
||||
tokens, err := client.Tokens(inv.Context(), codersdk.Me, codersdk.TokensFilter{
|
||||
IncludeAll: all,
|
||||
IncludeExpired: includeExpired,
|
||||
IncludeAll: all,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("list tokens: %w", err)
|
||||
@@ -276,12 +274,6 @@ func (r *RootCmd) listTokens() *serpent.Command {
|
||||
Description: "Specifies whether all users' tokens will be listed or not (must have Owner role to see all tokens).",
|
||||
Value: serpent.BoolOf(&all),
|
||||
},
|
||||
{
|
||||
Name: "include-expired",
|
||||
Flag: "include-expired",
|
||||
Description: "Include expired tokens in the output. By default, expired tokens are hidden.",
|
||||
Value: serpent.BoolOf(&includeExpired),
|
||||
},
|
||||
}
|
||||
|
||||
formatter.AttachOptions(&cmd.Options)
|
||||
@@ -331,13 +323,10 @@ func (r *RootCmd) viewToken() *serpent.Command {
|
||||
}
|
||||
|
||||
func (r *RootCmd) removeToken() *serpent.Command {
|
||||
var deleteToken bool
|
||||
cmd := &serpent.Command{
|
||||
Use: "remove <name|id|token>",
|
||||
Aliases: []string{"delete"},
|
||||
Short: "Expire or delete a token",
|
||||
Long: "Remove a token by expiring it. Use --delete to permanently hard-" +
|
||||
"delete the token instead.",
|
||||
Short: "Delete a token",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
@@ -349,7 +338,7 @@ func (r *RootCmd) removeToken() *serpent.Command {
|
||||
|
||||
token, err := client.APIKeyByName(inv.Context(), codersdk.Me, inv.Args[0])
|
||||
if err != nil {
|
||||
// If it's a token, we need to extract the ID.
|
||||
// If it's a token, we need to extract the ID
|
||||
maybeID := strings.Split(inv.Args[0], "-")[0]
|
||||
token, err = client.APIKeyByID(inv.Context(), codersdk.Me, maybeID)
|
||||
if err != nil {
|
||||
@@ -357,29 +346,17 @@ func (r *RootCmd) removeToken() *serpent.Command {
|
||||
}
|
||||
}
|
||||
|
||||
if deleteToken {
|
||||
err = client.DeleteAPIKey(inv.Context(), codersdk.Me, token.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete api key: %w", err)
|
||||
}
|
||||
cliui.Infof(inv.Stdout, "Token has been deleted.")
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.ExpireAPIKey(inv.Context(), codersdk.Me, token.ID)
|
||||
err = client.DeleteAPIKey(inv.Context(), codersdk.Me, token.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("expire api key: %w", err)
|
||||
return xerrors.Errorf("delete api key: %w", err)
|
||||
}
|
||||
cliui.Infof(inv.Stdout, "Token has been expired.")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = serpent.OptionSet{
|
||||
{
|
||||
Flag: "delete",
|
||||
Description: "Permanently delete the token instead of expiring it. This removes the audit trail.",
|
||||
Value: serpent.BoolOf(&deleteToken),
|
||||
cliui.Infof(
|
||||
inv.Stdout,
|
||||
"Token has been deleted.",
|
||||
)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+17
-153
@@ -6,16 +6,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -26,7 +22,7 @@ func TestTokens(t *testing.T) {
|
||||
adminUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
secondUserClient, secondUser := coderdtest.CreateAnotherUser(t, client, adminUser.OrganizationID)
|
||||
thirdUserClient, thirdUser := coderdtest.CreateAnotherUser(t, client, adminUser.OrganizationID)
|
||||
_, thirdUser := coderdtest.CreateAnotherUser(t, client, adminUser.OrganizationID)
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancelFunc()
|
||||
@@ -159,7 +155,7 @@ func TestTokens(t *testing.T) {
|
||||
require.Len(t, scopedToken.AllowList, 1)
|
||||
require.Equal(t, allowSpec, scopedToken.AllowList[0].String())
|
||||
|
||||
// Delete by name (default behavior is now expire)
|
||||
// Delete by name
|
||||
inv, root = clitest.New(t, "tokens", "rm", "token-one")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
@@ -168,42 +164,21 @@ func TestTokens(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
require.Contains(t, res, "expired")
|
||||
|
||||
// Regular users cannot expire other users' tokens (expire is default now).
|
||||
inv, root = clitest.New(t, "tokens", "rm", secondTokenID)
|
||||
clitest.SetupConfig(t, thirdUserClient, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not found")
|
||||
|
||||
// Only admin users can expire other users' tokens (expire is default now).
|
||||
inv, root = clitest.New(t, "tokens", "rm", secondTokenID)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
// Validate that token was expired
|
||||
if token, err := client.APIKeyByName(ctx, secondUser.ID.String(), "token-two"); assert.NoError(t, err) {
|
||||
require.True(t, token.ExpiresAt.Before(time.Now()))
|
||||
}
|
||||
|
||||
// Delete by ID (explicit delete flag)
|
||||
inv, root = clitest.New(t, "tokens", "rm", "--delete", secondTokenID)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
require.Contains(t, res, "deleted")
|
||||
|
||||
// Delete scoped token by ID (explicit delete flag)
|
||||
inv, root = clitest.New(t, "tokens", "rm", "--delete", scopedTokenID)
|
||||
// Delete by ID
|
||||
inv, root = clitest.New(t, "tokens", "rm", secondTokenID)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
require.Contains(t, res, "deleted")
|
||||
|
||||
// Delete scoped token by ID
|
||||
inv, root = clitest.New(t, "tokens", "rm", scopedTokenID)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
@@ -224,8 +199,8 @@ func TestTokens(t *testing.T) {
|
||||
require.NotEmpty(t, res)
|
||||
fourthToken := res
|
||||
|
||||
// Delete by token (explicit delete flag)
|
||||
inv, root = clitest.New(t, "tokens", "rm", "--delete", fourthToken)
|
||||
// Delete by token
|
||||
inv, root = clitest.New(t, "tokens", "rm", fourthToken)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
@@ -235,114 +210,3 @@ func TestTokens(t *testing.T) {
|
||||
require.NotEmpty(t, res)
|
||||
require.Contains(t, res, "deleted")
|
||||
}
|
||||
|
||||
func TestTokensListExpiredFiltering(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, _, api := coderdtest.NewWithAPI(t, nil)
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create a valid (non-expired) token
|
||||
validToken, _ := dbgen.APIKey(t, api.Database, database.APIKey{
|
||||
UserID: owner.UserID,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
LoginType: database.LoginTypeToken,
|
||||
TokenName: "valid-token",
|
||||
})
|
||||
|
||||
// Create an expired token
|
||||
expiredToken, _ := dbgen.APIKey(t, api.Database, database.APIKey{
|
||||
UserID: owner.UserID,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
LoginType: database.LoginTypeToken,
|
||||
TokenName: "expired-token",
|
||||
})
|
||||
|
||||
t.Run("HidesExpiredByDefault", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
inv, root := clitest.New(t, "tokens", "ls")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
res := buf.String()
|
||||
require.Contains(t, res, validToken.ID)
|
||||
require.Contains(t, res, "valid-token")
|
||||
require.NotContains(t, res, expiredToken.ID)
|
||||
require.NotContains(t, res, "expired-token")
|
||||
})
|
||||
|
||||
t.Run("ShowsExpiredWithFlag", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
inv, root := clitest.New(t, "tokens", "ls", "--include-expired")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
res := buf.String()
|
||||
require.Contains(t, res, validToken.ID)
|
||||
require.Contains(t, res, "valid-token")
|
||||
require.Contains(t, res, expiredToken.ID)
|
||||
require.Contains(t, res, "expired-token")
|
||||
})
|
||||
|
||||
t.Run("JSONOutputRespectsFilter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Default (no expired)
|
||||
inv, root := clitest.New(t, "tokens", "ls", "--output=json")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
res := buf.String()
|
||||
require.Contains(t, res, "valid-token")
|
||||
require.NotContains(t, res, "expired-token")
|
||||
|
||||
// With --include-expired
|
||||
inv, root = clitest.New(t, "tokens", "ls", "--output=json", "--include-expired")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
res = buf.String()
|
||||
require.Contains(t, res, "valid-token")
|
||||
require.Contains(t, res, "expired-token")
|
||||
})
|
||||
|
||||
t.Run("AllUsersWithIncludeExpired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
inv, root := clitest.New(t, "tokens", "ls", "--all", "--include-expired")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
res := buf.String()
|
||||
// Should show both valid and expired tokens
|
||||
require.Contains(t, res, validToken.ID)
|
||||
require.Contains(t, res, "valid-token")
|
||||
require.Contains(t, res, expiredToken.ID)
|
||||
require.Contains(t, res, "expired-token")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -990,74 +990,4 @@ func TestUpdateValidateRichParameters(t *testing.T) {
|
||||
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
})
|
||||
|
||||
t.Run("NewImmutableParameterViaFlag", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create template and workspace with only a mutable parameter.
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
templateParameters := []*proto.RichParameter{
|
||||
{Name: stringParameterName, Type: "string", Mutable: true, Required: true, Options: []*proto.RichParameterOption{
|
||||
{Name: "First option", Description: "This is first option", Value: "1st"},
|
||||
{Name: "Second option", Description: "This is second option", Value: "2nd"},
|
||||
}},
|
||||
}
|
||||
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(templateParameters))
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
|
||||
|
||||
inv, root := clitest.New(t, "create", "my-workspace", "--yes", "--template", template.Name, "--parameter", fmt.Sprintf("%s=%s", stringParameterName, "1st"))
|
||||
clitest.SetupConfig(t, member, root)
|
||||
err := inv.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update template: add a new immutable parameter.
|
||||
updatedTemplateParameters := []*proto.RichParameter{
|
||||
templateParameters[0],
|
||||
{Name: immutableParameterName, Type: "string", Mutable: false, Required: true, Options: []*proto.RichParameterOption{
|
||||
{Name: "fir", Description: "First option for immutable parameter", Value: "I"},
|
||||
{Name: "sec", Description: "Second option for immutable parameter", Value: "II"},
|
||||
}},
|
||||
}
|
||||
|
||||
updatedVersion := coderdtest.UpdateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(updatedTemplateParameters), template.ID)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, updatedVersion.ID)
|
||||
err = client.UpdateActiveTemplateVersion(context.Background(), template.ID, codersdk.UpdateActiveTemplateVersion{
|
||||
ID: updatedVersion.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update workspace, supplying the new immutable parameter via
|
||||
// the --parameter flag. This should succeed because it's the
|
||||
// first time this parameter is being set.
|
||||
inv, root = clitest.New(t, "update", "my-workspace",
|
||||
"--parameter", fmt.Sprintf("%s=%s", immutableParameterName, "II"))
|
||||
clitest.SetupConfig(t, member, root)
|
||||
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
doneChan := make(chan struct{})
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
pty.ExpectMatch("Planning workspace")
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
|
||||
// Verify the immutable parameter was set correctly.
|
||||
workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), "my-workspace", codersdk.WorkspaceOptions{})
|
||||
require.NoError(t, err)
|
||||
actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, actualParameters, codersdk.WorkspaceBuildParameter{
|
||||
Name: immutableParameterName,
|
||||
Value: "II",
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -179,8 +179,6 @@ func New(opts Options, workspace database.Workspace) *API {
|
||||
Database: opts.Database,
|
||||
Log: opts.Log,
|
||||
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
|
||||
Clock: opts.Clock,
|
||||
NotificationsEnqueuer: opts.NotificationsEnqueuer,
|
||||
}
|
||||
|
||||
api.MetadataAPI = &MetadataAPI{
|
||||
|
||||
@@ -2,10 +2,6 @@ package agentapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -13,14 +9,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
strutil "github.com/coder/coder/v2/coderd/util/strings"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
type AppsAPI struct {
|
||||
@@ -28,8 +17,6 @@ type AppsAPI struct {
|
||||
Database database.Store
|
||||
Log slog.Logger
|
||||
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
|
||||
NotificationsEnqueuer notifications.Enqueuer
|
||||
Clock quartz.Clock
|
||||
}
|
||||
|
||||
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
|
||||
@@ -117,230 +104,3 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
|
||||
}
|
||||
return &agentproto.BatchUpdateAppHealthResponse{}, nil
|
||||
}
|
||||
|
||||
func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
|
||||
if len(req.Message) > 160 {
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Message is too long.",
|
||||
Detail: "Message must be less than 160 characters.",
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "message", Detail: "Message must be less than 160 characters."},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
var dbState database.WorkspaceAppStatusState
|
||||
switch req.State {
|
||||
case agentproto.UpdateAppStatusRequest_COMPLETE:
|
||||
dbState = database.WorkspaceAppStatusStateComplete
|
||||
case agentproto.UpdateAppStatusRequest_FAILURE:
|
||||
dbState = database.WorkspaceAppStatusStateFailure
|
||||
case agentproto.UpdateAppStatusRequest_WORKING:
|
||||
dbState = database.WorkspaceAppStatusStateWorking
|
||||
case agentproto.UpdateAppStatusRequest_IDLE:
|
||||
dbState = database.WorkspaceAppStatusStateIdle
|
||||
default:
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid state provided.",
|
||||
Detail: fmt.Sprintf("invalid state: %q", req.State),
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "state", Detail: "State must be one of: complete, failure, working, idle."},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
workspaceAgent, err := a.AgentFn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
|
||||
AgentID: workspaceAgent.ID,
|
||||
Slug: req.Slug,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to get workspace app.",
|
||||
Detail: fmt.Sprintf("No app found with slug %q", req.Slug),
|
||||
})
|
||||
}
|
||||
|
||||
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to get workspace.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Treat the message as untrusted input.
|
||||
cleaned := strutil.UISanitize(req.Message)
|
||||
|
||||
// Get the latest status for the workspace app to detect no-op updates
|
||||
// nolint:gocritic // This is a system restricted operation.
|
||||
latestAppStatus, err := a.Database.GetLatestWorkspaceAppStatusByAppID(dbauthz.AsSystemRestricted(ctx), app.ID)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get latest workspace app status.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
// If no rows found, latestAppStatus will be a zero-value struct (ID == uuid.Nil)
|
||||
|
||||
// nolint:gocritic // This is a system restricted operation.
|
||||
_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
|
||||
ID: uuid.New(),
|
||||
CreatedAt: dbtime.Now(),
|
||||
WorkspaceID: workspace.ID,
|
||||
AgentID: workspaceAgent.ID,
|
||||
AppID: app.ID,
|
||||
State: dbState,
|
||||
Message: cleaned,
|
||||
Uri: sql.NullString{
|
||||
String: req.Uri,
|
||||
Valid: req.Uri != "",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to insert workspace app status.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
if a.PublishWorkspaceUpdateFn != nil {
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to publish workspace update.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Notify on state change to Working/Idle for AI tasks
|
||||
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent)
|
||||
|
||||
if shouldBump(dbState, latestAppStatus) {
|
||||
// We pass time.Time{} for nextAutostart since we don't have access to
|
||||
// TemplateScheduleStore here. The activity bump logic handles this by
|
||||
// defaulting to the template's activity_bump duration (typically 1 hour).
|
||||
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{})
|
||||
}
|
||||
// just return a blank response because it doesn't contain any settable fields at present.
|
||||
return new(agentproto.UpdateAppStatusResponse), nil
|
||||
}
|
||||
|
||||
func shouldBump(dbState database.WorkspaceAppStatusState, latestAppStatus database.WorkspaceAppStatus) bool {
|
||||
// Bump deadline when agent reports working or transitions away from working.
|
||||
// This prevents auto-pause during active work and gives users time to interact
|
||||
// after work completes.
|
||||
|
||||
// Bump if reporting working state.
|
||||
if dbState == database.WorkspaceAppStatusStateWorking {
|
||||
return true
|
||||
}
|
||||
|
||||
// Bump if transitioning away from working state.
|
||||
if latestAppStatus.ID != uuid.Nil {
|
||||
prevState := latestAppStatus.State
|
||||
if prevState == database.WorkspaceAppStatusStateWorking {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// enqueueAITaskStateNotification enqueues a notification when an AI task's app
|
||||
// transitions to Working or Idle.
|
||||
// No-op if:
|
||||
// - the workspace agent app isn't configured as an AI task,
|
||||
// - the new state equals the latest persisted state,
|
||||
// - the workspace agent is not ready (still starting up).
|
||||
func (a *AppsAPI) enqueueAITaskStateNotification(
|
||||
ctx context.Context,
|
||||
appID uuid.UUID,
|
||||
latestAppStatus database.WorkspaceAppStatus,
|
||||
newAppStatus database.WorkspaceAppStatusState,
|
||||
workspace database.Workspace,
|
||||
agent database.WorkspaceAgent,
|
||||
) {
|
||||
var notificationTemplate uuid.UUID
|
||||
switch newAppStatus {
|
||||
case database.WorkspaceAppStatusStateWorking:
|
||||
notificationTemplate = notifications.TemplateTaskWorking
|
||||
case database.WorkspaceAppStatusStateIdle:
|
||||
notificationTemplate = notifications.TemplateTaskIdle
|
||||
case database.WorkspaceAppStatusStateComplete:
|
||||
notificationTemplate = notifications.TemplateTaskCompleted
|
||||
case database.WorkspaceAppStatusStateFailure:
|
||||
notificationTemplate = notifications.TemplateTaskFailed
|
||||
default:
|
||||
// Not a notifiable state, do nothing
|
||||
return
|
||||
}
|
||||
|
||||
if !workspace.TaskID.Valid {
|
||||
// Workspace has no task ID, do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
// Only send notifications when the agent is ready. We want to skip
|
||||
// any state transitions that occur whilst the workspace is starting
|
||||
// up as it doesn't make sense to receive them.
|
||||
if agent.LifecycleState != database.WorkspaceAgentLifecycleStateReady {
|
||||
a.Log.Debug(ctx, "skipping AI task notification because agent is not ready",
|
||||
slog.F("agent_id", agent.ID),
|
||||
slog.F("lifecycle_state", agent.LifecycleState),
|
||||
slog.F("new_app_status", newAppStatus),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
|
||||
if err != nil {
|
||||
a.Log.Warn(ctx, "failed to get task", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if !task.WorkspaceAppID.Valid || task.WorkspaceAppID.UUID != appID {
|
||||
// Non-task app, do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if the latest persisted state equals the new state (no new transition)
|
||||
// Note: uuid.Nil check is valid here. If no previous status exists,
|
||||
// GetLatestWorkspaceAppStatusByAppID returns sql.ErrNoRows and we get a zero-value struct.
|
||||
if latestAppStatus.ID != uuid.Nil && latestAppStatus.State == newAppStatus {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip the initial "Working" notification when the task first starts.
|
||||
// This is obvious to the user since they just created the task.
|
||||
// We still notify on the first "Idle" status and all subsequent transitions.
|
||||
if latestAppStatus.ID == uuid.Nil && newAppStatus == database.WorkspaceAppStatusStateWorking {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := a.NotificationsEnqueuer.EnqueueWithData(
|
||||
// nolint:gocritic // Need notifier actor to enqueue notifications
|
||||
dbauthz.AsNotifier(ctx),
|
||||
workspace.OwnerID,
|
||||
notificationTemplate,
|
||||
map[string]string{
|
||||
"task": task.Name,
|
||||
"workspace": workspace.Name,
|
||||
},
|
||||
map[string]any{
|
||||
// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
|
||||
// allowing identical content to resend within the same day
|
||||
// (but not more than once every 10s).
|
||||
"dedupe_bypass_ts": a.Clock.Now().UTC().Truncate(time.Minute),
|
||||
},
|
||||
"api-workspace-agent-app-status",
|
||||
// Associate this notification with related entities
|
||||
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
|
||||
); err != nil {
|
||||
a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
package agentapi
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
)
|
||||
|
||||
func TestShouldBump(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prevState *database.WorkspaceAppStatusState // nil means no previous state
|
||||
newState database.WorkspaceAppStatusState
|
||||
shouldBump bool
|
||||
}{
|
||||
{
|
||||
name: "FirstStatusBumps",
|
||||
prevState: nil,
|
||||
newState: database.WorkspaceAppStatusStateWorking,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "WorkingToIdleBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "WorkingToCompleteBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
|
||||
newState: database.WorkspaceAppStatusStateComplete,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "CompleteToIdleNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "CompleteToCompleteNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
|
||||
newState: database.WorkspaceAppStatusStateComplete,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "FailureToIdleNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "FailureToFailureNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
|
||||
newState: database.WorkspaceAppStatusStateFailure,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "CompleteToWorkingBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
|
||||
newState: database.WorkspaceAppStatusStateWorking,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "FailureToCompleteNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
|
||||
newState: database.WorkspaceAppStatusStateComplete,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "WorkingToFailureBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
|
||||
newState: database.WorkspaceAppStatusStateFailure,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "IdleToIdleNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "IdleToWorkingBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
|
||||
newState: database.WorkspaceAppStatusStateWorking,
|
||||
shouldBump: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var prevAppStatus database.WorkspaceAppStatus
|
||||
// If there's a previous state, report it first.
|
||||
if tt.prevState != nil {
|
||||
prevAppStatus.ID = uuid.UUID{1}
|
||||
prevAppStatus.State = *tt.prevState
|
||||
}
|
||||
|
||||
didBump := shouldBump(tt.newState, prevAppStatus)
|
||||
if tt.shouldBump {
|
||||
require.True(t, didBump, "wanted deadline to bump but it didn't")
|
||||
} else {
|
||||
require.False(t, didBump, "wanted deadline not to bump but it did")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,13 +2,9 @@ package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
@@ -16,12 +12,8 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
@@ -261,183 +253,3 @@ func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkspaceAgentAppStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
fEnq := ¬ificationstest.FakeEnqueuer{}
|
||||
mClock := quartz.NewMock(t)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
assert.Equal(t, *agnt, agent)
|
||||
testutil.AssertSend(ctx, t, workspaceUpdates, kind)
|
||||
return nil
|
||||
},
|
||||
NotificationsEnqueuer: fEnq,
|
||||
Clock: mClock,
|
||||
}
|
||||
|
||||
app := database.WorkspaceApp{
|
||||
ID: uuid.UUID{8},
|
||||
}
|
||||
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), database.GetWorkspaceAppByAgentIDAndSlugParams{
|
||||
AgentID: agent.ID,
|
||||
Slug: "vscode",
|
||||
}).Times(1).Return(app, nil)
|
||||
task := database.Task{
|
||||
ID: uuid.UUID{7},
|
||||
WorkspaceAppID: uuid.NullUUID{
|
||||
Valid: true,
|
||||
UUID: app.ID,
|
||||
},
|
||||
}
|
||||
mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
|
||||
workspace := database.Workspace{
|
||||
ID: uuid.UUID{9},
|
||||
TaskID: uuid.NullUUID{
|
||||
Valid: true,
|
||||
UUID: task.ID,
|
||||
},
|
||||
}
|
||||
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
|
||||
appStatus := database.WorkspaceAppStatus{
|
||||
ID: uuid.UUID{6},
|
||||
}
|
||||
mDB.EXPECT().GetLatestWorkspaceAppStatusByAppID(gomock.Any(), app.ID).Times(1).Return(appStatus, nil)
|
||||
mDB.EXPECT().InsertWorkspaceAppStatus(
|
||||
gomock.Any(),
|
||||
gomock.Cond(func(params database.InsertWorkspaceAppStatusParams) bool {
|
||||
if params.AgentID == agent.ID && params.AppID == app.ID {
|
||||
assert.Equal(t, "testing", params.Message)
|
||||
assert.Equal(t, database.WorkspaceAppStatusStateComplete, params.State)
|
||||
assert.True(t, params.Uri.Valid)
|
||||
assert.Equal(t, "https://example.com", params.Uri.String)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})).Times(1).Return(database.WorkspaceAppStatus{}, nil)
|
||||
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "vscode",
|
||||
Message: "testing",
|
||||
Uri: "https://example.com",
|
||||
State: agentproto.UpdateAppStatusRequest_COMPLETE,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
kind := testutil.RequireReceive(ctx, t, workspaceUpdates)
|
||||
require.Equal(t, wspubsub.WorkspaceEventKindAgentAppStatusUpdate, kind)
|
||||
sent := fEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskCompleted))
|
||||
require.Len(t, sent, 1)
|
||||
})
|
||||
|
||||
t.Run("FailUnknownApp", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
|
||||
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), gomock.Any()).
|
||||
Times(1).
|
||||
Return(database.WorkspaceApp{}, sql.ErrNoRows)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "unknown",
|
||||
Message: "testing",
|
||||
Uri: "https://example.com",
|
||||
State: agentproto.UpdateAppStatusRequest_COMPLETE,
|
||||
})
|
||||
require.ErrorContains(t, err, "No app found with slug")
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("FailUnknownState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "vscode",
|
||||
Message: "testing",
|
||||
Uri: "https://example.com",
|
||||
State: 77,
|
||||
})
|
||||
require.ErrorContains(t, err, "Invalid state")
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("FailTooLong", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "vscode",
|
||||
Message: strings.Repeat("a", 161),
|
||||
Uri: "https://example.com",
|
||||
State: agentproto.UpdateAppStatusRequest_COMPLETE,
|
||||
})
|
||||
require.ErrorContains(t, err, "Message is too long")
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -128,7 +128,7 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
|
||||
Name: agentName,
|
||||
ResourceID: parentAgent.ResourceID,
|
||||
AuthToken: uuid.New(),
|
||||
AuthInstanceID: sql.NullString{},
|
||||
AuthInstanceID: parentAgent.AuthInstanceID,
|
||||
Architecture: req.Architecture,
|
||||
EnvironmentVariables: pqtype.NullRawMessage{},
|
||||
OperatingSystem: req.OperatingSystem,
|
||||
|
||||
@@ -175,52 +175,6 @@ func TestSubAgentAPI(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Context: https://github.com/coder/coder/pull/22196
|
||||
t.Run("CreateSubAgentDoesNotInheritAuthInstanceID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
log = testutil.Logger(t)
|
||||
clock = quartz.NewMock(t)
|
||||
|
||||
db, org = newDatabaseWithOrg(t)
|
||||
user, agent = newUserWithWorkspaceAgent(t, db, org)
|
||||
)
|
||||
|
||||
// Given: The parent agent has an AuthInstanceID set
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
parentAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agent.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, parentAgent.AuthInstanceID.Valid, "parent agent should have an AuthInstanceID")
|
||||
require.NotEmpty(t, parentAgent.AuthInstanceID.String)
|
||||
|
||||
api := newAgentAPI(t, log, db, clock, user, org, agent)
|
||||
|
||||
// When: We create a sub agent
|
||||
createResp, err := api.CreateSubAgent(ctx, &proto.CreateSubAgentRequest{
|
||||
Name: "sub-agent",
|
||||
Directory: "/workspaces/test",
|
||||
Architecture: "amd64",
|
||||
OperatingSystem: "linux",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
subAgentID, err := uuid.FromBytes(createResp.Agent.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: The sub-agent must NOT re-use the parent's AuthInstanceID.
|
||||
subAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), subAgentID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, subAgent.AuthInstanceID.Valid, "sub-agent should not have an AuthInstanceID")
|
||||
assert.Empty(t, subAgent.AuthInstanceID.String, "sub-agent AuthInstanceID string should be empty")
|
||||
|
||||
// Double-check: looking up by the parent's instance ID must
|
||||
// still return the parent, not the sub-agent.
|
||||
lookedUp, err := db.GetWorkspaceAgentByInstanceID(dbauthz.AsSystemRestricted(ctx), parentAgent.AuthInstanceID.String)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, parentAgent.ID, lookedUp.ID, "instance ID lookup should still return the parent agent")
|
||||
})
|
||||
|
||||
type expectedAppError struct {
|
||||
index int32
|
||||
field string
|
||||
@@ -1366,6 +1320,7 @@ func TestSubAgentAPI(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+3
-37
@@ -21,12 +21,10 @@ import (
|
||||
agentapisdk "github.com/coder/agentapi-sdk-go"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpapi/httperror"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/searchquery"
|
||||
@@ -192,8 +190,7 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
defer commitAuditWS()
|
||||
|
||||
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, &createWorkspaceOptions{
|
||||
remoteAddr: r.RemoteAddr,
|
||||
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, r, &createWorkspaceOptions{
|
||||
// Before creating the workspace, ensure that this task can be created.
|
||||
preCreateInTX: func(ctx context.Context, tx database.Store) error {
|
||||
// Create task record in the database before creating the workspace so that
|
||||
@@ -467,6 +464,7 @@ func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks
|
||||
|
||||
apiWorkspaces, err := convertWorkspaces(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
requesterID,
|
||||
workspaces,
|
||||
@@ -546,6 +544,7 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ws, err := convertWorkspace(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
apiKey.UserID,
|
||||
workspace,
|
||||
@@ -1301,23 +1300,6 @@ func (api *API) pauseTask(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := api.NotificationsEnqueuer.Enqueue(
|
||||
// nolint:gocritic // Need notifier actor to enqueue notifications.
|
||||
dbauthz.AsNotifier(ctx),
|
||||
workspace.OwnerID,
|
||||
notifications.TemplateTaskPaused,
|
||||
map[string]string{
|
||||
"task": task.Name,
|
||||
"task_id": task.ID.String(),
|
||||
"workspace": workspace.Name,
|
||||
"pause_reason": "manual",
|
||||
},
|
||||
"api-task-pause",
|
||||
workspace.ID, workspace.OwnerID, workspace.OrganizationID,
|
||||
); err != nil {
|
||||
api.Logger.Warn(ctx, "failed to notify of task paused", slog.Error(err), slog.F("task_id", task.ID), slog.F("workspace_id", workspace.ID))
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusAccepted, codersdk.PauseTaskResponse{
|
||||
WorkspaceBuild: &build,
|
||||
})
|
||||
@@ -1405,22 +1387,6 @@ func (api *API) resumeTask(rw http.ResponseWriter, r *http.Request) {
|
||||
httperror.WriteWorkspaceBuildError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
if _, err := api.NotificationsEnqueuer.Enqueue(
|
||||
// nolint:gocritic // Need notifier actor to enqueue notifications.
|
||||
dbauthz.AsNotifier(ctx),
|
||||
workspace.OwnerID,
|
||||
notifications.TemplateTaskResumed,
|
||||
map[string]string{
|
||||
"task": task.Name,
|
||||
"task_id": task.ID.String(),
|
||||
"workspace": workspace.Name,
|
||||
},
|
||||
"api-task-resume",
|
||||
workspace.ID, workspace.OwnerID, workspace.OrganizationID,
|
||||
); err != nil {
|
||||
api.Logger.Warn(ctx, "failed to notify of task resumed", slog.Error(err), slog.F("task_id", task.ID), slog.F("workspace_id", workspace.ID))
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusAccepted, codersdk.ResumeTaskResponse{
|
||||
WorkspaceBuild: &build,
|
||||
})
|
||||
|
||||
+39
-111
@@ -45,10 +45,10 @@ import (
|
||||
)
|
||||
|
||||
// createTaskInState is a helper to create a task in the desired state.
|
||||
// It returns a function that takes context, test, and status, and returns the task.
|
||||
// It returns a function that takes context, test, and status, and returns the task ID.
|
||||
// The caller is responsible for setting up the database, owner, and user.
|
||||
func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID, userID uuid.UUID) func(context.Context, *testing.T, database.TaskStatus) database.Task {
|
||||
return func(ctx context.Context, t *testing.T, status database.TaskStatus) database.Task {
|
||||
func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID, userID uuid.UUID) func(context.Context, *testing.T, database.TaskStatus) uuid.UUID {
|
||||
return func(ctx context.Context, t *testing.T, status database.TaskStatus) uuid.UUID {
|
||||
ctx = dbauthz.As(ctx, ownerSubject)
|
||||
|
||||
builder := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -65,9 +65,6 @@ func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID,
|
||||
builder = builder.Pending()
|
||||
case database.TaskStatusInitializing:
|
||||
builder = builder.Starting()
|
||||
case database.TaskStatusActive:
|
||||
// Default builder produces a succeeded start build.
|
||||
// Post-processing below sets agent and app to active.
|
||||
case database.TaskStatusPaused:
|
||||
builder = builder.Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStop,
|
||||
@@ -79,32 +76,31 @@ func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID,
|
||||
}
|
||||
|
||||
resp := builder.Do()
|
||||
taskID := resp.Task.ID
|
||||
|
||||
// Post-process by manipulating agent and app state.
|
||||
if status == database.TaskStatusActive || status == database.TaskStatusError {
|
||||
// Set agent to ready state so agent_status returns 'active'.
|
||||
if status == database.TaskStatusError {
|
||||
// First, set agent to ready state so agent_status returns 'active'.
|
||||
// This ensures the cascade reaches app_status.
|
||||
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
||||
ID: resp.Agents[0].ID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then set workspace app health to unhealthy to trigger error state.
|
||||
apps, err := db.GetWorkspaceAppsByAgentID(ctx, resp.Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, apps, 1, "expected exactly one app for task")
|
||||
|
||||
appHealth := database.WorkspaceAppHealthHealthy
|
||||
if status == database.TaskStatusError {
|
||||
appHealth = database.WorkspaceAppHealthUnhealthy
|
||||
}
|
||||
err = db.UpdateWorkspaceAppHealthByID(ctx, database.UpdateWorkspaceAppHealthByIDParams{
|
||||
ID: apps[0].ID,
|
||||
Health: appHealth,
|
||||
Health: database.WorkspaceAppHealthUnhealthy,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return resp.Task
|
||||
return taskID
|
||||
}
|
||||
}
|
||||
|
||||
@@ -832,7 +828,7 @@ func TestTasks(t *testing.T) {
|
||||
t.Run("SendToNonActiveStates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{})
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
@@ -849,9 +845,9 @@ func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
task := createTask(ctx, t, database.TaskStatusPaused)
|
||||
taskID := createTask(ctx, t, database.TaskStatusPaused)
|
||||
|
||||
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
|
||||
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
|
||||
Input: "Hello",
|
||||
})
|
||||
|
||||
@@ -867,9 +863,9 @@ func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
task := createTask(ctx, t, database.TaskStatusInitializing)
|
||||
taskID := createTask(ctx, t, database.TaskStatusInitializing)
|
||||
|
||||
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
|
||||
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
|
||||
Input: "Hello",
|
||||
})
|
||||
|
||||
@@ -885,9 +881,9 @@ func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
task := createTask(ctx, t, database.TaskStatusPending)
|
||||
taskID := createTask(ctx, t, database.TaskStatusPending)
|
||||
|
||||
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
|
||||
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
|
||||
Input: "Hello",
|
||||
})
|
||||
|
||||
@@ -903,9 +899,9 @@ func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
task := createTask(ctx, t, database.TaskStatusError)
|
||||
taskID := createTask(ctx, t, database.TaskStatusError)
|
||||
|
||||
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
|
||||
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
|
||||
Input: "Hello",
|
||||
})
|
||||
|
||||
@@ -1124,16 +1120,16 @@ func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
task := createTask(ctx, t, database.TaskStatusPending)
|
||||
taskID := createTask(ctx, t, database.TaskStatusPending)
|
||||
|
||||
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
|
||||
TaskID: task.ID,
|
||||
TaskID: taskID,
|
||||
LogSnapshot: json.RawMessage(snapshotJSON),
|
||||
LogSnapshotCreatedAt: snapshotTime,
|
||||
})
|
||||
require.NoError(t, err, "upserting task snapshot")
|
||||
|
||||
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
|
||||
logsResp, err := client.TaskLogs(ctx, "me", taskID)
|
||||
require.NoError(t, err, "fetching task logs")
|
||||
verifySnapshotLogs(t, logsResp)
|
||||
})
|
||||
@@ -1142,16 +1138,16 @@ func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
task := createTask(ctx, t, database.TaskStatusInitializing)
|
||||
taskID := createTask(ctx, t, database.TaskStatusInitializing)
|
||||
|
||||
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
|
||||
TaskID: task.ID,
|
||||
TaskID: taskID,
|
||||
LogSnapshot: json.RawMessage(snapshotJSON),
|
||||
LogSnapshotCreatedAt: snapshotTime,
|
||||
})
|
||||
require.NoError(t, err, "upserting task snapshot")
|
||||
|
||||
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
|
||||
logsResp, err := client.TaskLogs(ctx, "me", taskID)
|
||||
require.NoError(t, err, "fetching task logs")
|
||||
verifySnapshotLogs(t, logsResp)
|
||||
})
|
||||
@@ -1160,16 +1156,16 @@ func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
task := createTask(ctx, t, database.TaskStatusPaused)
|
||||
taskID := createTask(ctx, t, database.TaskStatusPaused)
|
||||
|
||||
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
|
||||
TaskID: task.ID,
|
||||
TaskID: taskID,
|
||||
LogSnapshot: json.RawMessage(snapshotJSON),
|
||||
LogSnapshotCreatedAt: snapshotTime,
|
||||
})
|
||||
require.NoError(t, err, "upserting task snapshot")
|
||||
|
||||
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
|
||||
logsResp, err := client.TaskLogs(ctx, "me", taskID)
|
||||
require.NoError(t, err, "fetching task logs")
|
||||
verifySnapshotLogs(t, logsResp)
|
||||
})
|
||||
@@ -1178,9 +1174,9 @@ func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
task := createTask(ctx, t, database.TaskStatusPending)
|
||||
taskID := createTask(ctx, t, database.TaskStatusPending)
|
||||
|
||||
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
|
||||
logsResp, err := client.TaskLogs(ctx, "me", taskID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, logsResp.Snapshot)
|
||||
@@ -1192,7 +1188,7 @@ func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
task := createTask(ctx, t, database.TaskStatusPending)
|
||||
taskID := createTask(ctx, t, database.TaskStatusPending)
|
||||
|
||||
invalidEnvelope := coderd.TaskLogSnapshotEnvelope{
|
||||
Format: "unknown-format",
|
||||
@@ -1202,13 +1198,13 @@ func TestTasks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
|
||||
TaskID: task.ID,
|
||||
TaskID: taskID,
|
||||
LogSnapshot: json.RawMessage(invalidJSON),
|
||||
LogSnapshotCreatedAt: snapshotTime,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.TaskLogs(ctx, "me", task.ID)
|
||||
_, err = client.TaskLogs(ctx, "me", taskID)
|
||||
require.Error(t, err)
|
||||
|
||||
var sdkErr *codersdk.Error
|
||||
@@ -1221,16 +1217,16 @@ func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
task := createTask(ctx, t, database.TaskStatusPending)
|
||||
taskID := createTask(ctx, t, database.TaskStatusPending)
|
||||
|
||||
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
|
||||
TaskID: task.ID,
|
||||
TaskID: taskID,
|
||||
LogSnapshot: json.RawMessage(`{"format":"agentapi","data":"not an object"}`),
|
||||
LogSnapshotCreatedAt: snapshotTime,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.TaskLogs(ctx, "me", task.ID)
|
||||
_, err = client.TaskLogs(ctx, "me", taskID)
|
||||
require.Error(t, err)
|
||||
|
||||
var sdkErr *codersdk.Error
|
||||
@@ -1242,9 +1238,9 @@ func TestTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
task := createTask(ctx, t, database.TaskStatusError)
|
||||
taskID := createTask(ctx, t, database.TaskStatusError)
|
||||
|
||||
_, err := client.TaskLogs(ctx, "me", task.ID)
|
||||
_, err := client.TaskLogs(ctx, "me", taskID)
|
||||
require.Error(t, err)
|
||||
|
||||
var sdkErr *codersdk.Error
|
||||
@@ -2567,6 +2563,7 @@ func TestPauseTask(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
task, _ := setupWorkspaceTask(t, db, owner)
|
||||
userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, tc.roles...)
|
||||
@@ -2790,41 +2787,6 @@ func TestPauseTask(t *testing.T) {
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusInternalServerError, apiErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("Notification", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
notifyEnq = ¬ificationstest.FakeEnqueuer{}
|
||||
ownerClient, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{NotificationsEnqueuer: notifyEnq})
|
||||
owner = coderdtest.CreateFirstUser(t, ownerClient)
|
||||
)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
ownerUser, err := ownerClient.User(ctx, owner.UserID.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
createTask := createTaskInState(db, coderdtest.AuthzUserSubject(ownerUser), owner.OrganizationID, owner.UserID)
|
||||
|
||||
// Given: A task in an active state
|
||||
task := createTask(ctx, t, database.TaskStatusActive)
|
||||
|
||||
workspace, err := ownerClient.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// When: We pause the task
|
||||
_, err = ownerClient.PauseTask(ctx, codersdk.Me, task.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: A notification should be sent
|
||||
sent := notifyEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskPaused))
|
||||
require.Len(t, sent, 1)
|
||||
require.Equal(t, owner.UserID, sent[0].UserID)
|
||||
require.Equal(t, task.Name, sent[0].Labels["task"])
|
||||
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
|
||||
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
|
||||
require.Equal(t, "manual", sent[0].Labels["pause_reason"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestResumeTask(t *testing.T) {
|
||||
@@ -3154,38 +3116,4 @@ func TestResumeTask(t *testing.T) {
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusInternalServerError, apiErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("Notification", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
notifyEnq = ¬ificationstest.FakeEnqueuer{}
|
||||
ownerClient, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{NotificationsEnqueuer: notifyEnq})
|
||||
owner = coderdtest.CreateFirstUser(t, ownerClient)
|
||||
)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
ownerUser, err := ownerClient.User(ctx, owner.UserID.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
createTask := createTaskInState(db, coderdtest.AuthzUserSubject(ownerUser), owner.OrganizationID, owner.UserID)
|
||||
|
||||
// Given: A task in a paused state
|
||||
task := createTask(ctx, t, database.TaskStatusPaused)
|
||||
|
||||
workspace, err := ownerClient.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// When: We resume the task
|
||||
_, err = ownerClient.ResumeTask(ctx, codersdk.Me, task.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: A notification should be sent
|
||||
sent := notifyEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskResumed))
|
||||
require.Len(t, sent, 1)
|
||||
require.Equal(t, owner.UserID, sent[0].UserID)
|
||||
require.Equal(t, task.Name, sent[0].Labels["task"])
|
||||
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
|
||||
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
|
||||
})
|
||||
}
|
||||
|
||||
Generated
+21
-1069
File diff suppressed because it is too large
Load Diff
Generated
+21
-993
File diff suppressed because it is too large
Load Diff
+8
-77
@@ -307,26 +307,20 @@ func (api *API) apiKeyByName(rw http.ResponseWriter, r *http.Request) {
|
||||
// @Tags Users
|
||||
// @Param user path string true "User ID, name, or me"
|
||||
// @Success 200 {array} codersdk.APIKey
|
||||
// @Param include_expired query bool false "Include expired tokens in the list"
|
||||
// @Router /users/{user}/keys/tokens [get]
|
||||
func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
keys []database.APIKey
|
||||
err error
|
||||
queryStr = r.URL.Query().Get("include_all")
|
||||
includeAll, _ = strconv.ParseBool(queryStr)
|
||||
expiredStr = r.URL.Query().Get("include_expired")
|
||||
includeExpired, _ = strconv.ParseBool(expiredStr)
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
keys []database.APIKey
|
||||
err error
|
||||
queryStr = r.URL.Query().Get("include_all")
|
||||
includeAll, _ = strconv.ParseBool(queryStr)
|
||||
)
|
||||
|
||||
if includeAll {
|
||||
// get tokens for all users
|
||||
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.GetAPIKeysByLoginTypeParams{
|
||||
LoginType: database.LoginTypeToken,
|
||||
IncludeExpired: includeExpired,
|
||||
})
|
||||
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.LoginTypeToken)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching API keys.",
|
||||
@@ -336,7 +330,7 @@ func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
} else {
|
||||
// get user's tokens only
|
||||
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID, IncludeExpired: includeExpired})
|
||||
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching API keys.",
|
||||
@@ -427,69 +421,6 @@ func (api *API) deleteAPIKey(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// @Summary Expire API key
|
||||
// @ID expire-api-key
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Users
|
||||
// @Param user path string true "User ID, name, or me"
|
||||
// @Param keyid path string true "Key ID" format(string)
|
||||
// @Success 204
|
||||
// @Failure 404 {object} codersdk.Response
|
||||
// @Failure 500 {object} codersdk.Response
|
||||
// @Router /users/{user}/keys/{keyid}/expire [put]
|
||||
func (api *API) expireAPIKey(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
keyID = chi.URLParam(r, "keyid")
|
||||
auditor = api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.APIKey](rw, &audit.RequestParams{
|
||||
Audit: *auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
||||
if err := api.Database.InTx(func(db database.Store) error {
|
||||
key, err := db.GetAPIKeyByID(ctx, keyID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fetch API key: %w", err)
|
||||
}
|
||||
if !key.ExpiresAt.After(api.Clock.Now()) {
|
||||
return nil // Already expired
|
||||
}
|
||||
aReq.Old = key
|
||||
if err := db.UpdateAPIKeyByID(ctx, database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: dbtime.Now(),
|
||||
IPAddress: key.IPAddress,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("expire API key: %w", err)
|
||||
}
|
||||
// Fetch the updated key for audit log.
|
||||
newKey, err := db.GetAPIKeyByID(ctx, keyID)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to fetch updated API key for audit log", slog.Error(err))
|
||||
} else {
|
||||
aReq.New = newKey
|
||||
}
|
||||
return nil
|
||||
}, nil); httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
} else if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error expiring API key.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// @Summary Get token config
|
||||
// @ID get-token-config
|
||||
// @Security CoderSessionToken
|
||||
|
||||
+3
-196
@@ -69,44 +69,6 @@ func TestTokenCRUD(t *testing.T) {
|
||||
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[numLogs-1].Action)
|
||||
}
|
||||
|
||||
func TestTokensFilterExpired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
adminClient := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
// Create a token.
|
||||
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
|
||||
Lifetime: time.Hour * 24 * 7,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
keyID := strings.Split(res.Key, "-")[0]
|
||||
|
||||
// List tokens without including expired - should see the token.
|
||||
keys, err := adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
|
||||
// Expire the token.
|
||||
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List tokens without including expired - should NOT see expired token.
|
||||
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, keys)
|
||||
|
||||
// List tokens WITH including expired - should see expired token.
|
||||
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{
|
||||
IncludeExpired: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, keyID, keys[0].ID)
|
||||
}
|
||||
|
||||
func TestTokenScoped(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -477,7 +439,7 @@ func TestAPIKey_PrebuildsNotAllowed(t *testing.T) {
|
||||
DeploymentValues: dc,
|
||||
})
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Given: an existing api token for the prebuilds user
|
||||
_, prebuildsToken := dbgen.APIKey(t, db, database.APIKey{
|
||||
@@ -486,167 +448,12 @@ func TestAPIKey_PrebuildsNotAllowed(t *testing.T) {
|
||||
client.SetSessionToken(prebuildsToken)
|
||||
|
||||
// When: the prebuilds user tries to create an API key
|
||||
_, err := client.CreateAPIKey(setupCtx, database.PrebuildsSystemUserID.String())
|
||||
_, err := client.CreateAPIKey(ctx, database.PrebuildsSystemUserID.String())
|
||||
// Then: denied.
|
||||
require.ErrorContains(t, err, httpapi.ResourceForbiddenResponse.Message)
|
||||
|
||||
// When: the prebuilds user tries to create a token
|
||||
_, err = client.CreateToken(setupCtx, database.PrebuildsSystemUserID.String(), codersdk.CreateTokenRequest{})
|
||||
_, err = client.CreateToken(ctx, database.PrebuildsSystemUserID.String(), codersdk.CreateTokenRequest{})
|
||||
// Then: also denied.
|
||||
require.ErrorContains(t, err, httpapi.ResourceForbiddenResponse.Message)
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // Subtests share the same coderdtest instance and auditor.
|
||||
func TestExpireAPIKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor := audit.NewMock()
|
||||
adminClient := coderdtest.New(t, &coderdtest.Options{Auditor: auditor})
|
||||
admin := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, member := coderdtest.CreateAnotherUser(t, adminClient, admin.OrganizationID)
|
||||
|
||||
t.Run("OwnerCanExpireOwnToken", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Create a token.
|
||||
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
|
||||
Lifetime: time.Hour * 24 * 7,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
keyID := strings.Split(res.Key, "-")[0]
|
||||
|
||||
// Verify the token is not expired.
|
||||
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.After(time.Now()))
|
||||
|
||||
auditor.ResetLogs()
|
||||
|
||||
// Expire the token.
|
||||
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the token is expired.
|
||||
key, err = adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.Before(time.Now()))
|
||||
|
||||
// Verify audit log.
|
||||
als := auditor.AuditLogs()
|
||||
require.Len(t, als, 1)
|
||||
require.Equal(t, database.AuditActionWrite, als[0].Action)
|
||||
require.Equal(t, database.ResourceTypeApiKey, als[0].ResourceType)
|
||||
require.Equal(t, admin.UserID.String(), als[0].UserID.String())
|
||||
})
|
||||
|
||||
t.Run("AdminCanExpireOtherUsersToken", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Create a token for the member.
|
||||
res, err := memberClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
|
||||
Lifetime: time.Hour * 24 * 7,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
keyID := strings.Split(res.Key, "-")[0]
|
||||
|
||||
// Admin expires the member's token.
|
||||
err = adminClient.ExpireAPIKey(ctx, member.ID.String(), keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the token is expired.
|
||||
key, err := memberClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.Before(time.Now()))
|
||||
})
|
||||
|
||||
t.Run("MemberCannotExpireOtherUsersToken", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Create a token for the admin.
|
||||
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
|
||||
Lifetime: time.Hour * 24 * 7,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
keyID := strings.Split(res.Key, "-")[0]
|
||||
|
||||
// Member attempts to expire admin's token.
|
||||
err = memberClient.ExpireAPIKey(ctx, admin.UserID.String(), keyID)
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
// Members cannot read other users, so they get a 404 Not Found
|
||||
// from the authorization layer.
|
||||
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Try to expire a non-existent token.
|
||||
err := adminClient.ExpireAPIKey(ctx, codersdk.Me, "nonexistent")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ExpiringAlreadyExpiredTokenSucceeds", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Create and expire a token.
|
||||
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
|
||||
Lifetime: time.Hour * 24 * 7,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
keyID := strings.Split(res.Key, "-")[0]
|
||||
|
||||
// Expire it once.
|
||||
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Invariant: make sure it's actually expired
|
||||
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.LessOrEqual(t, key.ExpiresAt, time.Now(), "key should be expired")
|
||||
|
||||
// Expire it again - should succeed (idempotent).
|
||||
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should still be just as expired as before. No more, no less.
|
||||
keyAgain, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, key.ExpiresAt, keyAgain.ExpiresAt, "expiration should be idempotent")
|
||||
})
|
||||
|
||||
t.Run("DeletingExpiredTokenSucceeds", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Create a token.
|
||||
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
|
||||
Lifetime: time.Hour * 24 * 7,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
keyID := strings.Split(res.Key, "-")[0]
|
||||
|
||||
// Expire it first.
|
||||
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's expired.
|
||||
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.Before(time.Now()))
|
||||
|
||||
// Delete the expired token - should succeed.
|
||||
err = adminClient.DeleteAPIKey(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's gone.
|
||||
_, err = adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@@ -9,7 +8,6 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
@@ -93,36 +91,6 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action policy.Action, object
|
||||
return true
|
||||
}
|
||||
|
||||
// AuthorizeContext checks whether the RBAC subject on the context
|
||||
// is authorized to perform the given action. The subject must have
|
||||
// been set via dbauthz.As or the ExtractAPIKey middleware. Returns
|
||||
// false if the subject is missing or unauthorized.
|
||||
func (h *HTTPAuthorizer) AuthorizeContext(ctx context.Context, action policy.Action, object rbac.Objecter) bool {
|
||||
roles, ok := dbauthz.ActorFromContext(ctx)
|
||||
if !ok {
|
||||
h.Logger.Error(ctx, "no authorization actor in context")
|
||||
return false
|
||||
}
|
||||
err := h.Authorizer.Authorize(ctx, roles, action, object.RBACObject())
|
||||
if err != nil {
|
||||
internalError := new(rbac.UnauthorizedError)
|
||||
logger := h.Logger
|
||||
if xerrors.As(err, internalError) {
|
||||
logger = h.Logger.With(slog.F("internal_error", internalError.Internal()))
|
||||
}
|
||||
logger.Warn(ctx, "requester is not authorized to access the object",
|
||||
slog.F("roles", roles.SafeRoleNames()),
|
||||
slog.F("actor_id", roles.ID),
|
||||
slog.F("actor_name", roles),
|
||||
slog.F("scope", roles.SafeScopeName()),
|
||||
slog.F("action", action),
|
||||
slog.F("object", object),
|
||||
)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// AuthorizeSQLFilter returns an authorization filter that can used in a
|
||||
// SQL 'WHERE' clause. If the filter is used, the resulting rows returned
|
||||
// from postgres are already authorized, and the caller does not need to
|
||||
@@ -138,22 +106,6 @@ func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action policy.Actio
|
||||
return prepared, nil
|
||||
}
|
||||
|
||||
// AuthorizeSQLFilterContext is like AuthorizeSQLFilter but reads the
|
||||
// RBAC subject from the context directly rather than from an
|
||||
// *http.Request. The subject must have been set via dbauthz.As.
|
||||
func (h *HTTPAuthorizer) AuthorizeSQLFilterContext(ctx context.Context, action policy.Action, objectType string) (rbac.PreparedAuthorized, error) {
|
||||
roles, ok := dbauthz.ActorFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, xerrors.New("no authorization actor in context")
|
||||
}
|
||||
prepared, err := h.Authorizer.Prepare(ctx, roles, action, objectType)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("prepare filter: %w", err)
|
||||
}
|
||||
|
||||
return prepared, nil
|
||||
}
|
||||
|
||||
// checkAuthorization returns if the current API key can use the given
|
||||
// permissions, factoring in the current user's roles and the API key scopes.
|
||||
//
|
||||
|
||||
@@ -231,7 +231,6 @@ func (e *Executor) runOnce(t time.Time) Stats {
|
||||
job *database.ProvisionerJob
|
||||
auditLog *auditParams
|
||||
shouldNotifyDormancy bool
|
||||
shouldNotifyTaskPause bool
|
||||
nextBuild *database.WorkspaceBuild
|
||||
activeTemplateVersion database.TemplateVersion
|
||||
ws database.Workspace
|
||||
@@ -317,10 +316,6 @@ func (e *Executor) runOnce(t time.Time) Stats {
|
||||
return nil
|
||||
}
|
||||
|
||||
if reason == database.BuildReasonTaskAutoPause {
|
||||
shouldNotifyTaskPause = true
|
||||
}
|
||||
|
||||
// Get the template version job to access tags
|
||||
templateVersionJob, err := tx.GetProvisionerJobByID(e.ctx, activeTemplateVersion.JobID)
|
||||
if err != nil {
|
||||
@@ -487,28 +482,6 @@ func (e *Executor) runOnce(t time.Time) Stats {
|
||||
log.Warn(e.ctx, "failed to notify of workspace marked as dormant", slog.Error(err), slog.F("workspace_id", ws.ID))
|
||||
}
|
||||
}
|
||||
if shouldNotifyTaskPause {
|
||||
task, err := e.db.GetTaskByID(e.ctx, ws.TaskID.UUID)
|
||||
if err != nil {
|
||||
log.Warn(e.ctx, "failed to get task for pause notification", slog.Error(err), slog.F("task_id", ws.TaskID.UUID), slog.F("workspace_id", ws.ID))
|
||||
} else {
|
||||
if _, err := e.notificationsEnqueuer.Enqueue(
|
||||
e.ctx,
|
||||
ws.OwnerID,
|
||||
notifications.TemplateTaskPaused,
|
||||
map[string]string{
|
||||
"task": task.Name,
|
||||
"task_id": task.ID.String(),
|
||||
"workspace": ws.Name,
|
||||
"pause_reason": "idle timeout",
|
||||
},
|
||||
"lifecycle_executor",
|
||||
ws.ID, ws.OwnerID, ws.OrganizationID,
|
||||
); err != nil {
|
||||
log.Warn(e.ctx, "failed to notify of task paused", slog.Error(err), slog.F("task_id", ws.TaskID.UUID), slog.F("workspace_id", ws.ID))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
if err != nil && !xerrors.Is(err, context.Canceled) {
|
||||
@@ -552,18 +525,10 @@ func getNextTransition(
|
||||
) {
|
||||
switch {
|
||||
case isEligibleForAutostop(user, ws, latestBuild, latestJob, currentTick):
|
||||
// Use task-specific reason for AI task workspaces.
|
||||
if ws.TaskID.Valid {
|
||||
return database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil
|
||||
}
|
||||
return database.WorkspaceTransitionStop, database.BuildReasonAutostop, nil
|
||||
case isEligibleForAutostart(user, ws, latestBuild, latestJob, templateSchedule, currentTick):
|
||||
return database.WorkspaceTransitionStart, database.BuildReasonAutostart, nil
|
||||
case isEligibleForFailedStop(latestBuild, latestJob, templateSchedule, currentTick):
|
||||
// Use task-specific reason for AI task workspaces.
|
||||
if ws.TaskID.Valid {
|
||||
return database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil
|
||||
}
|
||||
return database.WorkspaceTransitionStop, database.BuildReasonAutostop, nil
|
||||
case isEligibleForDormantStop(ws, templateSchedule, currentTick):
|
||||
// Only stop started workspaces.
|
||||
|
||||
@@ -5,113 +5,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/schedule"
|
||||
)
|
||||
|
||||
func Test_getNextTransition_TaskAutoPause(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Set up a workspace that is eligible for autostop (past deadline).
|
||||
now := time.Now()
|
||||
pastDeadline := now.Add(-time.Hour)
|
||||
|
||||
okUser := database.User{Status: database.UserStatusActive}
|
||||
okBuild := database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
Deadline: pastDeadline,
|
||||
}
|
||||
okJob := database.ProvisionerJob{
|
||||
JobStatus: database.ProvisionerJobStatusSucceeded,
|
||||
}
|
||||
okTemplateSchedule := schedule.TemplateScheduleOptions{}
|
||||
|
||||
// Failed build setup for failedstop tests.
|
||||
failedBuild := database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
}
|
||||
failedJob := database.ProvisionerJob{
|
||||
JobStatus: database.ProvisionerJobStatusFailed,
|
||||
CompletedAt: sql.NullTime{Time: now.Add(-time.Hour), Valid: true},
|
||||
}
|
||||
failedTemplateSchedule := schedule.TemplateScheduleOptions{
|
||||
FailureTTL: time.Minute, // TTL already elapsed since job completed an hour ago.
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
Workspace database.Workspace
|
||||
Build database.WorkspaceBuild
|
||||
Job database.ProvisionerJob
|
||||
TemplateSchedule schedule.TemplateScheduleOptions
|
||||
ExpectedReason database.BuildReason
|
||||
}{
|
||||
{
|
||||
Name: "RegularWorkspace_Autostop",
|
||||
Workspace: database.Workspace{
|
||||
DormantAt: sql.NullTime{Valid: false},
|
||||
},
|
||||
Build: okBuild,
|
||||
Job: okJob,
|
||||
TemplateSchedule: okTemplateSchedule,
|
||||
ExpectedReason: database.BuildReasonAutostop,
|
||||
},
|
||||
{
|
||||
Name: "TaskWorkspace_Autostop_UsesTaskAutoPause",
|
||||
Workspace: database.Workspace{
|
||||
DormantAt: sql.NullTime{Valid: false},
|
||||
TaskID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
},
|
||||
Build: okBuild,
|
||||
Job: okJob,
|
||||
TemplateSchedule: okTemplateSchedule,
|
||||
ExpectedReason: database.BuildReasonTaskAutoPause,
|
||||
},
|
||||
{
|
||||
Name: "RegularWorkspace_FailedStop",
|
||||
Workspace: database.Workspace{
|
||||
DormantAt: sql.NullTime{Valid: false},
|
||||
},
|
||||
Build: failedBuild,
|
||||
Job: failedJob,
|
||||
TemplateSchedule: failedTemplateSchedule,
|
||||
ExpectedReason: database.BuildReasonAutostop,
|
||||
},
|
||||
{
|
||||
Name: "TaskWorkspace_FailedStop_UsesTaskAutoPause",
|
||||
Workspace: database.Workspace{
|
||||
DormantAt: sql.NullTime{Valid: false},
|
||||
TaskID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
},
|
||||
Build: failedBuild,
|
||||
Job: failedJob,
|
||||
TemplateSchedule: failedTemplateSchedule,
|
||||
ExpectedReason: database.BuildReasonTaskAutoPause,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transition, reason, err := getNextTransition(
|
||||
okUser,
|
||||
tc.Workspace,
|
||||
tc.Build,
|
||||
tc.Job,
|
||||
tc.TemplateSchedule,
|
||||
now,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.WorkspaceTransitionStop, transition)
|
||||
require.Equal(t, tc.ExpectedReason, reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isEligibleForAutostart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -2019,69 +2019,5 @@ func TestExecutorTaskWorkspace(t *testing.T) {
|
||||
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
|
||||
assert.Equal(t, database.WorkspaceTransitionStop, stats.Transitions[workspace.ID], "should autostop the workspace")
|
||||
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
|
||||
|
||||
// Then: The build reason should be TaskAutoPause (not regular Autostop)
|
||||
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
|
||||
_ = coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
|
||||
assert.Equal(t, codersdk.BuildReasonTaskAutoPause, workspace.LatestBuild.Reason, "task workspace should use TaskAutoPause build reason")
|
||||
})
|
||||
|
||||
t.Run("AutostopNotification", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
tickCh = make(chan time.Time)
|
||||
statsCh = make(chan autobuild.Stats)
|
||||
notifyEnq = notificationstest.FakeEnqueuer{}
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
AutobuildTicker: tickCh,
|
||||
IncludeProvisionerDaemon: true,
|
||||
AutobuildStats: statsCh,
|
||||
NotificationsEnqueuer: ¬ifyEnq,
|
||||
})
|
||||
admin = coderdtest.CreateFirstUser(t, client)
|
||||
)
|
||||
|
||||
// Given: A task workspace with an 8 hour deadline
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 8*time.Hour)
|
||||
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostop notification")
|
||||
|
||||
// Given: The workspace is currently running
|
||||
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
|
||||
require.Equal(t, codersdk.WorkspaceTransitionStart, workspace.LatestBuild.Transition)
|
||||
require.NotZero(t, workspace.LatestBuild.Deadline, "workspace should have a deadline for autostop")
|
||||
|
||||
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// When: the autobuild executor ticks after the deadline
|
||||
go func() {
|
||||
tickTime := workspace.LatestBuild.Deadline.Time.Add(time.Minute)
|
||||
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
|
||||
tickCh <- tickTime
|
||||
close(tickCh)
|
||||
}()
|
||||
|
||||
// Then: We expect to see a stop transition
|
||||
stats := <-statsCh
|
||||
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
|
||||
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
|
||||
assert.Equal(t, database.WorkspaceTransitionStop, stats.Transitions[workspace.ID], "should autostop the workspace")
|
||||
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
|
||||
|
||||
// Then: A task paused notification was sent with "idle timeout" reason
|
||||
require.True(t, workspace.TaskID.Valid, "workspace should have a task ID")
|
||||
task, err := db.GetTaskByID(dbauthz.AsSystemRestricted(ctx), workspace.TaskID.UUID)
|
||||
require.NoError(t, err)
|
||||
|
||||
sent := notifyEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskPaused))
|
||||
require.Len(t, sent, 1)
|
||||
require.Equal(t, workspace.OwnerID, sent[0].UserID)
|
||||
require.Equal(t, task.Name, sent[0].Labels["task"])
|
||||
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
|
||||
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
|
||||
require.Equal(t, "idle timeout", sent[0].Labels["pause_reason"])
|
||||
})
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,305 +0,0 @@
|
||||
package chatd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replicaA := newTestServer(t, db, ps, uuid.New())
|
||||
replicaB := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replicaA.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "interrupt-me",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
runningWorker := uuid.New()
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: runningWorker, Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
updated := replicaA.InterruptChat(ctx, chat)
|
||||
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
||||
require.False(t, updated.WorkerID.Valid)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
|
||||
return false
|
||||
}
|
||||
return event.Status.Status == codersdk.ChatStatusWaiting
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "db-transition",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated := replica.InterruptChat(ctx, chat)
|
||||
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
||||
require.False(t, updated.WorkerID.Valid)
|
||||
|
||||
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
}
|
||||
|
||||
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "heartbeat-ownership",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
workerID := uuid.New()
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: uuid.New(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), rows)
|
||||
|
||||
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: workerID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), rows)
|
||||
}
|
||||
|
||||
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "queue-when-busy",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
workerID := uuid.New()
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "queued"}},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Queued)
|
||||
require.NotNil(t, result.QueuedMessage)
|
||||
require.Equal(t, database.ChatStatusRunning, result.Chat.Status)
|
||||
require.Equal(t, workerID, result.Chat.WorkerID.UUID)
|
||||
require.True(t, result.Chat.WorkerID.Valid)
|
||||
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queued, 1)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
}
|
||||
|
||||
func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "interrupt-when-busy",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "interrupt"}},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.Queued)
|
||||
require.Equal(t, database.ChatStatusPending, result.Chat.Status)
|
||||
require.False(t, result.Chat.WorkerID.Valid)
|
||||
|
||||
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusPending, fromDB.Status)
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queued, 0)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 2)
|
||||
require.Equal(t, int64(messages[len(messages)-1].ID), result.Message.ID)
|
||||
}
|
||||
|
||||
func newTestServer(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ps dbpubsub.Pubsub,
|
||||
replicaID uuid.UUID,
|
||||
) *chatd.Server {
|
||||
t.Helper()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: replicaID,
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: testutil.WaitSuperLong,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
return server
|
||||
}
|
||||
|
||||
func seedChatDependencies(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
) (database.User, database.ChatModelConfig) {
|
||||
t.Helper()
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: "",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return user, model
|
||||
}
|
||||
@@ -1,676 +0,0 @@
|
||||
package chatloop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
const (
|
||||
interruptedToolResultErrorMessage = "tool call was interrupted before it produced a result"
|
||||
)
|
||||
|
||||
var ErrInterrupted = xerrors.New("chat interrupted")
|
||||
|
||||
// PersistedStep contains the full content of a completed or
|
||||
// interrupted agent step. Content includes both assistant blocks
|
||||
// (text, reasoning, tool calls) and tool result blocks, mirroring
|
||||
// what fantasy provides in StepResult.Content. The persistence
|
||||
// layer is responsible for splitting these into separate database
|
||||
// messages by role.
|
||||
type PersistedStep struct {
|
||||
Content []fantasy.Content
|
||||
Usage fantasy.Usage
|
||||
ContextLimit sql.NullInt64
|
||||
}
|
||||
|
||||
// RunOptions configures a single streaming chat loop run.
|
||||
type RunOptions struct {
|
||||
Model fantasy.LanguageModel
|
||||
Messages []fantasy.Message
|
||||
Tools []fantasy.AgentTool
|
||||
StreamCall fantasy.AgentStreamCall
|
||||
MaxSteps int
|
||||
|
||||
ActiveTools []string
|
||||
ContextLimitFallback int64
|
||||
|
||||
PersistStep func(context.Context, PersistedStep) error
|
||||
PublishMessagePart func(
|
||||
role fantasy.MessageRole,
|
||||
part codersdk.ChatMessagePart,
|
||||
)
|
||||
Compaction *CompactionOptions
|
||||
|
||||
OnInterruptedPersistError func(error)
|
||||
}
|
||||
|
||||
// Run executes the chat step-stream loop and delegates persistence/publishing to callbacks.
|
||||
func Run(ctx context.Context, opts RunOptions) (*fantasy.AgentResult, error) {
|
||||
if opts.Model == nil {
|
||||
return nil, xerrors.New("chat model is required")
|
||||
}
|
||||
if opts.PersistStep == nil {
|
||||
return nil, xerrors.New("persist step callback is required")
|
||||
}
|
||||
if opts.MaxSteps <= 0 {
|
||||
opts.MaxSteps = 1
|
||||
}
|
||||
|
||||
publishMessagePart := func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
|
||||
if opts.PublishMessagePart == nil {
|
||||
return
|
||||
}
|
||||
opts.PublishMessagePart(role, part)
|
||||
}
|
||||
|
||||
var (
|
||||
stepStateMu sync.Mutex
|
||||
streamToolNames map[string]string
|
||||
streamReasoningTitles map[string]string
|
||||
streamReasoningText map[string]string
|
||||
// stepToolResultContents tracks tool results received during
|
||||
// streaming. These are needed for the interrupted-step path
|
||||
// where OnStepFinish never fires.
|
||||
stepToolResultContents []fantasy.ToolResultContent
|
||||
stepAssistantDraft []fantasy.Content
|
||||
stepToolCallIndexByID map[string]int
|
||||
)
|
||||
|
||||
resetStepState := func() {
|
||||
stepStateMu.Lock()
|
||||
streamToolNames = make(map[string]string)
|
||||
streamReasoningTitles = make(map[string]string)
|
||||
streamReasoningText = make(map[string]string)
|
||||
stepToolResultContents = nil
|
||||
stepAssistantDraft = nil
|
||||
stepToolCallIndexByID = make(map[string]int)
|
||||
stepStateMu.Unlock()
|
||||
}
|
||||
|
||||
setReasoningTitleFromText := func(id string, text string) {
|
||||
if id == "" || strings.TrimSpace(text) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
stepStateMu.Lock()
|
||||
defer stepStateMu.Unlock()
|
||||
|
||||
if streamReasoningTitles[id] != "" {
|
||||
return
|
||||
}
|
||||
|
||||
streamReasoningText[id] += text
|
||||
if !strings.ContainsAny(streamReasoningText[id], "\r\n") {
|
||||
return
|
||||
}
|
||||
title := chatprompt.ReasoningTitleFromFirstLine(streamReasoningText[id])
|
||||
if title == "" {
|
||||
return
|
||||
}
|
||||
|
||||
streamReasoningTitles[id] = title
|
||||
}
|
||||
|
||||
appendDraftText := func(text string) {
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
|
||||
stepStateMu.Lock()
|
||||
defer stepStateMu.Unlock()
|
||||
|
||||
if len(stepAssistantDraft) > 0 {
|
||||
lastIndex := len(stepAssistantDraft) - 1
|
||||
switch last := stepAssistantDraft[lastIndex].(type) {
|
||||
case fantasy.TextContent:
|
||||
last.Text += text
|
||||
stepAssistantDraft[lastIndex] = last
|
||||
return
|
||||
case *fantasy.TextContent:
|
||||
last.Text += text
|
||||
stepAssistantDraft[lastIndex] = fantasy.TextContent{Text: last.Text}
|
||||
return
|
||||
}
|
||||
}
|
||||
stepAssistantDraft = append(stepAssistantDraft, fantasy.TextContent{Text: text})
|
||||
}
|
||||
|
||||
appendDraftReasoning := func(text string) {
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
|
||||
stepStateMu.Lock()
|
||||
defer stepStateMu.Unlock()
|
||||
|
||||
if len(stepAssistantDraft) > 0 {
|
||||
lastIndex := len(stepAssistantDraft) - 1
|
||||
switch last := stepAssistantDraft[lastIndex].(type) {
|
||||
case fantasy.ReasoningContent:
|
||||
last.Text += text
|
||||
stepAssistantDraft[lastIndex] = last
|
||||
return
|
||||
case *fantasy.ReasoningContent:
|
||||
last.Text += text
|
||||
stepAssistantDraft[lastIndex] = fantasy.ReasoningContent{Text: last.Text}
|
||||
return
|
||||
}
|
||||
}
|
||||
stepAssistantDraft = append(stepAssistantDraft, fantasy.ReasoningContent{Text: text})
|
||||
}
|
||||
|
||||
upsertDraftToolCall := func(toolCallID, toolName, input string, appendInput bool) {
|
||||
if toolCallID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
stepStateMu.Lock()
|
||||
defer stepStateMu.Unlock()
|
||||
|
||||
if strings.TrimSpace(toolName) != "" {
|
||||
streamToolNames[toolCallID] = toolName
|
||||
}
|
||||
|
||||
index, exists := stepToolCallIndexByID[toolCallID]
|
||||
if !exists {
|
||||
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
|
||||
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Input: input,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if index < 0 || index >= len(stepAssistantDraft) {
|
||||
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
|
||||
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Input: input,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
existingCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](stepAssistantDraft[index])
|
||||
if !ok {
|
||||
if ptrCall, ptrOK := fantasy.AsContentType[*fantasy.ToolCallContent](stepAssistantDraft[index]); ptrOK && ptrCall != nil {
|
||||
existingCall = *ptrCall
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
stepToolCallIndexByID[toolCallID] = len(stepAssistantDraft)
|
||||
stepAssistantDraft = append(stepAssistantDraft, fantasy.ToolCallContent{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Input: input,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(toolName) != "" {
|
||||
existingCall.ToolName = toolName
|
||||
}
|
||||
if appendInput {
|
||||
existingCall.Input += input
|
||||
} else if input != "" || existingCall.Input == "" {
|
||||
existingCall.Input = input
|
||||
}
|
||||
stepAssistantDraft[index] = existingCall
|
||||
}
|
||||
|
||||
appendDraftSource := func(source fantasy.SourceContent) {
|
||||
stepStateMu.Lock()
|
||||
stepAssistantDraft = append(stepAssistantDraft, source)
|
||||
stepStateMu.Unlock()
|
||||
}
|
||||
|
||||
persistInterruptedStep := func() error {
|
||||
stepStateMu.Lock()
|
||||
draft := append([]fantasy.Content(nil), stepAssistantDraft...)
|
||||
toolResults := append([]fantasy.ToolResultContent(nil), stepToolResultContents...)
|
||||
toolNameByCallID := make(map[string]string, len(streamToolNames))
|
||||
for id, name := range streamToolNames {
|
||||
toolNameByCallID[id] = name
|
||||
}
|
||||
stepStateMu.Unlock()
|
||||
|
||||
if len(draft) == 0 && len(toolResults) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Track which tool calls already have results.
|
||||
answeredToolCalls := make(map[string]struct{}, len(toolResults))
|
||||
for _, tr := range toolResults {
|
||||
if tr.ToolCallID != "" {
|
||||
answeredToolCalls[tr.ToolCallID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the combined content: draft + received tool results
|
||||
// + synthetic interrupted results for unanswered tool calls.
|
||||
content := make([]fantasy.Content, 0, len(draft)+len(toolResults))
|
||||
content = append(content, draft...)
|
||||
for _, tr := range toolResults {
|
||||
content = append(content, tr)
|
||||
}
|
||||
|
||||
for _, block := range draft {
|
||||
toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block)
|
||||
if !ok {
|
||||
if ptrCall, ptrOK := fantasy.AsContentType[*fantasy.ToolCallContent](block); ptrOK && ptrCall != nil {
|
||||
toolCall = *ptrCall
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
if !ok || toolCall.ToolCallID == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := answeredToolCalls[toolCall.ToolCallID]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := strings.TrimSpace(toolCall.ToolName)
|
||||
if toolName == "" {
|
||||
toolName = strings.TrimSpace(toolNameByCallID[toolCall.ToolCallID])
|
||||
}
|
||||
|
||||
content = append(content, fantasy.ToolResultContent{
|
||||
ToolCallID: toolCall.ToolCallID,
|
||||
ToolName: toolName,
|
||||
Result: fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New(interruptedToolResultErrorMessage),
|
||||
},
|
||||
})
|
||||
answeredToolCalls[toolCall.ToolCallID] = struct{}{}
|
||||
}
|
||||
|
||||
persistCtx := context.WithoutCancel(ctx)
|
||||
return opts.PersistStep(persistCtx, PersistedStep{
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
resetStepState()
|
||||
|
||||
agent := fantasy.NewAgent(
|
||||
opts.Model,
|
||||
fantasy.WithTools(opts.Tools...),
|
||||
fantasy.WithStopConditions(fantasy.StepCountIs(opts.MaxSteps)),
|
||||
)
|
||||
applyAnthropicCaching := shouldApplyAnthropicPromptCaching(opts.Model)
|
||||
// Fantasy's AgentStreamCall currently requires a non-empty Prompt and always
|
||||
// appends it as a user message. chatd already supplies the full history in
|
||||
// Messages, so we pass and then strip a sentinel user message in PrepareStep.
|
||||
sentinelPrompt := "__chatd_agent_prompt_sentinel_" + uuid.NewString()
|
||||
|
||||
streamCall := opts.StreamCall
|
||||
streamCall.Prompt = sentinelPrompt
|
||||
streamCall.Messages = opts.Messages
|
||||
streamCall.PrepareStep = func(
|
||||
stepCtx context.Context,
|
||||
options fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
return stepCtx, prepareStepResult(
|
||||
options.Messages,
|
||||
sentinelPrompt,
|
||||
opts.ActiveTools,
|
||||
applyAnthropicCaching,
|
||||
), nil
|
||||
}
|
||||
streamCall.OnStepStart = func(_ int) error {
|
||||
resetStepState()
|
||||
return nil
|
||||
}
|
||||
streamCall.OnTextDelta = func(_ string, text string) error {
|
||||
appendDraftText(text)
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: text,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
streamCall.OnReasoningDelta = func(id string, text string) error {
|
||||
appendDraftReasoning(text)
|
||||
setReasoningTitleFromText(id, text)
|
||||
stepStateMu.Lock()
|
||||
title := streamReasoningTitles[id]
|
||||
stepStateMu.Unlock()
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: text,
|
||||
Title: title,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
streamCall.OnReasoningEnd = func(id string, _ fantasy.ReasoningContent) error {
|
||||
stepStateMu.Lock()
|
||||
if streamReasoningTitles[id] == "" {
|
||||
// At the end of reasoning we have the full text, so we can
|
||||
// safely evaluate first-line title format even if no newline
|
||||
// ever arrived in deltas.
|
||||
streamReasoningTitles[id] = chatprompt.ReasoningTitleFromFirstLine(
|
||||
streamReasoningText[id],
|
||||
)
|
||||
}
|
||||
title := streamReasoningTitles[id]
|
||||
stepStateMu.Unlock()
|
||||
if title != "" {
|
||||
// Publish a title-only reasoning part so clients can update the
|
||||
// reasoning header when metadata arrives at the end of streaming.
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Title: title,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
streamCall.OnToolInputStart = func(id, toolName string) error {
|
||||
upsertDraftToolCall(id, toolName, "", false)
|
||||
return nil
|
||||
}
|
||||
streamCall.OnToolInputDelta = func(id, delta string) error {
|
||||
stepStateMu.Lock()
|
||||
toolName := streamToolNames[id]
|
||||
stepStateMu.Unlock()
|
||||
upsertDraftToolCall(id, toolName, delta, true)
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: id,
|
||||
ToolName: toolName,
|
||||
ArgsDelta: delta,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
streamCall.OnToolCall = func(toolCall fantasy.ToolCallContent) error {
|
||||
upsertDraftToolCall(toolCall.ToolCallID, toolCall.ToolName, toolCall.Input, false)
|
||||
publishMessagePart(
|
||||
fantasy.MessageRoleAssistant,
|
||||
chatprompt.PartFromContent(toolCall),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
streamCall.OnSource = func(source fantasy.SourceContent) error {
|
||||
appendDraftSource(source)
|
||||
publishMessagePart(
|
||||
fantasy.MessageRoleAssistant,
|
||||
chatprompt.PartFromContent(source),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
streamCall.OnToolResult = func(result fantasy.ToolResultContent) error {
|
||||
publishMessagePart(
|
||||
fantasy.MessageRoleTool,
|
||||
chatprompt.PartFromContent(result),
|
||||
)
|
||||
|
||||
stepStateMu.Lock()
|
||||
if result.ToolCallID != "" && strings.TrimSpace(result.ToolName) != "" {
|
||||
streamToolNames[result.ToolCallID] = result.ToolName
|
||||
}
|
||||
stepToolResultContents = append(stepToolResultContents, result)
|
||||
stepStateMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
streamCall.OnStepFinish = func(stepResult fantasy.StepResult) error {
|
||||
contextLimit := extractContextLimit(stepResult.ProviderMetadata)
|
||||
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
|
||||
contextLimit = sql.NullInt64{
|
||||
Int64: opts.ContextLimitFallback,
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
return opts.PersistStep(ctx, PersistedStep{
|
||||
Content: stepResult.Content,
|
||||
Usage: stepResult.Usage,
|
||||
ContextLimit: contextLimit,
|
||||
})
|
||||
}
|
||||
|
||||
result, err := agent.Stream(ctx, streamCall)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) &&
|
||||
errors.Is(context.Cause(ctx), ErrInterrupted) {
|
||||
if persistErr := persistInterruptedStep(); persistErr != nil {
|
||||
if opts.OnInterruptedPersistError != nil {
|
||||
opts.OnInterruptedPersistError(persistErr)
|
||||
}
|
||||
}
|
||||
return nil, ErrInterrupted
|
||||
}
|
||||
return nil, xerrors.Errorf("stream response: %w", err)
|
||||
}
|
||||
if opts.Compaction != nil {
|
||||
if err := maybeCompact(ctx, opts, result); err != nil {
|
||||
if opts.Compaction.OnError != nil {
|
||||
opts.Compaction.OnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
//nolint:revive // Boolean controls Anthropic-specific caching behavior.
|
||||
func prepareStepResult(
|
||||
messages []fantasy.Message,
|
||||
sentinel string,
|
||||
activeTools []string,
|
||||
anthropicCaching bool,
|
||||
) fantasy.PrepareStepResult {
|
||||
filtered := make([]fantasy.Message, 0, len(messages))
|
||||
removed := false
|
||||
for _, message := range messages {
|
||||
if !removed &&
|
||||
message.Role == fantasy.MessageRoleUser &&
|
||||
len(message.Content) == 1 {
|
||||
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
|
||||
if ok && textPart.Text == sentinel {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, message)
|
||||
}
|
||||
|
||||
result := fantasy.PrepareStepResult{
|
||||
Messages: filtered,
|
||||
}
|
||||
if anthropicCaching {
|
||||
result.Messages = addAnthropicPromptCaching(result.Messages)
|
||||
}
|
||||
if len(activeTools) > 0 {
|
||||
result.ActiveTools = append([]string(nil), activeTools...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func shouldApplyAnthropicPromptCaching(model fantasy.LanguageModel) bool {
|
||||
if model == nil {
|
||||
return false
|
||||
}
|
||||
return model.Provider() == fantasyanthropic.Name
|
||||
}
|
||||
|
||||
func addAnthropicPromptCaching(messages []fantasy.Message) []fantasy.Message {
|
||||
for i := range messages {
|
||||
messages[i].ProviderOptions = nil
|
||||
}
|
||||
|
||||
providerOption := fantasy.ProviderOptions{
|
||||
fantasyanthropic.Name: &fantasyanthropic.ProviderCacheControlOptions{
|
||||
CacheControl: fantasyanthropic.CacheControl{Type: "ephemeral"},
|
||||
},
|
||||
}
|
||||
|
||||
lastSystemRoleIdx := -1
|
||||
systemMessageUpdated := false
|
||||
for i, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleSystem {
|
||||
lastSystemRoleIdx = i
|
||||
} else if !systemMessageUpdated && lastSystemRoleIdx >= 0 {
|
||||
messages[lastSystemRoleIdx].ProviderOptions = providerOption
|
||||
systemMessageUpdated = true
|
||||
}
|
||||
if i > len(messages)-3 {
|
||||
messages[i].ProviderOptions = providerOption
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 {
|
||||
if len(metadata) == 0 {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(metadata)
|
||||
if err != nil || len(encoded) == 0 {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
|
||||
var payload any
|
||||
if err := json.Unmarshal(encoded, &payload); err != nil {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
|
||||
limit, ok := findContextLimitValue(payload)
|
||||
if !ok {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
|
||||
return sql.NullInt64{
|
||||
Int64: limit,
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
func findContextLimitValue(value any) (int64, bool) {
|
||||
var (
|
||||
limit int64
|
||||
found bool
|
||||
)
|
||||
|
||||
collectContextLimitValues(value, func(candidate int64) {
|
||||
if !found || candidate > limit {
|
||||
limit = candidate
|
||||
found = true
|
||||
}
|
||||
})
|
||||
|
||||
return limit, found
|
||||
}
|
||||
|
||||
func collectContextLimitValues(value any, onValue func(int64)) {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
for key, child := range typed {
|
||||
if isContextLimitKey(key) {
|
||||
if numeric, ok := numericContextLimitValue(child); ok {
|
||||
onValue(numeric)
|
||||
}
|
||||
}
|
||||
collectContextLimitValues(child, onValue)
|
||||
}
|
||||
case []any:
|
||||
for _, child := range typed {
|
||||
collectContextLimitValues(child, onValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isContextLimitKey(key string) bool {
|
||||
normalized := normalizeMetadataKey(key)
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
switch normalized {
|
||||
case
|
||||
"contextlimit",
|
||||
"contextwindow",
|
||||
"contextlength",
|
||||
"maxcontext",
|
||||
"maxcontexttokens",
|
||||
"maxinputtokens",
|
||||
"maxinputtoken",
|
||||
"inputtokenlimit":
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.Contains(normalized, "context") &&
|
||||
(strings.Contains(normalized, "limit") ||
|
||||
strings.Contains(normalized, "window") ||
|
||||
strings.Contains(normalized, "length") ||
|
||||
strings.HasPrefix(normalized, "max"))
|
||||
}
|
||||
|
||||
func normalizeMetadataKey(key string) string {
|
||||
var b strings.Builder
|
||||
b.Grow(len(key))
|
||||
|
||||
for _, r := range key {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
_, _ = b.WriteRune(r)
|
||||
case r >= 'A' && r <= 'Z':
|
||||
_, _ = b.WriteRune(r + ('a' - 'A'))
|
||||
case r >= '0' && r <= '9':
|
||||
_, _ = b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func numericContextLimitValue(value any) (int64, bool) {
|
||||
switch typed := value.(type) {
|
||||
case int64:
|
||||
return positiveInt64(typed)
|
||||
case int32:
|
||||
return positiveInt64(int64(typed))
|
||||
case int:
|
||||
return positiveInt64(int64(typed))
|
||||
case float64:
|
||||
casted := int64(typed)
|
||||
if typed > 0 && float64(casted) == typed {
|
||||
return casted, true
|
||||
}
|
||||
case string:
|
||||
parsed, err := strconv.ParseInt(strings.TrimSpace(typed), 10, 64)
|
||||
if err == nil {
|
||||
return positiveInt64(parsed)
|
||||
}
|
||||
case json.Number:
|
||||
parsed, err := typed.Int64()
|
||||
if err == nil {
|
||||
return positiveInt64(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func positiveInt64(value int64) (int64, bool) {
|
||||
if value <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
@@ -1,289 +0,0 @@
|
||||
package chatloop //nolint:testpackage // Uses internal symbols.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"iter"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
const activeToolName = "read_file"
|
||||
|
||||
func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var capturedCall fantasy.Call
|
||||
model := &loopTestModel{
|
||||
provider: fantasyanthropic.Name,
|
||||
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
capturedCall = call
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}), nil
|
||||
},
|
||||
}
|
||||
|
||||
persistStepCalls := 0
|
||||
var persistedStep PersistedStep
|
||||
|
||||
_, err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleSystem, "sys-1"),
|
||||
textMessage(fantasy.MessageRoleSystem, "sys-2"),
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
textMessage(fantasy.MessageRoleAssistant, "working"),
|
||||
textMessage(fantasy.MessageRoleUser, "continue"),
|
||||
},
|
||||
Tools: []fantasy.AgentTool{
|
||||
newNoopTool(activeToolName),
|
||||
newNoopTool("write_file"),
|
||||
},
|
||||
MaxSteps: 3,
|
||||
ActiveTools: []string{activeToolName},
|
||||
ContextLimitFallback: 4096,
|
||||
PersistStep: func(_ context.Context, step PersistedStep) error {
|
||||
persistStepCalls++
|
||||
persistedStep = step
|
||||
return nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, 1, persistStepCalls)
|
||||
require.True(t, persistedStep.ContextLimit.Valid)
|
||||
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
|
||||
|
||||
require.NotEmpty(t, capturedCall.Prompt)
|
||||
require.False(t, containsPromptSentinel(capturedCall.Prompt))
|
||||
require.Len(t, capturedCall.Tools, 1)
|
||||
require.Equal(t, activeToolName, capturedCall.Tools[0].GetName())
|
||||
|
||||
require.Len(t, capturedCall.Prompt, 5)
|
||||
require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[0]))
|
||||
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[1]))
|
||||
require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[2]))
|
||||
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[3]))
|
||||
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[4]))
|
||||
}
|
||||
|
||||
func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
started := make(chan struct{})
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
||||
parts := []fantasy.StreamPart{
|
||||
{
|
||||
Type: fantasy.StreamPartTypeToolInputStart,
|
||||
ID: "interrupt-tool-1",
|
||||
ToolCallName: "read_file",
|
||||
},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeToolInputDelta,
|
||||
ID: "interrupt-tool-1",
|
||||
ToolCallName: "read_file",
|
||||
Delta: `{"path":"main.go"`,
|
||||
},
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "partial assistant output"},
|
||||
}
|
||||
for _, part := range parts {
|
||||
if !yield(part) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
default:
|
||||
close(started)
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
_ = yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeError,
|
||||
Error: ctx.Err(),
|
||||
})
|
||||
}), nil
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
defer cancel(nil)
|
||||
|
||||
go func() {
|
||||
<-started
|
||||
cancel(ErrInterrupted)
|
||||
}()
|
||||
|
||||
persistedAssistantCtxErr := xerrors.New("unset")
|
||||
var persistedContent []fantasy.Content
|
||||
|
||||
_, err := Run(ctx, RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
Tools: []fantasy.AgentTool{
|
||||
newNoopTool("read_file"),
|
||||
},
|
||||
MaxSteps: 3,
|
||||
PersistStep: func(persistCtx context.Context, step PersistedStep) error {
|
||||
persistedAssistantCtxErr = persistCtx.Err()
|
||||
persistedContent = append([]fantasy.Content(nil), step.Content...)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
require.ErrorIs(t, err, ErrInterrupted)
|
||||
require.NoError(t, persistedAssistantCtxErr)
|
||||
|
||||
require.NotEmpty(t, persistedContent)
|
||||
var (
|
||||
foundText bool
|
||||
foundToolCall bool
|
||||
foundToolResult bool
|
||||
)
|
||||
for _, block := range persistedContent {
|
||||
if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok {
|
||||
if strings.Contains(text.Text, "partial assistant output") {
|
||||
foundText = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block); ok {
|
||||
if toolCall.ToolCallID == "interrupt-tool-1" &&
|
||||
toolCall.ToolName == "read_file" &&
|
||||
strings.Contains(toolCall.Input, `"path":"main.go"`) {
|
||||
foundToolCall = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok {
|
||||
if toolResult.ToolCallID == "interrupt-tool-1" &&
|
||||
toolResult.ToolName == "read_file" {
|
||||
_, isErr := toolResult.Result.(fantasy.ToolResultOutputContentError)
|
||||
require.True(t, isErr, "interrupted tool result should be an error")
|
||||
foundToolResult = true
|
||||
}
|
||||
}
|
||||
}
|
||||
require.True(t, foundText)
|
||||
require.True(t, foundToolCall)
|
||||
require.True(t, foundToolResult)
|
||||
}
|
||||
|
||||
type loopTestModel struct {
|
||||
provider string
|
||||
model string
|
||||
generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
|
||||
streamFn func(context.Context, fantasy.Call) (fantasy.StreamResponse, error)
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Provider() string {
|
||||
if m.provider != "" {
|
||||
return m.provider
|
||||
}
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Model() string {
|
||||
if m.model != "" {
|
||||
return m.model
|
||||
}
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
if m.generateFn != nil {
|
||||
return m.generateFn(ctx, call)
|
||||
}
|
||||
return &fantasy.Response{}, nil
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
if m.streamFn != nil {
|
||||
return m.streamFn(ctx, call)
|
||||
}
|
||||
return streamFromParts([]fantasy.StreamPart{{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
}}), nil
|
||||
}
|
||||
|
||||
func (*loopTestModel) GenerateObject(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return nil, xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
func (*loopTestModel) StreamObject(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
return nil, xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
func streamFromParts(parts []fantasy.StreamPart) fantasy.StreamResponse {
|
||||
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
||||
for _, part := range parts {
|
||||
if !yield(part) {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newNoopTool(name string) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
name,
|
||||
"test noop tool",
|
||||
func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.ToolResponse{}, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func textMessage(role fantasy.MessageRole, text string) fantasy.Message {
|
||||
return fantasy.Message{
|
||||
Role: role,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: text},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func containsPromptSentinel(prompt []fantasy.Message) bool {
|
||||
for _, message := range prompt {
|
||||
if message.Role != fantasy.MessageRoleUser || len(message.Content) != 1 {
|
||||
continue
|
||||
}
|
||||
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(textPart.Text, "__chatd_agent_prompt_sentinel_") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
|
||||
if len(message.ProviderOptions) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
options, ok := message.ProviderOptions[fantasyanthropic.Name]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
cacheOptions, ok := options.(*fantasyanthropic.ProviderCacheControlOptions)
|
||||
return ok && cacheOptions.CacheControl.Type == "ephemeral"
|
||||
}
|
||||
@@ -1,209 +0,0 @@
|
||||
package chatloop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCompactionThresholdPercent = int32(70)
|
||||
minCompactionThresholdPercent = int32(0)
|
||||
maxCompactionThresholdPercent = int32(100)
|
||||
|
||||
defaultCompactionSummaryPrompt = "Summarize the current chat so a " +
|
||||
"new assistant can continue seamlessly. Include the user's goals, " +
|
||||
"decisions made, concrete technical details (files, commands, APIs), " +
|
||||
"errors encountered and fixes, and open questions. Be dense and factual. " +
|
||||
"Omit pleasantries and next-step suggestions."
|
||||
defaultCompactionSystemSummaryPrefix = "Summary of earlier chat context:"
|
||||
defaultCompactionTimeout = 90 * time.Second
|
||||
)
|
||||
|
||||
type CompactionOptions struct {
|
||||
ThresholdPercent int32
|
||||
ContextLimit int64
|
||||
SummaryPrompt string
|
||||
SystemSummaryPrefix string
|
||||
Timeout time.Duration
|
||||
Persist func(context.Context, CompactionResult) error
|
||||
OnError func(error)
|
||||
}
|
||||
|
||||
type CompactionResult struct {
|
||||
SystemSummary string
|
||||
SummaryReport string
|
||||
ThresholdPercent int32
|
||||
UsagePercent float64
|
||||
ContextTokens int64
|
||||
ContextLimit int64
|
||||
}
|
||||
|
||||
func maybeCompact(
|
||||
ctx context.Context,
|
||||
runOpts RunOptions,
|
||||
runResult *fantasy.AgentResult,
|
||||
) error {
|
||||
if runResult == nil || runOpts.Compaction == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := *runOpts.Compaction
|
||||
if config.Persist == nil {
|
||||
return xerrors.New("compaction persist callback is required")
|
||||
}
|
||||
if strings.TrimSpace(config.SummaryPrompt) == "" {
|
||||
config.SummaryPrompt = defaultCompactionSummaryPrompt
|
||||
}
|
||||
if strings.TrimSpace(config.SystemSummaryPrefix) == "" {
|
||||
config.SystemSummaryPrefix = defaultCompactionSystemSummaryPrefix
|
||||
}
|
||||
if config.Timeout <= 0 {
|
||||
config.Timeout = defaultCompactionTimeout
|
||||
}
|
||||
if config.ThresholdPercent < minCompactionThresholdPercent ||
|
||||
config.ThresholdPercent > maxCompactionThresholdPercent {
|
||||
config.ThresholdPercent = defaultCompactionThresholdPercent
|
||||
}
|
||||
|
||||
if config.ThresholdPercent >= maxCompactionThresholdPercent {
|
||||
return nil
|
||||
}
|
||||
if runOpts.MaxSteps > 0 && len(runResult.Steps) >= runOpts.MaxSteps {
|
||||
lastStep := runResult.Steps[len(runResult.Steps)-1]
|
||||
if lastStep.FinishReason == fantasy.FinishReasonToolCalls &&
|
||||
len(lastStep.Content.ToolCalls()) > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
contextTokens := int64(0)
|
||||
contextLimitFromMetadata := int64(0)
|
||||
for i := len(runResult.Steps) - 1; i >= 0; i-- {
|
||||
usage := runResult.Steps[i].Usage
|
||||
total := int64(0)
|
||||
hasContextTokens := false
|
||||
|
||||
if usage.InputTokens > 0 {
|
||||
total += usage.InputTokens
|
||||
hasContextTokens = true
|
||||
}
|
||||
if usage.CacheReadTokens > 0 {
|
||||
total += usage.CacheReadTokens
|
||||
hasContextTokens = true
|
||||
}
|
||||
if usage.CacheCreationTokens > 0 {
|
||||
total += usage.CacheCreationTokens
|
||||
hasContextTokens = true
|
||||
}
|
||||
if !hasContextTokens && usage.TotalTokens > 0 {
|
||||
total = usage.TotalTokens
|
||||
hasContextTokens = true
|
||||
}
|
||||
if !hasContextTokens || total <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
contextTokens = total
|
||||
metadataLimit := extractContextLimit(runResult.Steps[i].ProviderMetadata)
|
||||
if metadataLimit.Valid && metadataLimit.Int64 > 0 {
|
||||
contextLimitFromMetadata = metadataLimit.Int64
|
||||
}
|
||||
break
|
||||
}
|
||||
if contextTokens <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
contextLimit := contextLimitFromMetadata
|
||||
if contextLimit <= 0 && config.ContextLimit > 0 {
|
||||
contextLimit = config.ContextLimit
|
||||
}
|
||||
if contextLimit <= 0 && runOpts.ContextLimitFallback > 0 {
|
||||
contextLimit = runOpts.ContextLimitFallback
|
||||
}
|
||||
if contextLimit <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
usagePercent := (float64(contextTokens) / float64(contextLimit)) * 100
|
||||
if usagePercent < float64(config.ThresholdPercent) {
|
||||
return nil
|
||||
}
|
||||
|
||||
summary, err := generateCompactionSummary(
|
||||
ctx,
|
||||
runOpts.Model,
|
||||
runOpts.Messages,
|
||||
runResult.Steps,
|
||||
config,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if summary == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
systemSummary := strings.TrimSpace(
|
||||
config.SystemSummaryPrefix + "\n\n" + summary,
|
||||
)
|
||||
|
||||
return config.Persist(ctx, CompactionResult{
|
||||
SystemSummary: systemSummary,
|
||||
SummaryReport: summary,
|
||||
ThresholdPercent: config.ThresholdPercent,
|
||||
UsagePercent: usagePercent,
|
||||
ContextTokens: contextTokens,
|
||||
ContextLimit: contextLimit,
|
||||
})
|
||||
}
|
||||
|
||||
func generateCompactionSummary(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
messages []fantasy.Message,
|
||||
steps []fantasy.StepResult,
|
||||
options CompactionOptions,
|
||||
) (string, error) {
|
||||
summaryPrompt := make([]fantasy.Message, 0, len(messages)+len(steps)+1)
|
||||
summaryPrompt = append(summaryPrompt, messages...)
|
||||
for _, step := range steps {
|
||||
summaryPrompt = append(summaryPrompt, step.Messages...)
|
||||
}
|
||||
summaryPrompt = append(summaryPrompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: options.SummaryPrompt},
|
||||
},
|
||||
})
|
||||
toolChoice := fantasy.ToolChoiceNone
|
||||
|
||||
summaryCtx, cancel := context.WithTimeout(ctx, options.Timeout)
|
||||
defer cancel()
|
||||
|
||||
response, err := model.Generate(summaryCtx, fantasy.Call{
|
||||
Prompt: summaryPrompt,
|
||||
ToolChoice: &toolChoice,
|
||||
})
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("generate summary text: %w", err)
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(response.Content))
|
||||
for _, block := range response.Content {
|
||||
textBlock, ok := fantasy.AsContentType[fantasy.TextContent](block)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
text := strings.TrimSpace(textBlock.Text)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, text)
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(parts, " ")), nil
|
||||
}
|
||||
@@ -1,132 +0,0 @@
|
||||
package chatloop //nolint:testpackage // Uses internal symbols.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func TestRun_Compaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("PersistsWhenThresholdReached", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
persistCompactionCalls := 0
|
||||
var persistedCompaction CompactionResult
|
||||
const summaryText = "summary text for compaction"
|
||||
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 80,
|
||||
TotalTokens: 85,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
},
|
||||
generateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
require.NotEmpty(t, call.Prompt)
|
||||
lastPrompt := call.Prompt[len(call.Prompt)-1]
|
||||
require.Equal(t, fantasy.MessageRoleUser, lastPrompt.Role)
|
||||
require.Len(t, lastPrompt.Content, 1)
|
||||
|
||||
instruction, ok := fantasy.AsMessagePart[fantasy.TextPart](lastPrompt.Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "summarize now", instruction.Text)
|
||||
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
fantasy.TextContent{Text: summaryText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
_, err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
MaxSteps: 1,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
ContextLimitFallback: 100,
|
||||
Compaction: &CompactionOptions{
|
||||
ThresholdPercent: 70,
|
||||
SummaryPrompt: "summarize now",
|
||||
Persist: func(_ context.Context, result CompactionResult) error {
|
||||
persistCompactionCalls++
|
||||
persistedCompaction = result
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, persistCompactionCalls)
|
||||
require.Contains(t, persistedCompaction.SystemSummary, summaryText)
|
||||
require.Equal(t, summaryText, persistedCompaction.SummaryReport)
|
||||
require.Equal(t, int64(80), persistedCompaction.ContextTokens)
|
||||
require.Equal(t, int64(100), persistedCompaction.ContextLimit)
|
||||
require.InDelta(t, 80.0, persistedCompaction.UsagePercent, 0.0001)
|
||||
})
|
||||
|
||||
t.Run("ErrorsAreReported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 80,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
},
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return nil, xerrors.New("generate failed")
|
||||
},
|
||||
}
|
||||
|
||||
compactionErr := xerrors.New("unset")
|
||||
_, err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
MaxSteps: 1,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
ContextLimitFallback: 100,
|
||||
Compaction: &CompactionOptions{
|
||||
ThresholdPercent: 70,
|
||||
Persist: func(_ context.Context, _ CompactionResult) error {
|
||||
return nil
|
||||
},
|
||||
OnError: func(err error) {
|
||||
compactionErr = err
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Error(t, compactionErr)
|
||||
require.ErrorContains(t, compactionErr, "generate summary text")
|
||||
})
|
||||
}
|
||||
@@ -1,982 +0,0 @@
|
||||
package chatprompt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
var toolCallIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
|
||||
|
||||
func ConvertMessages(
|
||||
messages []database.ChatMessage,
|
||||
) ([]fantasy.Message, error) {
|
||||
prompt := make([]fantasy.Message, 0, len(messages))
|
||||
toolNameByCallID := make(map[string]string)
|
||||
for _, message := range messages {
|
||||
visibility := message.Visibility
|
||||
if visibility == "" {
|
||||
visibility = database.ChatMessageVisibilityBoth
|
||||
}
|
||||
if visibility != database.ChatMessageVisibilityModel &&
|
||||
visibility != database.ChatMessageVisibilityBoth {
|
||||
continue
|
||||
}
|
||||
|
||||
switch message.Role {
|
||||
case string(fantasy.MessageRoleSystem):
|
||||
content, err := parseSystemContent(message.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(content) == "" {
|
||||
continue
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: content},
|
||||
},
|
||||
})
|
||||
case string(fantasy.MessageRoleUser):
|
||||
content, err := ParseContent(string(fantasy.MessageRoleUser), message.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: ToMessageParts(content),
|
||||
})
|
||||
case string(fantasy.MessageRoleAssistant):
|
||||
content, err := ParseContent(string(fantasy.MessageRoleAssistant), message.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := normalizeAssistantToolCallInputs(ToMessageParts(content))
|
||||
for _, toolCall := range ExtractToolCalls(parts) {
|
||||
if toolCall.ToolCallID == "" || strings.TrimSpace(toolCall.ToolName) == "" {
|
||||
continue
|
||||
}
|
||||
toolNameByCallID[sanitizeToolCallID(toolCall.ToolCallID)] = toolCall.ToolName
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: parts,
|
||||
})
|
||||
case string(fantasy.MessageRoleTool):
|
||||
rows, err := parseToolResultRows(message.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := make([]fantasy.MessagePart, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
if row.ToolCallID != "" && row.ToolName != "" {
|
||||
toolNameByCallID[sanitizeToolCallID(row.ToolCallID)] = row.ToolName
|
||||
}
|
||||
parts = append(parts, row.toToolResultPart())
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleTool,
|
||||
Content: parts,
|
||||
})
|
||||
default:
|
||||
return nil, xerrors.Errorf("unsupported chat message role %q", message.Role)
|
||||
}
|
||||
}
|
||||
prompt = injectMissingToolResults(prompt)
|
||||
prompt = injectMissingToolUses(
|
||||
prompt,
|
||||
toolNameByCallID,
|
||||
)
|
||||
return prompt, nil
|
||||
}
|
||||
|
||||
// PrependSystem prepends a system message unless an existing system
|
||||
// message already mentions create_workspace guidance.
|
||||
func PrependSystem(prompt []fantasy.Message, instruction string) []fantasy.Message {
|
||||
instruction = strings.TrimSpace(instruction)
|
||||
if instruction == "" {
|
||||
return prompt
|
||||
}
|
||||
for _, message := range prompt {
|
||||
if message.Role != fantasy.MessageRoleSystem {
|
||||
continue
|
||||
}
|
||||
for _, part := range message.Content {
|
||||
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(strings.ToLower(textPart.Text), "create_workspace") {
|
||||
return prompt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]fantasy.Message, 0, len(prompt)+1)
|
||||
out = append(out, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: instruction},
|
||||
},
|
||||
})
|
||||
out = append(out, prompt...)
|
||||
return out
|
||||
}
|
||||
|
||||
// InsertSystem inserts a system message after the existing system
|
||||
// block and before the first non-system message.
|
||||
func InsertSystem(prompt []fantasy.Message, instruction string) []fantasy.Message {
|
||||
instruction = strings.TrimSpace(instruction)
|
||||
if instruction == "" {
|
||||
return prompt
|
||||
}
|
||||
|
||||
systemMessage := fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: instruction},
|
||||
},
|
||||
}
|
||||
|
||||
out := make([]fantasy.Message, 0, len(prompt)+1)
|
||||
inserted := false
|
||||
for _, message := range prompt {
|
||||
if !inserted && message.Role != fantasy.MessageRoleSystem {
|
||||
out = append(out, systemMessage)
|
||||
inserted = true
|
||||
}
|
||||
out = append(out, message)
|
||||
}
|
||||
if !inserted {
|
||||
out = append(out, systemMessage)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// AppendUser appends an instruction as a user message at the end of
|
||||
// the prompt.
|
||||
func AppendUser(prompt []fantasy.Message, instruction string) []fantasy.Message {
|
||||
instruction = strings.TrimSpace(instruction)
|
||||
if instruction == "" {
|
||||
return prompt
|
||||
}
|
||||
out := make([]fantasy.Message, 0, len(prompt)+1)
|
||||
out = append(out, prompt...)
|
||||
out = append(out, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: instruction},
|
||||
},
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
// ParseContent decodes persisted chat message content blocks.
|
||||
func ParseContent(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var text string
|
||||
if err := json.Unmarshal(raw.RawMessage, &text); err == nil {
|
||||
return []fantasy.Content{fantasy.TextContent{Text: text}}, nil
|
||||
}
|
||||
|
||||
var rawBlocks []json.RawMessage
|
||||
if err := json.Unmarshal(raw.RawMessage, &rawBlocks); err != nil {
|
||||
return nil, xerrors.Errorf("parse %s content: %w", role, err)
|
||||
}
|
||||
|
||||
content := make([]fantasy.Content, 0, len(rawBlocks))
|
||||
for i, rawBlock := range rawBlocks {
|
||||
block, err := fantasy.UnmarshalContent(rawBlock)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse %s content block %d: %w", role, i, err)
|
||||
}
|
||||
content = append(content, block)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// toolResultRaw is an untyped representation of a persisted tool
|
||||
// result row. We intentionally avoid a strict Go struct so that
|
||||
// historical shapes are never rejected.
|
||||
type toolResultRaw struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Result json.RawMessage `json:"result"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
// parseToolResultRows decodes persisted tool result rows.
|
||||
func parseToolResultRows(raw pqtype.NullRawMessage) ([]toolResultRaw, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var rows []toolResultRaw
|
||||
if err := json.Unmarshal(raw.RawMessage, &rows); err != nil {
|
||||
return nil, xerrors.Errorf("parse tool content: %w", err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (r toolResultRaw) toToolResultPart() fantasy.ToolResultPart {
|
||||
toolCallID := sanitizeToolCallID(r.ToolCallID)
|
||||
resultText := string(r.Result)
|
||||
if resultText == "" || resultText == "null" {
|
||||
resultText = "{}"
|
||||
}
|
||||
|
||||
if r.IsError {
|
||||
message := strings.TrimSpace(resultText)
|
||||
if extracted := extractErrorString(r.Result); extracted != "" {
|
||||
message = extracted
|
||||
}
|
||||
return fantasy.ToolResultPart{
|
||||
ToolCallID: toolCallID,
|
||||
Output: fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New(message),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return fantasy.ToolResultPart{
|
||||
ToolCallID: toolCallID,
|
||||
Output: fantasy.ToolResultOutputContentText{
|
||||
Text: resultText,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// extractErrorString pulls the "error" field from a JSON object if
|
||||
// present, returning it as a string. Returns "" if the field is
|
||||
// missing or the input is not an object.
|
||||
func extractErrorString(raw json.RawMessage) string {
|
||||
var fields map[string]json.RawMessage
|
||||
if err := json.Unmarshal(raw, &fields); err != nil {
|
||||
return ""
|
||||
}
|
||||
errField, ok := fields["error"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(errField, &s); err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
// ToMessageParts converts fantasy content blocks into message parts.
|
||||
func ToMessageParts(content []fantasy.Content) []fantasy.MessagePart {
|
||||
parts := make([]fantasy.MessagePart, 0, len(content))
|
||||
for _, block := range content {
|
||||
switch value := block.(type) {
|
||||
case fantasy.TextContent:
|
||||
parts = append(parts, fantasy.TextPart{
|
||||
Text: value.Text,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case *fantasy.TextContent:
|
||||
parts = append(parts, fantasy.TextPart{
|
||||
Text: value.Text,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case fantasy.ReasoningContent:
|
||||
parts = append(parts, fantasy.ReasoningPart{
|
||||
Text: value.Text,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case *fantasy.ReasoningContent:
|
||||
parts = append(parts, fantasy.ReasoningPart{
|
||||
Text: value.Text,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case fantasy.ToolCallContent:
|
||||
parts = append(parts, fantasy.ToolCallPart{
|
||||
ToolCallID: sanitizeToolCallID(value.ToolCallID),
|
||||
ToolName: value.ToolName,
|
||||
Input: value.Input,
|
||||
ProviderExecuted: value.ProviderExecuted,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case *fantasy.ToolCallContent:
|
||||
parts = append(parts, fantasy.ToolCallPart{
|
||||
ToolCallID: sanitizeToolCallID(value.ToolCallID),
|
||||
ToolName: value.ToolName,
|
||||
Input: value.Input,
|
||||
ProviderExecuted: value.ProviderExecuted,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case fantasy.FileContent:
|
||||
parts = append(parts, fantasy.FilePart{
|
||||
Data: value.Data,
|
||||
MediaType: value.MediaType,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case *fantasy.FileContent:
|
||||
parts = append(parts, fantasy.FilePart{
|
||||
Data: value.Data,
|
||||
MediaType: value.MediaType,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case fantasy.ToolResultContent:
|
||||
parts = append(parts, fantasy.ToolResultPart{
|
||||
ToolCallID: sanitizeToolCallID(value.ToolCallID),
|
||||
Output: value.Result,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
case *fantasy.ToolResultContent:
|
||||
parts = append(parts, fantasy.ToolResultPart{
|
||||
ToolCallID: sanitizeToolCallID(value.ToolCallID),
|
||||
Output: value.Result,
|
||||
ProviderOptions: fantasy.ProviderOptions(value.ProviderMetadata),
|
||||
})
|
||||
}
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func normalizeAssistantToolCallInputs(
|
||||
parts []fantasy.MessagePart,
|
||||
) []fantasy.MessagePart {
|
||||
normalized := make([]fantasy.MessagePart, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
|
||||
if !ok {
|
||||
normalized = append(normalized, part)
|
||||
continue
|
||||
}
|
||||
|
||||
toolCall.Input = normalizeToolCallInput(toolCall.Input)
|
||||
normalized = append(normalized, toolCall)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
// normalizeToolCallInput guarantees tool call input is a JSON object string.
|
||||
// Anthropic drops assistant tool calls with malformed input, which can leave
|
||||
// following tool results orphaned.
|
||||
func normalizeToolCallInput(input string) string {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
var object map[string]any
|
||||
if err := json.Unmarshal([]byte(input), &object); err != nil || object == nil {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
return input
|
||||
}
|
||||
|
||||
// ExtractToolCalls returns all tool call parts as content blocks.
|
||||
func ExtractToolCalls(parts []fantasy.MessagePart) []fantasy.ToolCallContent {
|
||||
toolCalls := make([]fantasy.ToolCallContent, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
toolCalls = append(toolCalls, fantasy.ToolCallContent{
|
||||
ToolCallID: toolCall.ToolCallID,
|
||||
ToolName: toolCall.ToolName,
|
||||
Input: toolCall.Input,
|
||||
ProviderExecuted: toolCall.ProviderExecuted,
|
||||
})
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
// MarshalContent encodes message content blocks for persistence.
|
||||
func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
|
||||
if len(blocks) == 0 {
|
||||
return pqtype.NullRawMessage{}, nil
|
||||
}
|
||||
|
||||
encodedBlocks := make([]json.RawMessage, 0, len(blocks))
|
||||
for i, block := range blocks {
|
||||
encoded, err := marshalContentBlock(block)
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, xerrors.Errorf(
|
||||
"encode content block %d: %w",
|
||||
i,
|
||||
err,
|
||||
)
|
||||
}
|
||||
encodedBlocks = append(encodedBlocks, encoded)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(encodedBlocks)
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, xerrors.Errorf("encode content blocks: %w", err)
|
||||
}
|
||||
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
|
||||
}
|
||||
|
||||
// MarshalToolResult encodes a single tool result for persistence as
|
||||
// an opaque JSON blob. The stored shape is
|
||||
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}].
|
||||
func MarshalToolResult(toolCallID, toolName string, result json.RawMessage, isError bool) (pqtype.NullRawMessage, error) {
|
||||
row := toolResultRaw{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Result: result,
|
||||
IsError: isError,
|
||||
}
|
||||
data, err := json.Marshal([]toolResultRaw{row})
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, xerrors.Errorf("encode tool result: %w", err)
|
||||
}
|
||||
return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
|
||||
}
|
||||
|
||||
// MarshalToolResultContent encodes a fantasy tool result content
|
||||
// block for persistence. It extracts the raw fields and delegates
|
||||
// to MarshalToolResult.
|
||||
func MarshalToolResultContent(content fantasy.ToolResultContent) (pqtype.NullRawMessage, error) {
|
||||
var result json.RawMessage
|
||||
var isError bool
|
||||
|
||||
switch output := content.Result.(type) {
|
||||
case fantasy.ToolResultOutputContentError:
|
||||
isError = true
|
||||
if output.Error != nil {
|
||||
result, _ = json.Marshal(map[string]any{"error": output.Error.Error()})
|
||||
} else {
|
||||
result = []byte(`{"error":""}`)
|
||||
}
|
||||
case fantasy.ToolResultOutputContentText:
|
||||
result = json.RawMessage(output.Text)
|
||||
if !json.Valid(result) {
|
||||
result, _ = json.Marshal(map[string]any{"output": output.Text})
|
||||
}
|
||||
case fantasy.ToolResultOutputContentMedia:
|
||||
result, _ = json.Marshal(map[string]any{
|
||||
"data": output.Data,
|
||||
"mime_type": output.MediaType,
|
||||
"text": output.Text,
|
||||
})
|
||||
default:
|
||||
result = []byte(`{}`)
|
||||
}
|
||||
|
||||
return MarshalToolResult(content.ToolCallID, content.ToolName, result, isError)
|
||||
}
|
||||
|
||||
// PartFromContent converts fantasy content into a SDK chat message part.
|
||||
func PartFromContent(block fantasy.Content) codersdk.ChatMessagePart {
|
||||
switch value := block.(type) {
|
||||
case fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case *fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
Title: reasoningSummaryTitle(value.ProviderMetadata),
|
||||
}
|
||||
case *fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
Title: reasoningSummaryTitle(value.ProviderMetadata),
|
||||
}
|
||||
case fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
}
|
||||
case *fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
}
|
||||
case fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case *fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case *fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case fantasy.ToolResultContent:
|
||||
return toolResultContentToPart(value)
|
||||
case *fantasy.ToolResultContent:
|
||||
return toolResultContentToPart(*value)
|
||||
default:
|
||||
return codersdk.ChatMessagePart{}
|
||||
}
|
||||
}
|
||||
|
||||
// ToolResultToPart converts a tool call ID, raw result, and error
|
||||
// flag into a ChatMessagePart. This is the minimal conversion used
|
||||
// both during streaming and when reading from the database.
|
||||
func ToolResultToPart(toolCallID, toolName string, result json.RawMessage, isError bool) codersdk.ChatMessagePart {
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Result: result,
|
||||
IsError: isError,
|
||||
}
|
||||
}
|
||||
|
||||
// toolResultContentToPart converts a fantasy ToolResultContent
|
||||
// directly into a ChatMessagePart without an intermediate struct.
|
||||
func toolResultContentToPart(content fantasy.ToolResultContent) codersdk.ChatMessagePart {
|
||||
var result json.RawMessage
|
||||
var isError bool
|
||||
|
||||
switch output := content.Result.(type) {
|
||||
case fantasy.ToolResultOutputContentError:
|
||||
isError = true
|
||||
if output.Error != nil {
|
||||
result, _ = json.Marshal(map[string]any{"error": output.Error.Error()})
|
||||
} else {
|
||||
result = []byte(`{"error":""}`)
|
||||
}
|
||||
case fantasy.ToolResultOutputContentText:
|
||||
result = json.RawMessage(output.Text)
|
||||
// Ensure valid JSON; wrap in an object if not.
|
||||
if !json.Valid(result) {
|
||||
result, _ = json.Marshal(map[string]any{"output": output.Text})
|
||||
}
|
||||
case fantasy.ToolResultOutputContentMedia:
|
||||
result, _ = json.Marshal(map[string]any{
|
||||
"data": output.Data,
|
||||
"mime_type": output.MediaType,
|
||||
"text": output.Text,
|
||||
})
|
||||
default:
|
||||
result = []byte(`{}`)
|
||||
}
|
||||
|
||||
return ToolResultToPart(content.ToolCallID, content.ToolName, result, isError)
|
||||
}
|
||||
|
||||
// ReasoningTitleFromFirstLine extracts a compact markdown title.
|
||||
func ReasoningTitleFromFirstLine(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
firstLine := text
|
||||
if idx := strings.IndexAny(firstLine, "\r\n"); idx >= 0 {
|
||||
firstLine = firstLine[:idx]
|
||||
}
|
||||
firstLine = strings.TrimSpace(firstLine)
|
||||
if firstLine == "" || !strings.HasPrefix(firstLine, "**") {
|
||||
return ""
|
||||
}
|
||||
|
||||
rest := firstLine[2:]
|
||||
end := strings.Index(rest, "**")
|
||||
if end < 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(rest[:end])
|
||||
if title == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Require the first line to be exactly "**title**" (ignoring
|
||||
// surrounding whitespace) so providers without this format don't
|
||||
// accidentally emit a title.
|
||||
if strings.TrimSpace(rest[end+2:]) != "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return compactReasoningSummaryTitle(title)
|
||||
}
|
||||
|
||||
func injectMissingToolResults(prompt []fantasy.Message) []fantasy.Message {
|
||||
result := make([]fantasy.Message, 0, len(prompt))
|
||||
for i := 0; i < len(prompt); i++ {
|
||||
msg := prompt[i]
|
||||
result = append(result, msg)
|
||||
|
||||
if msg.Role != fantasy.MessageRoleAssistant {
|
||||
continue
|
||||
}
|
||||
toolCalls := ExtractToolCalls(msg.Content)
|
||||
if len(toolCalls) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Collect the tool call IDs that have results in the
|
||||
// following tool message(s).
|
||||
answered := make(map[string]struct{})
|
||||
j := i + 1
|
||||
for ; j < len(prompt); j++ {
|
||||
if prompt[j].Role != fantasy.MessageRoleTool {
|
||||
break
|
||||
}
|
||||
for _, part := range prompt[j].Content {
|
||||
tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
answered[tr.ToolCallID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if i+1 < j {
|
||||
// Preserve persisted tool result ordering and inject any
|
||||
// synthetic results after the existing contiguous tool messages.
|
||||
result = append(result, prompt[i+1:j]...)
|
||||
i = j - 1
|
||||
}
|
||||
|
||||
// Build synthetic results for any unanswered tool calls.
|
||||
var missing []fantasy.MessagePart
|
||||
for _, tc := range toolCalls {
|
||||
if _, ok := answered[tc.ToolCallID]; !ok {
|
||||
missing = append(missing, fantasy.ToolResultPart{
|
||||
ToolCallID: tc.ToolCallID,
|
||||
Output: fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New("tool call was interrupted and did not receive a result"),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
result = append(result, fantasy.Message{
|
||||
Role: fantasy.MessageRoleTool,
|
||||
Content: missing,
|
||||
})
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func injectMissingToolUses(
|
||||
prompt []fantasy.Message,
|
||||
toolNameByCallID map[string]string,
|
||||
) []fantasy.Message {
|
||||
result := make([]fantasy.Message, 0, len(prompt))
|
||||
for _, msg := range prompt {
|
||||
if msg.Role != fantasy.MessageRoleTool {
|
||||
result = append(result, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
toolResults := make([]fantasy.ToolResultPart, 0, len(msg.Content))
|
||||
for _, part := range msg.Content {
|
||||
toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
toolResults = append(toolResults, toolResult)
|
||||
}
|
||||
if len(toolResults) == 0 {
|
||||
result = append(result, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
// Walk backwards through the result to find the nearest
|
||||
// preceding assistant message (skipping over other tool
|
||||
// messages that belong to the same batch of results).
|
||||
answeredByPrevious := make(map[string]struct{})
|
||||
for k := len(result) - 1; k >= 0; k-- {
|
||||
if result[k].Role == fantasy.MessageRoleAssistant {
|
||||
for _, toolCall := range ExtractToolCalls(result[k].Content) {
|
||||
toolCallID := sanitizeToolCallID(toolCall.ToolCallID)
|
||||
if toolCallID == "" {
|
||||
continue
|
||||
}
|
||||
answeredByPrevious[toolCallID] = struct{}{}
|
||||
}
|
||||
break
|
||||
}
|
||||
if result[k].Role != fantasy.MessageRoleTool {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
matchingResults := make([]fantasy.ToolResultPart, 0, len(toolResults))
|
||||
orphanResults := make([]fantasy.ToolResultPart, 0, len(toolResults))
|
||||
for _, toolResult := range toolResults {
|
||||
toolCallID := sanitizeToolCallID(toolResult.ToolCallID)
|
||||
if _, ok := answeredByPrevious[toolCallID]; ok {
|
||||
matchingResults = append(matchingResults, toolResult)
|
||||
continue
|
||||
}
|
||||
orphanResults = append(orphanResults, toolResult)
|
||||
}
|
||||
|
||||
if len(orphanResults) == 0 {
|
||||
result = append(result, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
syntheticToolUse := syntheticToolUseMessage(
|
||||
orphanResults,
|
||||
toolNameByCallID,
|
||||
)
|
||||
if len(syntheticToolUse.Content) == 0 {
|
||||
result = append(result, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(matchingResults) > 0 {
|
||||
result = append(result, toolMessageFromToolResultParts(matchingResults))
|
||||
}
|
||||
result = append(result, syntheticToolUse)
|
||||
result = append(result, toolMessageFromToolResultParts(orphanResults))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func toolMessageFromToolResultParts(results []fantasy.ToolResultPart) fantasy.Message {
|
||||
parts := make([]fantasy.MessagePart, 0, len(results))
|
||||
for _, result := range results {
|
||||
parts = append(parts, result)
|
||||
}
|
||||
return fantasy.Message{
|
||||
Role: fantasy.MessageRoleTool,
|
||||
Content: parts,
|
||||
}
|
||||
}
|
||||
|
||||
func syntheticToolUseMessage(
|
||||
toolResults []fantasy.ToolResultPart,
|
||||
toolNameByCallID map[string]string,
|
||||
) fantasy.Message {
|
||||
parts := make([]fantasy.MessagePart, 0, len(toolResults))
|
||||
seen := make(map[string]struct{}, len(toolResults))
|
||||
|
||||
for _, toolResult := range toolResults {
|
||||
toolCallID := sanitizeToolCallID(toolResult.ToolCallID)
|
||||
if toolCallID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[toolCallID]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
toolName := strings.TrimSpace(toolNameByCallID[toolCallID])
|
||||
if toolName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
seen[toolCallID] = struct{}{}
|
||||
parts = append(parts, fantasy.ToolCallPart{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Input: "{}",
|
||||
})
|
||||
}
|
||||
|
||||
return fantasy.Message{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: parts,
|
||||
}
|
||||
}
|
||||
|
||||
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var content string
|
||||
if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
|
||||
return "", xerrors.Errorf("parse system message content: %w", err)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func sanitizeToolCallID(id string) string {
|
||||
if id == "" {
|
||||
return ""
|
||||
}
|
||||
return toolCallIDSanitizer.ReplaceAllString(id, "_")
|
||||
}
|
||||
|
||||
func marshalContentBlock(block fantasy.Content) (json.RawMessage, error) {
|
||||
encoded, err := json.Marshal(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
title, ok := reasoningTitleFromContent(block)
|
||||
if !ok || title == "" {
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data map[string]any `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(encoded, &envelope); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeReasoning)) {
|
||||
return encoded, nil
|
||||
}
|
||||
if envelope.Data == nil {
|
||||
envelope.Data = map[string]any{}
|
||||
}
|
||||
envelope.Data["title"] = title
|
||||
|
||||
encodedWithTitle, err := json.Marshal(envelope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encodedWithTitle, nil
|
||||
}
|
||||
|
||||
func reasoningTitleFromContent(block fantasy.Content) (string, bool) {
|
||||
switch value := block.(type) {
|
||||
case fantasy.ReasoningContent:
|
||||
return ReasoningTitleFromFirstLine(value.Text), true
|
||||
case *fantasy.ReasoningContent:
|
||||
if value == nil {
|
||||
return "", false
|
||||
}
|
||||
return ReasoningTitleFromFirstLine(value.Text), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func reasoningSummaryTitle(metadata fantasy.ProviderMetadata) string {
|
||||
if len(metadata) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
reasoningMetadata := fantasyopenai.GetReasoningMetadata(
|
||||
fantasy.ProviderOptions(metadata),
|
||||
)
|
||||
if reasoningMetadata == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, summary := range reasoningMetadata.Summary {
|
||||
if title := compactReasoningSummaryTitle(summary); title != "" {
|
||||
return title
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func compactReasoningSummaryTitle(summary string) string {
|
||||
const maxWords = 8
|
||||
const maxRunes = 80
|
||||
|
||||
summary = strings.TrimSpace(summary)
|
||||
if summary == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
summary = strings.Trim(summary, "\"'`")
|
||||
summary = reasoningSummaryHeadline(summary)
|
||||
words := strings.Fields(summary)
|
||||
if len(words) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
truncated := false
|
||||
if len(words) > maxWords {
|
||||
words = words[:maxWords]
|
||||
truncated = true
|
||||
}
|
||||
|
||||
title := strings.Join(words, " ")
|
||||
if truncated {
|
||||
title += "…"
|
||||
}
|
||||
return truncateRunes(title, maxRunes)
|
||||
}
|
||||
|
||||
func reasoningSummaryHeadline(summary string) string {
|
||||
summary = strings.TrimSpace(summary)
|
||||
if summary == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// OpenAI summary_text may be markdown like:
|
||||
// "**Title**\n\nLonger explanation ...".
|
||||
// Keep only the heading segment for UI titles.
|
||||
if idx := strings.Index(summary, "\n\n"); idx >= 0 {
|
||||
summary = summary[:idx]
|
||||
}
|
||||
|
||||
if idx := strings.IndexAny(summary, "\r\n"); idx >= 0 {
|
||||
summary = summary[:idx]
|
||||
}
|
||||
|
||||
summary = strings.TrimSpace(summary)
|
||||
if summary == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if strings.HasPrefix(summary, "**") {
|
||||
rest := summary[2:]
|
||||
if end := strings.Index(rest, "**"); end >= 0 {
|
||||
bold := strings.TrimSpace(rest[:end])
|
||||
if bold != "" {
|
||||
summary = bold
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(strings.Trim(summary, "\"'`"))
|
||||
}
|
||||
|
||||
func truncateRunes(value string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
runes := []rune(value)
|
||||
if len(runes) <= maxLen {
|
||||
return value
|
||||
}
|
||||
|
||||
return string(runes[:maxLen])
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
package chatprompt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
expected: "{}",
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
input: "{\"command\":",
|
||||
expected: "{}",
|
||||
},
|
||||
{
|
||||
name: "non-object json",
|
||||
input: "[]",
|
||||
expected: "{}",
|
||||
},
|
||||
{
|
||||
name: "valid object json",
|
||||
input: "{\"command\":\"ls\"}",
|
||||
expected: "{\"command\":\"ls\"}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assistantContent, err := MarshalContent([]fantasy.Content{
|
||||
fantasy.ToolCallContent{
|
||||
ToolCallID: "toolu_01C4PqN6F2493pi7Ebag8Vg7",
|
||||
ToolName: "execute",
|
||||
Input: tc.input,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
toolContent, err := MarshalToolResult(
|
||||
"toolu_01C4PqN6F2493pi7Ebag8Vg7",
|
||||
"execute",
|
||||
json.RawMessage(`{"error":"tool call was interrupted before it produced a result"}`),
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
prompt, err := ConvertMessages([]database.ChatMessage{
|
||||
{
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
Content: assistantContent,
|
||||
},
|
||||
{
|
||||
Role: string(fantasy.MessageRoleTool),
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
Content: toolContent,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
require.Equal(t, fantasy.MessageRoleAssistant, prompt[0].Role)
|
||||
toolCalls := ExtractToolCalls(prompt[0].Content)
|
||||
require.Len(t, toolCalls, 1)
|
||||
require.Equal(t, tc.expected, toolCalls[0].Input)
|
||||
require.Equal(t, "execute", toolCalls[0].ToolName)
|
||||
require.Equal(t, "toolu_01C4PqN6F2493pi7Ebag8Vg7", toolCalls[0].ToolCallID)
|
||||
|
||||
require.Equal(t, fantasy.MessageRoleTool, prompt[1].Role)
|
||||
})
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,191 +0,0 @@
|
||||
package chatprovider_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
fantasyopenrouter "charm.land/fantasy/providers/openrouter"
|
||||
fantasyvercel "charm.land/fantasy/providers/vercel"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestReasoningEffortFromChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
provider string
|
||||
input *string
|
||||
want *string
|
||||
}{
|
||||
{
|
||||
name: "OpenAICaseInsensitive",
|
||||
provider: "openai",
|
||||
input: stringPtr(" HIGH "),
|
||||
want: stringPtr(string(fantasyopenai.ReasoningEffortHigh)),
|
||||
},
|
||||
{
|
||||
name: "AnthropicEffort",
|
||||
provider: "anthropic",
|
||||
input: stringPtr("max"),
|
||||
want: stringPtr(string(fantasyanthropic.EffortMax)),
|
||||
},
|
||||
{
|
||||
name: "OpenRouterEffort",
|
||||
provider: "openrouter",
|
||||
input: stringPtr("medium"),
|
||||
want: stringPtr(string(fantasyopenrouter.ReasoningEffortMedium)),
|
||||
},
|
||||
{
|
||||
name: "VercelEffort",
|
||||
provider: "vercel",
|
||||
input: stringPtr("xhigh"),
|
||||
want: stringPtr(string(fantasyvercel.ReasoningEffortXHigh)),
|
||||
},
|
||||
{
|
||||
name: "InvalidEffortReturnsNil",
|
||||
provider: "openai",
|
||||
input: stringPtr("unknown"),
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "UnsupportedProviderReturnsNil",
|
||||
provider: "bedrock",
|
||||
input: stringPtr("high"),
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "NilInputReturnsNil",
|
||||
provider: "openai",
|
||||
input: nil,
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := chatprovider.ReasoningEffortFromChat(tt.provider, tt.input)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
options := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Enabled: boolPtr(true),
|
||||
},
|
||||
Provider: &codersdk.ChatModelOpenRouterProvider{
|
||||
Order: []string{"openai"},
|
||||
},
|
||||
},
|
||||
}
|
||||
defaults := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Enabled: boolPtr(false),
|
||||
Exclude: boolPtr(true),
|
||||
MaxTokens: int64Ptr(123),
|
||||
Effort: stringPtr("high"),
|
||||
},
|
||||
IncludeUsage: boolPtr(true),
|
||||
Provider: &codersdk.ChatModelOpenRouterProvider{
|
||||
Order: []string{"anthropic"},
|
||||
AllowFallbacks: boolPtr(true),
|
||||
RequireParameters: boolPtr(false),
|
||||
DataCollection: stringPtr("allow"),
|
||||
Only: []string{"openai"},
|
||||
Ignore: []string{"foo"},
|
||||
Quantizations: []string{"int8"},
|
||||
Sort: stringPtr("latency"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatprovider.MergeMissingProviderOptions(&options, defaults)
|
||||
|
||||
require.NotNil(t, options)
|
||||
require.NotNil(t, options.OpenRouter)
|
||||
require.NotNil(t, options.OpenRouter.Reasoning)
|
||||
require.True(t, *options.OpenRouter.Reasoning.Enabled)
|
||||
require.Equal(t, true, *options.OpenRouter.Reasoning.Exclude)
|
||||
require.EqualValues(t, 123, *options.OpenRouter.Reasoning.MaxTokens)
|
||||
require.Equal(t, "high", *options.OpenRouter.Reasoning.Effort)
|
||||
require.NotNil(t, options.OpenRouter.IncludeUsage)
|
||||
require.True(t, *options.OpenRouter.IncludeUsage)
|
||||
|
||||
require.NotNil(t, options.OpenRouter.Provider)
|
||||
require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Order)
|
||||
require.NotNil(t, options.OpenRouter.Provider.AllowFallbacks)
|
||||
require.True(t, *options.OpenRouter.Provider.AllowFallbacks)
|
||||
require.NotNil(t, options.OpenRouter.Provider.RequireParameters)
|
||||
require.False(t, *options.OpenRouter.Provider.RequireParameters)
|
||||
require.Equal(t, "allow", *options.OpenRouter.Provider.DataCollection)
|
||||
require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Only)
|
||||
require.Equal(t, []string{"foo"}, options.OpenRouter.Provider.Ignore)
|
||||
require.Equal(t, []string{"int8"}, options.OpenRouter.Provider.Quantizations)
|
||||
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
|
||||
}
|
||||
|
||||
func TestMergeMissingCallConfig_FillsUnsetFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dst := codersdk.ChatModelCallConfig{
|
||||
Temperature: float64Ptr(0.2),
|
||||
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
||||
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
||||
User: stringPtr("alice"),
|
||||
},
|
||||
},
|
||||
}
|
||||
defaults := codersdk.ChatModelCallConfig{
|
||||
MaxOutputTokens: int64Ptr(512),
|
||||
Temperature: float64Ptr(0.9),
|
||||
TopP: float64Ptr(0.8),
|
||||
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
||||
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
||||
User: stringPtr("bob"),
|
||||
ReasoningEffort: stringPtr("medium"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatprovider.MergeMissingCallConfig(&dst, defaults)
|
||||
|
||||
require.NotNil(t, dst.MaxOutputTokens)
|
||||
require.EqualValues(t, 512, *dst.MaxOutputTokens)
|
||||
require.NotNil(t, dst.Temperature)
|
||||
require.Equal(t, 0.2, *dst.Temperature)
|
||||
require.NotNil(t, dst.TopP)
|
||||
require.Equal(t, 0.8, *dst.TopP)
|
||||
require.NotNil(t, dst.ProviderOptions)
|
||||
require.NotNil(t, dst.ProviderOptions.OpenAI)
|
||||
require.Equal(t, "alice", *dst.ProviderOptions.OpenAI.User)
|
||||
require.Equal(t, "medium", *dst.ProviderOptions.OpenAI.ReasoningEffort)
|
||||
}
|
||||
|
||||
func stringPtr(value string) *string {
|
||||
return &value
|
||||
}
|
||||
|
||||
func boolPtr(value bool) *bool {
|
||||
return &value
|
||||
}
|
||||
|
||||
func int64Ptr(value int64) *int64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
func float64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
@@ -1,402 +0,0 @@
|
||||
package chattest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AnthropicHandler handles Anthropic API requests and returns a response.
|
||||
type AnthropicHandler func(req *AnthropicRequest) AnthropicResponse
|
||||
|
||||
// AnthropicResponse represents a response to an Anthropic request.
|
||||
// Either StreamingChunks or Response should be set, not both.
|
||||
type AnthropicResponse struct {
|
||||
StreamingChunks <-chan AnthropicChunk
|
||||
Response *AnthropicMessage
|
||||
}
|
||||
|
||||
// AnthropicRequest represents an Anthropic messages request.
|
||||
type AnthropicRequest struct {
|
||||
*http.Request // Embed http.Request
|
||||
Model string `json:"model"`
|
||||
Messages []AnthropicRequestMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Options map[string]interface{} `json:",inline"`
|
||||
}
|
||||
|
||||
// AnthropicRequestMessage represents a message in an Anthropic request.
|
||||
// Content may be either a string or a structured content array.
|
||||
type AnthropicRequestMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
// AnthropicMessage represents a message in an Anthropic response.
|
||||
type AnthropicMessage struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
Usage AnthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicUsage represents usage information in an Anthropic response.
|
||||
type AnthropicUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// AnthropicChunk represents a streaming chunk from Anthropic.
|
||||
type AnthropicChunk struct {
|
||||
Type string `json:"type"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Message AnthropicChunkMessage `json:"message,omitempty"`
|
||||
ContentBlock AnthropicContentBlock `json:"content_block,omitempty"`
|
||||
Delta AnthropicDeltaBlock `json:"delta,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
Usage AnthropicUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicChunkMessage represents message metadata in a chunk.
|
||||
type AnthropicChunkMessage struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// AnthropicContentBlock represents a content block in a chunk.
|
||||
type AnthropicContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicDeltaBlock represents a delta block in a chunk.
|
||||
type AnthropicDeltaBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
}
|
||||
|
||||
// anthropicServer is a test server that mocks the Anthropic API.
|
||||
type anthropicServer struct {
|
||||
mu sync.Mutex
|
||||
server *httptest.Server
|
||||
handler AnthropicHandler
|
||||
request *AnthropicRequest
|
||||
}
|
||||
|
||||
// NewAnthropic creates a new Anthropic test server with a handler function.
|
||||
// The handler is called for each request and should return either a streaming
|
||||
// response (via channel) or a non-streaming response.
|
||||
// Returns the base URL of the server.
|
||||
func NewAnthropic(t testing.TB, handler AnthropicHandler) string {
|
||||
t.Helper()
|
||||
|
||||
s := &anthropicServer{
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("POST /v1/messages", s.handleMessages)
|
||||
|
||||
s.server = httptest.NewServer(mux)
|
||||
|
||||
t.Cleanup(func() {
|
||||
s.server.Close()
|
||||
})
|
||||
|
||||
return s.server.URL
|
||||
}
|
||||
|
||||
func (s *anthropicServer) handleMessages(w http.ResponseWriter, r *http.Request) {
|
||||
var req AnthropicRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
// Return a more detailed error for debugging
|
||||
http.Error(w, fmt.Sprintf("decode request: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
req.Request = r // Embed the original http.Request
|
||||
|
||||
s.mu.Lock()
|
||||
s.request = &req
|
||||
s.mu.Unlock()
|
||||
|
||||
resp := s.handler(&req)
|
||||
s.writeResponse(w, &req, resp)
|
||||
}
|
||||
|
||||
func (s *anthropicServer) writeResponse(w http.ResponseWriter, req *AnthropicRequest, resp AnthropicResponse) {
|
||||
hasStreaming := resp.StreamingChunks != nil
|
||||
hasNonStreaming := resp.Response != nil
|
||||
|
||||
switch {
|
||||
case hasStreaming && hasNonStreaming:
|
||||
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
|
||||
return
|
||||
case !hasStreaming && !hasNonStreaming:
|
||||
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
|
||||
return
|
||||
case req.Stream && !hasStreaming:
|
||||
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case !req.Stream && !hasNonStreaming:
|
||||
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case hasStreaming:
|
||||
s.writeStreamingResponse(w, resp.StreamingChunks)
|
||||
default:
|
||||
s.writeNonStreamingResponse(w, resp.Response)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *anthropicServer) writeStreamingResponse(w http.ResponseWriter, chunks <-chan AnthropicChunk) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("anthropic-version", "2023-06-01")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
for chunk := range chunks {
|
||||
chunkData := make(map[string]interface{})
|
||||
chunkData["type"] = chunk.Type
|
||||
|
||||
switch chunk.Type {
|
||||
case "message_start":
|
||||
chunkData["message"] = chunk.Message
|
||||
case "content_block_start":
|
||||
chunkData["index"] = chunk.Index
|
||||
chunkData["content_block"] = chunk.ContentBlock
|
||||
case "content_block_delta":
|
||||
chunkData["index"] = chunk.Index
|
||||
chunkData["delta"] = chunk.Delta
|
||||
case "content_block_stop":
|
||||
chunkData["index"] = chunk.Index
|
||||
case "message_delta":
|
||||
chunkData["delta"] = map[string]interface{}{
|
||||
"stop_reason": chunk.StopReason,
|
||||
"stop_sequence": chunk.StopSequence,
|
||||
}
|
||||
chunkData["usage"] = chunk.Usage
|
||||
case "message_stop":
|
||||
// No additional fields
|
||||
}
|
||||
|
||||
chunkBytes, err := json.Marshal(chunkData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Send both event and data lines to match Anthropic API format
|
||||
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", chunk.Type, chunkBytes); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *anthropicServer) writeNonStreamingResponse(w http.ResponseWriter, resp *AnthropicMessage) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
response := map[string]interface{}{
|
||||
"id": resp.ID,
|
||||
"type": resp.Type,
|
||||
"role": resp.Role,
|
||||
"model": resp.Model,
|
||||
"content": []map[string]interface{}{
|
||||
{
|
||||
"type": "text",
|
||||
"text": resp.Content,
|
||||
},
|
||||
},
|
||||
"stop_reason": resp.StopReason,
|
||||
"usage": resp.Usage,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("anthropic-version", "2023-06-01")
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// AnthropicStreamingResponse creates a streaming response from chunks.
|
||||
func AnthropicStreamingResponse(chunks ...AnthropicChunk) AnthropicResponse {
|
||||
ch := make(chan AnthropicChunk, len(chunks))
|
||||
go func() {
|
||||
for _, chunk := range chunks {
|
||||
ch <- chunk
|
||||
}
|
||||
close(ch)
|
||||
}()
|
||||
return AnthropicResponse{StreamingChunks: ch}
|
||||
}
|
||||
|
||||
// AnthropicNonStreamingResponse creates a non-streaming response with the given text.
|
||||
func AnthropicNonStreamingResponse(text string) AnthropicResponse {
|
||||
return AnthropicResponse{
|
||||
Response: &AnthropicMessage{
|
||||
ID: fmt.Sprintf("msg-%s", uuid.New().String()[:8]),
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Content: text,
|
||||
Model: "claude-3-opus-20240229",
|
||||
StopReason: "end_turn",
|
||||
Usage: AnthropicUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// AnthropicTextChunks creates a complete streaming response with text deltas.
|
||||
// Takes text deltas and creates all required chunks (message_start,
|
||||
// content_block_start, content_block_delta for each delta,
|
||||
// content_block_stop, message_delta, message_stop).
|
||||
func AnthropicTextChunks(deltas ...string) []AnthropicChunk {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
|
||||
model := "claude-3-opus-20240229"
|
||||
|
||||
chunks := []AnthropicChunk{
|
||||
{
|
||||
Type: "message_start",
|
||||
Message: AnthropicChunkMessage{
|
||||
ID: messageID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: model,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "content_block_start",
|
||||
Index: 0,
|
||||
ContentBlock: AnthropicContentBlock{
|
||||
Type: "text",
|
||||
Text: "", // According to Anthropic API spec, text should be empty in content_block_start
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add a delta chunk for each delta
|
||||
for _, delta := range deltas {
|
||||
chunks = append(chunks, AnthropicChunk{
|
||||
Type: "content_block_delta",
|
||||
Index: 0,
|
||||
Delta: AnthropicDeltaBlock{
|
||||
Type: "text_delta",
|
||||
Text: delta,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
chunks = append(chunks,
|
||||
AnthropicChunk{
|
||||
Type: "content_block_stop",
|
||||
Index: 0,
|
||||
},
|
||||
AnthropicChunk{
|
||||
Type: "message_delta",
|
||||
StopReason: "end_turn",
|
||||
Usage: AnthropicUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
},
|
||||
},
|
||||
AnthropicChunk{
|
||||
Type: "message_stop",
|
||||
},
|
||||
)
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// AnthropicToolCallChunks creates a complete streaming response for a tool call.
|
||||
// Input JSON can be split across multiple deltas, matching Anthropic's
|
||||
// input_json_delta streaming behavior.
|
||||
func AnthropicToolCallChunks(toolName string, inputJSONDeltas ...string) []AnthropicChunk {
|
||||
if len(inputJSONDeltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
if toolName == "" {
|
||||
toolName = "tool"
|
||||
}
|
||||
|
||||
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
|
||||
model := "claude-3-opus-20240229"
|
||||
toolCallID := fmt.Sprintf("toolu_%s", uuid.New().String()[:8])
|
||||
|
||||
chunks := []AnthropicChunk{
|
||||
{
|
||||
Type: "message_start",
|
||||
Message: AnthropicChunkMessage{
|
||||
ID: messageID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: model,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "content_block_start",
|
||||
Index: 0,
|
||||
ContentBlock: AnthropicContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: toolCallID,
|
||||
Name: toolName,
|
||||
Input: json.RawMessage("{}"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, delta := range inputJSONDeltas {
|
||||
chunks = append(chunks, AnthropicChunk{
|
||||
Type: "content_block_delta",
|
||||
Index: 0,
|
||||
Delta: AnthropicDeltaBlock{
|
||||
Type: "input_json_delta",
|
||||
PartialJSON: delta,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
chunks = append(chunks,
|
||||
AnthropicChunk{
|
||||
Type: "content_block_stop",
|
||||
Index: 0,
|
||||
},
|
||||
AnthropicChunk{
|
||||
Type: "message_delta",
|
||||
StopReason: "tool_use",
|
||||
Usage: AnthropicUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
},
|
||||
},
|
||||
AnthropicChunk{
|
||||
Type: "message_stop",
|
||||
},
|
||||
)
|
||||
|
||||
return chunks
|
||||
}
|
||||
@@ -1,221 +0,0 @@
|
||||
package chattest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
)
|
||||
|
||||
func TestAnthropic_Streaming(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
return chattest.AnthropicStreamingResponse(
|
||||
chattest.AnthropicTextChunks("Hello", " world", "!")...,
|
||||
)
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server
|
||||
client, err := fantasyanthropic.New(
|
||||
fantasyanthropic.WithAPIKey("test-key"),
|
||||
fantasyanthropic.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "claude-3-opus-20240229")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Say hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stream, err := model.Stream(ctx, call)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedDeltas := []string{"Hello", " world", "!"}
|
||||
deltaIndex := 0
|
||||
|
||||
var allParts []fantasy.StreamPart
|
||||
for part := range stream {
|
||||
allParts = append(allParts, part)
|
||||
if part.Type == fantasy.StreamPartTypeTextDelta {
|
||||
require.Less(t, deltaIndex, len(expectedDeltas), "Received more deltas than expected")
|
||||
require.Equal(t, expectedDeltas[deltaIndex], part.Delta,
|
||||
"Delta at index %d should be %q, got %q", deltaIndex, expectedDeltas[deltaIndex], part.Delta)
|
||||
deltaIndex++
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d. Total parts received: %d", len(expectedDeltas), deltaIndex, len(allParts))
|
||||
}
|
||||
|
||||
func TestAnthropic_ToolCalls(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestCount atomic.Int32
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
switch requestCount.Add(1) {
|
||||
case 1:
|
||||
return chattest.AnthropicStreamingResponse(
|
||||
chattest.AnthropicToolCallChunks("get_weather", `{"location":"San Francisco"}`)...,
|
||||
)
|
||||
default:
|
||||
return chattest.AnthropicStreamingResponse(
|
||||
chattest.AnthropicTextChunks("The weather in San Francisco is 72F.")...,
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
client, err := fantasyanthropic.New(
|
||||
fantasyanthropic.WithAPIKey("test-key"),
|
||||
fantasyanthropic.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
|
||||
require.NoError(t, err)
|
||||
|
||||
type weatherInput struct {
|
||||
Location string `json:"location"`
|
||||
}
|
||||
var toolCallCount atomic.Int32
|
||||
weatherTool := fantasy.NewAgentTool(
|
||||
"get_weather",
|
||||
"Get weather for a location.",
|
||||
func(ctx context.Context, input weatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
toolCallCount.Add(1)
|
||||
require.Equal(t, "San Francisco", input.Location)
|
||||
return fantasy.NewTextResponse("72F"), nil
|
||||
},
|
||||
)
|
||||
|
||||
agent := fantasy.NewAgent(
|
||||
model,
|
||||
fantasy.WithSystemPrompt("You are a helpful assistant."),
|
||||
fantasy.WithTools(weatherTool),
|
||||
)
|
||||
|
||||
result, err := agent.Stream(context.Background(), fantasy.AgentStreamCall{
|
||||
Prompt: "What's the weather in San Francisco?",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
require.Equal(t, int32(1), toolCallCount.Load(), "expected exactly one tool execution")
|
||||
require.GreaterOrEqual(t, requestCount.Load(), int32(2), "expected follow-up model call after tool execution")
|
||||
}
|
||||
|
||||
func TestAnthropic_NonStreaming(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
return chattest.AnthropicNonStreamingResponse("Response text")
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server
|
||||
client, err := fantasyanthropic.New(
|
||||
fantasyanthropic.WithAPIKey("test-key"),
|
||||
fantasyanthropic.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "claude-3-opus-20240229")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Test message"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err := model.Generate(ctx, call)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
}
|
||||
|
||||
func TestAnthropic_Streaming_MismatchReturnsErrorPart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
return chattest.AnthropicNonStreamingResponse("wrong response type")
|
||||
})
|
||||
|
||||
client, err := fantasyanthropic.New(
|
||||
fantasyanthropic.WithAPIKey("test-key"),
|
||||
fantasyanthropic.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
|
||||
require.NoError(t, err)
|
||||
|
||||
stream, err := model.Stream(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var streamErr error
|
||||
for part := range stream {
|
||||
if part.Type == fantasy.StreamPartTypeError {
|
||||
streamErr = part.Error
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Error(t, streamErr)
|
||||
require.Contains(t, streamErr.Error(), "500 Internal Server Error")
|
||||
}
|
||||
|
||||
func TestAnthropic_NonStreaming_MismatchReturnsError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
return chattest.AnthropicStreamingResponse(
|
||||
chattest.AnthropicTextChunks("wrong", " response")...,
|
||||
)
|
||||
})
|
||||
|
||||
client, err := fantasyanthropic.New(
|
||||
fantasyanthropic.WithAPIKey("test-key"),
|
||||
fantasyanthropic.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "500 Internal Server Error")
|
||||
}
|
||||
@@ -1,457 +0,0 @@
|
||||
package chattest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// OpenAIHandler handles OpenAI API requests and returns a response.
|
||||
type OpenAIHandler func(req *OpenAIRequest) OpenAIResponse
|
||||
|
||||
// OpenAIResponse represents a response to an OpenAI request.
|
||||
// Either StreamingChunks or Response should be set, not both.
|
||||
type OpenAIResponse struct {
|
||||
StreamingChunks <-chan OpenAIChunk
|
||||
Response *OpenAICompletion
|
||||
}
|
||||
|
||||
// OpenAIRequest represents an OpenAI chat completion request.
|
||||
type OpenAIRequest struct {
|
||||
*http.Request
|
||||
Model string `json:"model"`
|
||||
Messages []OpenAIMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Prompt []interface{} `json:"prompt,omitempty"` // For responses API
|
||||
Options map[string]interface{} `json:",inline"`
|
||||
}
|
||||
|
||||
// OpenAIMessage represents a message in an OpenAI request.
|
||||
type OpenAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// OpenAIToolCallFunction represents the function details in a tool call.
|
||||
type OpenAIToolCallFunction struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIToolCall represents a tool call in a streaming chunk or completion.
|
||||
type OpenAIToolCall struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function OpenAIToolCallFunction `json:"function,omitempty"`
|
||||
Index int `json:"index,omitempty"` // For streaming deltas
|
||||
}
|
||||
|
||||
// OpenAIChunkChoice represents a choice in a streaming chunk.
|
||||
type OpenAIChunkChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIChunk represents a streaming chunk from OpenAI.
|
||||
type OpenAIChunk struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []OpenAIChunkChoice `json:"choices"`
|
||||
}
|
||||
|
||||
// OpenAICompletionChoice represents a choice in a completion response.
|
||||
type OpenAICompletionChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message OpenAIMessage `json:"message"`
|
||||
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// OpenAICompletionUsage represents usage information in a completion response.
|
||||
type OpenAICompletionUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// OpenAICompletion represents a non-streaming OpenAI completion response.
|
||||
type OpenAICompletion struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []OpenAICompletionChoice `json:"choices"`
|
||||
Usage OpenAICompletionUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// openAIServer is a test server that mocks the OpenAI API.
|
||||
type openAIServer struct {
|
||||
mu sync.Mutex
|
||||
server *httptest.Server
|
||||
handler OpenAIHandler
|
||||
request *OpenAIRequest
|
||||
}
|
||||
|
||||
// NewOpenAI creates a new OpenAI test server with a handler function.
|
||||
// The handler is called for each request and should return either a streaming
|
||||
// response (via channel) or a non-streaming response.
|
||||
// Returns the base URL of the server.
|
||||
func NewOpenAI(t testing.TB, handler OpenAIHandler) string {
|
||||
t.Helper()
|
||||
|
||||
s := &openAIServer{
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("POST /chat/completions", s.handleChatCompletions)
|
||||
mux.HandleFunc("POST /responses", s.handleResponses)
|
||||
|
||||
s.server = httptest.NewServer(mux)
|
||||
|
||||
t.Cleanup(func() {
|
||||
s.server.Close()
|
||||
})
|
||||
|
||||
return s.server.URL
|
||||
}
|
||||
|
||||
func (s *openAIServer) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
var req OpenAIRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
req.Request = r
|
||||
|
||||
s.mu.Lock()
|
||||
s.request = &req
|
||||
s.mu.Unlock()
|
||||
|
||||
resp := s.handler(&req)
|
||||
s.writeChatCompletionsResponse(w, &req, resp)
|
||||
}
|
||||
|
||||
func (s *openAIServer) handleResponses(w http.ResponseWriter, r *http.Request) {
|
||||
var req OpenAIRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
req.Request = r
|
||||
|
||||
s.mu.Lock()
|
||||
s.request = &req
|
||||
s.mu.Unlock()
|
||||
|
||||
resp := s.handler(&req)
|
||||
s.writeResponsesAPIResponse(w, &req, resp)
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
|
||||
hasStreaming := resp.StreamingChunks != nil
|
||||
hasNonStreaming := resp.Response != nil
|
||||
|
||||
switch {
|
||||
case hasStreaming && hasNonStreaming:
|
||||
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
|
||||
return
|
||||
case !hasStreaming && !hasNonStreaming:
|
||||
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
|
||||
return
|
||||
case req.Stream && !hasStreaming:
|
||||
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case !req.Stream && !hasNonStreaming:
|
||||
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case hasStreaming:
|
||||
s.writeChatCompletionsStreaming(w, resp.StreamingChunks)
|
||||
default:
|
||||
s.writeChatCompletionsNonStreaming(w, resp.Response)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
|
||||
hasStreaming := resp.StreamingChunks != nil
|
||||
hasNonStreaming := resp.Response != nil
|
||||
|
||||
switch {
|
||||
case hasStreaming && hasNonStreaming:
|
||||
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
|
||||
return
|
||||
case !hasStreaming && !hasNonStreaming:
|
||||
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
|
||||
return
|
||||
case req.Stream && !hasStreaming:
|
||||
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case !req.Stream && !hasNonStreaming:
|
||||
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case hasStreaming:
|
||||
s.writeResponsesAPIStreaming(w, resp.StreamingChunks)
|
||||
default:
|
||||
s.writeResponsesAPINonStreaming(w, resp.Response)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
for chunk := range chunks {
|
||||
choicesData := make([]map[string]interface{}, len(chunk.Choices))
|
||||
for i, choice := range chunk.Choices {
|
||||
choiceData := map[string]interface{}{
|
||||
"index": choice.Index,
|
||||
}
|
||||
if choice.Delta != "" {
|
||||
choiceData["delta"] = map[string]interface{}{
|
||||
"content": choice.Delta,
|
||||
}
|
||||
}
|
||||
if len(choice.ToolCalls) > 0 {
|
||||
// Tool calls come in the delta
|
||||
if choiceData["delta"] == nil {
|
||||
choiceData["delta"] = make(map[string]interface{})
|
||||
}
|
||||
delta, ok := choiceData["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
delta = make(map[string]interface{})
|
||||
choiceData["delta"] = delta
|
||||
}
|
||||
delta["tool_calls"] = choice.ToolCalls
|
||||
}
|
||||
if choice.FinishReason != "" {
|
||||
choiceData["finish_reason"] = choice.FinishReason
|
||||
}
|
||||
choicesData[i] = choiceData
|
||||
}
|
||||
|
||||
chunkData := map[string]interface{}{
|
||||
"id": chunk.ID,
|
||||
"object": chunk.Object,
|
||||
"created": chunk.Created,
|
||||
"model": chunk.Model,
|
||||
"choices": choicesData,
|
||||
}
|
||||
|
||||
chunkBytes, err := json.Marshal(chunkData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
itemIDs := make(map[int]string)
|
||||
|
||||
for chunk := range chunks {
|
||||
// Responses API sends one event per choice
|
||||
for outputIndex, choice := range chunk.Choices {
|
||||
if choice.Index != 0 {
|
||||
outputIndex = choice.Index
|
||||
}
|
||||
itemID, found := itemIDs[outputIndex]
|
||||
if !found {
|
||||
itemID = fmt.Sprintf("msg_%s", uuid.New().String()[:8])
|
||||
itemIDs[outputIndex] = itemID
|
||||
}
|
||||
|
||||
chunkData := map[string]interface{}{
|
||||
"type": "response.output_text.delta",
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"created": chunk.Created,
|
||||
"model": chunk.Model,
|
||||
"content_index": 0,
|
||||
"delta": choice.Delta,
|
||||
}
|
||||
|
||||
chunkBytes, err := json.Marshal(chunkData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeChatCompletionsNonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeResponsesAPINonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
// Convert all choices to output format
|
||||
outputs := make([]map[string]interface{}, len(resp.Choices))
|
||||
for i, choice := range resp.Choices {
|
||||
outputs[i] = map[string]interface{}{
|
||||
"id": uuid.New().String(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []map[string]interface{}{
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": choice.Message.Content,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": resp.ID,
|
||||
"object": "response",
|
||||
"created": resp.Created,
|
||||
"model": resp.Model,
|
||||
"output": outputs,
|
||||
"usage": resp.Usage,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// OpenAIStreamingResponse creates a streaming response from chunks.
|
||||
func OpenAIStreamingResponse(chunks ...OpenAIChunk) OpenAIResponse {
|
||||
ch := make(chan OpenAIChunk, len(chunks))
|
||||
go func() {
|
||||
for _, chunk := range chunks {
|
||||
ch <- chunk
|
||||
}
|
||||
close(ch)
|
||||
}()
|
||||
return OpenAIResponse{StreamingChunks: ch}
|
||||
}
|
||||
|
||||
// OpenAINonStreamingResponse creates a non-streaming response with the given text.
|
||||
func OpenAINonStreamingResponse(text string) OpenAIResponse {
|
||||
return OpenAIResponse{
|
||||
Response: &OpenAICompletion{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: "gpt-4",
|
||||
Choices: []OpenAICompletionChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: OpenAIMessage{
|
||||
Role: "assistant",
|
||||
Content: text,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Usage: OpenAICompletionUsage{
|
||||
PromptTokens: 10,
|
||||
CompletionTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAITextChunks creates streaming chunks with text deltas.
|
||||
// Each delta string becomes a separate chunk with a single choice.
|
||||
// Returns a slice of chunks, one per delta, with each choice having its index (0, 1, 2, ...).
|
||||
func OpenAITextChunks(deltas ...string) []OpenAIChunk {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
chunkID := fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8])
|
||||
now := time.Now().Unix()
|
||||
chunks := make([]OpenAIChunk, len(deltas))
|
||||
|
||||
for i, delta := range deltas {
|
||||
chunks[i] = OpenAIChunk{
|
||||
ID: chunkID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: now,
|
||||
Model: "gpt-4",
|
||||
Choices: []OpenAIChunkChoice{
|
||||
{
|
||||
Index: i,
|
||||
Delta: delta,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// OpenAIToolCallChunk creates a streaming chunk with a tool call.
|
||||
// Takes the tool name and arguments JSON string, creates a tool call for choice index 0.
|
||||
func OpenAIToolCallChunk(toolName, arguments string) OpenAIChunk {
|
||||
return OpenAIChunk{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: time.Now().Unix(),
|
||||
Model: "gpt-4",
|
||||
Choices: []OpenAIChunkChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ToolCalls: []OpenAIToolCall{
|
||||
{
|
||||
Index: 0,
|
||||
ID: fmt.Sprintf("call_%s", uuid.New().String()[:8]),
|
||||
Type: "function",
|
||||
Function: OpenAIToolCallFunction{
|
||||
Name: toolName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,367 +0,0 @@
|
||||
package chattest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
)
|
||||
|
||||
func TestOpenAI_Streaming(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
append(
|
||||
append(
|
||||
chattest.OpenAITextChunks("Hello", "Hi"),
|
||||
chattest.OpenAITextChunks(" world", " there")...,
|
||||
),
|
||||
chattest.OpenAITextChunks("!", "!")...,
|
||||
)...,
|
||||
)
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Say hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stream, err := model.Stream(ctx, call)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We expect chunks in order: one choice per chunk
|
||||
// So we get: "Hello" (choice 0), "Hi" (choice 1), " world" (choice 0), " there" (choice 1), "!" (choice 0), "!" (choice 1)
|
||||
expectedDeltas := []string{"Hello", "Hi", " world", " there", "!", "!"}
|
||||
deltaIndex := 0
|
||||
|
||||
for part := range stream {
|
||||
if part.Type == fantasy.StreamPartTypeTextDelta {
|
||||
// Verify we're getting deltas in the expected order
|
||||
require.Less(t, deltaIndex, len(expectedDeltas), "Received more deltas than expected")
|
||||
require.Equal(t, expectedDeltas[deltaIndex], part.Delta,
|
||||
"Delta at index %d should be %q, got %q", deltaIndex, expectedDeltas[deltaIndex], part.Delta)
|
||||
deltaIndex++
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we received all expected deltas
|
||||
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d", len(expectedDeltas), deltaIndex)
|
||||
}
|
||||
|
||||
func TestOpenAI_Streaming_ResponsesAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
append(
|
||||
append(
|
||||
chattest.OpenAITextChunks("First", "Second"),
|
||||
chattest.OpenAITextChunks(" output", " output")...,
|
||||
),
|
||||
chattest.OpenAITextChunks("!", "!")...,
|
||||
)...,
|
||||
)
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server (responses API)
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
fantasyopenai.WithUseResponsesAPI(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Say hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stream, err := model.Stream(ctx, call)
|
||||
require.NoError(t, err)
|
||||
|
||||
var parts []fantasy.StreamPart
|
||||
for part := range stream {
|
||||
parts = append(parts, part)
|
||||
}
|
||||
|
||||
// Verify we received the chunks in order
|
||||
require.Greater(t, len(parts), 0)
|
||||
|
||||
// Extract text deltas from parts and verify they match expected chunks in order
|
||||
// We expect: "First", " output", "!" for choice 0, and "Second", " output", "!" for choice 1
|
||||
var allDeltas []string
|
||||
for _, part := range parts {
|
||||
if part.Type == fantasy.StreamPartTypeTextDelta {
|
||||
allDeltas = append(allDeltas, part.Delta)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we received deltas (responses API may handle multiple choices differently)
|
||||
// If we got text deltas, verify the content
|
||||
if len(allDeltas) > 0 {
|
||||
allText := ""
|
||||
for _, delta := range allDeltas {
|
||||
allText += delta
|
||||
}
|
||||
require.Contains(t, allText, "First")
|
||||
require.Contains(t, allText, "Second")
|
||||
require.Contains(t, allText, "output")
|
||||
require.Contains(t, allText, "!")
|
||||
} else {
|
||||
// If no text deltas, at least verify we got some parts (may be different format)
|
||||
require.Greater(t, len(parts), 0, "Expected at least one stream part")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAI_NonStreaming_CompletionsAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAINonStreamingResponse("First response")
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server (completions API)
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Test message"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err := model.Generate(ctx, call)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
}
|
||||
|
||||
func TestOpenAI_ToolCalls(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestCount atomic.Int32
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
switch requestCount.Add(1) {
|
||||
case 1:
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk("get_weather", `{"location":"San Francisco"}`),
|
||||
)
|
||||
default:
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("The weather in San Francisco is 72F.")...,
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
type weatherInput struct {
|
||||
Location string `json:"location"`
|
||||
}
|
||||
var toolCallCount atomic.Int32
|
||||
weatherTool := fantasy.NewAgentTool(
|
||||
"get_weather",
|
||||
"Get weather for a location.",
|
||||
func(ctx context.Context, input weatherInput, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
toolCallCount.Add(1)
|
||||
require.Equal(t, "San Francisco", input.Location)
|
||||
return fantasy.NewTextResponse("72F"), nil
|
||||
},
|
||||
)
|
||||
|
||||
agent := fantasy.NewAgent(
|
||||
model,
|
||||
fantasy.WithSystemPrompt("You are a helpful assistant."),
|
||||
fantasy.WithTools(weatherTool),
|
||||
)
|
||||
|
||||
result, err := agent.Stream(ctx, fantasy.AgentStreamCall{
|
||||
Prompt: "What's the weather in San Francisco?",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, int32(1), toolCallCount.Load(), "expected exactly one tool execution")
|
||||
require.GreaterOrEqual(t, requestCount.Load(), int32(2), "expected follow-up model call after tool execution")
|
||||
}
|
||||
|
||||
func TestOpenAI_NonStreaming_ResponsesAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAINonStreamingResponse("First output")
|
||||
})
|
||||
|
||||
// Create fantasy client pointing to our test server (responses API)
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
fantasyopenai.WithUseResponsesAPI(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
model, err := client.LanguageModel(ctx, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
call := fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "Test message"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err := model.Generate(ctx, call)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
}
|
||||
|
||||
func TestOpenAI_Streaming_MismatchReturnsErrorPart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAINonStreamingResponse("wrong response type")
|
||||
})
|
||||
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
stream, err := model.Stream(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var streamErr error
|
||||
for part := range stream {
|
||||
if part.Type == fantasy.StreamPartTypeError {
|
||||
streamErr = part.Error
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Error(t, streamErr)
|
||||
require.Contains(t, streamErr.Error(), "non-streaming response for streaming request")
|
||||
}
|
||||
|
||||
func TestOpenAI_NonStreaming_MismatchReturnsError_CompletionsAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("wrong response type")...)
|
||||
})
|
||||
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "streaming response for non-streaming request")
|
||||
}
|
||||
|
||||
func TestOpenAI_NonStreaming_MismatchReturnsError_ResponsesAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("wrong response type")...)
|
||||
})
|
||||
|
||||
client, err := fantasyopenai.New(
|
||||
fantasyopenai.WithAPIKey("test-key"),
|
||||
fantasyopenai.WithBaseURL(serverURL),
|
||||
fantasyopenai.WithUseResponsesAPI(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "streaming response for non-streaming request")
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// toolResponse builds a fantasy.ToolResponse from a JSON-serializable
|
||||
// result payload.
|
||||
func toolResponse(result map[string]any) fantasy.ToolResponse {
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return fantasy.NewTextResponse("{}")
|
||||
}
|
||||
return fantasy.NewTextResponse(string(data))
|
||||
}
|
||||
|
||||
// parseOwnerID parses a UUID string into a uuid.UUID, returning
|
||||
// an error if the string is empty or not a valid UUID.
|
||||
func parseOwnerID(raw string) (uuid.UUID, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return uuid.Nil, xerrors.New("owner ID is empty")
|
||||
}
|
||||
id, err := uuid.Parse(raw)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("invalid owner ID %q: %w", raw, err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func truncateRunes(value string, maxLen int) string {
|
||||
if maxLen <= 0 || value == "" {
|
||||
return ""
|
||||
}
|
||||
if utf8.RuneCountInString(value) <= maxLen {
|
||||
return value
|
||||
}
|
||||
|
||||
runes := []rune(value)
|
||||
if maxLen > len(runes) {
|
||||
maxLen = len(runes)
|
||||
}
|
||||
return string(runes[:maxLen])
|
||||
}
|
||||
@@ -1,426 +0,0 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/util/namesgenerator"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// buildPollInterval is how often we check if the workspace
|
||||
// build has completed.
|
||||
buildPollInterval = 2 * time.Second
|
||||
// buildTimeout is the maximum time to wait for a workspace
|
||||
// build to complete before giving up.
|
||||
buildTimeout = 10 * time.Minute
|
||||
// agentConnectTimeout is the maximum time to wait for the
|
||||
// workspace agent to become reachable after a successful build.
|
||||
agentConnectTimeout = 2 * time.Minute
|
||||
// agentRetryInterval is how often we retry connecting to the
|
||||
// workspace agent.
|
||||
agentRetryInterval = 2 * time.Second
|
||||
// agentAttemptTimeout is the timeout for a single connection
|
||||
// attempt to the workspace agent during the retry loop.
|
||||
agentAttemptTimeout = 5 * time.Second
|
||||
// agentPingTimeout is the timeout for a single agent ping
|
||||
// when checking whether an existing workspace is alive.
|
||||
agentPingTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// CreateWorkspaceFn creates a workspace for the given owner.
|
||||
type CreateWorkspaceFn func(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
req codersdk.CreateWorkspaceRequest,
|
||||
) (codersdk.Workspace, error)
|
||||
|
||||
// AgentConnFunc provides access to workspace agent connections.
|
||||
type AgentConnFunc func(
|
||||
ctx context.Context,
|
||||
agentID uuid.UUID,
|
||||
) (workspacesdk.AgentConn, func(), error)
|
||||
|
||||
// CreateWorkspaceOptions configures the create_workspace tool.
|
||||
type CreateWorkspaceOptions struct {
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
CreateFn CreateWorkspaceFn
|
||||
AgentConnFn AgentConnFunc
|
||||
WorkspaceMu *sync.Mutex
|
||||
}
|
||||
|
||||
type createWorkspaceArgs struct {
|
||||
TemplateID string `json:"template_id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Parameters map[string]string `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
// CreateWorkspace returns a tool that creates a new workspace from a
|
||||
// template. The tool is idempotent: if the chat already has a
|
||||
// workspace that is building or running, it returns the existing
|
||||
// workspace instead of creating a new one. A mutex prevents parallel
|
||||
// calls from creating duplicate workspaces.
|
||||
func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"create_workspace",
|
||||
"Create a new workspace from a template. Requires a "+
|
||||
"template_id (from list_templates). Optionally provide "+
|
||||
"a name and parameter values (from read_template). "+
|
||||
"If no name is given, one will be generated. "+
|
||||
"This tool is idempotent — if the chat already has a "+
|
||||
"workspace that is building or running, the existing "+
|
||||
"workspace is returned.",
|
||||
func(ctx context.Context, args createWorkspaceArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.CreateFn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace creator is not configured"), nil
|
||||
}
|
||||
|
||||
templateIDStr := strings.TrimSpace(args.TemplateID)
|
||||
if templateIDStr == "" {
|
||||
return fantasy.NewTextErrorResponse("template_id is required; use list_templates to find one"), nil
|
||||
}
|
||||
templateID, err := uuid.Parse(templateIDStr)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("invalid template_id: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
|
||||
// Serialize workspace creation to prevent parallel
|
||||
// tool calls from creating duplicate workspaces.
|
||||
if options.WorkspaceMu != nil {
|
||||
options.WorkspaceMu.Lock()
|
||||
defer options.WorkspaceMu.Unlock()
|
||||
}
|
||||
|
||||
// Check for an existing workspace on the chat.
|
||||
if options.DB != nil && options.ChatID != uuid.Nil {
|
||||
existing, done, existErr := checkExistingWorkspace(
|
||||
ctx, options.DB, options.ChatID,
|
||||
options.AgentConnFn,
|
||||
)
|
||||
if existErr != nil {
|
||||
return fantasy.NewTextErrorResponse(existErr.Error()), nil
|
||||
}
|
||||
if done {
|
||||
return toolResponse(existing), nil
|
||||
}
|
||||
}
|
||||
|
||||
ownerID := options.OwnerID
|
||||
|
||||
// Set up dbauthz context for DB lookups.
|
||||
if options.DB != nil {
|
||||
ownerCtx, ownerErr := asOwner(ctx, options.DB, ownerID)
|
||||
if ownerErr != nil {
|
||||
return fantasy.NewTextErrorResponse(ownerErr.Error()), nil
|
||||
}
|
||||
ctx = ownerCtx
|
||||
}
|
||||
|
||||
createReq := codersdk.CreateWorkspaceRequest{
|
||||
TemplateID: templateID,
|
||||
}
|
||||
|
||||
// Resolve workspace name.
|
||||
name := strings.TrimSpace(args.Name)
|
||||
if name == "" {
|
||||
seed := "workspace"
|
||||
if options.DB != nil {
|
||||
if t, lookupErr := options.DB.GetTemplateByID(ctx, templateID); lookupErr == nil {
|
||||
seed = t.Name
|
||||
}
|
||||
}
|
||||
name = generatedWorkspaceName(seed)
|
||||
} else if err := codersdk.NameValid(name); err != nil {
|
||||
name = generatedWorkspaceName(name)
|
||||
}
|
||||
createReq.Name = name
|
||||
|
||||
// Map parameters.
|
||||
for k, v := range args.Parameters {
|
||||
createReq.RichParameterValues = append(
|
||||
createReq.RichParameterValues,
|
||||
codersdk.WorkspaceBuildParameter{Name: k, Value: v},
|
||||
)
|
||||
}
|
||||
|
||||
workspace, err := options.CreateFn(ctx, ownerID, createReq)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Wait for the build to complete and the agent to
|
||||
// come online so subsequent tools can use the
|
||||
// workspace immediately.
|
||||
if options.DB != nil {
|
||||
if err := waitForBuild(ctx, options.DB, workspace.ID); err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("workspace build failed: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Look up the first agent so we can link it to the chat.
|
||||
workspaceAgentID := uuid.Nil
|
||||
if options.DB != nil {
|
||||
agents, agentErr := options.DB.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
|
||||
if agentErr == nil && len(agents) > 0 {
|
||||
workspaceAgentID = agents[0].ID
|
||||
}
|
||||
}
|
||||
|
||||
// Persist workspace + agent association on the chat.
|
||||
if options.DB != nil && options.ChatID != uuid.Nil {
|
||||
_, _ = options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
|
||||
ID: options.ChatID,
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspace.ID,
|
||||
Valid: true,
|
||||
},
|
||||
WorkspaceAgentID: uuid.NullUUID{
|
||||
UUID: workspaceAgentID,
|
||||
Valid: workspaceAgentID != uuid.Nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Wait for the agent to come online.
|
||||
if workspaceAgentID != uuid.Nil && options.AgentConnFn != nil {
|
||||
if err := waitForAgent(ctx, options.AgentConnFn, workspaceAgentID); err != nil {
|
||||
// Non-fatal: the workspace was created
|
||||
// successfully, the agent just isn't ready
|
||||
// yet. The model can retry.
|
||||
return toolResponse(map[string]any{
|
||||
"created": true,
|
||||
"workspace_name": workspace.FullName(),
|
||||
"agent_status": "not_ready",
|
||||
"agent_error": err.Error(),
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
|
||||
return toolResponse(map[string]any{
|
||||
"created": true,
|
||||
"workspace_name": workspace.FullName(),
|
||||
}), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// checkExistingWorkspace checks whether the chat already has a usable
|
||||
// workspace. Returns the result map and true if the caller should
|
||||
// return early (workspace exists and is alive or building). Returns
|
||||
// false if the caller should proceed with creation (workspace is dead
|
||||
// or missing).
|
||||
func checkExistingWorkspace(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
agentConnFn AgentConnFunc,
|
||||
) (map[string]any, bool, error) {
|
||||
chat, err := db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, false, xerrors.Errorf("load chat: %w", err)
|
||||
}
|
||||
if !chat.WorkspaceID.Valid {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// Check if workspace still exists.
|
||||
ws, err := db.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
// Workspace was deleted — allow creation.
|
||||
return nil, false, nil
|
||||
}
|
||||
return nil, false, xerrors.Errorf("load workspace: %w", err)
|
||||
}
|
||||
|
||||
// Check the latest build status.
|
||||
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
|
||||
if err != nil {
|
||||
// Can't determine status — allow creation.
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
job, err := db.GetProvisionerJobByID(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
switch job.JobStatus {
|
||||
case database.ProvisionerJobStatusPending,
|
||||
database.ProvisionerJobStatusRunning:
|
||||
// Build is in progress — wait for it instead of
|
||||
// creating a new workspace.
|
||||
if err := waitForBuild(ctx, db, ws.ID); err != nil {
|
||||
return nil, false, xerrors.Errorf(
|
||||
"existing workspace build failed: %w", err,
|
||||
)
|
||||
}
|
||||
return map[string]any{
|
||||
"created": false,
|
||||
"workspace_name": ws.Name,
|
||||
"status": "already_exists",
|
||||
"message": "workspace was already being built and is now ready",
|
||||
}, true, nil
|
||||
|
||||
case database.ProvisionerJobStatusSucceeded:
|
||||
// Build succeeded — check if agent is reachable.
|
||||
if chat.WorkspaceAgentID.Valid && agentConnFn != nil {
|
||||
pingCtx, cancel := context.WithTimeout(
|
||||
ctx, agentPingTimeout,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
conn, release, connErr := agentConnFn(
|
||||
pingCtx, chat.WorkspaceAgentID.UUID,
|
||||
)
|
||||
if connErr == nil {
|
||||
release()
|
||||
_ = conn
|
||||
return map[string]any{
|
||||
"created": false,
|
||||
"workspace_name": ws.Name,
|
||||
"status": "already_exists",
|
||||
"message": "workspace is already running and reachable",
|
||||
}, true, nil
|
||||
}
|
||||
// Agent unreachable — workspace is dead, allow
|
||||
// creation.
|
||||
}
|
||||
// No agent ID or no conn func — allow creation.
|
||||
return nil, false, nil
|
||||
|
||||
default:
|
||||
// Failed, canceled, etc — allow creation.
|
||||
return nil, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// waitForBuild polls the workspace's latest build until it
|
||||
// completes or the context expires.
|
||||
func waitForBuild(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
workspaceID uuid.UUID,
|
||||
) error {
|
||||
buildCtx, cancel := context.WithTimeout(ctx, buildTimeout)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(buildPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(
|
||||
buildCtx, workspaceID,
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get latest build: %w", err)
|
||||
}
|
||||
|
||||
job, err := db.GetProvisionerJobByID(buildCtx, build.JobID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
|
||||
switch job.JobStatus {
|
||||
case database.ProvisionerJobStatusSucceeded:
|
||||
return nil
|
||||
case database.ProvisionerJobStatusFailed:
|
||||
errMsg := "build failed"
|
||||
if job.Error.Valid {
|
||||
errMsg = job.Error.String
|
||||
}
|
||||
return xerrors.New(errMsg)
|
||||
case database.ProvisionerJobStatusCanceled:
|
||||
return xerrors.New("build was canceled")
|
||||
case database.ProvisionerJobStatusPending,
|
||||
database.ProvisionerJobStatusRunning,
|
||||
database.ProvisionerJobStatusCanceling:
|
||||
// Still in progress — keep waiting.
|
||||
default:
|
||||
return xerrors.Errorf("unexpected job status: %s", job.JobStatus)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-buildCtx.Done():
|
||||
return xerrors.Errorf(
|
||||
"timed out waiting for workspace build: %w",
|
||||
buildCtx.Err(),
|
||||
)
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForAgent retries connecting to the workspace agent until it
|
||||
// succeeds or the timeout expires.
|
||||
func waitForAgent(
|
||||
ctx context.Context,
|
||||
agentConnFn AgentConnFunc,
|
||||
agentID uuid.UUID,
|
||||
) error {
|
||||
agentCtx, cancel := context.WithTimeout(ctx, agentConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(agentRetryInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastErr error
|
||||
for {
|
||||
attemptCtx, attemptCancel := context.WithTimeout(agentCtx, agentAttemptTimeout)
|
||||
conn, release, err := agentConnFn(attemptCtx, agentID)
|
||||
attemptCancel()
|
||||
if err == nil {
|
||||
release()
|
||||
_ = conn
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
select {
|
||||
case <-agentCtx.Done():
|
||||
return xerrors.Errorf(
|
||||
"timed out waiting for workspace agent: %w",
|
||||
lastErr,
|
||||
)
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func generatedWorkspaceName(seed string) string {
|
||||
base := codersdk.UsernameFrom(strings.TrimSpace(strings.ToLower(seed)))
|
||||
if strings.TrimSpace(base) == "" {
|
||||
base = "workspace"
|
||||
}
|
||||
|
||||
suffix := strings.ReplaceAll(uuid.NewString(), "-", "")[:4]
|
||||
if len(base) > 27 {
|
||||
base = strings.Trim(base[:27], "-")
|
||||
}
|
||||
if base == "" {
|
||||
base = "workspace"
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("%s-%s", base, suffix)
|
||||
if err := codersdk.NameValid(name); err == nil {
|
||||
return name
|
||||
}
|
||||
return namesgenerator.NameDigitWith("-")
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
type EditFilesOptions struct {
|
||||
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
}
|
||||
|
||||
type EditFilesArgs struct {
|
||||
Files []workspacesdk.FileEdits `json:"files"`
|
||||
}
|
||||
|
||||
func EditFiles(options EditFilesOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"edit_files",
|
||||
"Perform search-and-replace edits on one or more files in the workspace."+
|
||||
" Each file can have multiple edits applied atomically.",
|
||||
func(ctx context.Context, args EditFilesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
}
|
||||
conn, err := options.GetWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return executeEditFilesTool(ctx, conn, args)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func executeEditFilesTool(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
args EditFilesArgs,
|
||||
) (fantasy.ToolResponse, error) {
|
||||
if len(args.Files) == 0 {
|
||||
return fantasy.NewTextErrorResponse("files is required"), nil
|
||||
}
|
||||
|
||||
if err := conn.EditFiles(ctx, workspacesdk.FileEditRequest{Files: args.Files}); err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return toolResponse(map[string]any{"ok": true}), nil
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultExecuteTimeout = 60 * time.Second
|
||||
chatAgentEnvVar = "CODER_CHAT_AGENT"
|
||||
gitAuthRequiredPrefix = "CODER_GITAUTH_REQUIRED:"
|
||||
authRequiredResultReason = "authentication_required"
|
||||
)
|
||||
|
||||
type ExecuteOptions struct {
|
||||
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
DefaultTimeout time.Duration
|
||||
}
|
||||
|
||||
type ExecuteArgs struct {
|
||||
Command string `json:"command"`
|
||||
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
func Execute(options ExecuteOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"execute",
|
||||
"Execute a shell command in the workspace.",
|
||||
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
}
|
||||
conn, err := options.GetWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return executeTool(ctx, conn, args, options.DefaultTimeout), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func executeTool(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
args ExecuteArgs,
|
||||
defaultTimeout time.Duration,
|
||||
) fantasy.ToolResponse {
|
||||
if args.Command == "" {
|
||||
return fantasy.NewTextErrorResponse("command is required")
|
||||
}
|
||||
|
||||
timeout := defaultTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = defaultExecuteTimeout
|
||||
}
|
||||
if args.TimeoutSeconds != nil {
|
||||
timeout = time.Duration(*args.TimeoutSeconds) * time.Second
|
||||
}
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
output, exitCode, err := runCommand(cmdCtx, conn, args.Command)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error())
|
||||
}
|
||||
return toolResponse(map[string]any{
|
||||
"output": output,
|
||||
"exit_code": exitCode,
|
||||
})
|
||||
}
|
||||
|
||||
func runCommand(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
command string,
|
||||
) (string, int, error) {
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
defer sshClient.Close()
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
defer session.Close()
|
||||
if err := session.Setenv(chatAgentEnvVar, "true"); err != nil {
|
||||
return "", 0, xerrors.Errorf("set %s: %w", chatAgentEnvVar, err)
|
||||
}
|
||||
|
||||
resultCh := make(chan struct {
|
||||
output string
|
||||
exitCode int
|
||||
err error
|
||||
}, 1)
|
||||
|
||||
go func() {
|
||||
output, err := session.CombinedOutput(command)
|
||||
exitCode := 0
|
||||
if err != nil {
|
||||
var exitErr *ssh.ExitError
|
||||
if xerrors.As(err, &exitErr) {
|
||||
exitCode = exitErr.ExitStatus()
|
||||
} else {
|
||||
exitCode = 1
|
||||
}
|
||||
}
|
||||
resultCh <- struct {
|
||||
output string
|
||||
exitCode int
|
||||
err error
|
||||
}{
|
||||
output: string(output),
|
||||
exitCode: exitCode,
|
||||
err: err,
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Close()
|
||||
return "", 0, ctx.Err()
|
||||
case result := <-resultCh:
|
||||
return result.output, result.exitCode, result.err
|
||||
}
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
)
|
||||
|
||||
// ListTemplatesOptions configures the list_templates tool.
|
||||
type ListTemplatesOptions struct {
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
}
|
||||
|
||||
type listTemplatesArgs struct {
|
||||
Query string `json:"query,omitempty"`
|
||||
}
|
||||
|
||||
// ListTemplates returns a tool that lists available workspace templates.
|
||||
// The agent uses this to discover templates before creating a workspace.
|
||||
func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"list_templates",
|
||||
"List available workspace templates. Optionally filter by a "+
|
||||
"search query matching template name or description. "+
|
||||
"Use this to find a template before creating a workspace.",
|
||||
func(ctx context.Context, args listTemplatesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.DB == nil {
|
||||
return fantasy.NewTextErrorResponse("database is not configured"), nil
|
||||
}
|
||||
|
||||
ctx, err := asOwner(ctx, options.DB, options.OwnerID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
filterParams := database.GetTemplatesWithFilterParams{
|
||||
Deleted: false,
|
||||
Deprecated: sql.NullBool{
|
||||
Bool: false,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
query := strings.TrimSpace(args.Query)
|
||||
if query != "" {
|
||||
filterParams.FuzzyName = query
|
||||
}
|
||||
|
||||
templates, err := options.DB.GetTemplatesWithFilter(ctx, filterParams)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
items := make([]map[string]any, 0, len(templates))
|
||||
for _, t := range templates {
|
||||
item := map[string]any{
|
||||
"id": t.ID.String(),
|
||||
"name": t.Name,
|
||||
}
|
||||
if display := strings.TrimSpace(t.DisplayName); display != "" {
|
||||
item["display_name"] = display
|
||||
}
|
||||
if desc := strings.TrimSpace(t.Description); desc != "" {
|
||||
item["description"] = truncateRunes(desc, 200)
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return toolResponse(map[string]any{
|
||||
"templates": items,
|
||||
"count": len(items),
|
||||
}), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// asOwner sets up a dbauthz context for the given owner so that
|
||||
// subsequent database calls are scoped to what that user can access.
|
||||
func asOwner(ctx context.Context, db database.Store, ownerID uuid.UUID) (context.Context, error) {
|
||||
actor, _, err := httpmw.UserRBACSubject(ctx, db, ownerID, rbac.ScopeAll)
|
||||
if err != nil {
|
||||
return ctx, xerrors.Errorf("load user authorization: %w", err)
|
||||
}
|
||||
return dbauthz.As(ctx, actor), nil
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
type ReadFileOptions struct {
|
||||
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
}
|
||||
|
||||
type ReadFileArgs struct {
|
||||
Path string `json:"path"`
|
||||
Offset *int64 `json:"offset,omitempty"`
|
||||
Limit *int64 `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
func ReadFile(options ReadFileOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"read_file",
|
||||
"Read a file from the workspace.",
|
||||
func(ctx context.Context, args ReadFileArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
}
|
||||
conn, err := options.GetWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return executeReadFileTool(ctx, conn, args)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func executeReadFileTool(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
args ReadFileArgs,
|
||||
) (fantasy.ToolResponse, error) {
|
||||
if args.Path == "" {
|
||||
return fantasy.NewTextErrorResponse("path is required"), nil
|
||||
}
|
||||
|
||||
offset := int64(0)
|
||||
limit := int64(0)
|
||||
if args.Offset != nil {
|
||||
offset = *args.Offset
|
||||
}
|
||||
if args.Limit != nil {
|
||||
limit = *args.Limit
|
||||
}
|
||||
|
||||
reader, mimeType, err := conn.ReadFile(ctx, args.Path, offset, limit)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolResponse(map[string]any{
|
||||
"content": string(data),
|
||||
"mime_type": mimeType,
|
||||
}), nil
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
// ReadTemplateOptions configures the read_template tool.
|
||||
type ReadTemplateOptions struct {
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
}
|
||||
|
||||
type readTemplateArgs struct {
|
||||
TemplateID string `json:"template_id"`
|
||||
}
|
||||
|
||||
// ReadTemplate returns a tool that retrieves details about a specific
|
||||
// template, including its configurable rich parameters. The agent
|
||||
// uses this after list_templates and before create_workspace.
|
||||
func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"read_template",
|
||||
"Get details about a workspace template, including its "+
|
||||
"configurable parameters. Use this after finding a "+
|
||||
"template with list_templates and before creating a "+
|
||||
"workspace with create_workspace.",
|
||||
func(ctx context.Context, args readTemplateArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.DB == nil {
|
||||
return fantasy.NewTextErrorResponse("database is not configured"), nil
|
||||
}
|
||||
|
||||
templateIDStr := strings.TrimSpace(args.TemplateID)
|
||||
if templateIDStr == "" {
|
||||
return fantasy.NewTextErrorResponse("template_id is required"), nil
|
||||
}
|
||||
templateID, err := uuid.Parse(templateIDStr)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("invalid template_id: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
|
||||
ctx, err = asOwner(ctx, options.DB, options.OwnerID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
template, err := options.DB.GetTemplateByID(ctx, templateID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("template not found"), nil
|
||||
}
|
||||
|
||||
params, err := options.DB.GetTemplateVersionParameters(ctx, template.ActiveVersionID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("failed to get template parameters: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
|
||||
templateInfo := map[string]any{
|
||||
"id": template.ID.String(),
|
||||
"name": template.Name,
|
||||
"active_version_id": template.ActiveVersionID.String(),
|
||||
}
|
||||
if display := strings.TrimSpace(template.DisplayName); display != "" {
|
||||
templateInfo["display_name"] = display
|
||||
}
|
||||
if desc := strings.TrimSpace(template.Description); desc != "" {
|
||||
templateInfo["description"] = desc
|
||||
}
|
||||
|
||||
paramList := make([]map[string]any, 0, len(params))
|
||||
for _, p := range params {
|
||||
param := map[string]any{
|
||||
"name": p.Name,
|
||||
"type": p.Type,
|
||||
"required": p.Required,
|
||||
}
|
||||
if display := strings.TrimSpace(p.DisplayName); display != "" {
|
||||
param["display_name"] = display
|
||||
}
|
||||
if desc := strings.TrimSpace(p.Description); desc != "" {
|
||||
param["description"] = truncateRunes(desc, 300)
|
||||
}
|
||||
if p.DefaultValue != "" {
|
||||
param["default"] = p.DefaultValue
|
||||
}
|
||||
if p.Mutable {
|
||||
param["mutable"] = true
|
||||
}
|
||||
if p.Ephemeral {
|
||||
param["ephemeral"] = true
|
||||
}
|
||||
if p.FormType != "" {
|
||||
param["form_type"] = string(p.FormType)
|
||||
}
|
||||
if len(p.Options) > 0 && string(p.Options) != "null" && string(p.Options) != "[]" {
|
||||
var opts []map[string]any
|
||||
if err := json.Unmarshal(p.Options, &opts); err == nil && len(opts) > 0 {
|
||||
param["options"] = opts
|
||||
}
|
||||
}
|
||||
if p.ValidationRegex != "" {
|
||||
param["validation_regex"] = p.ValidationRegex
|
||||
}
|
||||
if p.ValidationMin.Valid {
|
||||
param["validation_min"] = p.ValidationMin.Int32
|
||||
}
|
||||
if p.ValidationMax.Valid {
|
||||
param["validation_max"] = p.ValidationMax.Int32
|
||||
}
|
||||
|
||||
paramList = append(paramList, param)
|
||||
}
|
||||
|
||||
return toolResponse(map[string]any{
|
||||
"template": templateInfo,
|
||||
"parameters": paramList,
|
||||
}), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
type WriteFileOptions struct {
|
||||
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
}
|
||||
|
||||
type WriteFileArgs struct {
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func WriteFile(options WriteFileOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"write_file",
|
||||
"Write a file to the workspace.",
|
||||
func(ctx context.Context, args WriteFileArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
}
|
||||
conn, err := options.GetWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return executeWriteFileTool(ctx, conn, args)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func executeWriteFileTool(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
args WriteFileArgs,
|
||||
) (fantasy.ToolResponse, error) {
|
||||
if args.Path == "" {
|
||||
return fantasy.NewTextErrorResponse("path is required"), nil
|
||||
}
|
||||
|
||||
if err := conn.WriteFile(ctx, args.Path, strings.NewReader(args.Content)); err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return toolResponse(map[string]any{"ok": true}), nil
|
||||
}
|
||||
@@ -1,126 +0,0 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
coderHomeInstructionDir = ".coder"
|
||||
coderHomeInstructionFile = "AGENTS.md"
|
||||
maxInstructionFileBytes = 64 * 1024
|
||||
)
|
||||
|
||||
var markdownCommentPattern = regexp.MustCompile(`<!--[\s\S]*?-->`)
|
||||
|
||||
func readHomeInstructionFile(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
) (content string, sourcePath string, truncated bool, err error) {
|
||||
if conn == nil {
|
||||
return "", "", false, nil
|
||||
}
|
||||
|
||||
coderDir, err := conn.LS(ctx, "", workspacesdk.LSRequest{
|
||||
Path: []string{coderHomeInstructionDir},
|
||||
Relativity: workspacesdk.LSRelativityHome,
|
||||
})
|
||||
if err != nil {
|
||||
if isCodersdkStatusCode(err, http.StatusNotFound) {
|
||||
return "", "", false, nil
|
||||
}
|
||||
return "", "", false, xerrors.Errorf("list home instruction directory: %w", err)
|
||||
}
|
||||
|
||||
var filePath string
|
||||
for _, entry := range coderDir.Contents {
|
||||
if entry.IsDir {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(entry.Name), coderHomeInstructionFile) {
|
||||
filePath = strings.TrimSpace(entry.AbsolutePathString)
|
||||
break
|
||||
}
|
||||
}
|
||||
if filePath == "" {
|
||||
return "", "", false, nil
|
||||
}
|
||||
|
||||
reader, _, err := conn.ReadFile(
|
||||
ctx,
|
||||
filePath,
|
||||
0,
|
||||
maxInstructionFileBytes+1,
|
||||
)
|
||||
if err != nil {
|
||||
if isCodersdkStatusCode(err, http.StatusNotFound) {
|
||||
return "", "", false, nil
|
||||
}
|
||||
return "", "", false, xerrors.Errorf("read home instruction file: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
raw, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return "", "", false, xerrors.Errorf("read home instruction bytes: %w", err)
|
||||
}
|
||||
|
||||
truncated = int64(len(raw)) > maxInstructionFileBytes
|
||||
if truncated {
|
||||
raw = raw[:maxInstructionFileBytes]
|
||||
}
|
||||
|
||||
content = sanitizeInstructionMarkdown(string(raw))
|
||||
if content == "" {
|
||||
return "", "", truncated, nil
|
||||
}
|
||||
|
||||
return content, filePath, truncated, nil
|
||||
}
|
||||
|
||||
func sanitizeInstructionMarkdown(content string) string {
|
||||
content = strings.ReplaceAll(content, "\r\n", "\n")
|
||||
content = strings.ReplaceAll(content, "\r", "\n")
|
||||
content = markdownCommentPattern.ReplaceAllString(content, "")
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
//nolint:revive // Boolean indicates content was truncated.
|
||||
func formatHomeInstruction(content string, sourcePath string, truncated bool) string {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
sourcePath = strings.TrimSpace(sourcePath)
|
||||
if sourcePath == "" {
|
||||
sourcePath = "~/.coder/AGENTS.md"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
_, _ = b.WriteString("<coder-home-instructions>\n")
|
||||
_, _ = b.WriteString("Source: ")
|
||||
_, _ = b.WriteString(sourcePath)
|
||||
if truncated {
|
||||
_, _ = b.WriteString(" (truncated to 64KiB)")
|
||||
}
|
||||
_, _ = b.WriteString("\n\n")
|
||||
_, _ = b.WriteString(content)
|
||||
_, _ = b.WriteString("\n</coder-home-instructions>")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func isCodersdkStatusCode(err error, statusCode int) bool {
|
||||
var sdkErr *codersdk.Error
|
||||
if !xerrors.As(err, &sdkErr) {
|
||||
return false
|
||||
}
|
||||
return sdkErr.StatusCode() == statusCode
|
||||
}
|
||||
@@ -1,134 +0,0 @@
|
||||
package chatd //nolint:testpackage // Uses internal symbols.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
)
|
||||
|
||||
func TestSanitizeInstructionMarkdown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := "line 1\r\n<!-- hidden -->\r\nline 2\r\n"
|
||||
require.Equal(t, "line 1\n\nline 2", sanitizeInstructionMarkdown(input))
|
||||
}
|
||||
|
||||
func TestReadHomeInstructionFileNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).DoAndReturn(
|
||||
func(context.Context, string, workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
|
||||
return workspacesdk.LSResponse{}, codersdk.NewTestError(404, "POST", "/api/v0/list-directory")
|
||||
},
|
||||
)
|
||||
|
||||
content, sourcePath, truncated, err := readHomeInstructionFile(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, content)
|
||||
require.Empty(t, sourcePath)
|
||||
require.False(t, truncated)
|
||||
}
|
||||
|
||||
func TestReadHomeInstructionFileSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).DoAndReturn(
|
||||
func(context.Context, string, workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
|
||||
return workspacesdk.LSResponse{
|
||||
Contents: []workspacesdk.LSFile{{
|
||||
Name: "AGENTS.md",
|
||||
AbsolutePathString: "/home/coder/.coder/AGENTS.md",
|
||||
}},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/.coder/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
io.NopCloser(strings.NewReader("base\n<!-- hidden -->\nlocal")),
|
||||
"text/markdown",
|
||||
nil,
|
||||
)
|
||||
|
||||
content, sourcePath, truncated, err := readHomeInstructionFile(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "base\n\nlocal", content)
|
||||
require.Equal(t, "/home/coder/.coder/AGENTS.md", sourcePath)
|
||||
require.False(t, truncated)
|
||||
}
|
||||
|
||||
func TestReadHomeInstructionFileTruncates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
content := strings.Repeat("a", maxInstructionFileBytes+8)
|
||||
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
|
||||
workspacesdk.LSResponse{
|
||||
Contents: []workspacesdk.LSFile{{
|
||||
Name: "AGENTS.md",
|
||||
AbsolutePathString: "/home/coder/.coder/AGENTS.md",
|
||||
}},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/.coder/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(io.NopCloser(strings.NewReader(content)), "text/markdown", nil)
|
||||
|
||||
got, _, truncated, err := readHomeInstructionFile(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
require.True(t, truncated)
|
||||
require.Len(t, got, maxInstructionFileBytes)
|
||||
}
|
||||
|
||||
func TestInsertSystemInstructionAfterSystemMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "base"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "hello"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := chatprompt.InsertSystem(prompt, "project rules")
|
||||
require.Len(t, got, 3)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[1].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
|
||||
|
||||
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "project rules", part.Text)
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package chatd
|
||||
|
||||
// DefaultSystemPrompt is used for new chats when no deployment override is
|
||||
// configured.
|
||||
const DefaultSystemPrompt = `You are the Coder agent — an interactive chat tool that helps users with software-engineering tasks inside of the Coder product.
|
||||
Use the instructions below and the tools available to you to assist User.
|
||||
|
||||
IMPORTANT — obey every rule in this prompt before anything else.
|
||||
Do EXACTLY what the User asked, never more, never less.
|
||||
|
||||
<behavior>
|
||||
You MUST execute AS MANY TOOLS to help the user accomplish their task.
|
||||
You are COMFORTABLE with vague tasks - using your tools to collect the most relevant answer possible.
|
||||
If a user asks how something works, no matter how vague, you MUST use your tools to collect the most relevant answer possible.
|
||||
DO NOT ask the user for clarification - just use your tools.
|
||||
</behavior>
|
||||
|
||||
<personality>
|
||||
Analytical — You break problems into measurable steps, relying on tool output and data rather than intuition.
|
||||
Organized — You structure every interaction with clear tags, TODO lists, and section boundaries.
|
||||
Precision-Oriented — You insist on exact formatting, package-manager choice, and rule adherence.
|
||||
Efficiency-Focused — You minimize chatter, run tasks in parallel, and favor small, complete answers.
|
||||
Clarity-Seeking — You ask for missing details instead of guessing, avoiding any ambiguity.
|
||||
</personality>
|
||||
|
||||
<communication>
|
||||
Be concise, direct, and to the point.
|
||||
NO emojis unless the User explicitly asks for them.
|
||||
If a task appears incomplete or ambiguous, **pause and ask the User** rather than guessing or marking "done".
|
||||
Prefer accuracy over reassurance; confirm facts with tool calls instead of assuming the User is right.
|
||||
If you face an architectural, tooling, or package-manager choice, **ask the User's preference first**.
|
||||
Default to the project's existing package manager / tooling; never substitute without confirmation.
|
||||
You MUST avoid text before/after your response, such as "The answer is" or "Short answer:", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
|
||||
Mimic the style of the User's messages.
|
||||
Do not remind the User you are happy to help.
|
||||
Do not inherently assume the User is correct; they may be making assumptions.
|
||||
If you are not confident in your answer, DO NOT provide an answer. Use your tools to collect more information, or ask the User for help.
|
||||
Do not act with sycophantic flattery or over-the-top enthusiasm.
|
||||
|
||||
Here are examples to demonstrate appropriate communication style and level of verbosity:
|
||||
|
||||
<example>
|
||||
user: find me a good issue to work on
|
||||
assistant: Issue [#1234](https://example) indicates a bug in the frontend, which you've contributed to in the past.
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: work on this issue <url>
|
||||
...assistant does work...
|
||||
assistant: I've put up this pull request: https://github.com/example/example/pull/1824. Please let me know your thoughts!
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what is 2+2?
|
||||
assistant: 4
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: how does X work in <popular-repository-name>?
|
||||
assistant: Let me take a look at the code...
|
||||
[tool calls to investigate the repository]
|
||||
</example>
|
||||
</communication>
|
||||
|
||||
<collaboration>
|
||||
When a user asks for help with a task or there is ambiguity on the objective, always start by asking clarifying questions to understand:
|
||||
- What specific aspect they want to focus on
|
||||
- Their goals and vision for the changes
|
||||
- Their preferences for approach or style
|
||||
- What problems they're trying to solve
|
||||
|
||||
Don't assume what needs to be done - collaborate to define the scope together.
|
||||
</collaboration>`
|
||||
@@ -1,513 +0,0 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
var ErrSubagentNotDescendant = xerrors.New("target chat is not a descendant of current chat")
|
||||
|
||||
const (
|
||||
subagentAwaitPollInterval = 200 * time.Millisecond
|
||||
defaultSubagentWaitTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
type spawnAgentArgs struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Title string `json:"title,omitempty"`
|
||||
}
|
||||
|
||||
type waitAgentArgs struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
type messageAgentArgs struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
Message string `json:"message"`
|
||||
Interrupt bool `json:"interrupt,omitempty"`
|
||||
}
|
||||
|
||||
type closeAgentArgs struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
}
|
||||
|
||||
func (p *Server) subagentTools(currentChat func() database.Chat) []fantasy.AgentTool {
|
||||
return []fantasy.AgentTool{
|
||||
fantasy.NewAgentTool(
|
||||
"spawn_agent",
|
||||
"Spawn a delegated child agent chat from the root chat.",
|
||||
func(ctx context.Context, args spawnAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if currentChat == nil {
|
||||
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
|
||||
}
|
||||
|
||||
parent := currentChat()
|
||||
if parent.ParentChatID.Valid {
|
||||
return fantasy.NewTextErrorResponse("delegated chats cannot create child subagents"), nil
|
||||
}
|
||||
|
||||
parent, err := p.db.GetChatByID(ctx, parent.ID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
childChat, err := p.createChildSubagentChat(
|
||||
ctx,
|
||||
parent,
|
||||
args.Prompt,
|
||||
args.Title,
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolJSONResponse(map[string]any{
|
||||
"chat_id": childChat.ID.String(),
|
||||
"title": childChat.Title,
|
||||
"status": string(childChat.Status),
|
||||
}), nil
|
||||
},
|
||||
),
|
||||
fantasy.NewAgentTool(
|
||||
"wait_agent",
|
||||
"Wait until a delegated descendant agent reaches a non-streaming status.",
|
||||
func(ctx context.Context, args waitAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if currentChat == nil {
|
||||
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
|
||||
}
|
||||
|
||||
targetChatID, err := parseSubagentToolChatID(args.ChatID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
timeout := defaultSubagentWaitTimeout
|
||||
if args.TimeoutSeconds != nil {
|
||||
timeout = time.Duration(*args.TimeoutSeconds) * time.Second
|
||||
}
|
||||
|
||||
parent := currentChat()
|
||||
targetChat, report, err := p.awaitSubagentCompletion(
|
||||
ctx,
|
||||
parent.ID,
|
||||
targetChatID,
|
||||
timeout,
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolJSONResponse(map[string]any{
|
||||
"chat_id": targetChatID.String(),
|
||||
"title": targetChat.Title,
|
||||
"report": report,
|
||||
"status": string(targetChat.Status),
|
||||
}), nil
|
||||
},
|
||||
),
|
||||
fantasy.NewAgentTool(
|
||||
"message_agent",
|
||||
"Send a message to a delegated descendant agent. Use wait_agent to collect a response.",
|
||||
func(ctx context.Context, args messageAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if currentChat == nil {
|
||||
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
|
||||
}
|
||||
|
||||
targetChatID, err := parseSubagentToolChatID(args.ChatID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
parent := currentChat()
|
||||
targetChat, err := p.sendSubagentMessage(
|
||||
ctx,
|
||||
parent.ID,
|
||||
targetChatID,
|
||||
args.Message,
|
||||
args.Interrupt,
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolJSONResponse(map[string]any{
|
||||
"chat_id": targetChatID.String(),
|
||||
"title": targetChat.Title,
|
||||
"status": string(targetChat.Status),
|
||||
"interrupted": args.Interrupt,
|
||||
}), nil
|
||||
},
|
||||
),
|
||||
fantasy.NewAgentTool(
|
||||
"close_agent",
|
||||
"Interrupt a delegated descendant agent immediately.",
|
||||
func(ctx context.Context, args closeAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if currentChat == nil {
|
||||
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
|
||||
}
|
||||
|
||||
targetChatID, err := parseSubagentToolChatID(args.ChatID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
parent := currentChat()
|
||||
targetChat, err := p.closeSubagent(
|
||||
ctx,
|
||||
parent.ID,
|
||||
targetChatID,
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolJSONResponse(map[string]any{
|
||||
"chat_id": targetChatID.String(),
|
||||
"title": targetChat.Title,
|
||||
"terminated": true,
|
||||
"status": string(targetChat.Status),
|
||||
}), nil
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func parseSubagentToolChatID(raw string) (uuid.UUID, error) {
|
||||
chatID, err := uuid.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.New("chat_id must be a valid UUID")
|
||||
}
|
||||
return chatID, nil
|
||||
}
|
||||
|
||||
func (p *Server) createChildSubagentChat(
|
||||
ctx context.Context,
|
||||
parent database.Chat,
|
||||
prompt string,
|
||||
title string,
|
||||
) (database.Chat, error) {
|
||||
if parent.ParentChatID.Valid {
|
||||
return database.Chat{}, xerrors.New("delegated chats cannot create child subagents")
|
||||
}
|
||||
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return database.Chat{}, xerrors.New("prompt is required")
|
||||
}
|
||||
|
||||
title = strings.TrimSpace(title)
|
||||
if title == "" {
|
||||
title = subagentFallbackChatTitle(prompt)
|
||||
}
|
||||
|
||||
rootChatID := parent.ID
|
||||
if parent.RootChatID.Valid {
|
||||
rootChatID = parent.RootChatID.UUID
|
||||
}
|
||||
if parent.LastModelConfigID == uuid.Nil {
|
||||
return database.Chat{}, xerrors.New("parent chat model config id is required")
|
||||
}
|
||||
|
||||
child, err := p.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
WorkspaceAgentID: parent.WorkspaceAgentID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: rootChatID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: parent.LastModelConfigID,
|
||||
Title: title,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: prompt}},
|
||||
})
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("create child chat: %w", err)
|
||||
}
|
||||
|
||||
return child, nil
|
||||
}
|
||||
|
||||
func (p *Server) sendSubagentMessage(
|
||||
ctx context.Context,
|
||||
parentChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
message string,
|
||||
interrupt bool,
|
||||
) (database.Chat, error) {
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return database.Chat{}, xerrors.New("message is required")
|
||||
}
|
||||
|
||||
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if !isDescendant {
|
||||
return database.Chat{}, ErrSubagentNotDescendant
|
||||
}
|
||||
|
||||
busyBehavior := SendMessageBusyBehaviorQueue
|
||||
if interrupt {
|
||||
busyBehavior = SendMessageBusyBehaviorInterrupt
|
||||
}
|
||||
|
||||
sendResult, err := p.SendMessage(ctx, SendMessageOptions{
|
||||
ChatID: targetChatID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: message}},
|
||||
BusyBehavior: busyBehavior,
|
||||
})
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
return sendResult.Chat, nil
|
||||
}
|
||||
|
||||
func (p *Server) awaitSubagentCompletion(
|
||||
ctx context.Context,
|
||||
parentChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
timeout time.Duration,
|
||||
) (database.Chat, string, error) {
|
||||
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, "", err
|
||||
}
|
||||
if !isDescendant {
|
||||
return database.Chat{}, "", ErrSubagentNotDescendant
|
||||
}
|
||||
|
||||
if timeout <= 0 {
|
||||
timeout = defaultSubagentWaitTimeout
|
||||
}
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
ticker := time.NewTicker(subagentAwaitPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
targetChat, report, done, checkErr := p.checkSubagentCompletion(ctx, targetChatID)
|
||||
if checkErr != nil {
|
||||
return database.Chat{}, "", checkErr
|
||||
}
|
||||
if done {
|
||||
if targetChat.Status == database.ChatStatusError {
|
||||
reason := strings.TrimSpace(report)
|
||||
if reason == "" {
|
||||
reason = "agent reached error status"
|
||||
}
|
||||
return database.Chat{}, "", xerrors.New(reason)
|
||||
}
|
||||
return targetChat, report, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-timer.C:
|
||||
return database.Chat{}, "", xerrors.New("timed out waiting for delegated subagent completion")
|
||||
case <-ctx.Done():
|
||||
return database.Chat{}, "", ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Server) closeSubagent(
|
||||
ctx context.Context,
|
||||
parentChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
) (database.Chat, error) {
|
||||
isDescendant, err := isSubagentDescendant(ctx, p.db, parentChatID, targetChatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if !isDescendant {
|
||||
return database.Chat{}, ErrSubagentNotDescendant
|
||||
}
|
||||
|
||||
targetChat, err := p.db.GetChatByID(ctx, targetChatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("get target chat: %w", err)
|
||||
}
|
||||
|
||||
if targetChat.Status == database.ChatStatusWaiting {
|
||||
return targetChat, nil
|
||||
}
|
||||
|
||||
updatedChat := p.InterruptChat(ctx, targetChat)
|
||||
if updatedChat.Status != database.ChatStatusWaiting {
|
||||
return database.Chat{}, xerrors.New("set target chat waiting")
|
||||
}
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
func (p *Server) checkSubagentCompletion(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
) (database.Chat, string, bool, error) {
|
||||
chat, err := p.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, "", false, xerrors.Errorf("get chat: %w", err)
|
||||
}
|
||||
|
||||
if chat.Status == database.ChatStatusPending || chat.Status == database.ChatStatusRunning {
|
||||
return database.Chat{}, "", false, nil
|
||||
}
|
||||
|
||||
report, err := latestSubagentAssistantMessage(ctx, p.db, chatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, "", false, err
|
||||
}
|
||||
|
||||
return chat, report, true, nil
|
||||
}
|
||||
|
||||
func latestSubagentAssistantMessage(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
chatID uuid.UUID,
|
||||
) (string, error) {
|
||||
messages, err := store.GetChatMessagesByChatID(ctx, chatID)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
|
||||
sort.Slice(messages, func(i, j int) bool {
|
||||
if messages[i].CreatedAt.Equal(messages[j].CreatedAt) {
|
||||
return messages[i].ID < messages[j].ID
|
||||
}
|
||||
return messages[i].CreatedAt.Before(messages[j].CreatedAt)
|
||||
})
|
||||
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
message := messages[i]
|
||||
if message.Role != string(fantasy.MessageRoleAssistant) ||
|
||||
message.Visibility == database.ChatMessageVisibilityModel {
|
||||
continue
|
||||
}
|
||||
|
||||
content, parseErr := chatprompt.ParseContent(message.Role, message.Content)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
text := strings.TrimSpace(contentBlocksToText(content))
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func isSubagentDescendant(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
ancestorChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
) (bool, error) {
|
||||
if ancestorChatID == targetChatID {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
descendants, err := listSubagentDescendants(ctx, store, ancestorChatID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, descendant := range descendants {
|
||||
if descendant.ID == targetChatID {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func listSubagentDescendants(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
chatID uuid.UUID,
|
||||
) ([]database.Chat, error) {
|
||||
queue := []uuid.UUID{chatID}
|
||||
visited := map[uuid.UUID]struct{}{chatID: {}}
|
||||
|
||||
out := make([]database.Chat, 0)
|
||||
for len(queue) > 0 {
|
||||
parentChatID := queue[0]
|
||||
queue = queue[1:]
|
||||
|
||||
children, err := store.ListChildChatsByParentID(ctx, parentChatID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("list child chats for %s: %w", parentChatID, err)
|
||||
}
|
||||
|
||||
for _, child := range children {
|
||||
if _, ok := visited[child.ID]; ok {
|
||||
continue
|
||||
}
|
||||
visited[child.ID] = struct{}{}
|
||||
out = append(out, child)
|
||||
queue = append(queue, child.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func subagentFallbackChatTitle(message string) string {
|
||||
const maxWords = 6
|
||||
const maxRunes = 80
|
||||
|
||||
words := strings.Fields(message)
|
||||
if len(words) == 0 {
|
||||
return "New Chat"
|
||||
}
|
||||
|
||||
truncated := false
|
||||
if len(words) > maxWords {
|
||||
words = words[:maxWords]
|
||||
truncated = true
|
||||
}
|
||||
|
||||
title := strings.Join(words, " ")
|
||||
if truncated {
|
||||
title += "..."
|
||||
}
|
||||
|
||||
return subagentTruncateRunes(title, maxRunes)
|
||||
}
|
||||
|
||||
func subagentTruncateRunes(value string, max int) string {
|
||||
if max <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
runes := []rune(value)
|
||||
if len(runes) <= max {
|
||||
return value
|
||||
}
|
||||
|
||||
return string(runes[:max])
|
||||
}
|
||||
|
||||
func toolJSONResponse(result map[string]any) fantasy.ToolResponse {
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return fantasy.NewTextResponse("{}")
|
||||
}
|
||||
return fantasy.NewTextResponse(string(data))
|
||||
}
|
||||
@@ -1,216 +0,0 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
)
|
||||
|
||||
const titleGenerationPrompt = "Generate a concise title (max 8 words, under 128 characters) for " +
|
||||
"the user's first message. Return plain text only — no quotes, no emoji, " +
|
||||
"no markdown, no special characters."
|
||||
|
||||
// maybeGenerateChatTitle generates an AI title for the chat when
|
||||
// appropriate (first user message, no assistant reply yet, and the
|
||||
// current title is either empty or still the fallback truncation).
|
||||
// It is a best-effort operation that logs and swallows errors.
|
||||
func (p *Server) maybeGenerateChatTitle(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
messages []database.ChatMessage,
|
||||
model fantasy.LanguageModel,
|
||||
logger slog.Logger,
|
||||
) {
|
||||
input, ok := titleInput(chat, messages)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
titleCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
title, err := generateTitle(titleCtx, model, input)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "failed to generate chat title",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if title == "" || title == chat.Title {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.db.UpdateChatByID(ctx, database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: title,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to update generated chat title",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
chat.Title = title
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange)
|
||||
}
|
||||
|
||||
// generateTitle calls the model with a title-generation system prompt
|
||||
// and returns the normalized result.
|
||||
func generateTitle(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
input string,
|
||||
) (string, error) {
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: titleGenerationPrompt},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: input},
|
||||
},
|
||||
},
|
||||
}
|
||||
toolChoice := fantasy.ToolChoiceNone
|
||||
response, err := model.Generate(ctx, fantasy.Call{
|
||||
Prompt: prompt,
|
||||
ToolChoice: &toolChoice,
|
||||
})
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("generate title text: %w", err)
|
||||
}
|
||||
|
||||
title := normalizeTitleOutput(contentBlocksToText(response.Content))
|
||||
if title == "" {
|
||||
return "", xerrors.New("generated title was empty")
|
||||
}
|
||||
return title, nil
|
||||
}
|
||||
|
||||
// titleInput returns the first user message text and whether title
|
||||
// generation should proceed. It returns false when the chat already
|
||||
// has assistant/tool replies, has more than one visible user message,
|
||||
// or the current title doesn't look like a candidate for replacement.
|
||||
func titleInput(
|
||||
chat database.Chat,
|
||||
messages []database.ChatMessage,
|
||||
) (string, bool) {
|
||||
userCount := 0
|
||||
firstUserText := ""
|
||||
|
||||
for _, message := range messages {
|
||||
if message.Visibility == database.ChatMessageVisibilityModel {
|
||||
continue
|
||||
}
|
||||
|
||||
switch message.Role {
|
||||
case string(fantasy.MessageRoleAssistant), string(fantasy.MessageRoleTool):
|
||||
return "", false
|
||||
case string(fantasy.MessageRoleUser):
|
||||
userCount++
|
||||
if firstUserText == "" {
|
||||
parsed, err := chatprompt.ParseContent(
|
||||
string(fantasy.MessageRoleUser), message.Content,
|
||||
)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
firstUserText = strings.TrimSpace(
|
||||
contentBlocksToText(parsed),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if userCount != 1 || firstUserText == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
currentTitle := strings.TrimSpace(chat.Title)
|
||||
if currentTitle == "" {
|
||||
return firstUserText, true
|
||||
}
|
||||
|
||||
if currentTitle != fallbackChatTitle(firstUserText) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return firstUserText, true
|
||||
}
|
||||
|
||||
func normalizeTitleOutput(title string) string {
|
||||
title = strings.TrimSpace(title)
|
||||
if title == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
title = strings.Trim(title, "\"'`")
|
||||
title = strings.Join(strings.Fields(title), " ")
|
||||
return truncateRunes(title, 80)
|
||||
}
|
||||
|
||||
func fallbackChatTitle(message string) string {
|
||||
const maxWords = 6
|
||||
const maxRunes = 80
|
||||
|
||||
words := strings.Fields(message)
|
||||
if len(words) == 0 {
|
||||
return "New Chat"
|
||||
}
|
||||
|
||||
truncated := false
|
||||
if len(words) > maxWords {
|
||||
words = words[:maxWords]
|
||||
truncated = true
|
||||
}
|
||||
|
||||
title := strings.Join(words, " ")
|
||||
if truncated {
|
||||
title += "…"
|
||||
}
|
||||
|
||||
return truncateRunes(title, maxRunes)
|
||||
}
|
||||
|
||||
// contentBlocksToText concatenates the text parts of content blocks
|
||||
// into a single space-separated string.
|
||||
func contentBlocksToText(content []fantasy.Content) string {
|
||||
parts := make([]string, 0, len(content))
|
||||
for _, block := range content {
|
||||
textBlock, ok := fantasy.AsContentType[fantasy.TextContent](block)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
text := strings.TrimSpace(textBlock.Text)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, text)
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func truncateRunes(value string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
runes := []rune(value)
|
||||
if len(runes) <= maxLen {
|
||||
return value
|
||||
}
|
||||
return string(runes[:maxLen])
|
||||
}
|
||||
-3229
File diff suppressed because it is too large
Load Diff
@@ -1,336 +0,0 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("PostChats", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello from chats route tests",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEqual(t, uuid.Nil, chat.ID)
|
||||
require.Equal(t, user.UserID, chat.OwnerID)
|
||||
require.Equal(t, modelConfig.ID, chat.LastModelConfigID)
|
||||
require.Equal(t, "hello from chats route tests", chat.Title)
|
||||
require.Equal(t, codersdk.ChatStatusPending, chat.Status)
|
||||
require.NotZero(t, chat.CreatedAt)
|
||||
require.NotZero(t, chat.UpdatedAt)
|
||||
|
||||
require.Nil(t, chat.WorkspaceID)
|
||||
require.Nil(t, chat.WorkspaceAgentID)
|
||||
require.NotNil(t, chat.RootChatID)
|
||||
require.Equal(t, chat.ID, *chat.RootChatID)
|
||||
|
||||
chatWithMessages, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chat.ID, chatWithMessages.Chat.ID)
|
||||
|
||||
foundUserMessage := false
|
||||
for _, message := range chatWithMessages.Messages {
|
||||
if message.Role != "user" {
|
||||
continue
|
||||
}
|
||||
for _, part := range message.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText &&
|
||||
part.Text == "hello from chats route tests" {
|
||||
foundUserMessage = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
require.True(t, foundUserMessage)
|
||||
})
|
||||
|
||||
t.Run("HidesSystemPromptMessages", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "verify hidden system prompt",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chatWithMessages, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
for _, message := range chatWithMessages.Messages {
|
||||
require.NotEqual(t, "system", message.Role)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WorkspaceNotAccessible", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient, db := coderdtest.NewWithDatabase(t, nil)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
OwnerID: firstUser.UserID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
_, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
},
|
||||
},
|
||||
WorkspaceID: &workspaceBuild.Workspace.ID,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Workspace not found or you do not have access to this resource", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("WorkspaceNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
},
|
||||
},
|
||||
WorkspaceID: &workspaceID,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Workspace not found or you do not have access to this resource", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("WorkspaceSelectsFirstAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := coderdtest.NewWithDatabase(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
},
|
||||
},
|
||||
WorkspaceID: &workspaceBuild.Workspace.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, chat.WorkspaceID)
|
||||
require.Equal(t, workspaceBuild.Workspace.ID, *chat.WorkspaceID)
|
||||
require.NotNil(t, chat.WorkspaceAgentID)
|
||||
require.Equal(t, workspaceBuild.Agents[0].ID, *chat.WorkspaceAgentID)
|
||||
require.Equal(t, modelConfig.ID, chat.LastModelConfigID)
|
||||
})
|
||||
|
||||
t.Run("MissingDefaultModelConfig", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "No default chat model config is configured.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("EmptyContent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: nil,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Content is required.", sdkErr.Message)
|
||||
require.Equal(t, "Content cannot be empty.", sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("EmptyText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: " ",
|
||||
},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid input part.", sdkErr.Message)
|
||||
require.Equal(t, "content[0].text cannot be empty.", sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("UnsupportedPartType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartType("image"),
|
||||
Text: "hello",
|
||||
},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid input part.", sdkErr.Message)
|
||||
require.Equal(t, `content[0].type "image" is not supported.`, sdkErr.Detail)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ListChatModels", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
models, err := client.ListChatModels(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
var openAIProvider *codersdk.ChatModelProvider
|
||||
for i := range models.Providers {
|
||||
if models.Providers[i].Provider == "openai" {
|
||||
openAIProvider = &models.Providers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, openAIProvider)
|
||||
require.True(t, openAIProvider.Available)
|
||||
|
||||
foundModel := false
|
||||
for _, model := range openAIProvider.Models {
|
||||
if model.Provider == "openai" && model.Model == "gpt-4o-mini" {
|
||||
foundModel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundModel)
|
||||
})
|
||||
|
||||
t.Run("ListChatProviders", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
providers, err := client.ListChatProviders(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
var openAIProvider *codersdk.ChatProviderConfig
|
||||
for i := range providers {
|
||||
if providers[i].Provider == "openai" {
|
||||
openAIProvider = &providers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, openAIProvider)
|
||||
require.Equal(t, codersdk.ChatProviderConfigSourceDatabase, openAIProvider.Source)
|
||||
require.True(t, openAIProvider.Enabled)
|
||||
require.True(t, openAIProvider.HasAPIKey)
|
||||
})
|
||||
}
|
||||
|
||||
func createChatModelConfig(t *testing.T, client *codersdk.Client) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return modelConfig
|
||||
}
|
||||
|
||||
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
|
||||
t.Helper()
|
||||
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, expectedStatus, sdkErr.StatusCode())
|
||||
return sdkErr
|
||||
}
|
||||
+7
-63
@@ -49,7 +49,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/awsidentity"
|
||||
"github.com/coder/coder/v2/coderd/boundaryusage"
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/connectionlog"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -239,9 +238,6 @@ type Options struct {
|
||||
SSHConfig codersdk.SSHConfigResponse
|
||||
|
||||
HTTPClient *http.Client
|
||||
// ChatRemotePartsProvider provides cross-replica message_part streaming.
|
||||
// Set by enterprise for HA deployments. Nil in AGPL single-replica.
|
||||
ChatRemotePartsProvider chatd.RemotePartsProvider
|
||||
|
||||
UpdateAgentMetrics func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
|
||||
StatsBatcher workspacestats.Batcher
|
||||
@@ -592,6 +588,7 @@ func New(options *Options) *API {
|
||||
var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker]
|
||||
var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{}
|
||||
buildUsageChecker.Store(&noopUsageChecker)
|
||||
|
||||
api := &API{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
@@ -757,17 +754,6 @@ func New(options *Options) *API {
|
||||
panic("failed to setup server tailnet: " + err.Error())
|
||||
}
|
||||
api.agentProvider = stn
|
||||
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chats"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
RemotePartsProvider: options.ChatRemotePartsProvider,
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
})
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(stn)
|
||||
api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry)
|
||||
@@ -914,7 +900,6 @@ func New(options *Options) *API {
|
||||
sharedhttpmw.Recover(api.Logger),
|
||||
httpmw.WithProfilingLabels,
|
||||
tracing.StatusWriterMiddleware,
|
||||
options.DeploymentValues.HTTPCookies.Middleware,
|
||||
tracing.Middleware(api.TracerProvider),
|
||||
httpmw.AttachRequestID,
|
||||
httpmw.ExtractRealIP(api.RealIPConfig),
|
||||
@@ -1174,43 +1159,6 @@ func New(options *Options) *API {
|
||||
r.Get("/", api.auditLogs)
|
||||
r.Post("/testgenerate", api.generateFakeAuditLog)
|
||||
})
|
||||
r.Route("/chats", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Get("/", api.listChats)
|
||||
r.Post("/", api.postChats)
|
||||
r.Get("/models", api.listChatModels)
|
||||
r.Get("/watch", api.watchChats)
|
||||
r.Route("/providers", func(r chi.Router) {
|
||||
r.Get("/", api.listChatProviders)
|
||||
r.Post("/", api.createChatProvider)
|
||||
r.Route("/{providerConfig}", func(r chi.Router) {
|
||||
r.Patch("/", api.updateChatProvider)
|
||||
r.Delete("/", api.deleteChatProvider)
|
||||
})
|
||||
})
|
||||
r.Route("/model-configs", func(r chi.Router) {
|
||||
r.Get("/", api.listChatModelConfigs)
|
||||
r.Post("/", api.createChatModelConfig)
|
||||
r.Route("/{modelConfig}", func(r chi.Router) {
|
||||
r.Patch("/", api.updateChatModelConfig)
|
||||
r.Delete("/", api.deleteChatModelConfig)
|
||||
})
|
||||
})
|
||||
r.Route("/{chat}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractChatParam(options.Database))
|
||||
r.Get("/", api.getChat)
|
||||
r.Delete("/", api.deleteChat)
|
||||
r.Post("/messages", api.postChatMessages)
|
||||
r.Get("/stream", api.streamChat)
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Get("/diff-status", api.getChatDiffStatus)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
r.Delete("/", api.deleteChatQueuedMessage)
|
||||
r.Post("/promote", api.promoteChatQueuedMessage)
|
||||
})
|
||||
})
|
||||
})
|
||||
r.Route("/files", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
@@ -1284,10 +1232,7 @@ func New(options *Options) *API {
|
||||
r.Get("/", api.organizationMember)
|
||||
r.Delete("/", api.deleteOrganizationMember)
|
||||
r.Put("/roles", api.putMemberRoles)
|
||||
r.Route("/workspaces", func(r chi.Router) {
|
||||
r.Post("/", api.postWorkspacesByOrganization)
|
||||
r.Get("/available-users", api.workspaceAvailableUsers)
|
||||
})
|
||||
r.Post("/workspaces", api.postWorkspacesByOrganization)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1454,7 +1399,6 @@ func New(options *Options) *API {
|
||||
r.Route("/{keyid}", func(r chi.Router) {
|
||||
r.Get("/", api.apiKeyByID)
|
||||
r.Delete("/", api.deleteAPIKey)
|
||||
r.Put("/expire", api.expireAPIKey)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1577,6 +1521,10 @@ func New(options *Options) *API {
|
||||
})
|
||||
r.Get("/timings", api.workspaceTimings)
|
||||
r.Route("/acl", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceSharing),
|
||||
)
|
||||
|
||||
r.Get("/", api.workspaceACL)
|
||||
r.Patch("/", api.patchWorkspaceACL)
|
||||
r.Delete("/", api.deleteWorkspaceACL)
|
||||
@@ -1953,8 +1901,6 @@ type API struct {
|
||||
// dbRolluper rolls up template usage stats from raw agent and app
|
||||
// stats. This is used to provide insights in the WebUI.
|
||||
dbRolluper *dbrollup.Rolluper
|
||||
// chatDaemon handles background processing of pending chats.
|
||||
chatDaemon *chatd.Server
|
||||
}
|
||||
|
||||
// Close waits for all WebSocket connections to drain before returning.
|
||||
@@ -1983,10 +1929,8 @@ func (api *API) Close() error {
|
||||
case <-timer.C:
|
||||
api.Logger.Warn(api.ctx, "websocket shutdown timed out after 10 seconds")
|
||||
}
|
||||
|
||||
api.dbRolluper.Close()
|
||||
if err := api.chatDaemon.Close(); err != nil {
|
||||
api.Logger.Warn(api.ctx, "close chat processor", slog.Error(err))
|
||||
}
|
||||
api.metricsCache.Close()
|
||||
if api.updateChecker != nil {
|
||||
api.updateChecker.Close()
|
||||
|
||||
@@ -6,20 +6,17 @@ type CheckConstraint string
|
||||
|
||||
// CheckConstraint enums.
|
||||
const (
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
package db2sdk
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@@ -12,7 +11,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
@@ -20,7 +18,6 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
@@ -1053,332 +1050,3 @@ func jsonOrEmptyMap(rawMessage pqtype.NullRawMessage) map[string]any {
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func ChatMessage(m database.ChatMessage) codersdk.ChatMessage {
|
||||
modelConfigID := &m.ModelConfigID.UUID
|
||||
if !m.ModelConfigID.Valid {
|
||||
modelConfigID = nil
|
||||
}
|
||||
msg := codersdk.ChatMessage{
|
||||
ID: m.ID,
|
||||
ChatID: m.ChatID,
|
||||
ModelConfigID: modelConfigID,
|
||||
CreatedAt: m.CreatedAt,
|
||||
Role: m.Role,
|
||||
}
|
||||
if m.Content.Valid {
|
||||
parts, err := chatMessageParts(m.Role, m.Content)
|
||||
if err == nil {
|
||||
msg.Content = parts
|
||||
}
|
||||
}
|
||||
usage := chatMessageUsage(m)
|
||||
if usage != nil {
|
||||
msg.Usage = usage
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// chatMessageUsage builds a ChatMessageUsage from the database row,
|
||||
// returning nil when no token fields are populated.
|
||||
func chatMessageUsage(m database.ChatMessage) *codersdk.ChatMessageUsage {
|
||||
inputTokens := nullInt64Ptr(m.InputTokens)
|
||||
outputTokens := nullInt64Ptr(m.OutputTokens)
|
||||
totalTokens := nullInt64Ptr(m.TotalTokens)
|
||||
reasoningTokens := nullInt64Ptr(m.ReasoningTokens)
|
||||
cacheCreationTokens := nullInt64Ptr(m.CacheCreationTokens)
|
||||
cacheReadTokens := nullInt64Ptr(m.CacheReadTokens)
|
||||
contextLimit := nullInt64Ptr(m.ContextLimit)
|
||||
|
||||
if inputTokens == nil && outputTokens == nil && totalTokens == nil &&
|
||||
reasoningTokens == nil && cacheCreationTokens == nil &&
|
||||
cacheReadTokens == nil && contextLimit == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &codersdk.ChatMessageUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TotalTokens: totalTokens,
|
||||
ReasoningTokens: reasoningTokens,
|
||||
CacheCreationTokens: cacheCreationTokens,
|
||||
CacheReadTokens: cacheReadTokens,
|
||||
ContextLimit: contextLimit,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatQueuedMessages converts a slice of database queued messages
|
||||
// to their SDK representation.
|
||||
func ChatQueuedMessages(messages []database.ChatQueuedMessage) []codersdk.ChatQueuedMessage {
|
||||
out := make([]codersdk.ChatQueuedMessage, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
out = append(out, codersdk.ChatQueuedMessage{
|
||||
ID: message.ID,
|
||||
ChatID: message.ChatID,
|
||||
Content: message.Content,
|
||||
CreatedAt: message.CreatedAt,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMessagePart, error) {
|
||||
switch role {
|
||||
case string(fantasy.MessageRoleSystem):
|
||||
content, err := parseSystemContent(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return []codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: content,
|
||||
}}, nil
|
||||
case string(fantasy.MessageRoleUser), string(fantasy.MessageRoleAssistant):
|
||||
content, err := parseContentBlocks(role, raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawBlocks []json.RawMessage
|
||||
if role == string(fantasy.MessageRoleAssistant) {
|
||||
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
|
||||
}
|
||||
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(content))
|
||||
for i, block := range content {
|
||||
part := contentBlockToPart(block)
|
||||
if part.Type == "" {
|
||||
continue
|
||||
}
|
||||
if part.Type == codersdk.ChatMessagePartTypeReasoning {
|
||||
part.Title = ""
|
||||
if i < len(rawBlocks) {
|
||||
part.Title = reasoningStoredTitle(rawBlocks[i])
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
return parts, nil
|
||||
case string(fantasy.MessageRoleTool):
|
||||
results, err := parseToolResults(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(results))
|
||||
for _, result := range results {
|
||||
parts = append(parts, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: result.ToolCallID,
|
||||
ToolName: result.ToolName,
|
||||
Result: result.Result,
|
||||
IsError: result.IsError,
|
||||
})
|
||||
}
|
||||
return parts, nil
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
var content string
|
||||
if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
|
||||
return "", xerrors.Errorf("parse system content: %w", err)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func parseContentBlocks(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if role == string(fantasy.MessageRoleUser) {
|
||||
var text string
|
||||
if err := json.Unmarshal(raw.RawMessage, &text); err == nil {
|
||||
return []fantasy.Content{
|
||||
fantasy.TextContent{Text: text},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
var blocks []json.RawMessage
|
||||
if err := json.Unmarshal(raw.RawMessage, &blocks); err != nil {
|
||||
return nil, xerrors.Errorf("parse content blocks: %w", err)
|
||||
}
|
||||
|
||||
content := make([]fantasy.Content, 0, len(blocks))
|
||||
for _, block := range blocks {
|
||||
decoded, err := fantasy.UnmarshalContent(block)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse content block: %w", err)
|
||||
}
|
||||
content = append(content, decoded)
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// toolResultRow is used only for extracting top-level fields from
|
||||
// persisted tool result JSON. The result payload is kept as raw JSON.
|
||||
type toolResultRow struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Result json.RawMessage `json:"result"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
func parseToolResults(raw pqtype.NullRawMessage) ([]toolResultRow, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var results []toolResultRow
|
||||
if err := json.Unmarshal(raw.RawMessage, &results); err != nil {
|
||||
return nil, xerrors.Errorf("parse tool results: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func reasoningStoredTitle(raw json.RawMessage) string {
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data struct {
|
||||
Title string `json:"title"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &envelope); err != nil {
|
||||
return ""
|
||||
}
|
||||
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeReasoning)) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(envelope.Data.Title)
|
||||
}
|
||||
|
||||
func contentBlockToPart(block fantasy.Content) codersdk.ChatMessagePart {
|
||||
switch value := block.(type) {
|
||||
case fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case *fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
}
|
||||
case *fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
}
|
||||
case fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
}
|
||||
case *fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
}
|
||||
case fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case *fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case *fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case fantasy.ToolResultContent:
|
||||
return chatprompt.ToolResultToPart(
|
||||
value.ToolCallID,
|
||||
value.ToolName,
|
||||
toolResultOutputToRawJSON(value.Result),
|
||||
toolResultOutputIsError(value.Result),
|
||||
)
|
||||
case *fantasy.ToolResultContent:
|
||||
return chatprompt.ToolResultToPart(
|
||||
value.ToolCallID,
|
||||
value.ToolName,
|
||||
toolResultOutputToRawJSON(value.Result),
|
||||
toolResultOutputIsError(value.Result),
|
||||
)
|
||||
default:
|
||||
return codersdk.ChatMessagePart{}
|
||||
}
|
||||
}
|
||||
|
||||
func toolResultOutputToRawJSON(output fantasy.ToolResultOutputContent) json.RawMessage {
|
||||
switch v := output.(type) {
|
||||
case fantasy.ToolResultOutputContentError:
|
||||
if v.Error != nil {
|
||||
data, _ := json.Marshal(map[string]any{"error": v.Error.Error()})
|
||||
return data
|
||||
}
|
||||
return json.RawMessage(`{"error":""}`)
|
||||
case fantasy.ToolResultOutputContentText:
|
||||
raw := json.RawMessage(v.Text)
|
||||
if json.Valid(raw) {
|
||||
return raw
|
||||
}
|
||||
data, _ := json.Marshal(map[string]any{"output": v.Text})
|
||||
return data
|
||||
case fantasy.ToolResultOutputContentMedia:
|
||||
data, _ := json.Marshal(map[string]any{
|
||||
"data": v.Data,
|
||||
"mime_type": v.MediaType,
|
||||
"text": v.Text,
|
||||
})
|
||||
return data
|
||||
default:
|
||||
return json.RawMessage(`{}`)
|
||||
}
|
||||
}
|
||||
|
||||
func toolResultOutputIsError(output fantasy.ToolResultOutputContent) bool {
|
||||
_, ok := output.(fantasy.ToolResultOutputContentError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
value := v.Int64
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -437,79 +435,3 @@ func TestAIBridgeInterception(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatMessage_ReasoningPartWithoutPersistedTitleIsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assistantContent, err := json.Marshal([]fantasy.Content{
|
||||
fantasy.ReasoningContent{
|
||||
Text: "Plan migration",
|
||||
ProviderMetadata: fantasy.ProviderMetadata{
|
||||
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
|
||||
ItemID: "reasoning-1",
|
||||
Summary: []string{"Plan migration"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
message := db2sdk.ChatMessage(database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
CreatedAt: time.Now(),
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: assistantContent,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, message.Content, 1)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
|
||||
require.Equal(t, "Plan migration", message.Content[0].Text)
|
||||
require.Empty(t, message.Content[0].Title)
|
||||
}
|
||||
|
||||
func TestChatMessage_ReasoningPartPrefersPersistedTitle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reasoningContent, err := json.Marshal(fantasy.ReasoningContent{
|
||||
Text: "Verify schema updates, then apply changes in order.",
|
||||
ProviderMetadata: fantasy.ProviderMetadata{
|
||||
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
|
||||
ItemID: "reasoning-1",
|
||||
Summary: []string{
|
||||
"**Metadata-derived title**\n\nLonger explanation.",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var envelope map[string]any
|
||||
require.NoError(t, json.Unmarshal(reasoningContent, &envelope))
|
||||
dataValue, ok := envelope["data"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
dataValue["title"] = "Persisted stream title"
|
||||
|
||||
encodedReasoning, err := json.Marshal(envelope)
|
||||
require.NoError(t, err)
|
||||
assistantContent, err := json.Marshal([]json.RawMessage{encodedReasoning})
|
||||
require.NoError(t, err)
|
||||
|
||||
message := db2sdk.ChatMessage(database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
CreatedAt: time.Now(),
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: assistantContent,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, message.Content, 1)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
|
||||
require.Equal(t, "Persisted stream title", message.Content[0].Title)
|
||||
}
|
||||
|
||||
@@ -453,7 +453,6 @@ var (
|
||||
rbac.ResourceProvisionerJobs.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreate},
|
||||
rbac.ResourceOauth2App.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
rbac.ResourceOauth2AppSecret.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
}),
|
||||
User: []rbac.Permission{},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{},
|
||||
@@ -669,31 +668,6 @@ var (
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
|
||||
subjectWorkspaceBuilder = rbac.Subject{
|
||||
Type: rbac.SubjectTypeWorkspaceBuilder,
|
||||
FriendlyName: "Workspace Builder",
|
||||
ID: uuid.Nil.String(),
|
||||
Roles: rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleIdentifier{Name: "workspace-builder"},
|
||||
DisplayName: "Workspace Builder",
|
||||
Site: rbac.Permissions(map[string][]policy.Action{
|
||||
// Reading provisioner daemons to check eligibility.
|
||||
rbac.ResourceProvisionerDaemon.Type: {policy.ActionRead},
|
||||
// Updating provisioner jobs (e.g. marking prebuild
|
||||
// jobs complete).
|
||||
rbac.ResourceProvisionerJobs.Type: {policy.ActionUpdate},
|
||||
// Reading provisioner state requires template update
|
||||
// permission.
|
||||
rbac.ResourceTemplate.Type: {policy.ActionUpdate},
|
||||
}),
|
||||
User: []rbac.Permission{},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{},
|
||||
},
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
)
|
||||
|
||||
// AsProvisionerd returns a context with an actor that has permissions required
|
||||
@@ -800,14 +774,6 @@ func AsBoundaryUsageTracker(ctx context.Context) context.Context {
|
||||
return As(ctx, subjectBoundaryUsageTracker)
|
||||
}
|
||||
|
||||
// AsWorkspaceBuilder returns a context with an actor that has permissions
|
||||
// required for the workspace builder to prepare workspace builds. This
|
||||
// includes reading provisioner daemons, updating provisioner jobs, and
|
||||
// reading provisioner state (which requires template update permission).
|
||||
func AsWorkspaceBuilder(ctx context.Context) context.Context {
|
||||
return As(ctx, subjectWorkspaceBuilder)
|
||||
}
|
||||
|
||||
var AsRemoveActor = rbac.Subject{
|
||||
ID: "remove-actor",
|
||||
}
|
||||
@@ -1485,15 +1451,6 @@ func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.Prov
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *querier) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
|
||||
// AcquireChat is a system-level operation used by the chat processor.
|
||||
// Authorization is done at the system level, not per-user.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.AcquireChat(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) AcquireLock(ctx context.Context, id int64) error {
|
||||
return q.db.AcquireLock(ctx, id)
|
||||
}
|
||||
@@ -1722,17 +1679,6 @@ func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) e
|
||||
return q.db.DeleteAPIKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteAllChatQueuedMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return err
|
||||
@@ -1757,54 +1703,6 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
|
||||
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
// Authorize delete on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatMessagesByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatModelConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatProviderByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatQueuedMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -2263,12 +2161,12 @@ func (q *querier) GetAPIKeyByName(ctx context.Context, arg database.GetAPIKeyByN
|
||||
return fetch(q.log, q.auth, q.db.GetAPIKeyByName)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
|
||||
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByLoginType)(ctx, loginType)
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysByUserID(ctx context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, params)
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, database.GetAPIKeysByUserIDParams{LoginType: params.LoginType, UserID: params.UserID})
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]database.APIKey, error) {
|
||||
@@ -2359,7 +2257,7 @@ func (q *querier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditL
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow, error) {
|
||||
// This is a system function.
|
||||
// This is a system function
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow{}, err
|
||||
}
|
||||
@@ -2373,131 +2271,6 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
|
||||
return q.db.GetAuthorizationUserRoles(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
return fetch(q.log, q.auth, q.db.GetChatByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
return fetch(q.log, q.auth, q.db.GetChatByIDForUpdate)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
return q.db.GetChatDiffStatusByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uuid.UUID) ([]database.ChatDiffStatus, error) {
|
||||
if len(chatIDs) == 0 {
|
||||
return []database.ChatDiffStatus{}, nil
|
||||
}
|
||||
|
||||
actor, ok := ActorFromContext(ctx)
|
||||
if ok && actor.Type == rbac.SubjectTypeSystemRestricted {
|
||||
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
||||
}
|
||||
|
||||
for _, chatID := range chatIDs {
|
||||
// Authorize read on each parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
// ChatMessages are authorized through their parent Chat.
|
||||
// We need to fetch the message first to get its chat_id.
|
||||
msg, err := q.db.GetChatMessageByID(ctx, id)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
// Authorize read on the parent chat.
|
||||
_, err = q.GetChatByID(ctx, msg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessagesByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessagesForPromptByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.GetChatModelConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.GetChatModelConfigByProviderAndModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.GetChatProviderByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.GetChatProviderByProvider(ctx, provider)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatProviders(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatQueuedMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
// Just like with the audit logs query, shortcut if the user is an owner.
|
||||
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
|
||||
@@ -2555,13 +2328,6 @@ func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
|
||||
return q.db.GetDERPMeshKey(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.GetDefaultChatModelConfig(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetDefaultOrganization(ctx context.Context) (database.Organization, error) {
|
||||
return fetch(q.log, q.auth, func(ctx context.Context, _ any) (database.Organization, error) {
|
||||
return q.db.GetDefaultOrganization(ctx)
|
||||
@@ -2602,20 +2368,6 @@ func (q *querier) GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.C
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetEligibleProvisionerDaemonsByProvisionerJobIDs)(ctx, provisionerJobIDs)
|
||||
}
|
||||
|
||||
func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetEnabledChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetEnabledChatProviders(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetExternalAuthLink)(ctx, arg)
|
||||
}
|
||||
@@ -3315,14 +3067,6 @@ func (q *querier) GetRuntimeConfig(ctx context.Context, key string) (string, err
|
||||
return q.db.GetRuntimeConfig(ctx, key)
|
||||
}
|
||||
|
||||
func (q *querier) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
|
||||
// GetStaleChats is a system-level operation used by the chat processor for recovery.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetStaleChats(ctx, staleThreshold)
|
||||
}
|
||||
|
||||
func (q *querier) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return nil, err
|
||||
@@ -3389,13 +3133,6 @@ func (q *querier) GetTelemetryItems(ctx context.Context) ([]database.TelemetryIt
|
||||
return q.db.GetTelemetryItems(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetTelemetryTaskEvents(ctx context.Context, arg database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTask.All()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetTelemetryTaskEvents(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
|
||||
if err := q.authorizeTemplateInsights(ctx, arg.TemplateIDs); err != nil {
|
||||
return nil, err
|
||||
@@ -4177,11 +3914,6 @@ func (q *querier) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, wor
|
||||
return q.db.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, prep)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, buildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
|
||||
// Fetching the provisioner state requires Update permission on the template.
|
||||
return fetchWithAction(q.log, q.auth, policy.ActionUpdate, q.db.GetWorkspaceBuildProvisionerStateByID)(ctx, buildID)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
@@ -4446,47 +4178,6 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo
|
||||
return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
// Authorize create on the parent chat (using update permission).
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
return q.db.InsertChatMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.InsertChatModelConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.InsertChatProvider(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
return q.db.InsertChatQueuedMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -5055,14 +4746,6 @@ func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Contex
|
||||
return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
// This function is a system function until we implement a join for aibridge interceptions.
|
||||
// Matches the behavior of the workspaces listing endpoint.
|
||||
@@ -5093,14 +4776,6 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
|
||||
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
|
||||
}
|
||||
|
||||
func (q *querier) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChatsByRootID)(ctx, rootChatID)
|
||||
}
|
||||
|
||||
func (q *querier) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChildChatsByParentID)(ctx, parentChatID)
|
||||
}
|
||||
|
||||
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
|
||||
}
|
||||
@@ -5181,17 +4856,6 @@ func (q *querier) PaginatedOrganizationMembers(ctx context.Context, arg database
|
||||
return q.db.PaginatedOrganizationMembers(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
return q.db.PopNextQueuedMessage(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
|
||||
template, err := q.db.GetTemplateByID(ctx, templateID)
|
||||
if err != nil {
|
||||
@@ -5270,13 +4934,6 @@ func (q *querier) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error {
|
||||
return update(q.log, q.auth, fetch, q.db.UnfavoriteWorkspace)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) UnsetDefaultChatModelConfigs(ctx context.Context) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UnsetDefaultChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateAIBridgeInterceptionEnded(ctx context.Context, params database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, params.ID); err != nil {
|
||||
return database.AIBridgeInterception{}, err
|
||||
@@ -5291,75 +4948,6 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
|
||||
return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.UpdateChatHeartbeat(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.UpdateChatModelConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.UpdateChatProvider(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
// UpdateChatStatus is used by the chat processor to change chat status.
|
||||
// It should be called with system context.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace is manually implemented for chat tables and may not be
|
||||
// present on every wrapped store interface yet.
|
||||
chatWorkspaceUpdater, ok := q.db.(interface {
|
||||
UpdateChatWorkspace(context.Context, database.UpdateChatWorkspaceParams) (database.Chat, error)
|
||||
})
|
||||
if !ok {
|
||||
return database.Chat{}, xerrors.New("update chat workspace is not implemented by wrapped store")
|
||||
}
|
||||
return chatWorkspaceUpdater.UpdateChatWorkspace(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -6415,30 +6003,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
|
||||
return q.db.UpsertBoundaryUsageStats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
return q.db.UpsertChatDiffStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
return q.db.UpsertChatDiffStatusReference(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
|
||||
return database.ConnectionLog{}, err
|
||||
@@ -6730,17 +6294,16 @@ func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg databas
|
||||
return q.CountConnectionLogs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
return q.db.ListAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, _ rbac.PreparedAuthorized) ([]string, error) {
|
||||
// TODO: Delete this function, all ListAIBridgeModels should be authorized. For now just call ListAIBridgeModels on the authz querier.
|
||||
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
// TODO: Delete this function, all ListAIBridgeInterceptions should be authorized. For now just call ListAIBridgeInterceptions on the authz querier.
|
||||
// This cannot be deleted for now because it's included in the
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.ListAIBridgeModels(ctx, arg)
|
||||
return q.ListAIBridgeInterceptions(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) (int64, error) {
|
||||
// TODO: Delete this function, all CountAIBridgeInterceptions should be authorized. For now just call CountAIBridgeInterceptions on the authz querier.
|
||||
// This cannot be deleted for now because it's included in the
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.CountAIBridgeInterceptions(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -170,7 +170,6 @@ func TestDBAuthzRecursive(t *testing.T) {
|
||||
Groups: []string{},
|
||||
Scope: rbac.ScopeAll,
|
||||
}
|
||||
preparedAuthorizedType := reflect.TypeOf((*rbac.PreparedAuthorized)(nil)).Elem()
|
||||
for i := 0; i < reflect.TypeOf(q).NumMethod(); i++ {
|
||||
var ins []reflect.Value
|
||||
ctx := dbauthz.As(context.Background(), actor)
|
||||
@@ -178,13 +177,7 @@ func TestDBAuthzRecursive(t *testing.T) {
|
||||
ins = append(ins, reflect.ValueOf(ctx))
|
||||
method := reflect.TypeOf(q).Method(i)
|
||||
for i := 2; i < method.Type.NumIn(); i++ {
|
||||
inType := method.Type.In(i)
|
||||
if inType.Implements(preparedAuthorizedType) {
|
||||
ins = append(ins, reflect.ValueOf(emptyPreparedAuthorized{}))
|
||||
continue
|
||||
}
|
||||
|
||||
ins = append(ins, reflect.New(inType).Elem())
|
||||
ins = append(ins, reflect.New(method.Type.In(i)).Elem())
|
||||
}
|
||||
if method.Name == "InTx" ||
|
||||
method.Name == "Ping" ||
|
||||
@@ -244,8 +237,8 @@ func (s *MethodTestSuite) TestAPIKey() {
|
||||
s.Run("GetAPIKeysByLoginType", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
a := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
|
||||
b := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
|
||||
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Return([]database.APIKey{a, b}, nil).AnyTimes()
|
||||
check.Args(database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
|
||||
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.LoginTypePassword).Return([]database.APIKey{a, b}, nil).AnyTimes()
|
||||
check.Args(database.LoginTypePassword).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
|
||||
}))
|
||||
s.Run("GetAPIKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u1 := testutil.Fake(s.T(), faker, database.User{})
|
||||
@@ -371,344 +364,6 @@ func (s *MethodTestSuite) TestConnectionLogs() {
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestChats() {
|
||||
s.Run("AcquireChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.AcquireChatParams{
|
||||
StartedAt: dbtime.Now(),
|
||||
WorkerID: uuid.New(),
|
||||
}
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().AcquireChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("DeleteAllChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteAllChatQueuedMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionDelete).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatMessagesByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionDelete).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
id := uuid.New()
|
||||
dbm.EXPECT().DeleteChatModelConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
id := uuid.New()
|
||||
dbm.EXPECT().DeleteChatProviderByID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
args := database.DeleteChatQueuedMessageParams{ID: 123, ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatQueuedMessage(gomock.Any(), args).Return(nil).AnyTimes()
|
||||
check.Args(args).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("GetChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
|
||||
}))
|
||||
s.Run("GetChatByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByIDForUpdate(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
|
||||
}))
|
||||
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatDiffStatusByChatID(gomock.Any(), chat.ID).Return(diffStatus, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(diffStatus)
|
||||
}))
|
||||
s.Run("GetChatDiffStatusesByChatIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chatA := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
chatB := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
ids := []uuid.UUID{chatA.ID, chatB.ID}
|
||||
diffStatusA := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chatA.ID})
|
||||
diffStatusB := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chatB.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chatA.ID).Return(chatA, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chatB.ID).Return(chatB, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatDiffStatusesByChatIDs(gomock.Any(), ids).Return([]database.ChatDiffStatus{diffStatusA, diffStatusB}, nil).AnyTimes()
|
||||
check.Args(ids).
|
||||
Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).
|
||||
Returns([]database.ChatDiffStatus{diffStatusA, diffStatusB})
|
||||
}))
|
||||
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatMessageByID(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(msg.ID).Asserts(chat, policy.ActionRead).Returns(msg)
|
||||
}))
|
||||
s.Run("GetChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatMessagesForPromptByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetChatModelConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes()
|
||||
check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetDefaultChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetDefaultChatModelConfig(gomock.Any()).Return(config, nil).AnyTimes()
|
||||
check.Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetChatModelConfigByProviderAndModel", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
args := database.GetChatModelConfigByProviderAndModelParams{
|
||||
Provider: config.Provider,
|
||||
Model: config.Model,
|
||||
}
|
||||
dbm.EXPECT().GetChatModelConfigByProviderAndModel(gomock.Any(), args).Return(config, nil).AnyTimes()
|
||||
check.Args(args).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
|
||||
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
|
||||
}))
|
||||
s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerName := "test-provider"
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
|
||||
dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
|
||||
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
|
||||
}))
|
||||
s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
c1 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
c2 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), c1.OwnerID).Return([]database.Chat{c1, c2}, nil).AnyTimes()
|
||||
check.Args(c1.OwnerID).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
|
||||
}))
|
||||
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
qms := []database.ChatQueuedMessage{testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatQueuedMessages(gomock.Any(), chat.ID).Return(qms, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(qms)
|
||||
}))
|
||||
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetEnabledChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
s.Run("ListChatsByRootID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
rootChatID := uuid.New()
|
||||
chatA := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
|
||||
chatB := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
|
||||
dbm.EXPECT().ListChatsByRootID(gomock.Any(), rootChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(rootChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("ListChildChatsByParentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
parentChatID := uuid.New()
|
||||
chatA := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
|
||||
chatB := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
|
||||
dbm.EXPECT().ListChildChatsByParentID(gomock.Any(), parentChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(parentChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("GetStaleChats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
threshold := dbtime.Now()
|
||||
chats := []database.Chat{testutil.Fake(s.T(), faker, database.Chat{})}
|
||||
dbm.EXPECT().GetStaleChats(gomock.Any(), threshold).Return(chats, nil).AnyTimes()
|
||||
check.Args(threshold).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(chats)
|
||||
}))
|
||||
s.Run("InsertChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatParams{})
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{OwnerID: arg.OwnerID})
|
||||
dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat)
|
||||
}))
|
||||
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().InsertChatMessage(gomock.Any(), arg).Return(msg, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msg)
|
||||
}))
|
||||
s.Run("InsertChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatQueuedMessageParams{ChatID: chat.ID})
|
||||
qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().InsertChatQueuedMessage(gomock.Any(), arg).Return(qm, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(qm)
|
||||
}))
|
||||
s.Run("InsertChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.InsertChatModelConfigParams{
|
||||
Provider: "test-provider",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
Enabled: true,
|
||||
}
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{Provider: arg.Provider, Model: arg.Model, DisplayName: arg.DisplayName, Enabled: arg.Enabled})
|
||||
dbm.EXPECT().InsertChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("InsertChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.InsertChatProviderParams{
|
||||
Provider: "test-provider",
|
||||
DisplayName: "Test Provider",
|
||||
APIKey: "test-api-key",
|
||||
Enabled: true,
|
||||
}
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: arg.Provider, DisplayName: arg.DisplayName, APIKey: arg.APIKey, Enabled: arg.Enabled})
|
||||
dbm.EXPECT().InsertChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
|
||||
}))
|
||||
s.Run("PopNextQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().PopNextQueuedMessage(gomock.Any(), chat.ID).Return(qm, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns(qm)
|
||||
}))
|
||||
s.Run("UpdateChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: "Updated title",
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: uuid.New(),
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
}))
|
||||
s.Run("UpdateChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
arg := database.UpdateChatModelConfigParams{
|
||||
ID: config.ID,
|
||||
Provider: "updated-provider",
|
||||
Model: "updated-model",
|
||||
DisplayName: "Updated Model",
|
||||
Enabled: true,
|
||||
}
|
||||
dbm.EXPECT().UpdateChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("UpdateChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
arg := database.UpdateChatProviderParams{
|
||||
ID: provider.ID,
|
||||
DisplayName: "Updated Provider",
|
||||
APIKey: "updated-api-key",
|
||||
Enabled: true,
|
||||
}
|
||||
dbm.EXPECT().UpdateChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
|
||||
}))
|
||||
s.Run("UpdateChatStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatWorkspace", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatWorkspaceParams{
|
||||
ID: chat.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
WorkspaceAgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
}
|
||||
updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatWorkspace(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat)
|
||||
}))
|
||||
s.Run("UnsetDefaultChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UnsetDefaultChatModelConfigs(gomock.Any()).Return(nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatDiffStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
now := dbtime.Now()
|
||||
arg := database.UpsertChatDiffStatusParams{
|
||||
ChatID: chat.ID,
|
||||
Url: sql.NullString{String: "https://example.com/pr/123", Valid: true},
|
||||
PullRequestState: sql.NullString{String: "open", Valid: true},
|
||||
ChangesRequested: false,
|
||||
Additions: 10,
|
||||
Deletions: 5,
|
||||
ChangedFiles: 2,
|
||||
RefreshedAt: now,
|
||||
StaleAt: now.Add(time.Hour),
|
||||
}
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertChatDiffStatus(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
|
||||
}))
|
||||
s.Run("UpsertChatDiffStatusReference", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpsertChatDiffStatusReferenceParams{
|
||||
ChatID: chat.ID,
|
||||
Url: sql.NullString{String: "https://example.com/pr/123", Valid: true},
|
||||
GitBranch: "feature/test",
|
||||
GitRemoteOrigin: "origin",
|
||||
StaleAt: dbtime.Now().Add(time.Hour),
|
||||
}
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestFile() {
|
||||
s.Run("GetFileByHashAndCreator", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
f := testutil.Fake(s.T(), faker, database.File{})
|
||||
@@ -1671,11 +1326,6 @@ func (s *MethodTestSuite) TestTemplate() {
|
||||
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
|
||||
}))
|
||||
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetTelemetryTaskEventsParams{}
|
||||
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceTask.All(), policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTemplateAppInsights", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetTemplateAppInsightsParams{}
|
||||
dbm.EXPECT().GetTemplateAppInsights(gomock.Any(), arg).Return([]database.GetTemplateAppInsightsRow{}, nil).AnyTimes()
|
||||
@@ -2319,15 +1969,6 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
|
||||
check.Args(build.ID).Asserts(ws, policy.ActionRead).Returns(build)
|
||||
}))
|
||||
s.Run("GetWorkspaceBuildProvisionerStateByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
row := database.GetWorkspaceBuildProvisionerStateByIDRow{
|
||||
ProvisionerState: []byte("state"),
|
||||
TemplateID: uuid.New(),
|
||||
TemplateOrganizationID: uuid.New(),
|
||||
}
|
||||
dbm.EXPECT().GetWorkspaceBuildProvisionerStateByID(gomock.Any(), gomock.Any()).Return(row, nil).AnyTimes()
|
||||
check.Args(uuid.New()).Asserts(row, policy.ActionUpdate).Returns(row)
|
||||
}))
|
||||
s.Run("GetWorkspaceBuildByJobID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ws := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
build := testutil.Fake(s.T(), faker, database.WorkspaceBuild{WorkspaceID: ws.ID})
|
||||
@@ -5105,20 +4746,6 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeModels", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeModelsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeModels(gomock.Any(), params, gomock.Any()).Return([]string{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAuthorizedAIBridgeModels", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeModelsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeModels(gomock.Any(), params, gomock.Any()).Return([]string{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ids := []uuid.UUID{{1}}
|
||||
db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes()
|
||||
|
||||
@@ -67,8 +67,6 @@ type WorkspaceBuildBuilder struct {
|
||||
|
||||
jobError string // Error message for failed jobs
|
||||
jobErrorCode string // Error code for failed jobs
|
||||
|
||||
provisionerState []byte
|
||||
}
|
||||
|
||||
// BuilderOption is a functional option for customizing job timestamps
|
||||
@@ -140,15 +138,6 @@ func (b WorkspaceBuildBuilder) Seed(seed database.WorkspaceBuild) WorkspaceBuild
|
||||
return b
|
||||
}
|
||||
|
||||
// ProvisionerState sets the provisioner state for the workspace build.
|
||||
// This is stored separately from the seed because ProvisionerState is
|
||||
// not part of the WorkspaceBuild view struct.
|
||||
func (b WorkspaceBuildBuilder) ProvisionerState(state []byte) WorkspaceBuildBuilder {
|
||||
//nolint: revive // returns modified struct
|
||||
b.provisionerState = state
|
||||
return b
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) Resource(resource ...*sdkproto.Resource) WorkspaceBuildBuilder {
|
||||
//nolint: revive // returns modified struct
|
||||
b.resources = append(b.resources, resource...)
|
||||
@@ -475,14 +464,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
|
||||
}
|
||||
|
||||
resp.Build = dbgen.WorkspaceBuild(b.t, b.db, b.seed)
|
||||
if len(b.provisionerState) > 0 {
|
||||
err = b.db.UpdateWorkspaceBuildProvisionerStateByID(ownerCtx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
|
||||
ID: resp.Build.ID,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
ProvisionerState: b.provisionerState,
|
||||
})
|
||||
require.NoError(b.t, err, "update provisioner state")
|
||||
}
|
||||
b.logger.Debug(context.Background(), "created workspace build",
|
||||
slog.F("build_id", resp.Build.ID),
|
||||
slog.F("workspace_id", resp.Workspace.ID),
|
||||
|
||||
@@ -504,7 +504,7 @@ func WorkspaceBuild(t testing.TB, db database.Store, orig database.WorkspaceBuil
|
||||
Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart),
|
||||
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
|
||||
JobID: jobID,
|
||||
ProvisionerState: []byte{},
|
||||
ProvisionerState: takeFirstSlice(orig.ProvisionerState, []byte{}),
|
||||
Deadline: takeFirst(orig.Deadline, dbtime.Now().Add(time.Hour)),
|
||||
MaxDeadline: takeFirst(orig.MaxDeadline, time.Time{}),
|
||||
Reason: takeFirst(orig.Reason, database.BuildReasonInitiator),
|
||||
@@ -1373,8 +1373,6 @@ func OAuth2ProviderAppCode(t testing.TB, db database.Store, seed database.OAuth2
|
||||
ResourceUri: seed.ResourceUri,
|
||||
CodeChallenge: seed.CodeChallenge,
|
||||
CodeChallengeMethod: seed.CodeChallengeMethod,
|
||||
StateHash: seed.StateHash,
|
||||
RedirectUri: seed.RedirectUri,
|
||||
})
|
||||
require.NoError(t, err, "insert oauth2 app code")
|
||||
return code
|
||||
|
||||
@@ -104,14 +104,6 @@ func (m queryMetricsStore) DeleteOrganization(ctx context.Context, id uuid.UUID)
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.AcquireChat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("AcquireChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AcquireChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.AcquireLock(ctx, pgAdvisoryXactLock)
|
||||
@@ -164,7 +156,6 @@ func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context
|
||||
start := time.Now()
|
||||
r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("BatchUpdateWorkspaceAgentMetadata").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpdateWorkspaceAgentMetadata").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -320,14 +311,6 @@ func (m queryMetricsStore) DeleteAPIKeysByUserID(ctx context.Context, userID uui
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteAllChatQueuedMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("DeleteAllChatQueuedMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteAllChatQueuedMessages").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteAllTailnetTunnels(ctx, arg)
|
||||
@@ -352,46 +335,6 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatMessagesByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatMessagesByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatMessagesByChatID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatModelConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatModelConfigByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatProviderByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatProviderByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatProviderByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatQueuedMessage(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatQueuedMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatQueuedMessage").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
|
||||
@@ -831,7 +774,7 @@ func (m queryMetricsStore) GetAPIKeyByName(ctx context.Context, arg database.Get
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
|
||||
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAPIKeysByLoginType(ctx, loginType)
|
||||
m.queryLatencies.WithLabelValues("GetAPIKeysByLoginType").Observe(time.Since(start).Seconds())
|
||||
@@ -959,126 +902,6 @@ func (m queryMetricsStore) GetAuthorizationUserRoles(ctx context.Context, userID
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatByIDForUpdate(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatByIDForUpdate").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatByIDForUpdate").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatDiffStatusByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDiffStatusByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uuid.UUID) ([]database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
||||
m.queryLatencies.WithLabelValues("GetChatDiffStatusesByChatIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDiffStatusesByChatIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessageByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessageByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessageByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessagesByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesForPromptByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessagesForPromptByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesForPromptByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigByProviderAndModel(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigByProviderAndModel").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigByProviderAndModel").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviderByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByProvider(ctx, provider)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviderByProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviders(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviders").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviders").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatQueuedMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatQueuedMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatQueuedMessages").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
|
||||
m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
|
||||
@@ -1135,14 +958,6 @@ func (m queryMetricsStore) GetDERPMeshKey(ctx context.Context) (string, error) {
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetDefaultChatModelConfig(ctx context.Context) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetDefaultChatModelConfig(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetDefaultChatModelConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetDefaultChatModelConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetDefaultOrganization(ctx context.Context) (database.Organization, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetDefaultOrganization(ctx)
|
||||
@@ -1207,22 +1022,6 @@ func (m queryMetricsStore) GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetEnabledChatModelConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetEnabledChatModelConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatModelConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetEnabledChatProviders(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetEnabledChatProviders").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatProviders").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetExternalAuthLink(ctx, arg)
|
||||
@@ -1919,14 +1718,6 @@ func (m queryMetricsStore) GetRuntimeConfig(ctx context.Context, key string) (st
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetStaleChats(ctx, staleThreshold)
|
||||
m.queryLatencies.WithLabelValues("GetStaleChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetStaleChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetTailnetPeers(ctx, id)
|
||||
@@ -1999,14 +1790,6 @@ func (m queryMetricsStore) GetTelemetryItems(ctx context.Context) ([]database.Te
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetTelemetryTaskEvents(ctx context.Context, createdAfter database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetTelemetryTaskEvents(ctx, createdAfter)
|
||||
m.queryLatencies.WithLabelValues("GetTelemetryTaskEvents").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTelemetryTaskEvents").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetTemplateAppInsights(ctx, arg)
|
||||
@@ -2647,14 +2430,6 @@ func (m queryMetricsStore) GetWorkspaceBuildParametersByBuildIDs(ctx context.Con
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuildID)
|
||||
m.queryLatencies.WithLabelValues("GetWorkspaceBuildProvisionerStateByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceBuildProvisionerStateByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetWorkspaceBuildStatsByTemplates(ctx, since)
|
||||
@@ -2919,46 +2694,6 @@ func (m queryMetricsStore) InsertAuditLog(ctx context.Context, arg database.Inse
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatMessage(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessage").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatModelConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatModelConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatModelConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatProvider(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatQueuedMessage(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatQueuedMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatQueuedMessage").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertCryptoKey(ctx, arg)
|
||||
@@ -3463,14 +3198,6 @@ func (m queryMetricsStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx conte
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeModels(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("ListAIBridgeModels").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeModels").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
|
||||
@@ -3495,22 +3222,6 @@ func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChatsByRootID(ctx, rootChatID)
|
||||
m.queryLatencies.WithLabelValues("ListChatsByRootID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatsByRootID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChildChatsByParentID(ctx, parentChatID)
|
||||
m.queryLatencies.WithLabelValues("ListChildChatsByParentID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChildChatsByParentID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListProvisionerKeysByOrganization(ctx, organizationID)
|
||||
@@ -3591,14 +3302,6 @@ func (m queryMetricsStore) PaginatedOrganizationMembers(ctx context.Context, arg
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.PopNextQueuedMessage(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("PopNextQueuedMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "PopNextQueuedMessage").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx, templateID)
|
||||
@@ -3671,14 +3374,6 @@ func (m queryMetricsStore) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnsetDefaultChatModelConfigs(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UnsetDefaultChatModelConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("UnsetDefaultChatModelConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UnsetDefaultChatModelConfigs").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateAIBridgeInterceptionEnded(ctx, arg)
|
||||
@@ -3695,54 +3390,6 @@ func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.Up
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatModelConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatModelConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatModelConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatProvider(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatStatus(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatStatus").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatStatus").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatWorkspace(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatWorkspace").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspace").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateCryptoKeyDeletesAt(ctx, arg)
|
||||
@@ -4454,22 +4101,6 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDiffStatus").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDiffStatus").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatDiffStatusReference(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDiffStatusReference").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDiffStatusReference").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
|
||||
@@ -4781,11 +4412,3 @@ func (m queryMetricsStore) CountAuthorizedAIBridgeInterceptions(ctx context.Cont
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAuthorizedAIBridgeInterceptions").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAuthorizedAIBridgeModels(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeModels").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeModels").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user