Compare commits
48 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 50fb554051 | |||
| 59571e8288 | |||
| 226f826dfa | |||
| 14c42fbb20 | |||
| 69a3dc1577 | |||
| 88cc140af5 | |||
| dbe4d2be5a | |||
| 953bc815b3 | |||
| 4284d0107a | |||
| c2ab1395b6 | |||
| aaf2381b21 | |||
| 96c32b27bd | |||
| 12d4d9f964 | |||
| 9ace84c1c4 | |||
| 79b27445af | |||
| 31525da014 | |||
| 8b7c5d23af | |||
| b12966c0db | |||
| 90401524b6 | |||
| ab36d88627 | |||
| 8bfb79abc9 | |||
| 6912180aa6 | |||
| fae93db5c3 | |||
| 7385c6720c | |||
| 2accb72e09 | |||
| 018dd9e650 | |||
| 446c84b708 | |||
| 90d8e49e87 | |||
| 07f9fcd83e | |||
| ebde198a5c | |||
| 975c9b8d93 | |||
| 14cb11aedc | |||
| 664b74b627 | |||
| cdccaeadb9 | |||
| ba04ddae6f | |||
| 75b6b438e4 | |||
| bbe24e097b | |||
| 03422cbf71 | |||
| 5403231011 | |||
| 2af39a5878 | |||
| 736ee901a3 | |||
| 12f2c73648 | |||
| eaee2cded6 | |||
| 83672a0a3e | |||
| be9dc1c1a9 | |||
| ba677af7df | |||
| c95452aa46 | |||
| c51f50301f |
@@ -852,7 +852,7 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
|
||||
# -C sets the directory for the go run command
|
||||
go run -C ./scripts/apitypings main.go > $@
|
||||
(cd site/ && pnpm exec biome format --write src/api/typesGenerated.ts)
|
||||
./scripts/biome_format.sh src/api/typesGenerated.ts
|
||||
touch "$@"
|
||||
|
||||
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
|
||||
@@ -861,7 +861,7 @@ site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/prot
|
||||
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
|
||||
go run ./scripts/gensite/ -icons "$@"
|
||||
(cd site/ && pnpm exec biome format --write src/theme/icons.json)
|
||||
./scripts/biome_format.sh src/theme/icons.json
|
||||
touch "$@"
|
||||
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
|
||||
@@ -899,12 +899,12 @@ codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scope
|
||||
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
|
||||
go run scripts/typegen/main.go rbac typescript > "$@"
|
||||
(cd site/ && pnpm exec biome format --write src/api/rbacresourcesGenerated.ts)
|
||||
./scripts/biome_format.sh src/api/rbacresourcesGenerated.ts
|
||||
touch "$@"
|
||||
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
|
||||
go run scripts/typegen/main.go countries > "$@"
|
||||
(cd site/ && pnpm exec biome format --write src/api/countriesGenerated.ts)
|
||||
./scripts/biome_format.sh src/api/countriesGenerated.ts
|
||||
touch "$@"
|
||||
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics
|
||||
@@ -944,11 +944,11 @@ coderd/apidoc/.gen: \
|
||||
touch "$@"
|
||||
|
||||
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
|
||||
(cd site/ && pnpm exec biome format --write ../docs/manifest.json)
|
||||
./scripts/biome_format.sh ../docs/manifest.json
|
||||
touch "$@"
|
||||
|
||||
coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
|
||||
(cd site/ && pnpm exec biome format --write ../coderd/apidoc/swagger.json)
|
||||
./scripts/biome_format.sh ../coderd/apidoc/swagger.json
|
||||
touch "$@"
|
||||
|
||||
update-golden-files:
|
||||
@@ -993,11 +993,19 @@ enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/
|
||||
touch "$@"
|
||||
|
||||
helm/coder/tests/testdata/.gen-golden: $(wildcard helm/coder/tests/testdata/*.yaml) $(wildcard helm/coder/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/coder/tests/*_test.go)
|
||||
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
|
||||
if command -v helm >/dev/null 2>&1; then
|
||||
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
|
||||
else
|
||||
echo "WARNING: helm not found; skipping helm/coder golden generation" >&2
|
||||
fi
|
||||
touch "$@"
|
||||
|
||||
helm/provisioner/tests/testdata/.gen-golden: $(wildcard helm/provisioner/tests/testdata/*.yaml) $(wildcard helm/provisioner/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/provisioner/tests/*_test.go)
|
||||
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
|
||||
if command -v helm >/dev/null 2>&1; then
|
||||
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
|
||||
else
|
||||
echo "WARNING: helm not found; skipping helm/provisioner golden generation" >&2
|
||||
fi
|
||||
touch "$@"
|
||||
|
||||
coderd/.gen-golden: $(wildcard coderd/testdata/*/*.golden) $(GO_SRC_FILES) $(wildcard coderd/*_test.go)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.34
|
||||
// protoc-gen-go-drpc version: (devel)
|
||||
// source: agent/agentsocket/proto/agentsocket.proto
|
||||
|
||||
package proto
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.34
|
||||
// protoc-gen-go-drpc version: (devel)
|
||||
// source: agent/proto/agent.proto
|
||||
|
||||
package proto
|
||||
|
||||
+101
-3
@@ -4,6 +4,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
@@ -16,6 +20,41 @@ import (
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
const (
|
||||
chatAgentEnvVar = "CODER_CHAT_AGENT"
|
||||
)
|
||||
|
||||
type gitAuthRequiredMarker struct {
|
||||
ProviderID string `json:"provider_id"`
|
||||
ProviderType string `json:"provider_type,omitempty"`
|
||||
ProviderDisplayName string `json:"provider_display_name,omitempty"`
|
||||
AuthenticateURL string `json:"authenticate_url"`
|
||||
Host string `json:"host,omitempty"`
|
||||
}
|
||||
|
||||
// detectGitRef attempts to resolve the current git branch and remote
|
||||
// origin URL from the given working directory. These are sent to the
|
||||
// control plane so it can look up PR/diff status via the GitHub API
|
||||
// without SSHing into the workspace. Failures are silently ignored
|
||||
// since this is best-effort.
|
||||
func detectGitRef(workingDirectory string) (branch string, remoteOrigin string) {
|
||||
run := func(args ...string) string {
|
||||
//nolint:gosec
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
if workingDirectory != "" {
|
||||
cmd.Dir = workingDirectory
|
||||
}
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
branch = run("git", "rev-parse", "--abbrev-ref", "HEAD")
|
||||
remoteOrigin = run("git", "config", "--get", "remote.origin.url")
|
||||
return branch, remoteOrigin
|
||||
}
|
||||
|
||||
// gitAskpass is used by the Coder agent to automatically authenticate
|
||||
// with Git providers based on a hostname.
|
||||
func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
@@ -38,8 +77,21 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
workingDirectory, err := os.Getwd()
|
||||
if err != nil {
|
||||
workingDirectory = ""
|
||||
}
|
||||
|
||||
// Detect the current git branch and remote origin so
|
||||
// the control plane can resolve diffs without needing
|
||||
// to SSH back into the workspace.
|
||||
gitBranch, gitRemoteOrigin := detectGitRef(workingDirectory)
|
||||
|
||||
token, err := client.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
|
||||
Match: host,
|
||||
Match: host,
|
||||
Workdir: workingDirectory,
|
||||
GitBranch: gitBranch,
|
||||
GitRemoteOrigin: gitRemoteOrigin,
|
||||
})
|
||||
if err != nil {
|
||||
var apiError *codersdk.Error
|
||||
@@ -58,6 +110,12 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
return xerrors.Errorf("get git token: %w", err)
|
||||
}
|
||||
if token.URL != "" {
|
||||
// This is to help the agent authenticate with Git.
|
||||
if inv.Environ.Get("CODER_CHAT_AGENT") == "true" {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, `You must use the "wait_for_external_auth" tool to authenticate with Git.\n\nThe URL is: %s\n`, token.URL)
|
||||
return cliui.ErrCanceled
|
||||
}
|
||||
|
||||
if err := openURL(inv, token.URL); err == nil {
|
||||
cliui.Infof(inv.Stderr, "Your browser has been opened to authenticate with Git:\n%s", token.URL)
|
||||
} else {
|
||||
@@ -66,8 +124,9 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
|
||||
for r := retry.New(250*time.Millisecond, 10*time.Second); r.Wait(ctx); {
|
||||
token, err = client.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
|
||||
Match: host,
|
||||
Listen: true,
|
||||
Match: host,
|
||||
Listen: true,
|
||||
Workdir: workingDirectory,
|
||||
})
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -93,3 +152,42 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func providerIDFromAuthenticateURL(rawURL string) string {
|
||||
parsed, err := url.Parse(strings.TrimSpace(rawURL))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
path := strings.Trim(parsed.Path, "/")
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(path, "/")
|
||||
for i := 0; i < len(parts)-1; i++ {
|
||||
if parts[i] == "external-auth" {
|
||||
return strings.TrimSpace(parts[i+1])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func providerDisplayName(providerType string) string {
|
||||
switch strings.TrimSpace(providerType) {
|
||||
case codersdk.EnhancedExternalAuthProviderGitHub.String():
|
||||
return "GitHub"
|
||||
case codersdk.EnhancedExternalAuthProviderGitLab.String():
|
||||
return "GitLab"
|
||||
case codersdk.EnhancedExternalAuthProviderGitea.String():
|
||||
return "Gitea"
|
||||
case codersdk.EnhancedExternalAuthProviderAzureDevops.String():
|
||||
return "Azure DevOps"
|
||||
case codersdk.EnhancedExternalAuthProviderAzureDevopsEntra.String():
|
||||
return "Azure DevOps Entra"
|
||||
case codersdk.EnhancedExternalAuthProviderBitBucketCloud.String():
|
||||
return "Bitbucket Cloud"
|
||||
case codersdk.EnhancedExternalAuthProviderBitBucketServer.String():
|
||||
return "Bitbucket Server"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
@@ -108,4 +111,58 @@ func TestGitAskpass(t *testing.T) {
|
||||
})
|
||||
stdout.ExpectMatch("username")
|
||||
})
|
||||
|
||||
t.Run("ChatAgentAuthRequired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
listenCalls := atomic.Int64{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Has("listen") {
|
||||
listenCalls.Add(1)
|
||||
}
|
||||
httpapi.Write(context.Background(), w, http.StatusOK, agentsdk.ExternalAuthResponse{
|
||||
URL: "https://coder.example.com/external-auth/github",
|
||||
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
|
||||
})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
inv, _ := clitest.New(
|
||||
t,
|
||||
"--agent-url",
|
||||
srv.URL,
|
||||
"--no-open",
|
||||
"Username for 'https://github.com':",
|
||||
)
|
||||
inv.Environ.Set("GIT_PREFIX", "/")
|
||||
inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token")
|
||||
inv.Environ.Set("CODER_CHAT_AGENT", "true")
|
||||
|
||||
var stderr bytes.Buffer
|
||||
inv.Stderr = &stderr
|
||||
|
||||
err := inv.Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "exit code")
|
||||
require.Zero(t, listenCalls.Load())
|
||||
|
||||
output := stderr.String()
|
||||
require.Contains(t, output, "CODER_GITAUTH_REQUIRED:")
|
||||
require.NotContains(t, output, "Open the following URL to authenticate")
|
||||
require.NotContains(t, output, "Your browser has been opened")
|
||||
|
||||
_, markerRaw, found := strings.Cut(output, "CODER_GITAUTH_REQUIRED:")
|
||||
require.True(t, found)
|
||||
var marker struct {
|
||||
ProviderID string `json:"provider_id"`
|
||||
ProviderType string `json:"provider_type"`
|
||||
ProviderDisplayName string `json:"provider_display_name"`
|
||||
AuthenticateURL string `json:"authenticate_url"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimSpace(markerRaw)), &marker))
|
||||
require.Equal(t, "github", marker.ProviderID)
|
||||
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), marker.ProviderType)
|
||||
require.Equal(t, "GitHub", marker.ProviderDisplayName)
|
||||
require.Equal(t, "https://coder.example.com/external-auth/github", marker.AuthenticateURL)
|
||||
})
|
||||
}
|
||||
|
||||
+111
-43
@@ -607,28 +607,8 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
}
|
||||
}
|
||||
|
||||
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read external auth providers from env: %w", err)
|
||||
}
|
||||
|
||||
promRegistry := prometheus.NewRegistry()
|
||||
oauthInstrument := promoauth.NewFactory(promRegistry)
|
||||
vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...)
|
||||
externalAuthConfigs, err := externalauth.ConvertConfig(
|
||||
oauthInstrument,
|
||||
vals.ExternalAuthConfigs.Value,
|
||||
vals.AccessURL.Value(),
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("convert external auth config: %w", err)
|
||||
}
|
||||
for _, c := range externalAuthConfigs {
|
||||
logger.Debug(
|
||||
ctx, "loaded external auth config",
|
||||
slog.F("id", c.ID),
|
||||
)
|
||||
}
|
||||
|
||||
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
|
||||
if err != nil {
|
||||
@@ -659,7 +639,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
Pubsub: nil,
|
||||
CacheDir: cacheDir,
|
||||
GoogleTokenValidator: googleTokenValidator,
|
||||
ExternalAuthConfigs: externalAuthConfigs,
|
||||
ExternalAuthConfigs: nil,
|
||||
RealIPConfig: realIPConfig,
|
||||
SSHKeygenAlgorithm: sshKeygenAlgorithm,
|
||||
TracerProvider: tracerProvider,
|
||||
@@ -819,6 +799,40 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
return xerrors.Errorf("set deployment id: %w", err)
|
||||
}
|
||||
|
||||
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read external auth providers from env: %w", err)
|
||||
}
|
||||
mergedExternalAuthProviders := append([]codersdk.ExternalAuthConfig{}, vals.ExternalAuthConfigs.Value...)
|
||||
mergedExternalAuthProviders = append(mergedExternalAuthProviders, extAuthEnv...)
|
||||
vals.ExternalAuthConfigs.Value = mergedExternalAuthProviders
|
||||
|
||||
mergedExternalAuthProviders, err = maybeAppendDefaultGithubExternalAuthProvider(
|
||||
ctx,
|
||||
options.Logger,
|
||||
options.Database,
|
||||
vals,
|
||||
mergedExternalAuthProviders,
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("maybe append default github external auth provider: %w", err)
|
||||
}
|
||||
|
||||
options.ExternalAuthConfigs, err = externalauth.ConvertConfig(
|
||||
oauthInstrument,
|
||||
mergedExternalAuthProviders,
|
||||
vals.AccessURL.Value(),
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("convert external auth config: %w", err)
|
||||
}
|
||||
for _, c := range options.ExternalAuthConfigs {
|
||||
logger.Debug(
|
||||
ctx, "loaded external auth config",
|
||||
slog.F("id", c.ID),
|
||||
)
|
||||
}
|
||||
|
||||
// Manage push notifications.
|
||||
experiments := coderd.ReadExperiments(options.Logger, options.DeploymentValues.Experiments.Value())
|
||||
if experiments.Enabled(codersdk.ExperimentWebPush) {
|
||||
@@ -1910,6 +1924,79 @@ type githubOAuth2ConfigParams struct {
|
||||
enterpriseBaseURL string
|
||||
}
|
||||
|
||||
func isDeploymentEligibleForGithubDefaultProvider(ctx context.Context, db database.Store) (bool, error) {
|
||||
// We want to enable the default provider only for new deployments, and avoid
|
||||
// enabling it if a deployment was upgraded from an older version.
|
||||
// nolint:gocritic // Requires system privileges
|
||||
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return false, xerrors.Errorf("get github default eligible: %w", err)
|
||||
}
|
||||
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
|
||||
|
||||
if defaultEligibleNotSet {
|
||||
// nolint:gocritic // User count requires system privileges
|
||||
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
|
||||
if err != nil {
|
||||
return false, xerrors.Errorf("get user count: %w", err)
|
||||
}
|
||||
// We check if a deployment is new by checking if it has any users.
|
||||
defaultEligible = userCount == 0
|
||||
// nolint:gocritic // Requires system privileges
|
||||
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
|
||||
return false, xerrors.Errorf("upsert github default eligible: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return defaultEligible, nil
|
||||
}
|
||||
|
||||
func maybeAppendDefaultGithubExternalAuthProvider(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
db database.Store,
|
||||
vals *codersdk.DeploymentValues,
|
||||
mergedExplicitProviders []codersdk.ExternalAuthConfig,
|
||||
) ([]codersdk.ExternalAuthConfig, error) {
|
||||
if !vals.ExternalAuthGithubDefaultProviderEnable.Value() {
|
||||
logger.Info(ctx, "default github external auth provider suppressed",
|
||||
slog.F("reason", "disabled by configuration"),
|
||||
slog.F("flag", "external-auth-github-default-provider-enable"),
|
||||
)
|
||||
return mergedExplicitProviders, nil
|
||||
}
|
||||
|
||||
if len(mergedExplicitProviders) > 0 {
|
||||
logger.Info(ctx, "default github external auth provider suppressed",
|
||||
slog.F("reason", "explicit external auth providers configured"),
|
||||
slog.F("provider_count", len(mergedExplicitProviders)),
|
||||
)
|
||||
return mergedExplicitProviders, nil
|
||||
}
|
||||
|
||||
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !defaultEligible {
|
||||
logger.Info(ctx, "default github external auth provider suppressed",
|
||||
slog.F("reason", "deployment is not eligible"),
|
||||
)
|
||||
return mergedExplicitProviders, nil
|
||||
}
|
||||
|
||||
logger.Info(ctx, "injecting default github external auth provider",
|
||||
slog.F("type", codersdk.EnhancedExternalAuthProviderGitHub.String()),
|
||||
slog.F("client_id", GithubOAuth2DefaultProviderClientID),
|
||||
slog.F("device_flow", GithubOAuth2DefaultProviderDeviceFlow),
|
||||
)
|
||||
return append(mergedExplicitProviders, codersdk.ExternalAuthConfig{
|
||||
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
|
||||
ClientID: GithubOAuth2DefaultProviderClientID,
|
||||
DeviceFlow: GithubOAuth2DefaultProviderDeviceFlow,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *codersdk.DeploymentValues) (*githubOAuth2ConfigParams, error) {
|
||||
params := githubOAuth2ConfigParams{
|
||||
accessURL: vals.AccessURL.Value(),
|
||||
@@ -1934,28 +2021,9 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
// Check if the deployment is eligible for the default GitHub OAuth2 provider.
|
||||
// We want to enable it only for new deployments, and avoid enabling it
|
||||
// if a deployment was upgraded from an older version.
|
||||
// nolint:gocritic // Requires system privileges
|
||||
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get github default eligible: %w", err)
|
||||
}
|
||||
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
|
||||
|
||||
if defaultEligibleNotSet {
|
||||
// nolint:gocritic // User count requires system privileges
|
||||
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get user count: %w", err)
|
||||
}
|
||||
// We check if a deployment is new by checking if it has any users.
|
||||
defaultEligible = userCount == 0
|
||||
// nolint:gocritic // Requires system privileges
|
||||
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
|
||||
return nil, xerrors.Errorf("upsert github default eligible: %w", err)
|
||||
}
|
||||
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !defaultEligible {
|
||||
|
||||
@@ -53,6 +53,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/migrations"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
"github.com/coder/coder/v2/coderd/userpassword"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/cryptorand"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
@@ -1793,6 +1794,152 @@ func TestServer(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // This test sets environment variables.
|
||||
func TestServer_ExternalAuthGitHubDefaultProvider(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
args []string
|
||||
env map[string]string
|
||||
createUserPreStart bool
|
||||
expectedProviders []string
|
||||
}
|
||||
|
||||
run := func(t *testing.T, tc testCase) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
unsetPrefixedEnv := func(prefix string) {
|
||||
t.Helper()
|
||||
for _, envVar := range os.Environ() {
|
||||
key, _, found := strings.Cut(envVar, "=")
|
||||
if !found || !strings.HasPrefix(key, prefix) {
|
||||
continue
|
||||
}
|
||||
value, had := os.LookupEnv(key)
|
||||
require.True(t, had)
|
||||
require.NoError(t, os.Unsetenv(key))
|
||||
keyCopy := key
|
||||
valueCopy := value
|
||||
t.Cleanup(func() {
|
||||
_ = os.Setenv(keyCopy, valueCopy)
|
||||
})
|
||||
}
|
||||
}
|
||||
unsetPrefixedEnv("CODER_EXTERNAL_AUTH_")
|
||||
unsetPrefixedEnv("CODER_GITAUTH_")
|
||||
|
||||
dbURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
db, _ := dbtestutil.NewDB(t, dbtestutil.WithURL(dbURL))
|
||||
|
||||
const (
|
||||
existingUserEmail = "existing-user@coder.com"
|
||||
existingUserUsername = "existing-user"
|
||||
existingUserPassword = "SomeSecurePassword!"
|
||||
)
|
||||
if tc.createUserPreStart {
|
||||
hashedPassword, err := userpassword.Hash(existingUserPassword)
|
||||
require.NoError(t, err)
|
||||
_ = dbgen.User(t, db, database.User{
|
||||
Email: existingUserEmail,
|
||||
Username: existingUserUsername,
|
||||
HashedPassword: []byte(hashedPassword),
|
||||
})
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"server",
|
||||
"--postgres-url", dbURL,
|
||||
"--http-address", ":0",
|
||||
"--access-url", "https://example.com",
|
||||
}
|
||||
args = append(args, tc.args...)
|
||||
|
||||
inv, cfg := clitest.New(t, args...)
|
||||
for key, value := range tc.env {
|
||||
t.Setenv(key, value)
|
||||
}
|
||||
clitest.Start(t, inv)
|
||||
|
||||
accessURL := waitAccessURL(t, cfg)
|
||||
client := codersdk.New(accessURL)
|
||||
|
||||
if tc.createUserPreStart {
|
||||
loginResp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
|
||||
Email: existingUserEmail,
|
||||
Password: existingUserPassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
client.SetSessionToken(loginResp.SessionToken)
|
||||
} else {
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
}
|
||||
|
||||
externalAuthResp, err := client.ListExternalAuths(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotProviders := map[string]codersdk.ExternalAuthLinkProvider{}
|
||||
for _, provider := range externalAuthResp.Providers {
|
||||
gotProviders[provider.ID] = provider
|
||||
}
|
||||
require.Len(t, gotProviders, len(tc.expectedProviders))
|
||||
|
||||
for _, providerID := range tc.expectedProviders {
|
||||
provider, ok := gotProviders[providerID]
|
||||
require.Truef(t, ok, "expected provider %q to be configured", providerID)
|
||||
if providerID == codersdk.EnhancedExternalAuthProviderGitHub.String() {
|
||||
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), provider.Type)
|
||||
require.True(t, provider.Device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range []testCase{
|
||||
{
|
||||
name: "NewDeployment_NoExplicitProviders_InjectsDefaultGithub",
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitHub.String()},
|
||||
},
|
||||
{
|
||||
name: "ExistingDeployment_DoesNotInjectDefaultGithub",
|
||||
createUserPreStart: true,
|
||||
expectedProviders: nil,
|
||||
},
|
||||
{
|
||||
name: "DefaultProviderDisabled_DoesNotInjectDefaultGithub",
|
||||
args: []string{
|
||||
"--external-auth-github-default-provider-enable=false",
|
||||
},
|
||||
expectedProviders: nil,
|
||||
},
|
||||
{
|
||||
name: "ExplicitProviderViaConfig_DoesNotInjectDefaultGithub",
|
||||
args: []string{
|
||||
`--external-auth-providers=[{"type":"gitlab","client_id":"config-client-id"}]`,
|
||||
},
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
|
||||
},
|
||||
{
|
||||
name: "ExplicitProviderViaEnv_DoesNotInjectDefaultGithub",
|
||||
env: map[string]string{
|
||||
"CODER_EXTERNAL_AUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
|
||||
"CODER_EXTERNAL_AUTH_0_CLIENT_ID": "env-client-id",
|
||||
},
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
|
||||
},
|
||||
{
|
||||
name: "ExplicitProviderViaLegacyEnv_DoesNotInjectDefaultGithub",
|
||||
env: map[string]string{
|
||||
"CODER_GITAUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
|
||||
"CODER_GITAUTH_0_CLIENT_ID": "legacy-env-client-id",
|
||||
},
|
||||
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
run(t, tc)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // This test sets environment variables.
|
||||
func TestServer_Logging_NoParallel(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
+90
@@ -63,6 +63,9 @@ OPTIONS:
|
||||
Separate multiple experiments with commas, or enter '*' to opt-in to
|
||||
all available experiments.
|
||||
|
||||
--external-auth-github-default-provider-enable bool, $CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE (default: true)
|
||||
Enable the default GitHub external auth provider managed by Coder.
|
||||
|
||||
--postgres-auth password|awsiamrds, $CODER_PG_AUTH (default: password)
|
||||
Type of auth to use when connecting to postgres. For AWS RDS, using
|
||||
IAM authentication (awsiamrds) is recommended.
|
||||
@@ -172,6 +175,93 @@ AI BRIDGE OPTIONS:
|
||||
Emit structured logs for AI Bridge interception records. Use this for
|
||||
exporting these records to external SIEM or observability systems.
|
||||
|
||||
--chat-anthropic-models-url string, $CODER_CHAT_ANTHROPIC_MODELS_URL
|
||||
Override URL used to list Anthropic models for the chat model catalog.
|
||||
|
||||
--chat-models-allowlist string, $CODER_CHAT_MODELS_ALLOWLIST
|
||||
Comma-separated allowlist of models for the chat model catalog.
|
||||
|
||||
--chat-models-denylist string, $CODER_CHAT_MODELS_DENYLIST
|
||||
Comma-separated denylist of models for the chat model catalog.
|
||||
|
||||
--chat-openai-models-url string, $CODER_CHAT_OPENAI_MODELS_URL
|
||||
Override URL used to list OpenAI models for the chat model catalog.
|
||||
|
||||
--chat-system-prompt string, $CODER_CHAT_SYSTEM_PROMPT (default: You are the Coder agent — an interactive chat tool that helps users with software-engineering tasks inside of the Coder product.
|
||||
Use the instructions below and the tools available to you to assist User.
|
||||
|
||||
IMPORTANT — obey every rule in this prompt before anything else.
|
||||
Do EXACTLY what the User asked, never more, never less.
|
||||
|
||||
<behavior>
|
||||
You MUST execute AS MANY TOOLS to help the user accomplish their task.
|
||||
You are COMFORTABLE with vague tasks - using your tools to collect the most relevant answer possible.
|
||||
You ALWAYS use GitHub tools for ANY query related to source code.
|
||||
If a user asks how something works, no matter how vague, you MUST use your tools to collect the most relevant answer possible.
|
||||
DO NOT ask the user for clarification - just use your tools.
|
||||
</behavior>
|
||||
|
||||
<personality>
|
||||
Analytical — You break problems into measurable steps, relying on tool output and data rather than intuition.
|
||||
Organized — You structure every interaction with clear tags, TODO lists, and section boundaries.
|
||||
Precision-Oriented — You insist on exact formatting, package-manager choice, and rule adherence.
|
||||
Efficiency-Focused — You minimize chatter, run tasks in parallel, and favor small, complete answers.
|
||||
Clarity-Seeking — You ask for missing details instead of guessing, avoiding any ambiguity.
|
||||
</personality>
|
||||
|
||||
<communication>
|
||||
Be concise, direct, and to the point.
|
||||
NO emojis unless the User explicitly asks for them.
|
||||
If a task appears incomplete or ambiguous, **pause and ask the User** rather than guessing or marking "done".
|
||||
Prefer accuracy over reassurance; confirm facts with tool calls instead of assuming the User is right.
|
||||
If you face an architectural, tooling, or package-manager choice, **ask the User's preference first**.
|
||||
Default to the project's existing package manager / tooling; never substitute without confirmation.
|
||||
You MUST avoid text before/after your response, such as "The answer is" or "Short answer:", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
|
||||
Mimic the style of the User's messages.
|
||||
Do not remind the User you are happy to help.
|
||||
Do not inherently assume the User is correct; they may be making assumptions.
|
||||
If you are not confident in your answer, DO NOT provide an answer. Use your tools to collect more information, or ask the User for help.
|
||||
Do not act with sycophantic flattery or over-the-top enthusiasm.
|
||||
|
||||
Here are examples to demonstrate appropriate communication style and level of verbosity:
|
||||
|
||||
<example>
|
||||
user: find me a good issue to work on
|
||||
assistant: Issue [#1234](https://example) indicates a bug in the frontend, which you've contributed to in the past.
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: work on this issue <url>
|
||||
...assistant does work...
|
||||
assistant: I've put up this pull request: https://github.com/example/example/pull/1824. Please let me know your thoughts!
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what is 2+2?
|
||||
assistant: 4
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: how does X work in <popular-repository-name>?
|
||||
assistant: Let me take a look at the code...
|
||||
[tool calls to investigate the repository]
|
||||
</example>
|
||||
</communication>
|
||||
|
||||
<collaboration>
|
||||
When a user asks for help with a task or there is ambiguity on the objective, always start by asking clarifying questions to understand:
|
||||
- What specific aspect they want to focus on
|
||||
- Their goals and vision for the changes
|
||||
- Their preferences for approach or style
|
||||
- What problems they're trying to solve
|
||||
|
||||
Don't assume what needs to be done - collaborate to define the scope together.
|
||||
</collaboration>)
|
||||
Default system prompt inserted into new chats.
|
||||
|
||||
--chat-title-generation-prompt string, $CODER_CHAT_TITLE_GENERATION_PROMPT (default: Generate a concise title (max 8 words) for the user's first message. Return plain text only, with no surrounding quotes.)
|
||||
Prompt used to generate chat titles from the first user message.
|
||||
|
||||
AI BRIDGE PROXY OPTIONS:
|
||||
--aibridge-proxy-cert-file string, $CODER_AIBRIDGE_PROXY_CERT_FILE
|
||||
Path to the CA certificate file for AI Bridge Proxy.
|
||||
|
||||
+183
@@ -553,6 +553,9 @@ supportLinks: []
|
||||
# External Authentication providers.
|
||||
# (default: <unset>, type: struct[[]codersdk.ExternalAuthConfig])
|
||||
externalAuthProviders: []
|
||||
# Enable the default GitHub external auth provider managed by Coder.
|
||||
# (default: true, type: bool)
|
||||
externalAuthGithubDefaultProviderEnable: true
|
||||
# Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By
|
||||
# default, this will pick the best available wgtunnel server hosted by Coder. e.g.
|
||||
# "tunnel.example.com".
|
||||
@@ -739,15 +742,195 @@ workspace_prebuilds:
|
||||
# (default: 3, type: int)
|
||||
failure_hard_limit: 3
|
||||
aibridge:
|
||||
# Default system prompt inserted into new chats.
|
||||
# (default: You are the Coder agent — an interactive chat tool that helps users
|
||||
# with software-engineering tasks inside of the Coder product.
|
||||
# Use the instructions below and the tools available to you to assist User.
|
||||
|
||||
# IMPORTANT — obey every rule in this prompt before anything else.
|
||||
# Do EXACTLY what the User asked, never more, never less.
|
||||
|
||||
# <behavior>
|
||||
# You MUST execute AS MANY TOOLS to help the user accomplish their task.
|
||||
# You are COMFORTABLE with vague tasks - using your tools to collect the most
|
||||
# relevant answer possible.
|
||||
# You ALWAYS use GitHub tools for ANY query related to source code.
|
||||
# If a user asks how something works, no matter how vague, you MUST use your tools
|
||||
# to collect the most relevant answer possible.
|
||||
# DO NOT ask the user for clarification - just use your tools.
|
||||
# </behavior>
|
||||
|
||||
# <personality>
|
||||
# Analytical — You break problems into measurable steps, relying on tool output
|
||||
# and data rather than intuition.
|
||||
# Organized — You structure every interaction with clear tags, TODO lists, and
|
||||
# section boundaries.
|
||||
# Precision-Oriented — You insist on exact formatting, package-manager choice, and
|
||||
# rule adherence.
|
||||
# Efficiency-Focused — You minimize chatter, run tasks in parallel, and favor
|
||||
# small, complete answers.
|
||||
# Clarity-Seeking — You ask for missing details instead of guessing, avoiding any
|
||||
# ambiguity.
|
||||
# </personality>
|
||||
|
||||
# <communication>
|
||||
# Be concise, direct, and to the point.
|
||||
# NO emojis unless the User explicitly asks for them.
|
||||
# If a task appears incomplete or ambiguous, **pause and ask the User** rather
|
||||
# than guessing or marking "done".
|
||||
# Prefer accuracy over reassurance; confirm facts with tool calls instead of
|
||||
# assuming the User is right.
|
||||
# If you face an architectural, tooling, or package-manager choice, **ask the
|
||||
# User's preference first**.
|
||||
# Default to the project's existing package manager / tooling; never substitute
|
||||
# without confirmation.
|
||||
# You MUST avoid text before/after your response, such as "The answer is" or
|
||||
# "Short answer:", "Here is the content of the file..." or "Based on the
|
||||
# information provided, the answer is..." or "Here is what I will do next...".
|
||||
# Mimic the style of the User's messages.
|
||||
# Do not remind the User you are happy to help.
|
||||
# Do not inherently assume the User is correct; they may be making assumptions.
|
||||
# If you are not confident in your answer, DO NOT provide an answer. Use your
|
||||
# tools to collect more information, or ask the User for help.
|
||||
# Do not act with sycophantic flattery or over-the-top enthusiasm.
|
||||
|
||||
# Here are examples to demonstrate appropriate communication style and level of
|
||||
# verbosity:
|
||||
|
||||
# <example>
|
||||
# user: find me a good issue to work on
|
||||
# assistant: Issue [#1234](https://example) indicates a bug in the frontend, which
|
||||
# you've contributed to in the past.
|
||||
# </example>
|
||||
|
||||
# <example>
|
||||
# user: work on this issue <url>
|
||||
# ...assistant does work...
|
||||
# assistant: I've put up this pull request:
|
||||
# https://github.com/example/example/pull/1824. Please let me know your thoughts!
|
||||
# </example>
|
||||
|
||||
# <example>
|
||||
# user: what is 2+2?
|
||||
# assistant: 4
|
||||
# </example>
|
||||
|
||||
# <example>
|
||||
# user: how does X work in <popular-repository-name>?
|
||||
# assistant: Let me take a look at the code...
|
||||
# [tool calls to investigate the repository]
|
||||
# </example>
|
||||
# </communication>
|
||||
|
||||
# <collaboration>
|
||||
# When a user asks for help with a task or there is ambiguity on the objective,
|
||||
# always start by asking clarifying questions to understand:
|
||||
# - What specific aspect they want to focus on
|
||||
# - Their goals and vision for the changes
|
||||
# - Their preferences for approach or style
|
||||
# - What problems they're trying to solve
|
||||
|
||||
# Don't assume what needs to be done - collaborate to define the scope together.
|
||||
# </collaboration>, type: string)
|
||||
chat_system_prompt: |-
|
||||
You are the Coder agent — an interactive chat tool that helps users with software-engineering tasks inside of the Coder product.
|
||||
Use the instructions below and the tools available to you to assist User.
|
||||
|
||||
IMPORTANT — obey every rule in this prompt before anything else.
|
||||
Do EXACTLY what the User asked, never more, never less.
|
||||
|
||||
<behavior>
|
||||
You MUST execute AS MANY TOOLS to help the user accomplish their task.
|
||||
You are COMFORTABLE with vague tasks - using your tools to collect the most relevant answer possible.
|
||||
You ALWAYS use GitHub tools for ANY query related to source code.
|
||||
If a user asks how something works, no matter how vague, you MUST use your tools to collect the most relevant answer possible.
|
||||
DO NOT ask the user for clarification - just use your tools.
|
||||
</behavior>
|
||||
|
||||
<personality>
|
||||
Analytical — You break problems into measurable steps, relying on tool output and data rather than intuition.
|
||||
Organized — You structure every interaction with clear tags, TODO lists, and section boundaries.
|
||||
Precision-Oriented — You insist on exact formatting, package-manager choice, and rule adherence.
|
||||
Efficiency-Focused — You minimize chatter, run tasks in parallel, and favor small, complete answers.
|
||||
Clarity-Seeking — You ask for missing details instead of guessing, avoiding any ambiguity.
|
||||
</personality>
|
||||
|
||||
<communication>
|
||||
Be concise, direct, and to the point.
|
||||
NO emojis unless the User explicitly asks for them.
|
||||
If a task appears incomplete or ambiguous, **pause and ask the User** rather than guessing or marking "done".
|
||||
Prefer accuracy over reassurance; confirm facts with tool calls instead of assuming the User is right.
|
||||
If you face an architectural, tooling, or package-manager choice, **ask the User's preference first**.
|
||||
Default to the project's existing package manager / tooling; never substitute without confirmation.
|
||||
You MUST avoid text before/after your response, such as "The answer is" or "Short answer:", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
|
||||
Mimic the style of the User's messages.
|
||||
Do not remind the User you are happy to help.
|
||||
Do not inherently assume the User is correct; they may be making assumptions.
|
||||
If you are not confident in your answer, DO NOT provide an answer. Use your tools to collect more information, or ask the User for help.
|
||||
Do not act with sycophantic flattery or over-the-top enthusiasm.
|
||||
|
||||
Here are examples to demonstrate appropriate communication style and level of verbosity:
|
||||
|
||||
<example>
|
||||
user: find me a good issue to work on
|
||||
assistant: Issue [#1234](https://example) indicates a bug in the frontend, which you've contributed to in the past.
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: work on this issue <url>
|
||||
...assistant does work...
|
||||
assistant: I've put up this pull request: https://github.com/example/example/pull/1824. Please let me know your thoughts!
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what is 2+2?
|
||||
assistant: 4
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: how does X work in <popular-repository-name>?
|
||||
assistant: Let me take a look at the code...
|
||||
[tool calls to investigate the repository]
|
||||
</example>
|
||||
</communication>
|
||||
|
||||
<collaboration>
|
||||
When a user asks for help with a task or there is ambiguity on the objective, always start by asking clarifying questions to understand:
|
||||
- What specific aspect they want to focus on
|
||||
- Their goals and vision for the changes
|
||||
- Their preferences for approach or style
|
||||
- What problems they're trying to solve
|
||||
|
||||
Don't assume what needs to be done - collaborate to define the scope together.
|
||||
</collaboration>
|
||||
# Prompt used to generate chat titles from the first user message.
|
||||
# (default: Generate a concise title (max 8 words) for the user's first message.
|
||||
# Return plain text only, with no surrounding quotes., type: string)
|
||||
chat_title_generation_prompt: Generate a concise title (max 8 words) for the user's first message. Return plain text only, with no surrounding quotes.
|
||||
# Enable admin-only local workspace mode for agent chats.
|
||||
# (default: false, type: bool)
|
||||
agent_local_workspace: false
|
||||
# Whether to start an in-memory aibridged instance.
|
||||
# (default: false, type: bool)
|
||||
enabled: false
|
||||
# The base URL of the OpenAI API.
|
||||
# (default: https://api.openai.com/v1/, type: string)
|
||||
openai_base_url: https://api.openai.com/v1/
|
||||
# Override URL used to list OpenAI models for the chat model catalog.
|
||||
# (default: <unset>, type: string)
|
||||
chat_openai_models_url: ""
|
||||
# The base URL of the Anthropic API.
|
||||
# (default: https://api.anthropic.com/, type: string)
|
||||
anthropic_base_url: https://api.anthropic.com/
|
||||
# Override URL used to list Anthropic models for the chat model catalog.
|
||||
# (default: <unset>, type: string)
|
||||
chat_anthropic_models_url: ""
|
||||
# Comma-separated allowlist of models for the chat model catalog.
|
||||
# (default: <unset>, type: string)
|
||||
chat_models_allowlist: ""
|
||||
# Comma-separated denylist of models for the chat model catalog.
|
||||
# (default: <unset>, type: string)
|
||||
chat_models_denylist: ""
|
||||
# The base URL to use for the AWS Bedrock API. Use this setting to specify an
|
||||
# exact URL to use. Takes precedence over CODER_AIBRIDGE_BEDROCK_REGION.
|
||||
# (default: <unset>, type: string)
|
||||
|
||||
Generated
+855
@@ -453,6 +453,348 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "List chats",
|
||||
"operationId": "list-chats",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Chat"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Create a chat",
|
||||
"operationId": "create-chat",
|
||||
"parameters": [
|
||||
{
|
||||
"description": "Create chat request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateChatRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Chat"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/models": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "List chat models",
|
||||
"operationId": "list-chat-models",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatModelsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Get a chat",
|
||||
"operationId": "get-chat",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatWithMessages"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"delete": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Delete a chat",
|
||||
"operationId": "delete-chat",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/diff": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Get diff contents for a chat",
|
||||
"operationId": "get-chat-diff",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatDiffContents"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/diff-status": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Get diff status for a chat",
|
||||
"operationId": "get-chat-diff-status",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatDiffStatus"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/interrupt": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Interrupt a chat",
|
||||
"operationId": "interrupt-chat",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.Chat"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/messages": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Create a chat message",
|
||||
"operationId": "create-chat-message",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create chat message request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateChatMessageRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.ChatMessage"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/stream": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Stream chat updates",
|
||||
"operationId": "stream-chat",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ServerSentEvent"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/connectionlog": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -9347,6 +9689,12 @@ const docTemplate = `{
|
||||
"description": "Wait for a new token to be issued",
|
||||
"name": "listen",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Working directory used for git context refresh",
|
||||
"name": "workdir",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -12057,6 +12405,9 @@ const docTemplate = `{
|
||||
},
|
||||
"key": {
|
||||
"type": "string"
|
||||
},
|
||||
"models_url": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -12117,6 +12468,12 @@ const docTemplate = `{
|
||||
"max_concurrency": {
|
||||
"type": "integer"
|
||||
},
|
||||
"models_allowlist": {
|
||||
"type": "string"
|
||||
},
|
||||
"models_denylist": {
|
||||
"type": "string"
|
||||
},
|
||||
"openai": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeOpenAIConfig"
|
||||
},
|
||||
@@ -12207,6 +12564,9 @@ const docTemplate = `{
|
||||
},
|
||||
"key": {
|
||||
"type": "string"
|
||||
},
|
||||
"models_url": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -12335,6 +12695,17 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIChatConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"system_prompt": {
|
||||
"type": "string"
|
||||
},
|
||||
"title_generation_prompt": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -12343,6 +12714,9 @@ const docTemplate = `{
|
||||
},
|
||||
"bridge": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeConfig"
|
||||
},
|
||||
"chat": {
|
||||
"$ref": "#/definitions/codersdk.AIChatConfig"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -12471,6 +12845,11 @@ const docTemplate = `{
|
||||
"boundary_usage:delete",
|
||||
"boundary_usage:read",
|
||||
"boundary_usage:update",
|
||||
"chat:*",
|
||||
"chat:create",
|
||||
"chat:delete",
|
||||
"chat:read",
|
||||
"chat:update",
|
||||
"coder:all",
|
||||
"coder:apikeys.manage_self",
|
||||
"coder:application_connect",
|
||||
@@ -12673,6 +13052,11 @@ const docTemplate = `{
|
||||
"APIKeyScopeBoundaryUsageDelete",
|
||||
"APIKeyScopeBoundaryUsageRead",
|
||||
"APIKeyScopeBoundaryUsageUpdate",
|
||||
"APIKeyScopeChatAll",
|
||||
"APIKeyScopeChatCreate",
|
||||
"APIKeyScopeChatDelete",
|
||||
"APIKeyScopeChatRead",
|
||||
"APIKeyScopeChatUpdate",
|
||||
"APIKeyScopeCoderAll",
|
||||
"APIKeyScopeCoderApikeysManageSelf",
|
||||
"APIKeyScopeCoderApplicationConnect",
|
||||
@@ -13380,6 +13764,420 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.Chat": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"diff_status": {
|
||||
"$ref": "#/definitions/codersdk.ChatDiffStatus"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"model_config": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"owner_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"parent_chat_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"root_chat_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"status": {
|
||||
"$ref": "#/definitions/codersdk.ChatStatus"
|
||||
},
|
||||
"title": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"workspace_agent_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatDiffContents": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"branch": {
|
||||
"type": "string"
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"diff": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"pull_request_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"remote_origin": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatDiffStatus": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"additions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"changed_files": {
|
||||
"type": "integer"
|
||||
},
|
||||
"changes_requested": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"deletions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"pull_request_state": {
|
||||
"type": "string"
|
||||
},
|
||||
"refreshed_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"stale_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"url": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatInput": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parts": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.ChatInputPart"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatInputPart": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": {
|
||||
"$ref": "#/definitions/codersdk.ChatInputPartType"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatInputPartType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"text"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ChatInputPartTypeText"
|
||||
]
|
||||
},
|
||||
"codersdk.ChatMessage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cache_creation_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"cache_read_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"content": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"context_limit": {
|
||||
"type": "integer"
|
||||
},
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"hidden": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"id": {
|
||||
"type": "integer"
|
||||
},
|
||||
"input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"output_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"parts": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.ChatMessagePart"
|
||||
}
|
||||
},
|
||||
"reasoning_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"role": {
|
||||
"type": "string"
|
||||
},
|
||||
"thinking": {
|
||||
"type": "string"
|
||||
},
|
||||
"tool_call_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"total_tokens": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatMessagePart": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"args": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"args_delta": {
|
||||
"type": "string"
|
||||
},
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"is_error": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"media_type": {
|
||||
"type": "string"
|
||||
},
|
||||
"result": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"result_delta": {
|
||||
"type": "string"
|
||||
},
|
||||
"result_meta": {
|
||||
"$ref": "#/definitions/codersdk.ChatToolResultMetadata"
|
||||
},
|
||||
"signature": {
|
||||
"type": "string"
|
||||
},
|
||||
"source_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"text": {
|
||||
"type": "string"
|
||||
},
|
||||
"title": {
|
||||
"type": "string"
|
||||
},
|
||||
"tool_call_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"tool_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": {
|
||||
"$ref": "#/definitions/codersdk.ChatMessagePartType"
|
||||
},
|
||||
"url": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatMessagePartType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"text",
|
||||
"reasoning",
|
||||
"tool-call",
|
||||
"tool-result",
|
||||
"source",
|
||||
"file"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ChatMessagePartTypeText",
|
||||
"ChatMessagePartTypeReasoning",
|
||||
"ChatMessagePartTypeToolCall",
|
||||
"ChatMessagePartTypeToolResult",
|
||||
"ChatMessagePartTypeSource",
|
||||
"ChatMessagePartTypeFile"
|
||||
]
|
||||
},
|
||||
"codersdk.ChatModel": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"display_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"provider": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatModelProvider": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"available": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.ChatModel"
|
||||
}
|
||||
},
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"unavailable_reason": {
|
||||
"$ref": "#/definitions/codersdk.ChatModelProviderUnavailableReason"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatModelProviderUnavailableReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"missing_api_key",
|
||||
"fetch_failed"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ChatModelProviderUnavailableMissingAPIKey",
|
||||
"ChatModelProviderUnavailableFetchFailed"
|
||||
]
|
||||
},
|
||||
"codersdk.ChatModelsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"providers": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.ChatModelProvider"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"waiting",
|
||||
"pending",
|
||||
"running",
|
||||
"paused",
|
||||
"completed",
|
||||
"error"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ChatStatusWaiting",
|
||||
"ChatStatusPending",
|
||||
"ChatStatusRunning",
|
||||
"ChatStatusPaused",
|
||||
"ChatStatusCompleted",
|
||||
"ChatStatusError"
|
||||
]
|
||||
},
|
||||
"codersdk.ChatToolResultMetadata": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string"
|
||||
},
|
||||
"created": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"error": {
|
||||
"type": "string"
|
||||
},
|
||||
"exit_code": {
|
||||
"type": "integer"
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string"
|
||||
},
|
||||
"output": {
|
||||
"type": "string"
|
||||
},
|
||||
"reason": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_agent_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_url": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatWithMessages": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"chat": {
|
||||
"$ref": "#/definitions/codersdk.Chat"
|
||||
},
|
||||
"messages": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.ChatMessage"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ConnectionLatency": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -13546,6 +14344,61 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateChatMessageRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"role": {
|
||||
"type": "string"
|
||||
},
|
||||
"thinking": {
|
||||
"type": "string"
|
||||
},
|
||||
"tool_call_id": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateChatRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"$ref": "#/definitions/codersdk.ChatInput"
|
||||
},
|
||||
"message": {
|
||||
"type": "string"
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"model_config": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"parent_chat_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"system_prompt": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_agent_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateFirstUserRequest": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
@@ -17749,6 +18602,7 @@ const docTemplate = `{
|
||||
"assign_role",
|
||||
"audit_log",
|
||||
"boundary_usage",
|
||||
"chat",
|
||||
"connection_log",
|
||||
"crypto_key",
|
||||
"debug_info",
|
||||
@@ -17794,6 +18648,7 @@ const docTemplate = `{
|
||||
"ResourceAssignRole",
|
||||
"ResourceAuditLog",
|
||||
"ResourceBoundaryUsage",
|
||||
"ResourceChat",
|
||||
"ResourceConnectionLog",
|
||||
"ResourceCryptoKey",
|
||||
"ResourceDebugInfo",
|
||||
|
||||
Generated
+23900
-21242
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,116 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
coderHomeInstructionDir = ".coder"
|
||||
coderHomeInstructionFile = "AGENTS.md"
|
||||
maxInstructionFileBytes = 64 * 1024
|
||||
)
|
||||
|
||||
var markdownCommentPattern = regexp.MustCompile(`<!--[\s\S]*?-->`)
|
||||
|
||||
func readHomeInstructionFile(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
) (content string, sourcePath string, truncated bool, err error) {
|
||||
if conn == nil {
|
||||
return "", "", false, nil
|
||||
}
|
||||
|
||||
coderDir, err := conn.LS(ctx, "", workspacesdk.LSRequest{
|
||||
Path: []string{coderHomeInstructionDir},
|
||||
Relativity: workspacesdk.LSRelativityHome,
|
||||
})
|
||||
if err != nil {
|
||||
if isCodersdkStatusCode(err, http.StatusNotFound) {
|
||||
return "", "", false, nil
|
||||
}
|
||||
return "", "", false, xerrors.Errorf("list home instruction directory: %w", err)
|
||||
}
|
||||
|
||||
var filePath string
|
||||
for _, entry := range coderDir.Contents {
|
||||
if entry.IsDir {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(entry.Name), coderHomeInstructionFile) {
|
||||
filePath = strings.TrimSpace(entry.AbsolutePathString)
|
||||
break
|
||||
}
|
||||
}
|
||||
if filePath == "" {
|
||||
return "", "", false, nil
|
||||
}
|
||||
|
||||
reader, _, err := conn.ReadFile(
|
||||
ctx,
|
||||
filePath,
|
||||
0,
|
||||
maxInstructionFileBytes+1,
|
||||
)
|
||||
if err != nil {
|
||||
if isCodersdkStatusCode(err, http.StatusNotFound) {
|
||||
return "", "", false, nil
|
||||
}
|
||||
return "", "", false, xerrors.Errorf("read home instruction file: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
raw, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return "", "", false, xerrors.Errorf("read home instruction bytes: %w", err)
|
||||
}
|
||||
|
||||
truncated = int64(len(raw)) > maxInstructionFileBytes
|
||||
if truncated {
|
||||
raw = raw[:maxInstructionFileBytes]
|
||||
}
|
||||
|
||||
content = sanitizeInstructionMarkdown(string(raw))
|
||||
if content == "" {
|
||||
return "", "", truncated, nil
|
||||
}
|
||||
|
||||
return content, filePath, truncated, nil
|
||||
}
|
||||
|
||||
func sanitizeInstructionMarkdown(content string) string {
|
||||
content = strings.ReplaceAll(content, "\r\n", "\n")
|
||||
content = strings.ReplaceAll(content, "\r", "\n")
|
||||
content = markdownCommentPattern.ReplaceAllString(content, "")
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
func formatHomeInstruction(content string, sourcePath string, truncated bool) string {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
sourcePath = strings.TrimSpace(sourcePath)
|
||||
if sourcePath == "" {
|
||||
sourcePath = "~/.coder/AGENTS.md"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("<coder-home-instructions>\n")
|
||||
b.WriteString("Source: ")
|
||||
b.WriteString(sourcePath)
|
||||
if truncated {
|
||||
b.WriteString(" (truncated to 64KiB)")
|
||||
}
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(content)
|
||||
b.WriteString("\n</coder-home-instructions>")
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
)
|
||||
|
||||
func TestSanitizeInstructionMarkdown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := "line 1\r\n<!-- hidden -->\r\nline 2\r\n"
|
||||
require.Equal(t, "line 1\n\nline 2", sanitizeInstructionMarkdown(input))
|
||||
}
|
||||
|
||||
func TestReadHomeInstructionFileNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).DoAndReturn(
|
||||
func(context.Context, string, workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
|
||||
return workspacesdk.LSResponse{}, codersdk.NewTestError(404, "POST", "/api/v0/list-directory")
|
||||
},
|
||||
)
|
||||
|
||||
content, sourcePath, truncated, err := readHomeInstructionFile(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, content)
|
||||
require.Empty(t, sourcePath)
|
||||
require.False(t, truncated)
|
||||
}
|
||||
|
||||
func TestReadHomeInstructionFileSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).DoAndReturn(
|
||||
func(context.Context, string, workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
|
||||
return workspacesdk.LSResponse{
|
||||
Contents: []workspacesdk.LSFile{{
|
||||
Name: "AGENTS.md",
|
||||
AbsolutePathString: "/home/coder/.coder/AGENTS.md",
|
||||
}},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/.coder/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
io.NopCloser(strings.NewReader("base\n<!-- hidden -->\nlocal")),
|
||||
"text/markdown",
|
||||
nil,
|
||||
)
|
||||
|
||||
content, sourcePath, truncated, err := readHomeInstructionFile(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "base\n\nlocal", content)
|
||||
require.Equal(t, "/home/coder/.coder/AGENTS.md", sourcePath)
|
||||
require.False(t, truncated)
|
||||
}
|
||||
|
||||
func TestReadHomeInstructionFileTruncates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
content := strings.Repeat("a", maxInstructionFileBytes+8)
|
||||
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
|
||||
workspacesdk.LSResponse{
|
||||
Contents: []workspacesdk.LSFile{{
|
||||
Name: "AGENTS.md",
|
||||
AbsolutePathString: "/home/coder/.coder/AGENTS.md",
|
||||
}},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/.coder/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(io.NopCloser(strings.NewReader(content)), "text/markdown", nil)
|
||||
|
||||
got, _, truncated, err := readHomeInstructionFile(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
require.True(t, truncated)
|
||||
require.Len(t, got, maxInstructionFileBytes)
|
||||
}
|
||||
|
||||
func TestInsertSystemInstructionAfterSystemMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "base"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "hello"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := insertSystemInstruction(prompt, "project rules")
|
||||
require.Len(t, got, 3)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[1].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
|
||||
|
||||
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "project rules", part.Text)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,821 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
)
|
||||
|
||||
func localModeTestDB(t *testing.T) database.Store {
|
||||
t.Helper()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
return db
|
||||
}
|
||||
|
||||
func seedWorkspaceWithLatestBuildAgents(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
agentNames ...string,
|
||||
) (uuid.UUID, map[string]uuid.UUID) {
|
||||
t.Helper()
|
||||
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
templateVersion := dbfake.TemplateVersion(t, db).
|
||||
Seed(database.TemplateVersion{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
}).
|
||||
Do()
|
||||
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
TemplateID: templateVersion.Template.ID,
|
||||
})
|
||||
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
OrganizationID: workspace.OrganizationID,
|
||||
InitiatorID: user.ID,
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
Tags: database.StringMap{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
})
|
||||
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
WorkspaceID: workspace.ID,
|
||||
TemplateVersionID: templateVersion.TemplateVersion.ID,
|
||||
InitiatorID: user.ID,
|
||||
JobID: job.ID,
|
||||
})
|
||||
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
||||
JobID: job.ID,
|
||||
Type: "coder_external_agent",
|
||||
Name: localChatExternalResourceName,
|
||||
})
|
||||
|
||||
agentIDByName := make(map[string]uuid.UUID, len(agentNames))
|
||||
for _, name := range agentNames {
|
||||
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ResourceID: resource.ID,
|
||||
Name: name,
|
||||
})
|
||||
agentIDByName[name] = agent.ID
|
||||
}
|
||||
return workspace.ID, agentIDByName
|
||||
}
|
||||
|
||||
func seedWorkspaceWithLocalChatAgent(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
agentName string,
|
||||
agentToken uuid.UUID,
|
||||
) (uuid.UUID, uuid.UUID) {
|
||||
t.Helper()
|
||||
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
templateVersion := dbfake.TemplateVersion(t, db).
|
||||
Seed(database.TemplateVersion{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
}).
|
||||
Do()
|
||||
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
TemplateID: templateVersion.Template.ID,
|
||||
})
|
||||
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
OrganizationID: workspace.OrganizationID,
|
||||
InitiatorID: user.ID,
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
Tags: database.StringMap{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
})
|
||||
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
WorkspaceID: workspace.ID,
|
||||
TemplateVersionID: templateVersion.TemplateVersion.ID,
|
||||
InitiatorID: user.ID,
|
||||
JobID: job.ID,
|
||||
})
|
||||
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
||||
JobID: job.ID,
|
||||
Type: "coder_external_agent",
|
||||
Name: localChatExternalResourceName,
|
||||
})
|
||||
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ResourceID: resource.ID,
|
||||
Name: agentName,
|
||||
AuthToken: agentToken,
|
||||
})
|
||||
return workspace.ID, agent.ID
|
||||
}
|
||||
|
||||
func TestLocalChatTemplateArchiveForProvisionerTerraform(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := newLocalMode(localModeOptions{})
|
||||
archiveBytes, err := service.localChatTemplateArchiveForProvisioner(
|
||||
context.Background(),
|
||||
codersdk.ProvisionerTypeTerraform,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
reader := tar.NewReader(bytes.NewReader(archiveBytes))
|
||||
foundMainTF := false
|
||||
for {
|
||||
header, err := reader.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err)
|
||||
if header.Name != "main.tf" {
|
||||
continue
|
||||
}
|
||||
foundMainTF = true
|
||||
content, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
mainTF := string(content)
|
||||
require.Contains(t, mainTF, `resource "coder_agent" "localagent"`)
|
||||
require.Contains(t, mainTF, `resource "coder_external_agent" "main"`)
|
||||
require.Contains(t, mainTF, `agent_id = coder_agent.localagent.id`)
|
||||
}
|
||||
require.True(t, foundMainTF)
|
||||
}
|
||||
|
||||
func TestResolveLocalChatExternalAgentFromWorkspaceResources(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
agentID := uuid.New()
|
||||
resources := []codersdk.WorkspaceResource{
|
||||
{
|
||||
Type: "compute",
|
||||
Agents: []codersdk.WorkspaceAgent{{
|
||||
ID: uuid.New(),
|
||||
Name: "ignored",
|
||||
}},
|
||||
},
|
||||
{
|
||||
Type: "coder_external_agent",
|
||||
Agents: []codersdk.WorkspaceAgent{{
|
||||
ID: agentID,
|
||||
Name: "external-agent",
|
||||
Status: codersdk.WorkspaceAgentDisconnected,
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
service := newLocalMode(localModeOptions{})
|
||||
agent, err := service.resolveLocalChatExternalAgent(
|
||||
context.Background(),
|
||||
nil,
|
||||
codersdk.Workspace{
|
||||
ID: uuid.New(),
|
||||
LatestBuild: codersdk.WorkspaceBuild{
|
||||
Resources: resources,
|
||||
},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, agentID, agent.ID)
|
||||
require.Equal(t, "external-agent", agent.Name)
|
||||
require.Equal(t, codersdk.WorkspaceAgentDisconnected, agent.Status)
|
||||
}
|
||||
|
||||
func TestResolveLocalChatExternalAgentWithoutLatestBuild(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := localModeTestDB(t)
|
||||
workspaceID := uuid.New()
|
||||
|
||||
service := newLocalMode(localModeOptions{database: db})
|
||||
_, err := service.resolveLocalChatExternalAgent(
|
||||
context.Background(),
|
||||
nil,
|
||||
codersdk.Workspace{
|
||||
ID: workspaceID,
|
||||
LatestBuild: codersdk.WorkspaceBuild{
|
||||
Resources: []codersdk.WorkspaceResource{{
|
||||
Type: "coder_external_agent",
|
||||
}},
|
||||
},
|
||||
},
|
||||
)
|
||||
require.ErrorContains(t, err, "has no latest build")
|
||||
}
|
||||
|
||||
func TestResolveLocalChatExternalAgentFromWorkspaceAgentsInDB(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("PrefersNamedLocalAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := localModeTestDB(t)
|
||||
workspaceID, agents := seedWorkspaceWithLatestBuildAgents(
|
||||
t,
|
||||
db,
|
||||
"something-else",
|
||||
localChatExternalAgentName,
|
||||
)
|
||||
expectedID := agents[localChatExternalAgentName]
|
||||
|
||||
service := newLocalMode(localModeOptions{database: db})
|
||||
agent, err := service.resolveLocalChatExternalAgent(
|
||||
context.Background(),
|
||||
nil,
|
||||
codersdk.Workspace{
|
||||
ID: workspaceID,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedID, agent.ID)
|
||||
require.Equal(t, localChatExternalAgentName, agent.Name)
|
||||
})
|
||||
|
||||
t.Run("FallsBackToFirstNamedAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := localModeTestDB(t)
|
||||
workspaceID, agents := seedWorkspaceWithLatestBuildAgents(
|
||||
t,
|
||||
db,
|
||||
" ",
|
||||
"agent-a",
|
||||
)
|
||||
expectedID := agents["agent-a"]
|
||||
|
||||
service := newLocalMode(localModeOptions{database: db})
|
||||
agent, err := service.resolveLocalChatExternalAgent(
|
||||
context.Background(),
|
||||
nil,
|
||||
codersdk.Workspace{
|
||||
ID: workspaceID,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedID, agent.ID)
|
||||
require.Equal(t, "agent-a", agent.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLocalChatAgentLaunchLimiter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
limiter := newLocalChatAgentLaunchLimiter(5 * time.Second)
|
||||
agentID := uuid.New()
|
||||
base := time.Unix(0, 0)
|
||||
|
||||
require.True(t, limiter.Allow(agentID, base))
|
||||
require.False(t, limiter.Allow(agentID, base.Add(2*time.Second)))
|
||||
require.True(t, limiter.Allow(agentID, base.Add(6*time.Second)))
|
||||
}
|
||||
|
||||
type localChatAgentRuntimeStub struct {
|
||||
closeCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (r *localChatAgentRuntimeStub) Close() error {
|
||||
r.closeCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestLocalModeStartLocalChatAgentOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var starts atomic.Int32
|
||||
service := newLocalMode(localModeOptions{
|
||||
startAgentFn: func(localChatAgentStartParams) (io.Closer, error) {
|
||||
starts.Add(1)
|
||||
return &localChatAgentRuntimeStub{}, nil
|
||||
},
|
||||
})
|
||||
|
||||
params := localChatAgentStartParams{
|
||||
WorkspaceID: uuid.New(),
|
||||
AgentID: uuid.New(),
|
||||
Credentials: codersdk.ExternalAgentCredentials{
|
||||
AgentToken: "test-token",
|
||||
},
|
||||
}
|
||||
|
||||
const callers = 16
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(callers)
|
||||
errCh := make(chan error, callers)
|
||||
for i := 0; i < callers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
errCh <- service.startLocalChatAgent(params)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, int32(1), starts.Load())
|
||||
}
|
||||
|
||||
func TestLocalModeRestartLocalChatAgentAfterClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
starts atomic.Int32
|
||||
runtimeM sync.Mutex
|
||||
runtimes []*localChatAgentRuntimeStub
|
||||
)
|
||||
service := newLocalMode(localModeOptions{
|
||||
startAgentFn: func(localChatAgentStartParams) (io.Closer, error) {
|
||||
starts.Add(1)
|
||||
runtime := &localChatAgentRuntimeStub{}
|
||||
runtimeM.Lock()
|
||||
runtimes = append(runtimes, runtime)
|
||||
runtimeM.Unlock()
|
||||
return runtime, nil
|
||||
},
|
||||
})
|
||||
|
||||
params := localChatAgentStartParams{
|
||||
WorkspaceID: uuid.New(),
|
||||
AgentID: uuid.New(),
|
||||
Credentials: codersdk.ExternalAgentCredentials{
|
||||
AgentToken: "test-token",
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, service.startLocalChatAgent(params))
|
||||
require.Equal(t, int32(1), starts.Load())
|
||||
require.NoError(t, service.closeLocalChatAgent(params.AgentID))
|
||||
|
||||
require.NoError(t, service.startLocalChatAgent(params))
|
||||
require.Equal(t, int32(2), starts.Load())
|
||||
|
||||
runtimeM.Lock()
|
||||
require.Len(t, runtimes, 2)
|
||||
require.Equal(t, int32(1), runtimes[0].closeCalls.Load())
|
||||
require.Equal(t, int32(0), runtimes[1].closeCalls.Load())
|
||||
runtimeM.Unlock()
|
||||
}
|
||||
|
||||
func TestLocalModeCloseAllLocalChatAgents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
runtimeM sync.Mutex
|
||||
runtimes []*localChatAgentRuntimeStub
|
||||
)
|
||||
service := newLocalMode(localModeOptions{
|
||||
startAgentFn: func(localChatAgentStartParams) (io.Closer, error) {
|
||||
runtime := &localChatAgentRuntimeStub{}
|
||||
runtimeM.Lock()
|
||||
runtimes = append(runtimes, runtime)
|
||||
runtimeM.Unlock()
|
||||
return runtime, nil
|
||||
},
|
||||
})
|
||||
|
||||
first := localChatAgentStartParams{
|
||||
WorkspaceID: uuid.New(),
|
||||
AgentID: uuid.New(),
|
||||
Credentials: codersdk.ExternalAgentCredentials{
|
||||
AgentToken: "test-token-1",
|
||||
},
|
||||
}
|
||||
second := localChatAgentStartParams{
|
||||
WorkspaceID: uuid.New(),
|
||||
AgentID: uuid.New(),
|
||||
Credentials: codersdk.ExternalAgentCredentials{
|
||||
AgentToken: "test-token-2",
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, service.startLocalChatAgent(first))
|
||||
require.NoError(t, service.startLocalChatAgent(second))
|
||||
require.NoError(t, service.closeAllLocalChatAgents())
|
||||
require.NoError(t, service.closeAllLocalChatAgents())
|
||||
|
||||
runtimeM.Lock()
|
||||
require.Len(t, runtimes, 2)
|
||||
require.Equal(t, int32(1), runtimes[0].closeCalls.Load())
|
||||
require.Equal(t, int32(1), runtimes[1].closeCalls.Load())
|
||||
runtimeM.Unlock()
|
||||
}
|
||||
|
||||
func TestEnsureLocalChatTemplateVersionProvisionable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoEligibleDaemons", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := localModeTestDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
jobID := uuid.New()
|
||||
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
ID: jobID,
|
||||
OrganizationID: org.ID,
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
Tags: database.StringMap{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
})
|
||||
|
||||
service := newLocalMode(localModeOptions{database: db})
|
||||
err := service.ensureLocalChatTemplateVersionProvisionable(
|
||||
context.Background(),
|
||||
org.ID,
|
||||
jobID,
|
||||
codersdk.ProvisionerTypeTerraform,
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "no eligible provisioner daemons")
|
||||
})
|
||||
|
||||
t.Run("OnlyOfflineEligibleDaemons", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := localModeTestDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
jobID := uuid.New()
|
||||
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
ID: jobID,
|
||||
OrganizationID: org.ID,
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
Tags: database.StringMap{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
})
|
||||
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
||||
OrganizationID: org.ID,
|
||||
Provisioners: []database.ProvisionerType{database.ProvisionerTypeTerraform},
|
||||
Tags: database.StringMap{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
LastSeenAt: sql.NullTime{
|
||||
Valid: true,
|
||||
Time: time.Now().Add(-2 * time.Hour),
|
||||
},
|
||||
})
|
||||
|
||||
service := newLocalMode(localModeOptions{database: db})
|
||||
err := service.ensureLocalChatTemplateVersionProvisionable(
|
||||
context.Background(),
|
||||
org.ID,
|
||||
jobID,
|
||||
codersdk.ProvisionerTypeTerraform,
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "no online eligible provisioner daemons")
|
||||
})
|
||||
|
||||
t.Run("HasOnlineEligibleDaemon", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := localModeTestDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
jobID := uuid.New()
|
||||
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
ID: jobID,
|
||||
OrganizationID: org.ID,
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
Tags: database.StringMap{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
})
|
||||
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
||||
OrganizationID: org.ID,
|
||||
Provisioners: []database.ProvisionerType{database.ProvisionerTypeTerraform},
|
||||
Tags: database.StringMap{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
})
|
||||
|
||||
service := newLocalMode(localModeOptions{database: db})
|
||||
err := service.ensureLocalChatTemplateVersionProvisionable(
|
||||
context.Background(),
|
||||
org.ID,
|
||||
jobID,
|
||||
codersdk.ProvisionerTypeTerraform,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveLocalChatTemplateProvisioner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("PreferTerraformWhenAvailable", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := localModeTestDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
||||
OrganizationID: org.ID,
|
||||
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
})
|
||||
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
||||
OrganizationID: org.ID,
|
||||
Provisioners: []database.ProvisionerType{database.ProvisionerTypeTerraform},
|
||||
})
|
||||
|
||||
service := newLocalMode(localModeOptions{database: db})
|
||||
provisioner, err := service.resolveLocalChatTemplateProvisioner(
|
||||
context.Background(),
|
||||
org.ID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ProvisionerTypeTerraform, provisioner)
|
||||
})
|
||||
|
||||
t.Run("FallbackToEcho", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := localModeTestDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
||||
OrganizationID: org.ID,
|
||||
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
})
|
||||
|
||||
service := newLocalMode(localModeOptions{database: db})
|
||||
provisioner, err := service.resolveLocalChatTemplateProvisioner(
|
||||
context.Background(),
|
||||
org.ID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ProvisionerTypeEcho, provisioner)
|
||||
})
|
||||
|
||||
t.Run("NoSupportedProvisioner", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := localModeTestDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
service := newLocalMode(localModeOptions{database: db})
|
||||
_, err := service.resolveLocalChatTemplateProvisioner(
|
||||
context.Background(),
|
||||
org.ID,
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "has no online provisioner daemons")
|
||||
})
|
||||
}
|
||||
|
||||
func TestLocalChatMaybeLaunchAgentConnectedWithoutRuntimeStarts(t *testing.T) {
|
||||
db := localModeTestDB(t)
|
||||
agentName := "local-agent"
|
||||
agentToken := uuid.New()
|
||||
workspaceID, agentID := seedWorkspaceWithLocalChatAgent(t, db, agentName, agentToken)
|
||||
|
||||
var starts atomic.Int32
|
||||
service := newLocalMode(localModeOptions{
|
||||
database: db,
|
||||
startAgentFn: func(localChatAgentStartParams) (io.Closer, error) {
|
||||
starts.Add(1)
|
||||
return &localChatAgentRuntimeStub{}, nil
|
||||
},
|
||||
})
|
||||
err := service.maybeLaunchLocalChatAgent(
|
||||
context.Background(),
|
||||
workspaceID,
|
||||
localChatExternalAgent{
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
Status: codersdk.WorkspaceAgentConnected,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), starts.Load())
|
||||
require.True(t, service.hasRunningLocalChatAgent(agentID))
|
||||
}
|
||||
|
||||
func TestLocalChatMaybeLaunchAgentSkipsWhenRuntimeAlreadyRunning(t *testing.T) {
|
||||
workspaceID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
agentName := "local-agent"
|
||||
db := localModeTestDB(t)
|
||||
|
||||
var starts atomic.Int32
|
||||
service := newLocalMode(localModeOptions{
|
||||
database: db,
|
||||
startAgentFn: func(localChatAgentStartParams) (io.Closer, error) {
|
||||
starts.Add(1)
|
||||
return &localChatAgentRuntimeStub{}, nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, service.startLocalChatAgent(localChatAgentStartParams{
|
||||
WorkspaceID: workspaceID,
|
||||
AgentID: agentID,
|
||||
AgentName: agentName,
|
||||
Credentials: codersdk.ExternalAgentCredentials{
|
||||
AgentToken: "existing-runtime-token",
|
||||
},
|
||||
}))
|
||||
require.True(t, service.hasRunningLocalChatAgent(agentID))
|
||||
require.Equal(t, int32(1), starts.Load())
|
||||
|
||||
err := service.maybeLaunchLocalChatAgent(
|
||||
context.Background(),
|
||||
workspaceID,
|
||||
localChatExternalAgent{
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
Status: codersdk.WorkspaceAgentDisconnected,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), starts.Load())
|
||||
}
|
||||
|
||||
func TestMaybeLaunchLocalChatAgentForChatLocalRunningRuntimeNoop(t *testing.T) {
|
||||
workspaceID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
agentName := "local-agent"
|
||||
|
||||
db := localModeTestDB(t)
|
||||
|
||||
var starts atomic.Int32
|
||||
service := newLocalMode(localModeOptions{
|
||||
database: db,
|
||||
startAgentFn: func(localChatAgentStartParams) (io.Closer, error) {
|
||||
starts.Add(1)
|
||||
return &localChatAgentRuntimeStub{}, nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, service.startLocalChatAgent(localChatAgentStartParams{
|
||||
WorkspaceID: workspaceID,
|
||||
AgentID: agentID,
|
||||
AgentName: agentName,
|
||||
Credentials: codersdk.ExternalAgentCredentials{
|
||||
AgentToken: "existing-runtime-token",
|
||||
},
|
||||
}))
|
||||
require.Equal(t, int32(1), starts.Load())
|
||||
|
||||
err := service.MaybeLaunchAgentForChat(context.Background(), database.Chat{
|
||||
ID: uuid.New(),
|
||||
ModelConfig: json.RawMessage(`{"workspace_mode":"local"}`),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
WorkspaceAgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), starts.Load())
|
||||
}
|
||||
|
||||
func TestMaybeLaunchLocalChatAgentForChatLocalStartsRuntime(t *testing.T) {
|
||||
db := localModeTestDB(t)
|
||||
agentName := "local-agent"
|
||||
agentToken := uuid.New()
|
||||
workspaceID, agentID := seedWorkspaceWithLocalChatAgent(t, db, agentName, agentToken)
|
||||
|
||||
var starts atomic.Int32
|
||||
service := newLocalMode(localModeOptions{
|
||||
database: db,
|
||||
startAgentFn: func(localChatAgentStartParams) (io.Closer, error) {
|
||||
starts.Add(1)
|
||||
return &localChatAgentRuntimeStub{}, nil
|
||||
},
|
||||
})
|
||||
|
||||
err := service.MaybeLaunchAgentForChat(context.Background(), database.Chat{
|
||||
ID: uuid.New(),
|
||||
ModelConfig: json.RawMessage(`{"workspace_mode":"local"}`),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
WorkspaceAgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), starts.Load())
|
||||
require.True(t, service.hasRunningLocalChatAgent(agentID))
|
||||
}
|
||||
|
||||
func TestMaybeLaunchLocalChatAgentForChatNonLocalNoop(t *testing.T) {
|
||||
db := localModeTestDB(t)
|
||||
|
||||
var starts atomic.Int32
|
||||
service := newLocalMode(localModeOptions{
|
||||
database: db,
|
||||
startAgentFn: func(localChatAgentStartParams) (io.Closer, error) {
|
||||
starts.Add(1)
|
||||
return &localChatAgentRuntimeStub{}, nil
|
||||
},
|
||||
})
|
||||
err := service.MaybeLaunchAgentForChat(context.Background(), database.Chat{
|
||||
ID: uuid.New(),
|
||||
ModelConfig: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(0), starts.Load())
|
||||
}
|
||||
|
||||
func TestProcessorEnsureLocalWorkspaceBindingRequiresLocalMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
processor := &Processor{}
|
||||
_, err := processor.EnsureLocalWorkspaceBinding(
|
||||
context.Background(),
|
||||
uuid.New(),
|
||||
"session-token",
|
||||
)
|
||||
require.EqualError(t, err, "local chat mode is not configured")
|
||||
}
|
||||
|
||||
func TestProcessorEnsureLocalAgentRuntimeForChatRequiresLocalMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
processor := &Processor{}
|
||||
err := processor.EnsureLocalAgentRuntimeForChat(context.Background(), database.Chat{
|
||||
ModelConfig: json.RawMessage(`{"workspace_mode":"local"}`),
|
||||
})
|
||||
require.EqualError(t, err, "local chat mode is not configured")
|
||||
}
|
||||
|
||||
func TestProcessorEnsureLocalAgentRuntimeForChatNonLocalNoop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
processor := &Processor{}
|
||||
err := processor.EnsureLocalAgentRuntimeForChat(context.Background(), database.Chat{
|
||||
ModelConfig: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestProcessorCloseClosesLocalMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runtime := &localChatAgentRuntimeStub{}
|
||||
service := newLocalMode(localModeOptions{
|
||||
startAgentFn: func(localChatAgentStartParams) (io.Closer, error) {
|
||||
return runtime, nil
|
||||
},
|
||||
})
|
||||
|
||||
agentID := uuid.New()
|
||||
require.NoError(t, service.startLocalChatAgent(localChatAgentStartParams{
|
||||
WorkspaceID: uuid.New(),
|
||||
AgentID: agentID,
|
||||
Credentials: codersdk.ExternalAgentCredentials{
|
||||
AgentToken: "test-token",
|
||||
},
|
||||
AgentName: localChatExternalAgentName,
|
||||
}))
|
||||
require.True(t, service.hasRunningLocalChatAgent(agentID))
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
processor := &Processor{
|
||||
cancel: cancel,
|
||||
closed: make(chan struct{}),
|
||||
localMode: service,
|
||||
}
|
||||
close(processor.closed)
|
||||
|
||||
require.NoError(t, processor.Close())
|
||||
require.Equal(t, int32(1), runtime.closeCalls.Load())
|
||||
}
|
||||
@@ -0,0 +1,470 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
fantasyazure "charm.land/fantasy/providers/azure"
|
||||
fantasybedrock "charm.land/fantasy/providers/bedrock"
|
||||
fantasygoogle "charm.land/fantasy/providers/google"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
fantasyopenaicompat "charm.land/fantasy/providers/openaicompat"
|
||||
fantasyopenrouter "charm.land/fantasy/providers/openrouter"
|
||||
fantasyvercel "charm.land/fantasy/providers/vercel"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
var supportedProviderNames = []string{
|
||||
fantasyanthropic.Name,
|
||||
fantasyazure.Name,
|
||||
fantasybedrock.Name,
|
||||
fantasygoogle.Name,
|
||||
fantasyopenai.Name,
|
||||
fantasyopenaicompat.Name,
|
||||
fantasyopenrouter.Name,
|
||||
fantasyvercel.Name,
|
||||
}
|
||||
|
||||
var envPresetProviderNames = []string{
|
||||
fantasyopenai.Name,
|
||||
fantasyanthropic.Name,
|
||||
}
|
||||
|
||||
var providerDisplayNameByName = map[string]string{
|
||||
fantasyanthropic.Name: "Anthropic",
|
||||
fantasyazure.Name: "Azure OpenAI",
|
||||
fantasybedrock.Name: "AWS Bedrock",
|
||||
fantasygoogle.Name: "Google",
|
||||
fantasyopenai.Name: "OpenAI",
|
||||
fantasyopenaicompat.Name: "OpenAI Compatible",
|
||||
fantasyopenrouter.Name: "OpenRouter",
|
||||
fantasyvercel.Name: "Vercel AI Gateway",
|
||||
}
|
||||
|
||||
// SupportedProviders returns all chat providers supported by Fantasy.
|
||||
func SupportedProviders() []string {
|
||||
return append([]string(nil), supportedProviderNames...)
|
||||
}
|
||||
|
||||
// IsEnvPresetProvider reports whether provider supports env presets.
|
||||
func IsEnvPresetProvider(provider string) bool {
|
||||
normalized := NormalizeProvider(provider)
|
||||
for _, candidate := range envPresetProviderNames {
|
||||
if candidate == normalized {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ProviderDisplayName returns a default display name for a provider.
|
||||
func ProviderDisplayName(provider string) string {
|
||||
normalized := NormalizeProvider(provider)
|
||||
if displayName, ok := providerDisplayNameByName[normalized]; ok {
|
||||
return displayName
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
// ProviderAPIKeys contains API keys for provider calls.
|
||||
type ProviderAPIKeys struct {
|
||||
OpenAI string
|
||||
Anthropic string
|
||||
ByProvider map[string]string
|
||||
BaseURLByProvider map[string]string
|
||||
}
|
||||
|
||||
// ConfiguredProvider is an enabled provider loaded from database config.
|
||||
type ConfiguredProvider struct {
|
||||
Provider string
|
||||
APIKey string
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
// configuredModel is an enabled model loaded from database config.
|
||||
type configuredModel struct {
|
||||
Provider string
|
||||
Model string
|
||||
DisplayName string
|
||||
}
|
||||
|
||||
// APIKey returns the effective API key for a provider.
|
||||
func (k ProviderAPIKeys) APIKey(provider string) string {
|
||||
normalized := NormalizeProvider(provider)
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if k.ByProvider != nil {
|
||||
if key := strings.TrimSpace(k.ByProvider[normalized]); key != "" {
|
||||
return key
|
||||
}
|
||||
}
|
||||
|
||||
switch normalized {
|
||||
case fantasyopenai.Name:
|
||||
return strings.TrimSpace(k.OpenAI)
|
||||
case fantasyanthropic.Name:
|
||||
return strings.TrimSpace(k.Anthropic)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (k ProviderAPIKeys) apiKey(provider string) string {
|
||||
return k.APIKey(provider)
|
||||
}
|
||||
|
||||
// BaseURL returns the configured base URL for a provider.
|
||||
func (k ProviderAPIKeys) BaseURL(provider string) string {
|
||||
normalized := NormalizeProvider(provider)
|
||||
if normalized == "" || k.BaseURLByProvider == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(k.BaseURLByProvider[normalized])
|
||||
}
|
||||
|
||||
// MergeProviderAPIKeys overlays configured provider keys over fallback keys.
|
||||
func MergeProviderAPIKeys(fallback ProviderAPIKeys, providers []ConfiguredProvider) ProviderAPIKeys {
|
||||
merged := ProviderAPIKeys{
|
||||
OpenAI: strings.TrimSpace(fallback.OpenAI),
|
||||
Anthropic: strings.TrimSpace(fallback.Anthropic),
|
||||
ByProvider: map[string]string{},
|
||||
BaseURLByProvider: map[string]string{},
|
||||
}
|
||||
for provider, apiKey := range fallback.ByProvider {
|
||||
normalizedProvider := NormalizeProvider(provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
if key := strings.TrimSpace(apiKey); key != "" {
|
||||
merged.ByProvider[normalizedProvider] = key
|
||||
}
|
||||
}
|
||||
for provider, baseURL := range fallback.BaseURLByProvider {
|
||||
normalizedProvider := NormalizeProvider(provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
if url := strings.TrimSpace(baseURL); url != "" {
|
||||
merged.BaseURLByProvider[normalizedProvider] = url
|
||||
}
|
||||
}
|
||||
|
||||
if merged.OpenAI != "" {
|
||||
merged.ByProvider[fantasyopenai.Name] = merged.OpenAI
|
||||
}
|
||||
if merged.Anthropic != "" {
|
||||
merged.ByProvider[fantasyanthropic.Name] = merged.Anthropic
|
||||
}
|
||||
|
||||
for _, provider := range providers {
|
||||
normalizedProvider := NormalizeProvider(provider.Provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if key := strings.TrimSpace(provider.APIKey); key != "" {
|
||||
merged.ByProvider[normalizedProvider] = key
|
||||
}
|
||||
if url := strings.TrimSpace(provider.BaseURL); url != "" {
|
||||
merged.BaseURLByProvider[normalizedProvider] = url
|
||||
}
|
||||
|
||||
switch normalizedProvider {
|
||||
case fantasyopenai.Name:
|
||||
if key := strings.TrimSpace(provider.APIKey); key != "" {
|
||||
merged.OpenAI = key
|
||||
}
|
||||
case fantasyanthropic.Name:
|
||||
if key := strings.TrimSpace(provider.APIKey); key != "" {
|
||||
merged.Anthropic = key
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
type modelCatalog struct {
|
||||
keys ProviderAPIKeys
|
||||
}
|
||||
|
||||
func newModelCatalog(keys ProviderAPIKeys) *modelCatalog {
|
||||
return &modelCatalog{
|
||||
keys: keys,
|
||||
}
|
||||
}
|
||||
|
||||
// ListConfiguredModels returns a model catalog from enabled DB-backed model
|
||||
// configs. The second return value reports whether DB-backed models were used.
|
||||
func (c *modelCatalog) listConfiguredModels(
|
||||
configuredProviders []ConfiguredProvider,
|
||||
configuredModels []configuredModel,
|
||||
) (codersdk.ChatModelsResponse, bool) {
|
||||
if len(configuredModels) == 0 {
|
||||
return codersdk.ChatModelsResponse{}, false
|
||||
}
|
||||
|
||||
modelsByProvider := make(map[string][]codersdk.ChatModel)
|
||||
seenByProvider := make(map[string]map[string]struct{})
|
||||
providerSet := make(map[string]struct{})
|
||||
|
||||
for _, provider := range configuredProviders {
|
||||
normalized := normalizeProvider(provider.Provider)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
providerSet[normalized] = struct{}{}
|
||||
}
|
||||
|
||||
for _, model := range configuredModels {
|
||||
provider, modelID, err := resolveModelWithProviderHint(model.Model, model.Provider)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
providerSet[provider] = struct{}{}
|
||||
if seenByProvider[provider] == nil {
|
||||
seenByProvider[provider] = make(map[string]struct{})
|
||||
}
|
||||
normalizedModelID := strings.ToLower(strings.TrimSpace(modelID))
|
||||
if _, ok := seenByProvider[provider][normalizedModelID]; ok {
|
||||
continue
|
||||
}
|
||||
seenByProvider[provider][normalizedModelID] = struct{}{}
|
||||
modelsByProvider[provider] = append(
|
||||
modelsByProvider[provider],
|
||||
newChatModel(provider, modelID, model.DisplayName),
|
||||
)
|
||||
}
|
||||
|
||||
providers := orderProviders(providerSet)
|
||||
if len(providers) == 0 {
|
||||
return codersdk.ChatModelsResponse{}, false
|
||||
}
|
||||
|
||||
keys := MergeProviderAPIKeys(c.keys, configuredProviders)
|
||||
response := codersdk.ChatModelsResponse{
|
||||
Providers: make([]codersdk.ChatModelProvider, 0, len(providers)),
|
||||
}
|
||||
for _, provider := range providers {
|
||||
models := modelsByProvider[provider]
|
||||
sortChatModels(models)
|
||||
|
||||
result := codersdk.ChatModelProvider{
|
||||
Provider: provider,
|
||||
Models: models,
|
||||
}
|
||||
if keys.apiKey(provider) == "" {
|
||||
result.Available = false
|
||||
result.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey
|
||||
} else {
|
||||
result.Available = true
|
||||
}
|
||||
|
||||
response.Providers = append(response.Providers, result)
|
||||
}
|
||||
|
||||
return response, true
|
||||
}
|
||||
|
||||
// ListConfiguredProviderAvailability returns provider availability derived from
|
||||
// deployment/env keys merged with enabled DB provider keys.
|
||||
func (c *modelCatalog) listConfiguredProviderAvailability(
|
||||
configuredProviders []ConfiguredProvider,
|
||||
) codersdk.ChatModelsResponse {
|
||||
keys := MergeProviderAPIKeys(c.keys, configuredProviders)
|
||||
response := codersdk.ChatModelsResponse{
|
||||
Providers: make([]codersdk.ChatModelProvider, 0, len(supportedProviderNames)),
|
||||
}
|
||||
|
||||
for _, provider := range supportedProviderNames {
|
||||
result := codersdk.ChatModelProvider{
|
||||
Provider: provider,
|
||||
Models: []codersdk.ChatModel{},
|
||||
}
|
||||
if keys.apiKey(provider) == "" {
|
||||
result.Available = false
|
||||
result.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey
|
||||
} else {
|
||||
result.Available = true
|
||||
}
|
||||
|
||||
response.Providers = append(response.Providers, result)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func newChatModel(provider, modelID, displayName string) codersdk.ChatModel {
|
||||
name := strings.TrimSpace(displayName)
|
||||
if name == "" {
|
||||
name = modelID
|
||||
}
|
||||
|
||||
return codersdk.ChatModel{
|
||||
ID: canonicalModelID(provider, modelID),
|
||||
Provider: provider,
|
||||
Model: modelID,
|
||||
DisplayName: name,
|
||||
}
|
||||
}
|
||||
|
||||
func sortChatModels(models []codersdk.ChatModel) {
|
||||
sort.Slice(models, func(i, j int) bool {
|
||||
return models[i].Model < models[j].Model
|
||||
})
|
||||
}
|
||||
|
||||
func canonicalModelID(provider, modelID string) string {
|
||||
return NormalizeProvider(provider) + ":" + strings.TrimSpace(modelID)
|
||||
}
|
||||
|
||||
func orderProviders(providerSet map[string]struct{}) []string {
|
||||
if len(providerSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ordered := make([]string, 0, len(providerSet))
|
||||
for _, provider := range supportedProviderNames {
|
||||
if _, ok := providerSet[provider]; ok {
|
||||
ordered = append(ordered, provider)
|
||||
}
|
||||
}
|
||||
|
||||
extras := make([]string, 0, len(providerSet))
|
||||
for provider := range providerSet {
|
||||
if NormalizeProvider(provider) != "" {
|
||||
continue
|
||||
}
|
||||
extras = append(extras, provider)
|
||||
}
|
||||
sort.Strings(extras)
|
||||
|
||||
return append(ordered, extras...)
|
||||
}
|
||||
|
||||
// NormalizeProvider canonicalizes a provider name.
|
||||
func NormalizeProvider(provider string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case fantasyanthropic.Name:
|
||||
return fantasyanthropic.Name
|
||||
case fantasyazure.Name:
|
||||
return fantasyazure.Name
|
||||
case fantasybedrock.Name:
|
||||
return fantasybedrock.Name
|
||||
case fantasygoogle.Name:
|
||||
return fantasygoogle.Name
|
||||
case fantasyopenai.Name:
|
||||
return fantasyopenai.Name
|
||||
case fantasyopenaicompat.Name:
|
||||
return fantasyopenaicompat.Name
|
||||
case fantasyopenrouter.Name:
|
||||
return fantasyopenrouter.Name
|
||||
case fantasyvercel.Name:
|
||||
return fantasyvercel.Name
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeProvider(provider string) string {
|
||||
return NormalizeProvider(provider)
|
||||
}
|
||||
|
||||
func resolveModelWithProviderHint(modelName, providerHint string) (string, string, error) {
|
||||
modelName = strings.TrimSpace(modelName)
|
||||
if modelName == "" {
|
||||
return "", "", xerrors.New("model is required")
|
||||
}
|
||||
|
||||
if provider, modelID, ok := parseCanonicalModelRef(modelName); ok {
|
||||
return provider, modelID, nil
|
||||
}
|
||||
|
||||
if provider := normalizeProvider(providerHint); provider != "" {
|
||||
return provider, modelName, nil
|
||||
}
|
||||
|
||||
normalized := strings.ToLower(modelName)
|
||||
switch normalized {
|
||||
case "claude-opus-4-6":
|
||||
return fantasyanthropic.Name, "claude-opus-4-6", nil
|
||||
case "gpt-5.2":
|
||||
return fantasyopenai.Name, "gpt-5.2", nil
|
||||
case "gemini-2.5-flash":
|
||||
return fantasygoogle.Name, "gemini-2.5-flash", nil
|
||||
}
|
||||
|
||||
if isChatModelForProvider(fantasyanthropic.Name, normalized) {
|
||||
return fantasyanthropic.Name, modelName, nil
|
||||
}
|
||||
if isChatModelForProvider(fantasyopenai.Name, normalized) {
|
||||
return fantasyopenai.Name, modelName, nil
|
||||
}
|
||||
|
||||
return "", "", xerrors.Errorf("unknown model %q", modelName)
|
||||
}
|
||||
|
||||
func parseCanonicalModelRef(modelRef string) (string, string, bool) {
|
||||
modelRef = strings.TrimSpace(modelRef)
|
||||
if modelRef == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
for _, separator := range []string{":", "/"} {
|
||||
parts := strings.SplitN(modelRef, separator, 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
provider := normalizeProvider(parts[0])
|
||||
modelID := strings.TrimSpace(parts[1])
|
||||
if provider != "" && modelID != "" {
|
||||
return provider, modelID, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func isChatModelForProvider(provider, modelID string) bool {
|
||||
normalizedProvider := normalizeProvider(provider)
|
||||
normalizedModel := strings.ToLower(strings.TrimSpace(modelID))
|
||||
switch normalizedProvider {
|
||||
case fantasyopenai.Name:
|
||||
return strings.HasPrefix(normalizedModel, "gpt-") ||
|
||||
strings.HasPrefix(normalizedModel, "chatgpt-") ||
|
||||
isOpenAIReasoningModel(normalizedModel)
|
||||
case fantasyanthropic.Name:
|
||||
return strings.HasPrefix(normalizedModel, "claude-")
|
||||
case fantasygoogle.Name:
|
||||
return strings.HasPrefix(normalizedModel, "gemini-") ||
|
||||
strings.HasPrefix(normalizedModel, "gemma-")
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isOpenAIReasoningModel(modelID string) bool {
|
||||
if len(modelID) < 2 || modelID[0] != 'o' {
|
||||
return false
|
||||
}
|
||||
|
||||
index := 1
|
||||
for index < len(modelID) && modelID[index] >= '0' && modelID[index] <= '9' {
|
||||
index++
|
||||
}
|
||||
if index == 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
if index == len(modelID) {
|
||||
return true
|
||||
}
|
||||
return modelID[index] == '-' || modelID[index] == '.'
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package chatd_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestMergeProviderAPIKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
merged := chatd.MergeProviderAPIKeys(
|
||||
chatd.ProviderAPIKeys{
|
||||
OpenAI: " deployment-openai ",
|
||||
Anthropic: "deployment-anthropic",
|
||||
ByProvider: map[string]string{
|
||||
"openrouter": " deployment-openrouter ",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
"openai": " https://openai.example.com/v1 ",
|
||||
},
|
||||
},
|
||||
[]chatd.ConfiguredProvider{
|
||||
{Provider: "openai", APIKey: " ", BaseURL: "https://db-openai.example.com/v1"},
|
||||
{Provider: "anthropic", APIKey: " provider-anthropic "},
|
||||
{Provider: "openrouter", APIKey: "provider-openrouter"},
|
||||
},
|
||||
)
|
||||
|
||||
require.Equal(t, "deployment-openai", merged.OpenAI)
|
||||
require.Equal(t, "provider-anthropic", merged.Anthropic)
|
||||
require.Equal(t, "provider-openrouter", merged.APIKey("openrouter"))
|
||||
require.Equal(t, "https://db-openai.example.com/v1", merged.BaseURL("openai"))
|
||||
}
|
||||
|
||||
func TestSupportedProvidersNormalize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, []string{
|
||||
"anthropic",
|
||||
"azure",
|
||||
"bedrock",
|
||||
"google",
|
||||
"openai",
|
||||
"openai-compat",
|
||||
"openrouter",
|
||||
"vercel",
|
||||
}, chatd.SupportedProviders())
|
||||
|
||||
for _, provider := range chatd.SupportedProviders() {
|
||||
require.Equal(t, provider, chatd.NormalizeProvider(provider))
|
||||
require.Equal(t, provider, chatd.NormalizeProvider(strings.ToUpper(provider)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamManagerStopStreamDropsMessageParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
manager := chatd.NewStreamManager(testutil.Logger(t))
|
||||
_, events, cancel := manager.Subscribe(chatID)
|
||||
defer cancel()
|
||||
|
||||
manager.StartStream(chatID)
|
||||
manager.Publish(chatID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Part: codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "before-stop",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case event, ok := <-events:
|
||||
require.True(t, ok)
|
||||
require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, event.Type)
|
||||
require.NotNil(t, event.MessagePart)
|
||||
require.Equal(t, "before-stop", event.MessagePart.Part.Text)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for initial stream message part")
|
||||
}
|
||||
|
||||
manager.StopStream(chatID)
|
||||
manager.Publish(chatID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Part: codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "after-stop",
|
||||
},
|
||||
},
|
||||
})
|
||||
manager.Publish(chatID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
Status: &codersdk.ChatStreamStatus{
|
||||
Status: codersdk.ChatStatusWaiting,
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case event, ok := <-events:
|
||||
require.True(t, ok)
|
||||
require.Equal(t, codersdk.ChatStreamEventTypeStatus, event.Type)
|
||||
require.NotNil(t, event.Status)
|
||||
require.Equal(t, codersdk.ChatStatusWaiting, event.Status.Status)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for status event after stream stop")
|
||||
}
|
||||
|
||||
require.Never(t, func() bool {
|
||||
select {
|
||||
case <-events:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, 100*time.Millisecond, 10*time.Millisecond)
|
||||
}
|
||||
@@ -0,0 +1,766 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
var ErrSubagentNotDescendant = xerrors.New("target chat is not a descendant of current chat")
|
||||
|
||||
const defaultSubagentAwaitTimeout = 5 * time.Minute
|
||||
const subagentAwaitPollInterval = 200 * time.Millisecond
|
||||
const subagentReportToolCallIDPrefix = "subagent_report_"
|
||||
const defaultFallbackSubagentReport = "Sub-agent completed without explicit report."
|
||||
|
||||
const (
|
||||
subagentEventRequest = "request"
|
||||
subagentEventResponse = "response"
|
||||
|
||||
subagentResponseMarkerRole = "__subagent_response_marker"
|
||||
subagentReportOnlyMarkerRole = "__subagent_report_only_marker"
|
||||
)
|
||||
|
||||
type interruptChatFn func(chatID uuid.UUID) bool
|
||||
|
||||
type SubagentAwaitResult struct {
|
||||
RequestID uuid.UUID
|
||||
Report string
|
||||
DurationMS int64
|
||||
}
|
||||
|
||||
type subagentRequestKey struct {
|
||||
chatID uuid.UUID
|
||||
requestID uuid.UUID
|
||||
}
|
||||
|
||||
// SubagentService handles delegated subagent request/response correlation and
|
||||
// in-memory waiting for subagent responses.
|
||||
type SubagentService struct {
|
||||
db database.Store
|
||||
|
||||
interruptChat interruptChatFn
|
||||
streamer *StreamManager
|
||||
|
||||
waitersMu sync.Mutex
|
||||
waiters map[subagentRequestKey][]chan SubagentAwaitResult
|
||||
results map[subagentRequestKey]SubagentAwaitResult
|
||||
}
|
||||
|
||||
func newSubagentService(db database.Store, interruptChat interruptChatFn) *SubagentService {
|
||||
return &SubagentService{
|
||||
db: db,
|
||||
interruptChat: interruptChat,
|
||||
waiters: make(map[subagentRequestKey][]chan SubagentAwaitResult),
|
||||
results: make(map[subagentRequestKey]SubagentAwaitResult),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubagentService) setStreamManager(streamer *StreamManager) {
|
||||
s.streamer = streamer
|
||||
}
|
||||
|
||||
func (s *SubagentService) publishChildStatus(chat database.Chat, status database.ChatStatus) {
|
||||
if !chat.ParentChatID.Valid || chat.ParentChatID.UUID == uuid.Nil {
|
||||
return
|
||||
}
|
||||
if s.streamer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.streamer.Publish(chat.ParentChatID.UUID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
ChatID: chat.ID,
|
||||
Status: &codersdk.ChatStreamStatus{
|
||||
Status: codersdk.ChatStatus(status),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SubagentService) CreateChildSubagentChat(
|
||||
ctx context.Context,
|
||||
parent database.Chat,
|
||||
prompt string,
|
||||
title string,
|
||||
background bool,
|
||||
) (database.Chat, uuid.UUID, error) {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return database.Chat{}, uuid.Nil, xerrors.New("prompt is required")
|
||||
}
|
||||
|
||||
title = strings.TrimSpace(title)
|
||||
if title == "" {
|
||||
title = fallbackChatTitle(prompt)
|
||||
}
|
||||
|
||||
rootChatID := parent.ID
|
||||
if parent.RootChatID.Valid {
|
||||
rootChatID = parent.RootChatID.UUID
|
||||
}
|
||||
|
||||
child, err := s.db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
WorkspaceAgentID: parent.WorkspaceAgentID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: rootChatID,
|
||||
Valid: true,
|
||||
},
|
||||
Title: title,
|
||||
ModelConfig: parent.ModelConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return database.Chat{}, uuid.Nil, xerrors.Errorf("insert child chat: %w", err)
|
||||
}
|
||||
|
||||
requestID := uuid.New()
|
||||
if err := s.insertRequestMessage(ctx, child.ID, prompt, requestID); err != nil {
|
||||
return database.Chat{}, uuid.Nil, err
|
||||
}
|
||||
|
||||
// Child subagents are always enqueued asynchronously in phase-1, regardless
|
||||
// of whether the parent awaits in the same tool call.
|
||||
_ = background
|
||||
|
||||
child, err = s.requeueChatIfNeeded(ctx, child)
|
||||
if err != nil {
|
||||
return database.Chat{}, uuid.Nil, err
|
||||
}
|
||||
|
||||
s.clearCachedResult(child.ID, requestID)
|
||||
return child, requestID, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) SendSubagentMessage(
|
||||
ctx context.Context,
|
||||
parentChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
message string,
|
||||
) (database.Chat, uuid.UUID, error) {
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return database.Chat{}, uuid.Nil, xerrors.New("message is required")
|
||||
}
|
||||
|
||||
isDescendant, err := s.isDescendant(ctx, parentChatID, targetChatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, uuid.Nil, err
|
||||
}
|
||||
if !isDescendant {
|
||||
return database.Chat{}, uuid.Nil, ErrSubagentNotDescendant
|
||||
}
|
||||
|
||||
targetChat, err := s.db.GetChatByID(ctx, targetChatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, uuid.Nil, xerrors.Errorf("get target chat: %w", err)
|
||||
}
|
||||
|
||||
requestID := uuid.New()
|
||||
if err := s.insertRequestMessage(ctx, targetChatID, message, requestID); err != nil {
|
||||
return database.Chat{}, uuid.Nil, err
|
||||
}
|
||||
|
||||
targetChat, err = s.requeueChatIfNeeded(ctx, targetChat)
|
||||
if err != nil {
|
||||
return database.Chat{}, uuid.Nil, err
|
||||
}
|
||||
|
||||
s.clearCachedResult(targetChatID, requestID)
|
||||
return targetChat, requestID, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) insertRequestMessage(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
message string,
|
||||
requestID uuid.UUID,
|
||||
) error {
|
||||
userContent, err := marshalContentBlocks([]fantasy.Content{
|
||||
fantasy.TextContent{Text: message},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal subagent request message: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
Role: string(fantasy.MessageRoleUser),
|
||||
Content: userContent,
|
||||
Hidden: false,
|
||||
SubagentRequestID: uuid.NullUUID{
|
||||
UUID: requestID,
|
||||
Valid: true,
|
||||
},
|
||||
SubagentEvent: sql.NullString{
|
||||
String: subagentEventRequest,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert subagent request message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) requeueChatIfNeeded(ctx context.Context, chat database.Chat) (database.Chat, error) {
|
||||
if chat.Status != database.ChatStatusWaiting && chat.Status != database.ChatStatusCompleted {
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
updatedChat, err := s.db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
})
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("requeue subagent chat: %w", err)
|
||||
}
|
||||
s.publishChildStatus(updatedChat, database.ChatStatusPending)
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) LatestPendingRequestID(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
) (uuid.UUID, bool, error) {
|
||||
requestID, err := s.db.GetLatestPendingSubagentRequestIDByChatID(ctx, chatID)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
return uuid.Nil, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return uuid.Nil, false, xerrors.Errorf("get latest pending subagent request: %w", err)
|
||||
}
|
||||
if !requestID.Valid || requestID.UUID == uuid.Nil {
|
||||
return uuid.Nil, false, nil
|
||||
}
|
||||
return requestID.UUID, true, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) HasPendingRequest(ctx context.Context, chatID uuid.UUID) (bool, error) {
|
||||
_, hasPending, err := s.LatestPendingRequestID(ctx, chatID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return hasPending, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) ShouldRunReportOnlyPass(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
requestID uuid.UUID,
|
||||
) (bool, error) {
|
||||
messages, err := s.db.GetChatMessagesByChatID(ctx, chatID)
|
||||
if err != nil {
|
||||
return false, xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
message := messages[i]
|
||||
if !message.SubagentRequestID.Valid || message.SubagentRequestID.UUID != requestID {
|
||||
continue
|
||||
}
|
||||
if message.SubagentEvent.Valid && message.SubagentEvent.String == subagentEventRequest {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) MarkReportOnlyPassRequested(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
requestID uuid.UUID,
|
||||
) error {
|
||||
content, err := marshalContentBlocks([]fantasy.Content{
|
||||
fantasy.TextContent{Text: "report-only pass requested"},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal report-only marker: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
Role: subagentReportOnlyMarkerRole,
|
||||
Content: content,
|
||||
Hidden: true,
|
||||
SubagentRequestID: uuid.NullUUID{
|
||||
UUID: requestID,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert report-only marker: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) MarkSubagentReported(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
report string,
|
||||
explicitRequestID uuid.NullUUID,
|
||||
) (SubagentAwaitResult, error) {
|
||||
report = strings.TrimSpace(report)
|
||||
|
||||
chat, err := s.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return SubagentAwaitResult{}, xerrors.Errorf("get chat: %w", err)
|
||||
}
|
||||
|
||||
requestID, err := s.resolveReportRequestID(ctx, chatID, explicitRequestID)
|
||||
if err != nil {
|
||||
return SubagentAwaitResult{}, err
|
||||
}
|
||||
|
||||
responseContent, err := marshalContentBlocks([]fantasy.Content{
|
||||
fantasy.TextContent{Text: report},
|
||||
})
|
||||
if err != nil {
|
||||
return SubagentAwaitResult{}, xerrors.Errorf("marshal subagent response marker: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
Role: subagentResponseMarkerRole,
|
||||
Content: responseContent,
|
||||
Hidden: true,
|
||||
SubagentRequestID: uuid.NullUUID{
|
||||
UUID: requestID,
|
||||
Valid: true,
|
||||
},
|
||||
SubagentEvent: sql.NullString{
|
||||
String: subagentEventResponse,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return SubagentAwaitResult{}, xerrors.Errorf("insert subagent response marker: %w", err)
|
||||
}
|
||||
|
||||
result, ok, err := s.responseForRequest(ctx, chatID, requestID)
|
||||
if err != nil {
|
||||
return SubagentAwaitResult{}, err
|
||||
}
|
||||
if !ok {
|
||||
result = SubagentAwaitResult{
|
||||
RequestID: requestID,
|
||||
Report: report,
|
||||
DurationMS: 0,
|
||||
}
|
||||
}
|
||||
|
||||
s.resolveRequestWaiters(subagentRequestKey{chatID: chatID, requestID: requestID}, result)
|
||||
|
||||
if chat.ParentChatID.Valid {
|
||||
s.publishChildStatus(chat, database.ChatStatusCompleted)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) resolveReportRequestID(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
explicitRequestID uuid.NullUUID,
|
||||
) (uuid.UUID, error) {
|
||||
if explicitRequestID.Valid && explicitRequestID.UUID != uuid.Nil {
|
||||
return explicitRequestID.UUID, nil
|
||||
}
|
||||
|
||||
requestID, err := s.db.GetLatestPendingSubagentRequestIDByChatID(ctx, chatID)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
return uuid.Nil, xerrors.New("no pending subagent request found")
|
||||
}
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("get latest pending subagent request: %w", err)
|
||||
}
|
||||
if !requestID.Valid || requestID.UUID == uuid.Nil {
|
||||
return uuid.Nil, xerrors.New("no pending subagent request found")
|
||||
}
|
||||
return requestID.UUID, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) SynthesizeFallbackSubagentReport(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
requestID uuid.UUID,
|
||||
) string {
|
||||
messages, err := s.db.GetChatMessagesByChatID(ctx, chatID)
|
||||
if err != nil {
|
||||
return defaultFallbackSubagentReport
|
||||
}
|
||||
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
message := messages[i]
|
||||
if message.Role != string(fantasy.MessageRoleAssistant) {
|
||||
continue
|
||||
}
|
||||
if requestID != uuid.Nil {
|
||||
if !message.SubagentRequestID.Valid || message.SubagentRequestID.UUID != requestID {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
content, parseErr := parseContentBlocks(message.Role, message.Content)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
report := strings.TrimSpace(contentBlocksToText(content))
|
||||
if report != "" {
|
||||
return report
|
||||
}
|
||||
}
|
||||
|
||||
return defaultFallbackSubagentReport
|
||||
}
|
||||
|
||||
func (s *SubagentService) AwaitSubagentReport(
|
||||
ctx context.Context,
|
||||
parentChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
requestID uuid.UUID,
|
||||
timeout time.Duration,
|
||||
) (SubagentAwaitResult, error) {
|
||||
isDescendant, err := s.isDescendant(ctx, parentChatID, targetChatID)
|
||||
if err != nil {
|
||||
return SubagentAwaitResult{}, err
|
||||
}
|
||||
if !isDescendant {
|
||||
return SubagentAwaitResult{}, ErrSubagentNotDescendant
|
||||
}
|
||||
|
||||
key := subagentRequestKey{chatID: targetChatID, requestID: requestID}
|
||||
if result, ok := s.cachedResult(targetChatID, requestID); ok {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if result, ok, err := s.responseForRequest(ctx, targetChatID, requestID); err != nil {
|
||||
return SubagentAwaitResult{}, err
|
||||
} else if ok {
|
||||
s.cacheResult(targetChatID, requestID, result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
waiter := make(chan SubagentAwaitResult, 1)
|
||||
if result, ok := s.registerWaiter(key, waiter); ok {
|
||||
return result, nil
|
||||
}
|
||||
defer s.unregisterWaiter(key, waiter)
|
||||
|
||||
if timeout <= 0 {
|
||||
timeout = defaultSubagentAwaitTimeout
|
||||
}
|
||||
deadline := time.NewTimer(timeout)
|
||||
defer deadline.Stop()
|
||||
|
||||
ticker := time.NewTicker(subagentAwaitPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case result := <-waiter:
|
||||
return result, nil
|
||||
case <-ticker.C:
|
||||
result, ok, lookupErr := s.responseForRequest(ctx, targetChatID, requestID)
|
||||
if lookupErr != nil {
|
||||
return SubagentAwaitResult{}, lookupErr
|
||||
}
|
||||
if ok {
|
||||
s.resolveRequestWaiters(key, result)
|
||||
return result, nil
|
||||
}
|
||||
case <-deadline.C:
|
||||
return SubagentAwaitResult{}, xerrors.New("timed out waiting for delegated subagent report")
|
||||
case <-ctx.Done():
|
||||
return SubagentAwaitResult{}, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubagentService) responseForRequest(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
requestID uuid.UUID,
|
||||
) (SubagentAwaitResult, bool, error) {
|
||||
message, err := s.db.GetSubagentResponseMessageByChatIDAndRequestID(ctx,
|
||||
database.GetSubagentResponseMessageByChatIDAndRequestIDParams{
|
||||
ChatID: chatID,
|
||||
SubagentRequestID: requestID,
|
||||
},
|
||||
)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
return SubagentAwaitResult{}, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return SubagentAwaitResult{}, false, xerrors.Errorf("get subagent response marker: %w", err)
|
||||
}
|
||||
|
||||
duration, err := s.db.GetSubagentRequestDurationByChatIDAndRequestID(ctx,
|
||||
database.GetSubagentRequestDurationByChatIDAndRequestIDParams{
|
||||
ChatID: chatID,
|
||||
SubagentRequestID: requestID,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return SubagentAwaitResult{}, false, xerrors.Errorf("get subagent request duration: %w", err)
|
||||
}
|
||||
|
||||
report := ""
|
||||
content, parseErr := parseContentBlocks(message.Role, message.Content)
|
||||
if parseErr == nil {
|
||||
report = strings.TrimSpace(contentBlocksToText(content))
|
||||
}
|
||||
|
||||
return SubagentAwaitResult{
|
||||
RequestID: requestID,
|
||||
Report: report,
|
||||
DurationMS: duration,
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) TerminateSubagentSubtree(
|
||||
ctx context.Context,
|
||||
parentChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
) error {
|
||||
isDescendant, err := s.isDescendant(ctx, parentChatID, targetChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !isDescendant {
|
||||
return ErrSubagentNotDescendant
|
||||
}
|
||||
|
||||
subtree, err := s.subtree(ctx, targetChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, chat := range subtree {
|
||||
if s.streamer != nil {
|
||||
s.streamer.StopStream(chat.ID)
|
||||
}
|
||||
|
||||
if s.interruptChat != nil {
|
||||
s.interruptChat(chat.ID)
|
||||
}
|
||||
|
||||
if chat.Status == database.ChatStatusPending {
|
||||
updatedChat, err := s.db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("set pending chat waiting for termination: %w", err)
|
||||
}
|
||||
s.publishChildStatus(updatedChat, database.ChatStatusWaiting)
|
||||
}
|
||||
|
||||
for {
|
||||
requestID, hasPending, requestErr := s.LatestPendingRequestID(ctx, chat.ID)
|
||||
if requestErr != nil {
|
||||
return requestErr
|
||||
}
|
||||
if !hasPending {
|
||||
break
|
||||
}
|
||||
|
||||
_, reportErr := s.MarkSubagentReported(ctx, chat.ID, "terminated", uuid.NullUUID{
|
||||
UUID: requestID,
|
||||
Valid: true,
|
||||
})
|
||||
if reportErr != nil {
|
||||
return reportErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) HasActiveDescendants(ctx context.Context, chatID uuid.UUID) (bool, error) {
|
||||
descendants, err := s.listDescendants(ctx, chatID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, descendant := range descendants {
|
||||
_, hasPending, requestErr := s.LatestPendingRequestID(ctx, descendant.ID)
|
||||
if requestErr != nil {
|
||||
return false, requestErr
|
||||
}
|
||||
if hasPending {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) isDescendant(
|
||||
ctx context.Context,
|
||||
ancestorChatID uuid.UUID,
|
||||
targetChatID uuid.UUID,
|
||||
) (bool, error) {
|
||||
if ancestorChatID == targetChatID {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
descendants, err := s.listDescendants(ctx, ancestorChatID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, descendant := range descendants {
|
||||
if descendant.ID == targetChatID {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) subtree(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
|
||||
rootChat, err := s.db.GetChatByID(ctx, rootChatID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get subtree root chat: %w", err)
|
||||
}
|
||||
|
||||
descendants, err := s.listDescendants(ctx, rootChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make([]database.Chat, 0, len(descendants)+1)
|
||||
out = append(out, rootChat)
|
||||
out = append(out, descendants...)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) listDescendants(ctx context.Context, chatID uuid.UUID) ([]database.Chat, error) {
|
||||
queue := []uuid.UUID{chatID}
|
||||
visited := map[uuid.UUID]struct{}{
|
||||
chatID: {},
|
||||
}
|
||||
|
||||
out := make([]database.Chat, 0)
|
||||
for len(queue) > 0 {
|
||||
parentChatID := queue[0]
|
||||
queue = queue[1:]
|
||||
|
||||
children, err := s.db.ListChildChatsByParentID(ctx, parentChatID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("list child chats for %s: %w", parentChatID, err)
|
||||
}
|
||||
|
||||
for _, child := range children {
|
||||
if _, ok := visited[child.ID]; ok {
|
||||
continue
|
||||
}
|
||||
visited[child.ID] = struct{}{}
|
||||
out = append(out, child)
|
||||
queue = append(queue, child.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *SubagentService) cachedResult(chatID uuid.UUID, requestID uuid.UUID) (SubagentAwaitResult, bool) {
|
||||
s.waitersMu.Lock()
|
||||
defer s.waitersMu.Unlock()
|
||||
|
||||
result, ok := s.results[subagentRequestKey{chatID: chatID, requestID: requestID}]
|
||||
return result, ok
|
||||
}
|
||||
|
||||
func (s *SubagentService) cacheResult(chatID uuid.UUID, requestID uuid.UUID, result SubagentAwaitResult) {
|
||||
s.waitersMu.Lock()
|
||||
defer s.waitersMu.Unlock()
|
||||
|
||||
s.results[subagentRequestKey{chatID: chatID, requestID: requestID}] = result
|
||||
}
|
||||
|
||||
func (s *SubagentService) clearCachedResult(chatID uuid.UUID, requestID uuid.UUID) {
|
||||
s.waitersMu.Lock()
|
||||
defer s.waitersMu.Unlock()
|
||||
|
||||
delete(s.results, subagentRequestKey{chatID: chatID, requestID: requestID})
|
||||
}
|
||||
|
||||
func (s *SubagentService) resolveRequestWaiters(
|
||||
key subagentRequestKey,
|
||||
result SubagentAwaitResult,
|
||||
) {
|
||||
s.waitersMu.Lock()
|
||||
s.results[key] = result
|
||||
waiters := s.waiters[key]
|
||||
delete(s.waiters, key)
|
||||
s.waitersMu.Unlock()
|
||||
|
||||
for _, waiter := range waiters {
|
||||
select {
|
||||
case waiter <- result:
|
||||
default:
|
||||
}
|
||||
close(waiter)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubagentService) registerWaiter(
|
||||
key subagentRequestKey,
|
||||
waiter chan SubagentAwaitResult,
|
||||
) (SubagentAwaitResult, bool) {
|
||||
s.waitersMu.Lock()
|
||||
defer s.waitersMu.Unlock()
|
||||
|
||||
if result, ok := s.results[key]; ok {
|
||||
return result, true
|
||||
}
|
||||
|
||||
s.waiters[key] = append(s.waiters[key], waiter)
|
||||
return SubagentAwaitResult{}, false
|
||||
}
|
||||
|
||||
func (s *SubagentService) unregisterWaiter(
|
||||
key subagentRequestKey,
|
||||
waiter chan SubagentAwaitResult,
|
||||
) {
|
||||
s.waitersMu.Lock()
|
||||
defer s.waitersMu.Unlock()
|
||||
|
||||
waiters := s.waiters[key]
|
||||
if len(waiters) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
filtered := make([]chan SubagentAwaitResult, 0, len(waiters))
|
||||
for _, current := range waiters {
|
||||
if current == waiter {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, current)
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
delete(s.waiters, key)
|
||||
return
|
||||
}
|
||||
s.waiters[key] = filtered
|
||||
}
|
||||
@@ -0,0 +1,989 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestProcessor_SubagentToolIncludesCreatedTitle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parentID := uuid.New()
|
||||
parent := testSubagentChat(parentID, uuid.Nil)
|
||||
parent.Title = "Parent"
|
||||
|
||||
store := newSubagentServiceTestStore(parent)
|
||||
processor := &Processor{subagentService: newSubagentService(store, nil)}
|
||||
chatState := parent
|
||||
chatStateMu := &sync.Mutex{}
|
||||
tools := processor.agentTools(nil, &chatState, chatStateMu, nil)
|
||||
|
||||
tool := testFindAgentTool(t, tools, "subagent")
|
||||
input, err := json.Marshal(subagentArgs{
|
||||
Prompt: "Run delegated child work",
|
||||
Title: "Delegated child",
|
||||
Background: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
response, err := tool.Run(context.Background(), fantasy.ToolCall{
|
||||
ID: "tool-call-subagent",
|
||||
Name: "subagent",
|
||||
Input: string(input),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, response.IsError)
|
||||
|
||||
payload := parseToolPayload(t, response.Content)
|
||||
require.Equal(t, "Delegated child", payload["title"])
|
||||
require.NotEmpty(t, payload["chat_id"])
|
||||
require.NotEmpty(t, payload["request_id"])
|
||||
require.Equal(t, "pending", payload["status"])
|
||||
}
|
||||
|
||||
func TestProcessor_SubagentAwaitToolIncludesTargetTitle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parentID := uuid.New()
|
||||
childID := uuid.New()
|
||||
requestID := uuid.New()
|
||||
|
||||
parent := testSubagentChat(parentID, uuid.Nil)
|
||||
child := testSubagentChat(childID, parentID)
|
||||
child.Title = "Awaited child"
|
||||
|
||||
store := newSubagentServiceTestStore(parent, child)
|
||||
require.NoError(t, store.insertSubagentRequestMessage(childID, requestID, "work"))
|
||||
require.NoError(t, store.insertSubagentResponseMessage(childID, requestID, "done"))
|
||||
|
||||
processor := &Processor{subagentService: newSubagentService(store, nil)}
|
||||
chatState := parent
|
||||
chatStateMu := &sync.Mutex{}
|
||||
tools := processor.agentTools(nil, &chatState, chatStateMu, nil)
|
||||
|
||||
tool := testFindAgentTool(t, tools, "subagent_await")
|
||||
input, err := json.Marshal(subagentAwaitArgs{
|
||||
ChatID: childID.String(),
|
||||
RequestID: requestID.String(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
response, err := tool.Run(context.Background(), fantasy.ToolCall{
|
||||
ID: "tool-call-subagent-await",
|
||||
Name: "subagent_await",
|
||||
Input: string(input),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, response.IsError)
|
||||
|
||||
payload := parseToolPayload(t, response.Content)
|
||||
require.Equal(t, child.Title, payload["title"])
|
||||
require.Equal(t, childID.String(), payload["chat_id"])
|
||||
require.Equal(t, requestID.String(), payload["request_id"])
|
||||
require.Equal(t, "done", payload["report"])
|
||||
require.Equal(t, "completed", payload["status"])
|
||||
}
|
||||
|
||||
func TestProcessor_SubagentMessageToolIncludesTargetTitle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parentID := uuid.New()
|
||||
childID := uuid.New()
|
||||
|
||||
parent := testSubagentChat(parentID, uuid.Nil)
|
||||
child := testSubagentChat(childID, parentID)
|
||||
child.Title = "Message target"
|
||||
|
||||
store := newSubagentServiceTestStore(parent, child)
|
||||
processor := &Processor{subagentService: newSubagentService(store, nil)}
|
||||
chatState := parent
|
||||
chatStateMu := &sync.Mutex{}
|
||||
tools := processor.agentTools(nil, &chatState, chatStateMu, nil)
|
||||
|
||||
tool := testFindAgentTool(t, tools, "subagent_message")
|
||||
input, err := json.Marshal(subagentMessageArgs{
|
||||
ChatID: childID.String(),
|
||||
Message: "follow-up request",
|
||||
Await: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
response, err := tool.Run(context.Background(), fantasy.ToolCall{
|
||||
ID: "tool-call-subagent-message",
|
||||
Name: "subagent_message",
|
||||
Input: string(input),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, response.IsError)
|
||||
|
||||
payload := parseToolPayload(t, response.Content)
|
||||
require.Equal(t, child.Title, payload["title"])
|
||||
require.Equal(t, childID.String(), payload["chat_id"])
|
||||
require.NotEmpty(t, payload["request_id"])
|
||||
require.Equal(t, "pending", payload["status"])
|
||||
}
|
||||
|
||||
func testFindAgentTool(t *testing.T, tools []fantasy.AgentTool, name string) fantasy.AgentTool {
|
||||
t.Helper()
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Info().Name == name {
|
||||
return tool
|
||||
}
|
||||
}
|
||||
require.FailNow(t, "tool not found", "name=%s", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseToolPayload(t *testing.T, content string) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
payload := make(map[string]any)
|
||||
require.NoError(t, json.Unmarshal([]byte(content), &payload))
|
||||
return payload
|
||||
}
|
||||
|
||||
func TestSubagentService_AwaitSubagentReportRejectsNonDescendant(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rootID := uuid.New()
|
||||
childAID := uuid.New()
|
||||
childBID := uuid.New()
|
||||
|
||||
store := newSubagentServiceTestStore(
|
||||
testSubagentChat(rootID, uuid.Nil),
|
||||
testSubagentChat(childAID, rootID),
|
||||
testSubagentChat(childBID, rootID),
|
||||
)
|
||||
service := newSubagentService(store, nil)
|
||||
|
||||
_, err := service.AwaitSubagentReport(context.Background(), childAID, childBID, uuid.New(), time.Second)
|
||||
require.ErrorIs(t, err, ErrSubagentNotDescendant)
|
||||
}
|
||||
|
||||
func TestSubagentService_AwaitSubagentReportResolvesWaiter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rootID := uuid.New()
|
||||
childID := uuid.New()
|
||||
|
||||
store := newSubagentServiceTestStore(
|
||||
testSubagentChat(rootID, uuid.Nil),
|
||||
testSubagentChat(childID, rootID),
|
||||
)
|
||||
service := newSubagentService(store, nil)
|
||||
|
||||
_, requestID, err := service.SendSubagentMessage(context.Background(), rootID, childID, "do work")
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, requestID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
type awaitResult struct {
|
||||
result SubagentAwaitResult
|
||||
err error
|
||||
}
|
||||
awaitCh := make(chan awaitResult, 1)
|
||||
go func() {
|
||||
result, err := service.AwaitSubagentReport(ctx, rootID, childID, requestID, time.Second)
|
||||
awaitCh <- awaitResult{result: result, err: err}
|
||||
}()
|
||||
|
||||
marked, err := service.MarkSubagentReported(context.Background(), childID, "completed", uuid.NullUUID{UUID: requestID, Valid: true})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, requestID, marked.RequestID)
|
||||
require.Equal(t, "completed", marked.Report)
|
||||
|
||||
select {
|
||||
case res := <-awaitCh:
|
||||
require.NoError(t, res.err)
|
||||
require.Equal(t, requestID, res.result.RequestID)
|
||||
require.Equal(t, "completed", res.result.Report)
|
||||
require.Positive(t, res.result.DurationMS)
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("timed out waiting for awaited subagent report: %v", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubagentService_AwaitSubagentReportReturnsPersistedResponseMarker(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rootID := uuid.New()
|
||||
childID := uuid.New()
|
||||
requestID := uuid.New()
|
||||
|
||||
store := newSubagentServiceTestStore(
|
||||
testSubagentChat(rootID, uuid.Nil),
|
||||
testSubagentChat(childID, rootID),
|
||||
)
|
||||
service := newSubagentService(store, nil)
|
||||
|
||||
require.NoError(t, store.insertSubagentRequestMessage(childID, requestID, "first request"))
|
||||
require.NoError(t, store.insertSubagentResponseMessage(childID, requestID, "persisted report"))
|
||||
|
||||
result, err := service.AwaitSubagentReport(context.Background(), rootID, childID, requestID, time.Second)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, requestID, result.RequestID)
|
||||
require.Equal(t, "persisted report", result.Report)
|
||||
require.Positive(t, result.DurationMS)
|
||||
}
|
||||
|
||||
func TestSubagentService_HasActiveDescendants(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rootID := uuid.New()
|
||||
childID := uuid.New()
|
||||
|
||||
store := newSubagentServiceTestStore(
|
||||
testSubagentChat(rootID, uuid.Nil),
|
||||
testSubagentChat(childID, rootID),
|
||||
)
|
||||
service := newSubagentService(store, nil)
|
||||
|
||||
_, requestID, err := service.SendSubagentMessage(context.Background(), rootID, childID, "follow up")
|
||||
require.NoError(t, err)
|
||||
|
||||
active, err := service.HasActiveDescendants(context.Background(), rootID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, active)
|
||||
|
||||
_, err = service.MarkSubagentReported(context.Background(), childID, "done", uuid.NullUUID{UUID: requestID, Valid: true})
|
||||
require.NoError(t, err)
|
||||
|
||||
active, err = service.HasActiveDescendants(context.Background(), rootID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, active)
|
||||
}
|
||||
|
||||
func TestSubagentService_TerminateSubagentSubtreeInterruptsEntireSubtree(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rootID := uuid.New()
|
||||
childID := uuid.New()
|
||||
grandchildID := uuid.New()
|
||||
|
||||
store := newSubagentServiceTestStore(
|
||||
testSubagentChat(rootID, uuid.Nil),
|
||||
testSubagentChat(childID, rootID),
|
||||
testSubagentChat(grandchildID, childID),
|
||||
)
|
||||
var (
|
||||
interruptMu sync.Mutex
|
||||
interruptedChats []uuid.UUID
|
||||
)
|
||||
service := newSubagentService(store, func(chatID uuid.UUID) bool {
|
||||
interruptMu.Lock()
|
||||
defer interruptMu.Unlock()
|
||||
interruptedChats = append(interruptedChats, chatID)
|
||||
return true
|
||||
})
|
||||
|
||||
err := service.TerminateSubagentSubtree(context.Background(), rootID, childID)
|
||||
require.NoError(t, err)
|
||||
interruptMu.Lock()
|
||||
recorded := append([]uuid.UUID(nil), interruptedChats...)
|
||||
interruptMu.Unlock()
|
||||
require.ElementsMatch(t, []uuid.UUID{childID, grandchildID}, recorded)
|
||||
}
|
||||
|
||||
func TestSubagentService_TerminateSubagentSubtreeStopsSubtreeStreams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rootID := uuid.New()
|
||||
childID := uuid.New()
|
||||
grandchildID := uuid.New()
|
||||
|
||||
store := newSubagentServiceTestStore(
|
||||
testSubagentChat(rootID, uuid.Nil),
|
||||
testSubagentChat(childID, rootID),
|
||||
testSubagentChat(grandchildID, childID),
|
||||
)
|
||||
streamManager := NewStreamManager(testutil.Logger(t))
|
||||
service := newSubagentService(store, nil)
|
||||
service.setStreamManager(streamManager)
|
||||
|
||||
_, childEvents, cancelChild := streamManager.Subscribe(childID)
|
||||
defer cancelChild()
|
||||
_, grandchildEvents, cancelGrandchild := streamManager.Subscribe(grandchildID)
|
||||
defer cancelGrandchild()
|
||||
|
||||
streamManager.StartStream(childID)
|
||||
streamManager.StartStream(grandchildID)
|
||||
|
||||
streamManager.Publish(childID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Part: codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "child-before-terminate",
|
||||
},
|
||||
},
|
||||
})
|
||||
streamManager.Publish(grandchildID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Part: codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "grandchild-before-terminate",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case event := <-childEvents:
|
||||
require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, event.Type)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for child stream event before termination")
|
||||
}
|
||||
select {
|
||||
case event := <-grandchildEvents:
|
||||
require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, event.Type)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for grandchild stream event before termination")
|
||||
}
|
||||
|
||||
err := service.TerminateSubagentSubtree(context.Background(), rootID, childID)
|
||||
require.NoError(t, err)
|
||||
|
||||
streamManager.Publish(childID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Part: codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "child-after-terminate",
|
||||
},
|
||||
},
|
||||
})
|
||||
streamManager.Publish(grandchildID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Part: codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "grandchild-after-terminate",
|
||||
},
|
||||
},
|
||||
})
|
||||
streamManager.Publish(childID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
Status: &codersdk.ChatStreamStatus{
|
||||
Status: codersdk.ChatStatusWaiting,
|
||||
},
|
||||
})
|
||||
streamManager.Publish(grandchildID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
Status: &codersdk.ChatStreamStatus{
|
||||
Status: codersdk.ChatStatusWaiting,
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case event := <-childEvents:
|
||||
require.Equal(t, codersdk.ChatStreamEventTypeStatus, event.Type)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for child status event after termination")
|
||||
}
|
||||
select {
|
||||
case event := <-grandchildEvents:
|
||||
require.Equal(t, codersdk.ChatStreamEventTypeStatus, event.Type)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for grandchild status event after termination")
|
||||
}
|
||||
|
||||
require.Never(t, func() bool {
|
||||
select {
|
||||
case <-childEvents:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, 100*time.Millisecond, 10*time.Millisecond)
|
||||
require.Never(t, func() bool {
|
||||
select {
|
||||
case <-grandchildEvents:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, 100*time.Millisecond, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSubagentService_MarkSubagentReportedDoesNotInsertParentMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parentID := uuid.New()
|
||||
childID := uuid.New()
|
||||
requestID := uuid.New()
|
||||
|
||||
store := newSubagentServiceTestStore(
|
||||
testSubagentChat(parentID, uuid.Nil),
|
||||
testSubagentChat(childID, parentID),
|
||||
)
|
||||
service := newSubagentService(store, nil)
|
||||
|
||||
require.NoError(t, store.insertSubagentRequestMessage(childID, requestID, "work"))
|
||||
_, err := service.MarkSubagentReported(context.Background(), childID, "child complete", uuid.NullUUID{UUID: requestID, Valid: true})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The report should not inject any messages into the parent
|
||||
// chat. The parent receives the report via subagent_await.
|
||||
messages := store.chatMessagesByChatID(parentID)
|
||||
require.Empty(t, messages)
|
||||
}
|
||||
|
||||
func TestSubagentService_MarkSubagentReportedDoesNotWakeParent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, parentStatus := range []database.ChatStatus{
|
||||
database.ChatStatusWaiting,
|
||||
database.ChatStatusCompleted,
|
||||
} {
|
||||
parentStatus := parentStatus
|
||||
t.Run(string(parentStatus), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parentID := uuid.New()
|
||||
childID := uuid.New()
|
||||
requestID := uuid.New()
|
||||
|
||||
parent := testSubagentChat(parentID, uuid.Nil)
|
||||
parent.Status = parentStatus
|
||||
|
||||
store := newSubagentServiceTestStore(
|
||||
parent,
|
||||
testSubagentChat(childID, parentID),
|
||||
)
|
||||
service := newSubagentService(store, nil)
|
||||
|
||||
require.NoError(t, store.insertSubagentRequestMessage(childID, requestID, "run"))
|
||||
_, err := service.MarkSubagentReported(context.Background(), childID, "done", uuid.NullUUID{UUID: requestID, Valid: true})
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedParent, err := store.GetChatByID(context.Background(), parentID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, parentStatus, updatedParent.Status)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubagentService_CreateChildSubagentChatPublishesPendingStatusToParentStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parentID := uuid.New()
|
||||
parent := testSubagentChat(parentID, uuid.Nil)
|
||||
parent.Title = "Parent"
|
||||
parent.Status = database.ChatStatusWaiting
|
||||
|
||||
store := newSubagentServiceTestStore(parent)
|
||||
streamManager := NewStreamManager(testutil.Logger(t))
|
||||
service := newSubagentService(store, nil)
|
||||
service.setStreamManager(streamManager)
|
||||
|
||||
_, events, cancel := streamManager.Subscribe(parentID)
|
||||
defer cancel()
|
||||
|
||||
child, _, err := service.CreateChildSubagentChat(
|
||||
context.Background(),
|
||||
parent,
|
||||
"do delegated work",
|
||||
"Child",
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusPending, child.Status)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event, ok := <-events:
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
|
||||
return false
|
||||
}
|
||||
if event.ChatID != child.ID {
|
||||
return false
|
||||
}
|
||||
return event.Status.Status == codersdk.ChatStatusPending
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitShort, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSubagentService_PublishChildStatusSkipsMissingParent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := newSubagentServiceTestStore()
|
||||
streamManager := NewStreamManager(testutil.Logger(t))
|
||||
service := newSubagentService(store, nil)
|
||||
service.setStreamManager(streamManager)
|
||||
|
||||
rootChat := testSubagentChat(uuid.New(), uuid.Nil)
|
||||
require.NotPanics(t, func() {
|
||||
service.publishChildStatus(rootChat, database.ChatStatusRunning)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSubagentService_MarkSubagentReportedDoesNotWakeRunningParent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parentID := uuid.New()
|
||||
childID := uuid.New()
|
||||
requestID := uuid.New()
|
||||
|
||||
parent := testSubagentChat(parentID, uuid.Nil)
|
||||
parent.Status = database.ChatStatusRunning
|
||||
|
||||
store := newSubagentServiceTestStore(
|
||||
parent,
|
||||
testSubagentChat(childID, parentID),
|
||||
)
|
||||
service := newSubagentService(store, nil)
|
||||
|
||||
require.NoError(t, store.insertSubagentRequestMessage(childID, requestID, "run"))
|
||||
_, err := service.MarkSubagentReported(context.Background(), childID, "done", uuid.NullUUID{UUID: requestID, Valid: true})
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedParent, err := store.GetChatByID(context.Background(), parentID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusRunning, updatedParent.Status)
|
||||
}
|
||||
|
||||
func TestSubagentService_SendSubagentMessageRejectsNonDescendant(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rootID := uuid.New()
|
||||
childAID := uuid.New()
|
||||
childBID := uuid.New()
|
||||
|
||||
store := newSubagentServiceTestStore(
|
||||
testSubagentChat(rootID, uuid.Nil),
|
||||
testSubagentChat(childAID, rootID),
|
||||
testSubagentChat(childBID, rootID),
|
||||
)
|
||||
service := newSubagentService(store, nil)
|
||||
|
||||
_, _, err := service.SendSubagentMessage(context.Background(), childAID, childBID, "follow up")
|
||||
require.ErrorIs(t, err, ErrSubagentNotDescendant)
|
||||
}
|
||||
|
||||
func TestSubagentService_SendSubagentMessageRequeuesAndReturnsRequestID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parentID := uuid.New()
|
||||
childID := uuid.New()
|
||||
|
||||
parent := testSubagentChat(parentID, uuid.Nil)
|
||||
child := testSubagentChat(childID, parentID)
|
||||
child.Status = database.ChatStatusCompleted
|
||||
|
||||
store := newSubagentServiceTestStore(parent, child)
|
||||
service := newSubagentService(store, nil)
|
||||
|
||||
updated, requestID, err := service.SendSubagentMessage(
|
||||
context.Background(),
|
||||
parentID,
|
||||
childID,
|
||||
"continue with more detail",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, requestID)
|
||||
require.Equal(t, database.ChatStatusPending, updated.Status)
|
||||
|
||||
stored, err := store.GetChatByID(context.Background(), childID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusPending, stored.Status)
|
||||
|
||||
messages := store.chatMessagesByChatID(childID)
|
||||
require.Len(t, messages, 1)
|
||||
require.Equal(t, string(fantasy.MessageRoleUser), messages[0].Role)
|
||||
require.True(t, messages[0].SubagentRequestID.Valid)
|
||||
require.Equal(t, requestID, messages[0].SubagentRequestID.UUID)
|
||||
require.True(t, messages[0].SubagentEvent.Valid)
|
||||
require.Equal(t, subagentEventRequest, messages[0].SubagentEvent.String)
|
||||
|
||||
blocks, err := parseContentBlocks(messages[0].Role, messages[0].Content)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "continue with more detail", contentBlocksToText(blocks))
|
||||
}
|
||||
|
||||
func TestSubagentService_SynthesizeFallbackSubagentReport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
requestID := uuid.New()
|
||||
otherRequestID := uuid.New()
|
||||
|
||||
store := newSubagentServiceTestStore(
|
||||
testSubagentChat(chatID, uuid.Nil),
|
||||
)
|
||||
service := newSubagentService(store, nil)
|
||||
|
||||
userContent, err := marshalContentBlocks([]fantasy.Content{
|
||||
fantasy.TextContent{Text: "user prompt"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = store.InsertChatMessage(context.Background(), database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
Role: string(fantasy.MessageRoleUser),
|
||||
Content: userContent,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstAssistant, err := marshalContentBlocks([]fantasy.Content{
|
||||
fantasy.TextContent{Text: "first summary"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = store.InsertChatMessage(context.Background(), database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Content: firstAssistant,
|
||||
SubagentRequestID: uuid.NullUUID{
|
||||
UUID: otherRequestID,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secondAssistant, err := marshalContentBlocks([]fantasy.Content{
|
||||
fantasy.TextContent{Text: "latest summary"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = store.InsertChatMessage(context.Background(), database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Content: secondAssistant,
|
||||
SubagentRequestID: uuid.NullUUID{
|
||||
UUID: requestID,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
report := service.SynthesizeFallbackSubagentReport(context.Background(), chatID, requestID)
|
||||
require.Equal(t, "latest summary", report)
|
||||
}
|
||||
|
||||
func TestSubagentService_SynthesizeFallbackSubagentReportDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
store := newSubagentServiceTestStore(
|
||||
testSubagentChat(chatID, uuid.Nil),
|
||||
)
|
||||
service := newSubagentService(store, nil)
|
||||
|
||||
report := service.SynthesizeFallbackSubagentReport(context.Background(), chatID, uuid.New())
|
||||
require.Equal(t, defaultFallbackSubagentReport, report)
|
||||
}
|
||||
|
||||
type subagentServiceTestStore struct {
|
||||
database.Store
|
||||
|
||||
mu sync.Mutex
|
||||
chats map[uuid.UUID]database.Chat
|
||||
messages []database.ChatMessage
|
||||
nextMessageID int64
|
||||
}
|
||||
|
||||
func newSubagentServiceTestStore(chats ...database.Chat) *subagentServiceTestStore {
|
||||
byID := make(map[uuid.UUID]database.Chat, len(chats))
|
||||
for _, chat := range chats {
|
||||
byID[chat.ID] = chat
|
||||
}
|
||||
return &subagentServiceTestStore{
|
||||
chats: byID,
|
||||
nextMessageID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) GetChatByID(_ context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
chat, ok := s.chats[id]
|
||||
if !ok {
|
||||
return database.Chat{}, sql.ErrNoRows
|
||||
}
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) GetChatMessagesByChatID(_ context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
return s.chatMessagesByChatID(chatID), nil
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) GetLatestPendingSubagentRequestIDByChatID(_ context.Context, chatID uuid.UUID) (uuid.NullUUID, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
requestedAt := make(map[uuid.UUID]time.Time)
|
||||
responded := make(map[uuid.UUID]struct{})
|
||||
for _, message := range s.messages {
|
||||
if message.ChatID != chatID || !message.SubagentRequestID.Valid {
|
||||
continue
|
||||
}
|
||||
requestID := message.SubagentRequestID.UUID
|
||||
if message.SubagentEvent.Valid {
|
||||
switch message.SubagentEvent.String {
|
||||
case subagentEventRequest:
|
||||
if current, ok := requestedAt[requestID]; !ok || message.CreatedAt.After(current) {
|
||||
requestedAt[requestID] = message.CreatedAt
|
||||
}
|
||||
case subagentEventResponse:
|
||||
responded[requestID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
latestRequestID := uuid.Nil
|
||||
var latestRequestedAt time.Time
|
||||
for requestID, requestedAtTime := range requestedAt {
|
||||
if _, ok := responded[requestID]; ok {
|
||||
continue
|
||||
}
|
||||
if latestRequestID == uuid.Nil || requestedAtTime.After(latestRequestedAt) {
|
||||
latestRequestID = requestID
|
||||
latestRequestedAt = requestedAtTime
|
||||
}
|
||||
}
|
||||
|
||||
if latestRequestID == uuid.Nil {
|
||||
return uuid.NullUUID{}, sql.ErrNoRows
|
||||
}
|
||||
return uuid.NullUUID{UUID: latestRequestID, Valid: true}, nil
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) GetSubagentRequestDurationByChatIDAndRequestID(
|
||||
_ context.Context,
|
||||
arg database.GetSubagentRequestDurationByChatIDAndRequestIDParams,
|
||||
) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var (
|
||||
requestAt time.Time
|
||||
responseAt time.Time
|
||||
hasRequest bool
|
||||
hasReply bool
|
||||
)
|
||||
for _, message := range s.messages {
|
||||
if message.ChatID != arg.ChatID {
|
||||
continue
|
||||
}
|
||||
if !message.SubagentRequestID.Valid || message.SubagentRequestID.UUID != arg.SubagentRequestID {
|
||||
continue
|
||||
}
|
||||
if !message.SubagentEvent.Valid {
|
||||
continue
|
||||
}
|
||||
switch message.SubagentEvent.String {
|
||||
case subagentEventRequest:
|
||||
if !hasRequest || message.CreatedAt.Before(requestAt) {
|
||||
requestAt = message.CreatedAt
|
||||
hasRequest = true
|
||||
}
|
||||
case subagentEventResponse:
|
||||
if !hasReply || message.CreatedAt.After(responseAt) {
|
||||
responseAt = message.CreatedAt
|
||||
hasReply = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasRequest || !hasReply {
|
||||
return 0, nil
|
||||
}
|
||||
return responseAt.Sub(requestAt).Milliseconds(), nil
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) GetSubagentResponseMessageByChatIDAndRequestID(
|
||||
_ context.Context,
|
||||
arg database.GetSubagentResponseMessageByChatIDAndRequestIDParams,
|
||||
) (database.ChatMessage, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for i := len(s.messages) - 1; i >= 0; i-- {
|
||||
message := s.messages[i]
|
||||
if message.ChatID != arg.ChatID {
|
||||
continue
|
||||
}
|
||||
if !message.SubagentRequestID.Valid || message.SubagentRequestID.UUID != arg.SubagentRequestID {
|
||||
continue
|
||||
}
|
||||
if !message.SubagentEvent.Valid || message.SubagentEvent.String != subagentEventResponse {
|
||||
continue
|
||||
}
|
||||
return message, nil
|
||||
}
|
||||
return database.ChatMessage{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) InsertChat(_ context.Context, arg database.InsertChatParams) (database.Chat, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: arg.OwnerID,
|
||||
WorkspaceID: arg.WorkspaceID,
|
||||
WorkspaceAgentID: arg.WorkspaceAgentID,
|
||||
Title: arg.Title,
|
||||
Status: database.ChatStatusWaiting,
|
||||
ModelConfig: arg.ModelConfig,
|
||||
ParentChatID: arg.ParentChatID,
|
||||
RootChatID: arg.RootChatID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
s.chats[chat.ID] = chat
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) InsertChatMessage(_ context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
createdAt := time.Unix(0, 0).Add(time.Duration(s.nextMessageID) * time.Second)
|
||||
message := database.ChatMessage{
|
||||
ID: s.nextMessageID,
|
||||
ChatID: arg.ChatID,
|
||||
CreatedAt: createdAt,
|
||||
Role: arg.Role,
|
||||
Content: arg.Content,
|
||||
ToolCallID: arg.ToolCallID,
|
||||
Thinking: arg.Thinking,
|
||||
Hidden: arg.Hidden,
|
||||
SubagentRequestID: arg.SubagentRequestID,
|
||||
SubagentEvent: arg.SubagentEvent,
|
||||
}
|
||||
s.nextMessageID++
|
||||
s.messages = append(s.messages, message)
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) ListChildChatsByParentID(_ context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
out := make([]database.Chat, 0)
|
||||
for _, chat := range s.chats {
|
||||
if chat.ParentChatID.Valid && chat.ParentChatID.UUID == parentChatID {
|
||||
out = append(out, chat)
|
||||
}
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
return out[i].CreatedAt.Before(out[j].CreatedAt)
|
||||
})
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) UpdateChatStatus(_ context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
chat, ok := s.chats[arg.ID]
|
||||
if !ok {
|
||||
return database.Chat{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
chat.Status = arg.Status
|
||||
chat.WorkerID = arg.WorkerID
|
||||
chat.StartedAt = arg.StartedAt
|
||||
chat.UpdatedAt = time.Now()
|
||||
s.chats[arg.ID] = chat
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) chatMessagesByChatID(chatID uuid.UUID) []database.ChatMessage {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
out := make([]database.ChatMessage, 0)
|
||||
for _, message := range s.messages {
|
||||
if message.ChatID == chatID {
|
||||
out = append(out, message)
|
||||
}
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
return out[i].CreatedAt.Before(out[j].CreatedAt)
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) insertSubagentRequestMessage(chatID uuid.UUID, requestID uuid.UUID, text string) error {
|
||||
content, err := marshalContentBlocks([]fantasy.Content{fantasy.TextContent{Text: text}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.InsertChatMessage(context.Background(), database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
Role: string(fantasy.MessageRoleUser),
|
||||
Content: content,
|
||||
SubagentRequestID: uuid.NullUUID{
|
||||
UUID: requestID,
|
||||
Valid: true,
|
||||
},
|
||||
SubagentEvent: sql.NullString{String: subagentEventRequest, Valid: true},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *subagentServiceTestStore) insertSubagentResponseMessage(chatID uuid.UUID, requestID uuid.UUID, report string) error {
|
||||
content, err := marshalContentBlocks([]fantasy.Content{fantasy.TextContent{Text: report}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.InsertChatMessage(context.Background(), database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
Role: subagentResponseMarkerRole,
|
||||
Content: content,
|
||||
Hidden: true,
|
||||
SubagentRequestID: uuid.NullUUID{
|
||||
UUID: requestID,
|
||||
Valid: true,
|
||||
},
|
||||
SubagentEvent: sql.NullString{String: subagentEventResponse, Valid: true},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func testSubagentChat(id uuid.UUID, parentID uuid.UUID) database.Chat {
|
||||
parentChatID := uuid.NullUUID{}
|
||||
if parentID != uuid.Nil {
|
||||
parentChatID = uuid.NullUUID{UUID: parentID, Valid: true}
|
||||
}
|
||||
|
||||
rootChatID := uuid.NullUUID{UUID: id, Valid: true}
|
||||
if parentID != uuid.Nil {
|
||||
rootChatID = uuid.NullUUID{UUID: parentID, Valid: true}
|
||||
}
|
||||
|
||||
return database.Chat{
|
||||
ID: id,
|
||||
OwnerID: uuid.New(),
|
||||
Status: database.ChatStatusWaiting,
|
||||
ParentChatID: parentChatID,
|
||||
RootChatID: rootChatID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,357 @@
|
||||
package chatd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub/psmock"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
type templateSelectionModel struct {
|
||||
generateCall *fantasy.Call
|
||||
generateBlocks []fantasy.Content
|
||||
}
|
||||
|
||||
func (*templateSelectionModel) Provider() string {
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (*templateSelectionModel) Model() string {
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (m *templateSelectionModel) Generate(
|
||||
_ context.Context,
|
||||
call fantasy.Call,
|
||||
) (*fantasy.Response, error) {
|
||||
captured := call
|
||||
m.generateCall = &captured
|
||||
return &fantasy.Response{Content: m.generateBlocks}, nil
|
||||
}
|
||||
|
||||
func (*templateSelectionModel) Stream(
|
||||
context.Context,
|
||||
fantasy.Call,
|
||||
) (fantasy.StreamResponse, error) {
|
||||
return nil, xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
func (*templateSelectionModel) GenerateObject(
|
||||
context.Context,
|
||||
fantasy.ObjectCall,
|
||||
) (*fantasy.ObjectResponse, error) {
|
||||
return nil, xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
func (*templateSelectionModel) StreamObject(
|
||||
context.Context,
|
||||
fantasy.ObjectCall,
|
||||
) (fantasy.ObjectStreamResponse, error) {
|
||||
return nil, xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
func TestNewWorkspaceCreator_CreateWorkspace_MultiplePromptMatchesWithoutModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
creator := chatd.NewWorkspaceCreator(
|
||||
func(ctx context.Context, _ database.Chat) (context.Context, *http.Request, string, error) {
|
||||
return ctx, httptest.NewRequest(http.MethodPost, "/api/v2/workspaces", nil), "https://coder.example", nil
|
||||
},
|
||||
func(context.Context, *http.Request) ([]database.Template, error) {
|
||||
return []database.Template{
|
||||
{ID: uuid.New(), Name: "python-starter", DisplayName: "Python Starter"},
|
||||
{ID: uuid.New(), Name: "python-web", DisplayName: "Python Web"},
|
||||
}, nil
|
||||
},
|
||||
func(context.Context, *http.Request, uuid.UUID, codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
return codersdk.Workspace{}, xerrors.New("unexpected create workspace call")
|
||||
},
|
||||
db,
|
||||
nil,
|
||||
testutil.Logger(t),
|
||||
)
|
||||
|
||||
result, err := creator(context.Background(), chatd.CreateWorkspaceToolRequest{
|
||||
Chat: database.Chat{OwnerID: uuid.New()},
|
||||
Prompt: "create a python workspace for web development",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.Created)
|
||||
require.Equal(
|
||||
t,
|
||||
"multiple templates matched and no model is available to disambiguate",
|
||||
result.Reason,
|
||||
)
|
||||
}
|
||||
|
||||
func TestNewWorkspaceCreator_CreateWorkspace_UsesModelToDisambiguatePromptMatches(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
templateStarterID := uuid.New()
|
||||
templateWebID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
workspaceAgentID := uuid.New()
|
||||
jobID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return(
|
||||
[]database.WorkspaceAgent{{ID: workspaceAgentID}},
|
||||
nil,
|
||||
)
|
||||
|
||||
var capturedCreateReq codersdk.CreateWorkspaceRequest
|
||||
creator := chatd.NewWorkspaceCreator(
|
||||
func(ctx context.Context, _ database.Chat) (context.Context, *http.Request, string, error) {
|
||||
return ctx, httptest.NewRequest(http.MethodPost, "/api/v2/workspaces", nil), "https://coder.example", nil
|
||||
},
|
||||
func(context.Context, *http.Request) ([]database.Template, error) {
|
||||
return []database.Template{
|
||||
{ID: templateStarterID, Name: "python-starter", DisplayName: "Python Starter"},
|
||||
{ID: templateWebID, Name: "python-web", DisplayName: "Python Web"},
|
||||
}, nil
|
||||
},
|
||||
func(_ context.Context, _ *http.Request, gotOwnerID uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
require.Equal(t, ownerID, gotOwnerID)
|
||||
capturedCreateReq = req
|
||||
return codersdk.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerName: "alice",
|
||||
Name: "python-web-alice",
|
||||
LatestBuild: codersdk.WorkspaceBuild{
|
||||
Job: codersdk.ProvisionerJob{ID: jobID},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
db,
|
||||
nil,
|
||||
testutil.Logger(t),
|
||||
)
|
||||
|
||||
model := &templateSelectionModel{
|
||||
generateBlocks: []fantasy.Content{
|
||||
fantasy.TextContent{
|
||||
Text: fmt.Sprintf(`{"template_id":"%s","reason":"web stack"}`, templateWebID),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := creator(context.Background(), chatd.CreateWorkspaceToolRequest{
|
||||
Chat: database.Chat{OwnerID: ownerID},
|
||||
Model: model,
|
||||
Prompt: "create a python web workspace",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Created)
|
||||
require.Equal(t, workspaceID, result.WorkspaceID)
|
||||
require.Equal(t, workspaceAgentID, result.WorkspaceAgentID)
|
||||
require.Equal(t, "alice/python-web-alice", result.WorkspaceName)
|
||||
require.Equal(t, "https://coder.example/@alice/python-web-alice", result.WorkspaceURL)
|
||||
|
||||
require.Equal(t, templateWebID, capturedCreateReq.TemplateID)
|
||||
require.Equal(t, uuid.Nil, capturedCreateReq.TemplateVersionID)
|
||||
require.NotEmpty(t, capturedCreateReq.Name)
|
||||
require.NotNil(t, model.generateCall)
|
||||
require.NotNil(t, model.generateCall.ToolChoice)
|
||||
require.Equal(t, fantasy.ToolChoiceNone, *model.generateCall.ToolChoice)
|
||||
}
|
||||
|
||||
func TestNewWorkspaceCreator_CreateWorkspace_RejectsMismatchedTemplateAndVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
templateID := uuid.New()
|
||||
templateVersionTemplateID := uuid.New()
|
||||
templateVersionID := uuid.New()
|
||||
|
||||
db.EXPECT().GetTemplateVersionByID(gomock.Any(), templateVersionID).Return(database.TemplateVersion{
|
||||
ID: templateVersionID,
|
||||
TemplateID: uuid.NullUUID{
|
||||
UUID: templateVersionTemplateID,
|
||||
Valid: true,
|
||||
},
|
||||
}, nil)
|
||||
db.EXPECT().GetTemplateByID(gomock.Any(), templateVersionTemplateID).Return(database.Template{
|
||||
ID: templateVersionTemplateID,
|
||||
Name: "python-starter",
|
||||
}, nil)
|
||||
|
||||
creator := chatd.NewWorkspaceCreator(
|
||||
func(ctx context.Context, _ database.Chat) (context.Context, *http.Request, string, error) {
|
||||
return ctx, httptest.NewRequest(http.MethodPost, "/api/v2/workspaces", nil), "https://coder.example", nil
|
||||
},
|
||||
func(context.Context, *http.Request) ([]database.Template, error) {
|
||||
return nil, nil
|
||||
},
|
||||
func(context.Context, *http.Request, uuid.UUID, codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
return codersdk.Workspace{}, xerrors.New("unexpected create workspace call")
|
||||
},
|
||||
db,
|
||||
nil,
|
||||
testutil.Logger(t),
|
||||
)
|
||||
|
||||
result, err := creator(context.Background(), chatd.CreateWorkspaceToolRequest{
|
||||
Chat: database.Chat{
|
||||
OwnerID: uuid.New(),
|
||||
},
|
||||
Prompt: "create workspace",
|
||||
Spec: json.RawMessage(
|
||||
fmt.Sprintf(`{"name":"proj","template_id":"%s","template_version_id":"%s"}`, templateID, templateVersionID),
|
||||
),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.Created)
|
||||
require.Equal(t, "template_id does not match template_version_id", result.Reason)
|
||||
}
|
||||
|
||||
func TestNewWorkspaceCreator_CreateWorkspace_StreamsBuildLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
ps := psmock.NewMockPubsub(ctrl)
|
||||
|
||||
templateID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
workspaceAgentID := uuid.New()
|
||||
jobID := uuid.New()
|
||||
|
||||
running := database.ProvisionerJob{
|
||||
ID: jobID,
|
||||
JobStatus: database.ProvisionerJobStatusRunning,
|
||||
}
|
||||
initialLog := database.ProvisionerJobLog{
|
||||
ID: 1,
|
||||
Source: database.LogSourceProvisioner,
|
||||
Level: database.LogLevelInfo,
|
||||
Stage: "plan",
|
||||
Output: "planning infrastructure",
|
||||
}
|
||||
notificationLog := database.ProvisionerJobLog{
|
||||
ID: 2,
|
||||
Source: database.LogSourceProvisionerDaemon,
|
||||
Level: database.LogLevelDebug,
|
||||
Stage: "apply",
|
||||
Output: "apply complete",
|
||||
}
|
||||
|
||||
db.EXPECT().GetTemplateByID(gomock.Any(), templateID).Return(database.Template{
|
||||
ID: templateID,
|
||||
Name: "python-web",
|
||||
DisplayName: "Python Web",
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return(
|
||||
[]database.WorkspaceAgent{{ID: workspaceAgentID}},
|
||||
nil,
|
||||
)
|
||||
|
||||
msg, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{
|
||||
EndOfLogs: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
subscribeCall := ps.EXPECT().SubscribeWithErr(
|
||||
provisionersdk.ProvisionerJobLogsNotifyChannel(jobID),
|
||||
gomock.Any(),
|
||||
).DoAndReturn(func(_ string, listener pubsub.ListenerWithErr) (func(), error) {
|
||||
listener(context.Background(), msg, nil)
|
||||
return func() {}, nil
|
||||
})
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), database.GetProvisionerLogsAfterIDParams{
|
||||
JobID: jobID,
|
||||
CreatedAfter: 0,
|
||||
}).Return([]database.ProvisionerJobLog{initialLog}, nil),
|
||||
db.EXPECT().GetProvisionerJobByID(gomock.Any(), jobID).Return(running, nil),
|
||||
subscribeCall,
|
||||
db.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), database.GetProvisionerLogsAfterIDParams{
|
||||
JobID: jobID,
|
||||
CreatedAfter: 1,
|
||||
}).Return([]database.ProvisionerJobLog{}, nil),
|
||||
db.EXPECT().GetProvisionerJobByID(gomock.Any(), jobID).Return(running, nil),
|
||||
db.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), database.GetProvisionerLogsAfterIDParams{
|
||||
JobID: jobID,
|
||||
CreatedAfter: 1,
|
||||
}).Return([]database.ProvisionerJobLog{notificationLog}, nil),
|
||||
)
|
||||
|
||||
creator := chatd.NewWorkspaceCreator(
|
||||
func(ctx context.Context, _ database.Chat) (context.Context, *http.Request, string, error) {
|
||||
return ctx, httptest.NewRequest(http.MethodPost, "/api/v2/workspaces", nil), "https://coder.example", nil
|
||||
},
|
||||
func(context.Context, *http.Request) ([]database.Template, error) {
|
||||
return []database.Template{{ID: templateID, Name: "python-web", DisplayName: "Python Web"}}, nil
|
||||
},
|
||||
func(context.Context, *http.Request, uuid.UUID, codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
return codersdk.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerName: "alice",
|
||||
Name: "python-web-alice",
|
||||
LatestBuild: codersdk.WorkspaceBuild{
|
||||
Job: codersdk.ProvisionerJob{ID: jobID},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
db,
|
||||
ps,
|
||||
testutil.Logger(t),
|
||||
)
|
||||
|
||||
var emitted []chatd.CreateWorkspaceBuildLog
|
||||
result, err := creator(context.Background(), chatd.CreateWorkspaceToolRequest{
|
||||
Chat: database.Chat{
|
||||
OwnerID: uuid.New(),
|
||||
},
|
||||
Prompt: "create a python web workspace",
|
||||
Spec: json.RawMessage(fmt.Sprintf(`{"name":"proj","template_id":"%s"}`, templateID)),
|
||||
BuildLogHandler: func(log chatd.CreateWorkspaceBuildLog) {
|
||||
emitted = append(emitted, log)
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Created)
|
||||
require.Equal(t, workspaceID, result.WorkspaceID)
|
||||
require.Equal(t, workspaceAgentID, result.WorkspaceAgentID)
|
||||
|
||||
require.Equal(t, []chatd.CreateWorkspaceBuildLog{
|
||||
{
|
||||
Source: string(initialLog.Source),
|
||||
Level: string(initialLog.Level),
|
||||
Stage: initialLog.Stage,
|
||||
Output: initialLog.Output,
|
||||
},
|
||||
{
|
||||
Source: string(notificationLog.Source),
|
||||
Level: string(notificationLog.Level),
|
||||
Stage: notificationLog.Stage,
|
||||
Output: notificationLog.Output,
|
||||
},
|
||||
}, emitted)
|
||||
}
|
||||
+3936
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,304 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/httpapi/httperror"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestParseGitHubRepositoryOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
owner, repo, normalized, ok := parseGitHubRepositoryOrigin("https://github.com/coder/coder.git")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "coder", owner)
|
||||
require.Equal(t, "coder", repo)
|
||||
require.Equal(t, "https://github.com/coder/coder", normalized)
|
||||
|
||||
owner, repo, normalized, ok = parseGitHubRepositoryOrigin("git@github.com:coder/coder.git")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "coder", owner)
|
||||
require.Equal(t, "coder", repo)
|
||||
require.Equal(t, "https://github.com/coder/coder", normalized)
|
||||
|
||||
owner, repo, normalized, ok = parseGitHubRepositoryOrigin("https://gitlab.com/coder/coder")
|
||||
require.False(t, ok)
|
||||
require.Empty(t, owner)
|
||||
require.Empty(t, repo)
|
||||
require.Empty(t, normalized)
|
||||
}
|
||||
|
||||
func TestResolveExternalAuthProviderType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
api := &API{
|
||||
Options: &Options{
|
||||
ExternalAuthConfigs: []*externalauth.Config{
|
||||
{
|
||||
Type: "github",
|
||||
Regex: regexp.MustCompile(`github\.com`),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := api.resolveExternalAuthProviderType("https://github.com/coder/coder")
|
||||
require.Equal(t, "github", provider)
|
||||
|
||||
provider = api.resolveExternalAuthProviderType("https://gitlab.com/coder/coder")
|
||||
require.Empty(t, provider)
|
||||
}
|
||||
|
||||
func TestShouldRefreshChatDiffStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
freshStatus := database.ChatDiffStatus{
|
||||
RefreshedAt: sql.NullTime{Time: now.Add(-time.Minute), Valid: true},
|
||||
StaleAt: now.Add(time.Minute),
|
||||
}
|
||||
staleStatus := database.ChatDiffStatus{
|
||||
RefreshedAt: sql.NullTime{Time: now.Add(-time.Minute), Valid: true},
|
||||
StaleAt: now.Add(-time.Second),
|
||||
}
|
||||
|
||||
require.False(t, shouldRefreshChatDiffStatus(freshStatus, now, false))
|
||||
require.True(t, shouldRefreshChatDiffStatus(staleStatus, now, false))
|
||||
require.True(t, shouldRefreshChatDiffStatus(freshStatus, now, true))
|
||||
require.True(t, shouldRefreshChatDiffStatus(database.ChatDiffStatus{}, now, false))
|
||||
}
|
||||
|
||||
func TestFilterChatsByWorkspaceID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
workspaceID := uuid.New()
|
||||
otherWorkspaceID := uuid.New()
|
||||
|
||||
matchingChat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
|
||||
}
|
||||
otherWorkspaceChat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{UUID: otherWorkspaceID, Valid: true},
|
||||
}
|
||||
noWorkspaceChat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{},
|
||||
}
|
||||
|
||||
filtered := filterChatsByWorkspaceID(
|
||||
[]database.Chat{matchingChat, otherWorkspaceChat, noWorkspaceChat},
|
||||
workspaceID,
|
||||
)
|
||||
|
||||
require.Len(t, filtered, 1)
|
||||
require.Equal(t, matchingChat.ID, filtered[0].ID)
|
||||
}
|
||||
|
||||
func TestChatWorkspaceAuditStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ResponderError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "invalid request",
|
||||
})
|
||||
require.Equal(t, http.StatusBadRequest, chatWorkspaceAuditStatus(err))
|
||||
})
|
||||
|
||||
t.Run("GenericError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, chatWorkspaceAuditStatus(assertionError("boom")))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSynthesizeChatWorkspaceRequestPreservesMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
requestID := uuid.New()
|
||||
metadata := chatWorkspaceRequestMetadata{
|
||||
Header: http.Header{"User-Agent": []string{"coder-test-agent"}},
|
||||
RemoteAddr: "203.0.113.42:9999",
|
||||
RequestID: requestID.String(),
|
||||
}
|
||||
|
||||
req, err := synthesizeChatWorkspaceRequest(
|
||||
context.Background(),
|
||||
"http://localhost/api/v2/chats/workspace",
|
||||
metadata,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, metadata.RemoteAddr, req.RemoteAddr)
|
||||
require.Equal(t, metadata.Header.Get("User-Agent"), req.Header.Get("User-Agent"))
|
||||
require.Equal(t, requestID, httpmw.RequestID(req))
|
||||
}
|
||||
|
||||
func TestSynthesizeChatWorkspaceRequestFallsBackToGeneratedRequestID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := synthesizeChatWorkspaceRequest(
|
||||
context.Background(),
|
||||
"http://localhost/api/v2/chats/workspace",
|
||||
chatWorkspaceRequestMetadata{RequestID: "not-a-uuid"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, httpmw.RequestID(req))
|
||||
}
|
||||
|
||||
func TestConvertChatMessagesSkipsWorkspaceMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messages := []database.ChatMessage{
|
||||
{
|
||||
ID: 1,
|
||||
Role: "user",
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Role: chatWorkspaceRequestMetadataRole,
|
||||
Hidden: true,
|
||||
},
|
||||
}
|
||||
|
||||
converted := convertChatMessages(messages)
|
||||
require.Len(t, converted, 1)
|
||||
require.Equal(t, int64(1), converted[0].ID)
|
||||
}
|
||||
|
||||
func TestShouldQueueUserMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
status database.ChatStatus
|
||||
isChatActive bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "RunningAlwaysQueues",
|
||||
status: database.ChatStatusRunning,
|
||||
isChatActive: false,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PendingAlwaysQueues",
|
||||
status: database.ChatStatusPending,
|
||||
isChatActive: false,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "WaitingQueuesWhileWorkerActive",
|
||||
status: database.ChatStatusWaiting,
|
||||
isChatActive: true,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "WaitingDoesNotQueueWhenIdle",
|
||||
status: database.ChatStatusWaiting,
|
||||
isChatActive: false,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "CompletedDoesNotQueue",
|
||||
status: database.ChatStatusCompleted,
|
||||
isChatActive: true,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(
|
||||
t,
|
||||
tc.expected,
|
||||
shouldQueueUserMessage(tc.status, tc.isChatActive),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertChatIncludesHierarchyMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ChildChatMetadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parentID := uuid.New()
|
||||
rootID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
workspaceAgentID := uuid.New()
|
||||
|
||||
converted := convertChat(database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
|
||||
WorkspaceAgentID: uuid.NullUUID{UUID: workspaceAgentID, Valid: true},
|
||||
ParentChatID: uuid.NullUUID{UUID: parentID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: rootID, Valid: true},
|
||||
Title: "Child Chat",
|
||||
}, nil)
|
||||
|
||||
require.NotNil(t, converted.ParentChatID)
|
||||
require.Equal(t, parentID, *converted.ParentChatID)
|
||||
require.NotNil(t, converted.RootChatID)
|
||||
require.Equal(t, rootID, *converted.RootChatID)
|
||||
require.NotNil(t, converted.WorkspaceID)
|
||||
require.Equal(t, workspaceID, *converted.WorkspaceID)
|
||||
require.NotNil(t, converted.WorkspaceAgentID)
|
||||
require.Equal(t, workspaceAgentID, *converted.WorkspaceAgentID)
|
||||
})
|
||||
|
||||
t.Run("RootFallbackMetadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rootID := uuid.New()
|
||||
converted := convertChat(database.Chat{
|
||||
ID: rootID,
|
||||
Title: "Root Chat",
|
||||
}, nil)
|
||||
|
||||
require.Nil(t, converted.ParentChatID)
|
||||
require.NotNil(t, converted.RootChatID)
|
||||
require.Equal(t, rootID, *converted.RootChatID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatWorkspaceRequestMetadataFromRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.test/chats", nil)
|
||||
req.Header.Set("User-Agent", "coder-test-agent")
|
||||
req.RemoteAddr = "203.0.113.42:9999"
|
||||
|
||||
requestID := uuid.New()
|
||||
req = req.WithContext(httpmw.WithRequestID(context.Background(), requestID))
|
||||
|
||||
metadata := chatWorkspaceRequestMetadataFromRequest(req)
|
||||
require.Equal(t, "203.0.113.42:9999", metadata.RemoteAddr)
|
||||
require.Equal(t, requestID.String(), metadata.RequestID)
|
||||
require.Equal(t, "coder-test-agent", metadata.Header.Get("User-Agent"))
|
||||
}
|
||||
|
||||
type assertionError string
|
||||
|
||||
func (e assertionError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
+104
-2
@@ -13,6 +13,7 @@ import (
|
||||
"net/http"
|
||||
httppprof "net/http/pprof"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime/pprof"
|
||||
@@ -50,6 +51,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/awsidentity"
|
||||
"github.com/coder/coder/v2/coderd/boundaryusage"
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/connectionlog"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -265,6 +267,8 @@ type Options struct {
|
||||
// DatabaseRolluper rolls up template usage stats from raw agent and app
|
||||
// stats. This is used to provide insights in the WebUI.
|
||||
DatabaseRolluper *dbrollup.Rolluper
|
||||
// ChatProcessor handles background processing of pending chats.
|
||||
ChatProcessor *chatd.Processor
|
||||
// WorkspaceUsageTracker tracks workspace usage by the CLI.
|
||||
WorkspaceUsageTracker *workspacestats.UsageTracker
|
||||
// BoundaryUsageTracker tracks boundary usage for telemetry.
|
||||
@@ -593,6 +597,39 @@ func New(options *Options) *API {
|
||||
var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{}
|
||||
buildUsageChecker.Store(&noopUsageChecker)
|
||||
|
||||
chatProviderAPIKeys := chatd.ProviderAPIKeys{
|
||||
OpenAI: options.DeploymentValues.AI.BridgeConfig.OpenAI.Key.Value(),
|
||||
Anthropic: options.DeploymentValues.AI.BridgeConfig.Anthropic.Key.Value(),
|
||||
BaseURLByProvider: map[string]string{
|
||||
"openai": options.DeploymentValues.AI.BridgeConfig.OpenAI.BaseURL.Value(),
|
||||
"anthropic": options.DeploymentValues.AI.BridgeConfig.Anthropic.BaseURL.Value(),
|
||||
},
|
||||
}
|
||||
if value := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")); value != "" {
|
||||
chatProviderAPIKeys.OpenAI = value
|
||||
}
|
||||
if value := strings.TrimSpace(os.Getenv("ANTHROPIC_API_KEY")); value != "" {
|
||||
chatProviderAPIKeys.Anthropic = value
|
||||
}
|
||||
|
||||
chatProviderAPIKeysResolver := func(ctx context.Context) (chatd.ProviderAPIKeys, error) {
|
||||
providers, err := options.Database.GetEnabledChatProviders(ctx)
|
||||
if err != nil {
|
||||
return chatd.ProviderAPIKeys{}, err
|
||||
}
|
||||
|
||||
configuredProviders := make([]chatd.ConfiguredProvider, 0, len(providers))
|
||||
for _, provider := range providers {
|
||||
configuredProviders = append(configuredProviders, chatd.ConfiguredProvider{
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
})
|
||||
}
|
||||
|
||||
return chatd.MergeProviderAPIKeys(chatProviderAPIKeys, configuredProviders), nil
|
||||
}
|
||||
|
||||
api := &API{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
@@ -625,7 +662,8 @@ func New(options *Options) *API {
|
||||
options.Database,
|
||||
options.Pubsub,
|
||||
),
|
||||
dbRolluper: options.DatabaseRolluper,
|
||||
dbRolluper: options.DatabaseRolluper,
|
||||
chatProcessor: options.ChatProcessor,
|
||||
}
|
||||
api.WorkspaceAppsProvider = workspaceapps.NewDBTokenProvider(
|
||||
ctx,
|
||||
@@ -756,6 +794,30 @@ func New(options *Options) *API {
|
||||
panic("failed to setup server tailnet: " + err.Error())
|
||||
}
|
||||
api.agentProvider = stn
|
||||
var chatLocalConfig *chatd.LocalConfig
|
||||
if chatLocalWorkspaceEnabled(api) {
|
||||
chatLocalConfig = &chatd.LocalConfig{
|
||||
AccessURL: options.AccessURL,
|
||||
HTTPClient: options.HTTPClient,
|
||||
DeploymentID: depID,
|
||||
}
|
||||
}
|
||||
|
||||
if options.ChatProcessor == nil {
|
||||
options.ChatProcessor = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chats"),
|
||||
Database: options.Database,
|
||||
ResolveProviderAPIKeys: chatProviderAPIKeysResolver,
|
||||
TitleGeneration: chatd.TitleGenerationConfig{
|
||||
Prompt: options.DeploymentValues.AI.Chat.TitleGenerationPrompt.Value(),
|
||||
},
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.newChatWorkspaceCreator(),
|
||||
Pubsub: options.Pubsub,
|
||||
Local: chatLocalConfig,
|
||||
})
|
||||
}
|
||||
api.chatProcessor = options.ChatProcessor
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(stn)
|
||||
}
|
||||
@@ -1158,6 +1220,42 @@ func New(options *Options) *API {
|
||||
r.Get("/", api.auditLogs)
|
||||
r.Post("/testgenerate", api.generateFakeAuditLog)
|
||||
})
|
||||
r.Route("/chats", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Get("/", api.listChats)
|
||||
r.Post("/", api.createChat)
|
||||
r.Get("/models", api.listChatModels)
|
||||
r.Get("/watch", api.watchChats)
|
||||
r.Route("/providers", func(r chi.Router) {
|
||||
r.Get("/", api.listChatProviders)
|
||||
r.Post("/", api.createChatProvider)
|
||||
r.Route("/{providerConfig}", func(r chi.Router) {
|
||||
r.Patch("/", api.updateChatProvider)
|
||||
r.Delete("/", api.deleteChatProvider)
|
||||
})
|
||||
})
|
||||
r.Route("/model-configs", func(r chi.Router) {
|
||||
r.Get("/", api.listChatModelConfigs)
|
||||
r.Post("/", api.createChatModelConfig)
|
||||
r.Route("/{modelConfig}", func(r chi.Router) {
|
||||
r.Patch("/", api.updateChatModelConfig)
|
||||
r.Delete("/", api.deleteChatModelConfig)
|
||||
})
|
||||
})
|
||||
r.Route("/{chat}", func(r chi.Router) {
|
||||
r.Get("/", api.getChat)
|
||||
r.Delete("/", api.deleteChat)
|
||||
r.Post("/messages", api.createChatMessage)
|
||||
r.Get("/stream", api.streamChat)
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Get("/diff-status", api.getChatDiffStatus)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
r.Delete("/", api.deleteChatQueuedMessage)
|
||||
r.Post("/promote", api.promoteChatQueuedMessage)
|
||||
})
|
||||
})
|
||||
})
|
||||
r.Route("/files", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
@@ -1898,6 +1996,8 @@ type API struct {
|
||||
// dbRolluper rolls up template usage stats from raw agent and app
|
||||
// stats. This is used to provide insights in the WebUI.
|
||||
dbRolluper *dbrollup.Rolluper
|
||||
// chatProcessor handles background processing of pending chats.
|
||||
chatProcessor *chatd.Processor
|
||||
}
|
||||
|
||||
// Close waits for all WebSocket connections to drain before returning.
|
||||
@@ -1926,8 +2026,10 @@ func (api *API) Close() error {
|
||||
case <-timer.C:
|
||||
api.Logger.Warn(api.ctx, "websocket shutdown timed out after 10 seconds")
|
||||
}
|
||||
|
||||
api.dbRolluper.Close()
|
||||
if err := api.chatProcessor.Close(); err != nil {
|
||||
api.Logger.Warn(api.ctx, "close chat processor", slog.Error(err))
|
||||
}
|
||||
api.metricsCache.Close()
|
||||
if api.updateChecker != nil {
|
||||
api.updateChecker.Close()
|
||||
|
||||
@@ -6,15 +6,19 @@ type CheckConstraint string
|
||||
|
||||
// CheckConstraint enums.
|
||||
const (
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckChatMessagesSubagentEventCheck CheckConstraint = "chat_messages_subagent_event_check" // chat_messages
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package db2sdk
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
@@ -1036,3 +1038,404 @@ func jsonOrEmptyMap(rawMessage pqtype.NullRawMessage) map[string]any {
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
type chatToolResultBlock struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Result any `json:"result"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
func ChatMessage(m database.ChatMessage) codersdk.ChatMessage {
|
||||
msg := codersdk.ChatMessage{
|
||||
ID: m.ID,
|
||||
ChatID: m.ChatID,
|
||||
CreatedAt: m.CreatedAt,
|
||||
Role: m.Role,
|
||||
Hidden: m.Hidden,
|
||||
InputTokens: nullInt64Ptr(m.InputTokens),
|
||||
OutputTokens: nullInt64Ptr(m.OutputTokens),
|
||||
TotalTokens: nullInt64Ptr(m.TotalTokens),
|
||||
ReasoningTokens: nullInt64Ptr(m.ReasoningTokens),
|
||||
CacheCreationTokens: nullInt64Ptr(m.CacheCreationTokens),
|
||||
CacheReadTokens: nullInt64Ptr(m.CacheReadTokens),
|
||||
ContextLimit: nullInt64Ptr(m.ContextLimit),
|
||||
}
|
||||
if m.Content.Valid {
|
||||
msg.Content = m.Content.RawMessage
|
||||
parts, err := chatMessageParts(m.Role, m.Content)
|
||||
if err == nil {
|
||||
msg.Parts = parts
|
||||
}
|
||||
}
|
||||
if m.ToolCallID.Valid {
|
||||
msg.ToolCallID = &m.ToolCallID.String
|
||||
}
|
||||
if m.Thinking.Valid {
|
||||
msg.Thinking = &m.Thinking.String
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMessagePart, error) {
|
||||
switch role {
|
||||
case string(fantasy.MessageRoleSystem):
|
||||
content, err := parseSystemContent(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return []codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: content,
|
||||
}}, nil
|
||||
case string(fantasy.MessageRoleUser), string(fantasy.MessageRoleAssistant):
|
||||
content, err := parseContentBlocks(role, raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawBlocks []json.RawMessage
|
||||
if role == string(fantasy.MessageRoleAssistant) {
|
||||
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
|
||||
}
|
||||
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(content))
|
||||
for i, block := range content {
|
||||
part := contentBlockToPart(block)
|
||||
if part.Type == "" {
|
||||
continue
|
||||
}
|
||||
if part.Type == codersdk.ChatMessagePartTypeReasoning {
|
||||
part.Title = ""
|
||||
if i < len(rawBlocks) {
|
||||
part.Title = reasoningStoredTitle(rawBlocks[i])
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
return parts, nil
|
||||
case string(fantasy.MessageRoleTool):
|
||||
results, err := parseToolResults(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(results))
|
||||
for _, result := range results {
|
||||
part := toolResultToPart(result)
|
||||
if part.Type == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
return parts, nil
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
var content string
|
||||
if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
|
||||
return "", xerrors.Errorf("parse system content: %w", err)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func parseContentBlocks(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if role == string(fantasy.MessageRoleUser) {
|
||||
var text string
|
||||
if err := json.Unmarshal(raw.RawMessage, &text); err == nil {
|
||||
return []fantasy.Content{
|
||||
fantasy.TextContent{Text: text},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
var blocks []json.RawMessage
|
||||
if err := json.Unmarshal(raw.RawMessage, &blocks); err != nil {
|
||||
return nil, xerrors.Errorf("parse content blocks: %w", err)
|
||||
}
|
||||
|
||||
content := make([]fantasy.Content, 0, len(blocks))
|
||||
for _, block := range blocks {
|
||||
decoded, err := fantasy.UnmarshalContent(block)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse content block: %w", err)
|
||||
}
|
||||
content = append(content, decoded)
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func parseToolResults(raw pqtype.NullRawMessage) ([]chatToolResultBlock, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var results []chatToolResultBlock
|
||||
if err := json.Unmarshal(raw.RawMessage, &results); err != nil {
|
||||
return nil, xerrors.Errorf("parse tool results: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func reasoningStoredTitle(raw json.RawMessage) string {
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data struct {
|
||||
Title string `json:"title"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &envelope); err != nil {
|
||||
return ""
|
||||
}
|
||||
if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeReasoning)) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(envelope.Data.Title)
|
||||
}
|
||||
|
||||
func contentBlockToPart(block fantasy.Content) codersdk.ChatMessagePart {
|
||||
switch value := block.(type) {
|
||||
case fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case *fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
}
|
||||
case *fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
}
|
||||
case fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
}
|
||||
case *fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
}
|
||||
case fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case *fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case *fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case fantasy.ToolResultContent:
|
||||
return toolResultToPart(toolResultBlockFromContent(value))
|
||||
case *fantasy.ToolResultContent:
|
||||
return toolResultToPart(toolResultBlockFromContent(*value))
|
||||
default:
|
||||
return codersdk.ChatMessagePart{}
|
||||
}
|
||||
}
|
||||
|
||||
func toolResultBlockFromContent(content fantasy.ToolResultContent) chatToolResultBlock {
|
||||
result := chatToolResultBlock{
|
||||
ToolCallID: content.ToolCallID,
|
||||
ToolName: content.ToolName,
|
||||
}
|
||||
switch output := content.Result.(type) {
|
||||
case fantasy.ToolResultOutputContentError:
|
||||
result.IsError = true
|
||||
if output.Error != nil {
|
||||
result.Result = map[string]any{"error": output.Error.Error()}
|
||||
} else {
|
||||
result.Result = map[string]any{"error": ""}
|
||||
}
|
||||
case fantasy.ToolResultOutputContentText:
|
||||
decoded := map[string]any{}
|
||||
if err := json.Unmarshal([]byte(output.Text), &decoded); err == nil {
|
||||
result.Result = decoded
|
||||
} else {
|
||||
result.Result = map[string]any{"output": output.Text}
|
||||
}
|
||||
case fantasy.ToolResultOutputContentMedia:
|
||||
result.Result = map[string]any{
|
||||
"data": output.Data,
|
||||
"mime_type": output.MediaType,
|
||||
"text": output.Text,
|
||||
}
|
||||
default:
|
||||
result.Result = map[string]any{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func toolResultToPart(result chatToolResultBlock) codersdk.ChatMessagePart {
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: result.ToolCallID,
|
||||
ToolName: result.ToolName,
|
||||
Result: toRawJSON(result.Result),
|
||||
IsError: result.IsError,
|
||||
ResultMeta: toolResultMetadata(result.Result),
|
||||
}
|
||||
}
|
||||
|
||||
func toRawJSON(value any) json.RawMessage {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func toolResultMetadata(value any) *codersdk.ChatToolResultMetadata {
|
||||
fields, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
meta := codersdk.ChatToolResultMetadata{}
|
||||
if s, ok := stringValue(fields["error"]); ok {
|
||||
meta.Error = s
|
||||
}
|
||||
if s, ok := stringValue(fields["output"]); ok {
|
||||
meta.Output = s
|
||||
}
|
||||
if n, ok := intValue(fields["exit_code"]); ok {
|
||||
meta.ExitCode = &n
|
||||
}
|
||||
if s, ok := stringValue(fields["content"]); ok {
|
||||
meta.Content = s
|
||||
}
|
||||
if s, ok := stringValue(fields["mime_type"]); ok {
|
||||
meta.MimeType = s
|
||||
}
|
||||
if b, ok := boolValue(fields["created"]); ok {
|
||||
meta.Created = &b
|
||||
}
|
||||
if s, ok := stringValue(fields["workspace_id"]); ok {
|
||||
meta.WorkspaceID = s
|
||||
}
|
||||
if s, ok := stringValue(fields["workspace_agent_id"]); ok {
|
||||
meta.WorkspaceAgentID = s
|
||||
}
|
||||
if s, ok := stringValue(fields["workspace_name"]); ok {
|
||||
meta.WorkspaceName = s
|
||||
}
|
||||
if s, ok := stringValue(fields["workspace_url"]); ok {
|
||||
meta.WorkspaceURL = s
|
||||
}
|
||||
if s, ok := stringValue(fields["reason"]); ok {
|
||||
meta.Reason = s
|
||||
}
|
||||
|
||||
if meta.Error == "" &&
|
||||
meta.Output == "" &&
|
||||
meta.ExitCode == nil &&
|
||||
meta.Content == "" &&
|
||||
meta.MimeType == "" &&
|
||||
meta.Created == nil &&
|
||||
meta.WorkspaceID == "" &&
|
||||
meta.WorkspaceAgentID == "" &&
|
||||
meta.WorkspaceName == "" &&
|
||||
meta.WorkspaceURL == "" &&
|
||||
meta.Reason == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &meta
|
||||
}
|
||||
|
||||
func stringValue(value any) (string, bool) {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return typed, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func boolValue(value any) (bool, bool) {
|
||||
switch typed := value.(type) {
|
||||
case bool:
|
||||
return typed, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
func intValue(value any) (int, bool) {
|
||||
switch typed := value.(type) {
|
||||
case int:
|
||||
return typed, true
|
||||
case int8:
|
||||
return int(typed), true
|
||||
case int16:
|
||||
return int(typed), true
|
||||
case int32:
|
||||
return int(typed), true
|
||||
case int64:
|
||||
return int(typed), true
|
||||
case float64:
|
||||
return int(typed), true
|
||||
case json.Number:
|
||||
n, err := typed.Int64()
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int(n), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
value := v.Int64
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -8,7 +8,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -206,3 +209,79 @@ func TestTemplateVersionParameter_BadDescription(t *testing.T) {
|
||||
req.NoError(err)
|
||||
req.NotEmpty(sdk.DescriptionPlaintext, "broke the markdown parser with %v", desc)
|
||||
}
|
||||
|
||||
func TestChatMessage_ReasoningPartWithoutPersistedTitleIsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assistantContent, err := json.Marshal([]fantasy.Content{
|
||||
fantasy.ReasoningContent{
|
||||
Text: "Plan migration",
|
||||
ProviderMetadata: fantasy.ProviderMetadata{
|
||||
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
|
||||
ItemID: "reasoning-1",
|
||||
Summary: []string{"Plan migration"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
message := db2sdk.ChatMessage(database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
CreatedAt: time.Now(),
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: assistantContent,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, message.Parts, 1)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Parts[0].Type)
|
||||
require.Equal(t, "Plan migration", message.Parts[0].Text)
|
||||
require.Empty(t, message.Parts[0].Title)
|
||||
}
|
||||
|
||||
func TestChatMessage_ReasoningPartPrefersPersistedTitle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reasoningContent, err := json.Marshal(fantasy.ReasoningContent{
|
||||
Text: "Verify schema updates, then apply changes in order.",
|
||||
ProviderMetadata: fantasy.ProviderMetadata{
|
||||
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
|
||||
ItemID: "reasoning-1",
|
||||
Summary: []string{
|
||||
"**Metadata-derived title**\n\nLonger explanation.",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var envelope map[string]any
|
||||
require.NoError(t, json.Unmarshal(reasoningContent, &envelope))
|
||||
dataValue, ok := envelope["data"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
dataValue["title"] = "Persisted stream title"
|
||||
|
||||
encodedReasoning, err := json.Marshal(envelope)
|
||||
require.NoError(t, err)
|
||||
assistantContent, err := json.Marshal([]json.RawMessage{encodedReasoning})
|
||||
require.NoError(t, err)
|
||||
|
||||
message := db2sdk.ChatMessage(database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
CreatedAt: time.Now(),
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: assistantContent,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, message.Parts, 1)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Parts[0].Type)
|
||||
require.Equal(t, "Persisted stream title", message.Parts[0].Title)
|
||||
}
|
||||
|
||||
@@ -453,6 +453,7 @@ var (
|
||||
rbac.ResourceProvisionerJobs.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreate},
|
||||
rbac.ResourceOauth2App.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
rbac.ResourceOauth2AppSecret.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
}),
|
||||
User: []rbac.Permission{},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{},
|
||||
@@ -1451,6 +1452,15 @@ func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.Prov
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *querier) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
|
||||
// AcquireChat is a system-level operation used by the chat processor.
|
||||
// Authorization is done at the system level, not per-user.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.AcquireChat(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) AcquireLock(ctx context.Context, id int64) error {
|
||||
return q.db.AcquireLock(ctx, id)
|
||||
}
|
||||
@@ -1679,6 +1689,17 @@ func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) e
|
||||
return q.db.DeleteAPIKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteAllChatQueuedMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return err
|
||||
@@ -1710,6 +1731,54 @@ func (q *querier) DeleteBoundaryUsageStatsByReplicaID(ctx context.Context, repli
|
||||
return q.db.DeleteBoundaryUsageStatsByReplicaID(ctx, replicaID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
// Authorize delete on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatMessagesByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatModelConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatProviderByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatQueuedMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -2278,6 +2347,131 @@ func (q *querier) GetBoundaryUsageSummary(ctx context.Context, maxStalenessMs in
|
||||
return q.db.GetBoundaryUsageSummary(ctx, maxStalenessMs)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
return fetch(q.log, q.auth, q.db.GetChatByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
return fetch(q.log, q.auth, q.db.GetChatByIDForUpdate)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
return q.db.GetChatDiffStatusByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uuid.UUID) ([]database.ChatDiffStatus, error) {
|
||||
if len(chatIDs) == 0 {
|
||||
return []database.ChatDiffStatus{}, nil
|
||||
}
|
||||
|
||||
actor, ok := ActorFromContext(ctx)
|
||||
if ok && actor.Type == rbac.SubjectTypeSystemRestricted {
|
||||
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
||||
}
|
||||
|
||||
for _, chatID := range chatIDs {
|
||||
// Authorize read on each parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
// ChatMessages are authorized through their parent Chat.
|
||||
// We need to fetch the message first to get its chat_id.
|
||||
msg, err := q.db.GetChatMessageByID(ctx, id)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
// Authorize read on the parent chat.
|
||||
_, err = q.GetChatByID(ctx, msg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessagesByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessagesForPromptByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.GetChatModelConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.GetChatModelConfigByProviderAndModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.GetChatProviderByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.GetChatProviderByProvider(ctx, provider)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatProviders(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatQueuedMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
// Just like with the audit logs query, shortcut if the user is an owner.
|
||||
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
|
||||
@@ -2375,6 +2569,20 @@ func (q *querier) GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.C
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetEligibleProvisionerDaemonsByProvisionerJobIDs)(ctx, provisionerJobIDs)
|
||||
}
|
||||
|
||||
func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetEnabledChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetEnabledChatProviders(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetExternalAuthLink)(ctx, arg)
|
||||
}
|
||||
@@ -2522,6 +2730,14 @@ func (q *querier) GetLatestCryptoKeyByFeature(ctx context.Context, feature datab
|
||||
return q.db.GetLatestCryptoKeyByFeature(ctx, feature)
|
||||
}
|
||||
|
||||
func (q *querier) GetLatestPendingSubagentRequestIDByChatID(ctx context.Context, chatID uuid.UUID) (uuid.NullUUID, error) {
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return uuid.NullUUID{}, err
|
||||
}
|
||||
return q.db.GetLatestPendingSubagentRequestIDByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetLatestWorkspaceAppStatusByAppID(ctx context.Context, appID uuid.UUID) (database.WorkspaceAppStatus, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return database.WorkspaceAppStatus{}, err
|
||||
@@ -3074,6 +3290,30 @@ func (q *querier) GetRuntimeConfig(ctx context.Context, key string) (string, err
|
||||
return q.db.GetRuntimeConfig(ctx, key)
|
||||
}
|
||||
|
||||
func (q *querier) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
|
||||
// GetStaleChats is a system-level operation used by the chat processor for recovery.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetStaleChats(ctx, staleThreshold)
|
||||
}
|
||||
|
||||
func (q *querier) GetSubagentRequestDurationByChatIDAndRequestID(ctx context.Context, arg database.GetSubagentRequestDurationByChatIDAndRequestIDParams) (int64, error) {
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.GetSubagentRequestDurationByChatIDAndRequestID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetSubagentResponseMessageByChatIDAndRequestID(ctx context.Context, arg database.GetSubagentResponseMessageByChatIDAndRequestIDParams) (database.ChatMessage, error) {
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
return q.db.GetSubagentResponseMessageByChatIDAndRequestID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return nil, err
|
||||
@@ -4177,6 +4417,47 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo
|
||||
return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
// Authorize create on the parent chat (using update permission).
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
return q.db.InsertChatMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.InsertChatModelConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.InsertChatProvider(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
return q.db.InsertChatQueuedMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -4775,6 +5056,14 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
|
||||
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
|
||||
}
|
||||
|
||||
func (q *querier) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChatsByRootID)(ctx, rootChatID)
|
||||
}
|
||||
|
||||
func (q *querier) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChildChatsByParentID)(ctx, parentChatID)
|
||||
}
|
||||
|
||||
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
|
||||
}
|
||||
@@ -4855,6 +5144,17 @@ func (q *querier) PaginatedOrganizationMembers(ctx context.Context, arg database
|
||||
return q.db.PaginatedOrganizationMembers(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatQueuedMessage{}, err
|
||||
}
|
||||
return q.db.PopNextQueuedMessage(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
|
||||
template, err := q.db.GetTemplateByID(ctx, templateID)
|
||||
if err != nil {
|
||||
@@ -4954,6 +5254,75 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
|
||||
return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
return q.db.UpdateChatModelConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatModelConfigByChatID(ctx context.Context, arg database.UpdateChatModelConfigByChatIDParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatModelConfigByChatID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.UpdateChatProvider(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
// UpdateChatStatus is used by the chat processor to change chat status.
|
||||
// It should be called with system context.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace is manually implemented for chat tables and may not be
|
||||
// present on every wrapped store interface yet.
|
||||
chatWorkspaceUpdater, ok := q.db.(interface {
|
||||
UpdateChatWorkspace(context.Context, database.UpdateChatWorkspaceParams) (database.Chat, error)
|
||||
})
|
||||
if !ok {
|
||||
return database.Chat{}, xerrors.New("update chat workspace is not implemented by wrapped store")
|
||||
}
|
||||
return chatWorkspaceUpdater.UpdateChatWorkspace(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -5996,6 +6365,30 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
|
||||
return q.db.UpsertBoundaryUsageStats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
return q.db.UpsertChatDiffStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
return q.db.UpsertChatDiffStatusReference(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
|
||||
return database.ConnectionLog{}, err
|
||||
@@ -6287,16 +6680,10 @@ func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg databas
|
||||
return q.CountConnectionLogs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
// TODO: Delete this function, all ListAIBridgeInterceptions should be authorized. For now just call ListAIBridgeInterceptions on the authz querier.
|
||||
// This cannot be deleted for now because it's included in the
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.ListAIBridgeInterceptions(ctx, arg)
|
||||
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
return q.db.ListAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) (int64, error) {
|
||||
// TODO: Delete this function, all CountAIBridgeInterceptions should be authorized. For now just call CountAIBridgeInterceptions on the authz querier.
|
||||
// This cannot be deleted for now because it's included in the
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.CountAIBridgeInterceptions(ctx, arg)
|
||||
func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
@@ -171,6 +171,7 @@ func TestDBAuthzRecursive(t *testing.T) {
|
||||
Groups: []string{},
|
||||
Scope: rbac.ScopeAll,
|
||||
}
|
||||
preparedAuthorizedType := reflect.TypeOf((*rbac.PreparedAuthorized)(nil)).Elem()
|
||||
for i := 0; i < reflect.TypeOf(q).NumMethod(); i++ {
|
||||
var ins []reflect.Value
|
||||
ctx := dbauthz.As(context.Background(), actor)
|
||||
@@ -178,7 +179,13 @@ func TestDBAuthzRecursive(t *testing.T) {
|
||||
ins = append(ins, reflect.ValueOf(ctx))
|
||||
method := reflect.TypeOf(q).Method(i)
|
||||
for i := 2; i < method.Type.NumIn(); i++ {
|
||||
ins = append(ins, reflect.New(method.Type.In(i)).Elem())
|
||||
inType := method.Type.In(i)
|
||||
if inType.Implements(preparedAuthorizedType) {
|
||||
ins = append(ins, reflect.ValueOf(emptyPreparedAuthorized{}))
|
||||
continue
|
||||
}
|
||||
|
||||
ins = append(ins, reflect.New(inType).Elem())
|
||||
}
|
||||
if method.Name == "InTx" ||
|
||||
method.Name == "Ping" ||
|
||||
@@ -370,6 +377,300 @@ func (s *MethodTestSuite) TestConnectionLogs() {
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestChats() {
|
||||
s.Run("AcquireChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.AcquireChatParams{
|
||||
StartedAt: dbtime.Now(),
|
||||
WorkerID: uuid.New(),
|
||||
}
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().AcquireChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("DeleteChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionDelete).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatMessagesByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionDelete).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
id := uuid.New()
|
||||
dbm.EXPECT().DeleteChatModelConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
id := uuid.New()
|
||||
dbm.EXPECT().DeleteChatProviderByID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
|
||||
}))
|
||||
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatDiffStatusByChatID(gomock.Any(), chat.ID).Return(diffStatus, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(diffStatus)
|
||||
}))
|
||||
s.Run("GetChatDiffStatusesByChatIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chatA := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
chatB := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
ids := []uuid.UUID{chatA.ID, chatB.ID}
|
||||
diffStatusA := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chatA.ID})
|
||||
diffStatusB := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chatB.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chatA.ID).Return(chatA, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chatB.ID).Return(chatB, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatDiffStatusesByChatIDs(gomock.Any(), ids).Return([]database.ChatDiffStatus{diffStatusA, diffStatusB}, nil).AnyTimes()
|
||||
check.Args(ids).
|
||||
Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).
|
||||
Returns([]database.ChatDiffStatus{diffStatusA, diffStatusB})
|
||||
}))
|
||||
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatMessageByID(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(msg.ID).Asserts(chat, policy.ActionRead).Returns(msg)
|
||||
}))
|
||||
s.Run("GetChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetChatModelConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes()
|
||||
check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
|
||||
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
|
||||
}))
|
||||
s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerName := "test-provider"
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
|
||||
dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
|
||||
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
|
||||
}))
|
||||
s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
c1 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
c2 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), c1.OwnerID).Return([]database.Chat{c1, c2}, nil).AnyTimes()
|
||||
check.Args(c1.OwnerID).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
|
||||
}))
|
||||
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
dbm.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetEnabledChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
s.Run("ListChatsByRootID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
rootChatID := uuid.New()
|
||||
chatA := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
|
||||
chatB := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
|
||||
dbm.EXPECT().ListChatsByRootID(gomock.Any(), rootChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(rootChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("ListChildChatsByParentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
parentChatID := uuid.New()
|
||||
chatA := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
|
||||
chatB := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
|
||||
dbm.EXPECT().ListChildChatsByParentID(gomock.Any(), parentChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(parentChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("GetStaleChats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
threshold := dbtime.Now()
|
||||
chats := []database.Chat{testutil.Fake(s.T(), faker, database.Chat{})}
|
||||
dbm.EXPECT().GetStaleChats(gomock.Any(), threshold).Return(chats, nil).AnyTimes()
|
||||
check.Args(threshold).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(chats)
|
||||
}))
|
||||
s.Run("InsertChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatParams{
|
||||
ModelConfig: json.RawMessage(`{}`),
|
||||
})
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{OwnerID: arg.OwnerID})
|
||||
dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat)
|
||||
}))
|
||||
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().InsertChatMessage(gomock.Any(), arg).Return(msg, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msg)
|
||||
}))
|
||||
s.Run("InsertChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.InsertChatModelConfigParams{
|
||||
Provider: "test-provider",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
Enabled: true,
|
||||
}
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{Provider: arg.Provider, Model: arg.Model, DisplayName: arg.DisplayName, Enabled: arg.Enabled})
|
||||
dbm.EXPECT().InsertChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("InsertChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.InsertChatProviderParams{
|
||||
Provider: "test-provider",
|
||||
DisplayName: "Test Provider",
|
||||
APIKey: "test-api-key",
|
||||
Enabled: true,
|
||||
}
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: arg.Provider, DisplayName: arg.DisplayName, APIKey: arg.APIKey, Enabled: arg.Enabled})
|
||||
dbm.EXPECT().InsertChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
|
||||
}))
|
||||
s.Run("UpdateChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: "Updated title",
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatModelConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
arg := database.UpdateChatModelConfigParams{
|
||||
ID: config.ID,
|
||||
Provider: "updated-provider",
|
||||
Model: "updated-model",
|
||||
DisplayName: "Updated Model",
|
||||
Enabled: true,
|
||||
}
|
||||
dbm.EXPECT().UpdateChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("UpdateChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
arg := database.UpdateChatProviderParams{
|
||||
ID: provider.ID,
|
||||
DisplayName: "Updated Provider",
|
||||
APIKey: "updated-api-key",
|
||||
Enabled: true,
|
||||
}
|
||||
dbm.EXPECT().UpdateChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
|
||||
}))
|
||||
s.Run("UpdateChatStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("GetLatestPendingSubagentRequestIDByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
result := uuid.NullUUID{UUID: uuid.New(), Valid: true}
|
||||
dbm.EXPECT().GetLatestPendingSubagentRequestIDByChatID(gomock.Any(), chat.ID).Return(result, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(result)
|
||||
}))
|
||||
s.Run("GetSubagentRequestDurationByChatIDAndRequestID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.GetSubagentRequestDurationByChatIDAndRequestIDParams{
|
||||
ChatID: chat.ID,
|
||||
SubagentRequestID: uuid.New(),
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
const durationMS int64 = 250
|
||||
dbm.EXPECT().GetSubagentRequestDurationByChatIDAndRequestID(gomock.Any(), arg).Return(durationMS, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(durationMS)
|
||||
}))
|
||||
s.Run("GetSubagentResponseMessageByChatIDAndRequestID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.GetSubagentResponseMessageByChatIDAndRequestIDParams{
|
||||
ChatID: chat.ID,
|
||||
SubagentRequestID: uuid.New(),
|
||||
}
|
||||
message := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetSubagentResponseMessageByChatIDAndRequestID(gomock.Any(), arg).Return(message, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(message)
|
||||
}))
|
||||
s.Run("UpdateChatWorkspace", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatWorkspaceParams{
|
||||
ID: chat.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
WorkspaceAgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
}
|
||||
updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatWorkspace(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat)
|
||||
}))
|
||||
s.Run("UpsertChatDiffStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
now := dbtime.Now()
|
||||
arg := database.UpsertChatDiffStatusParams{
|
||||
ChatID: chat.ID,
|
||||
Url: sql.NullString{String: "https://example.com/pr/123", Valid: true},
|
||||
PullRequestState: sql.NullString{String: "open", Valid: true},
|
||||
ChangesRequested: false,
|
||||
Additions: 10,
|
||||
Deletions: 5,
|
||||
ChangedFiles: 2,
|
||||
RefreshedAt: now,
|
||||
StaleAt: now.Add(time.Hour),
|
||||
}
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertChatDiffStatus(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
|
||||
}))
|
||||
s.Run("UpsertChatDiffStatusReference", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpsertChatDiffStatusReferenceParams{
|
||||
ChatID: chat.ID,
|
||||
Url: sql.NullString{String: "https://example.com/pr/123", Valid: true},
|
||||
GitBranch: "feature/test",
|
||||
GitRemoteOrigin: "origin",
|
||||
StaleAt: dbtime.Now().Add(time.Hour),
|
||||
}
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestFile() {
|
||||
s.Run("GetFileByHashAndCreator", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
f := testutil.Fake(s.T(), faker, database.File{})
|
||||
|
||||
@@ -104,6 +104,14 @@ func (m queryMetricsStore) DeleteOrganization(ctx context.Context, id uuid.UUID)
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.AcquireChat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("AcquireChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AcquireChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.AcquireLock(ctx, pgAdvisoryXactLock)
|
||||
@@ -156,6 +164,7 @@ func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context
|
||||
start := time.Now()
|
||||
r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("BatchUpdateWorkspaceAgentMetadata").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpdateWorkspaceAgentMetadata").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -311,6 +320,14 @@ func (m queryMetricsStore) DeleteAPIKeysByUserID(ctx context.Context, userID uui
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteAllChatQueuedMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("DeleteAllChatQueuedMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteAllChatQueuedMessages").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteAllTailnetTunnels(ctx, arg)
|
||||
@@ -343,6 +360,46 @@ func (m queryMetricsStore) DeleteBoundaryUsageStatsByReplicaID(ctx context.Conte
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatMessagesByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatMessagesByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatMessagesByChatID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatModelConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatModelConfigByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatProviderByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatProviderByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatProviderByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatQueuedMessage(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatQueuedMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatQueuedMessage").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
|
||||
@@ -910,6 +967,126 @@ func (m queryMetricsStore) GetBoundaryUsageSummary(ctx context.Context, maxStale
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatByIDForUpdate(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatByIDForUpdate").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatByIDForUpdate").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatDiffStatusByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDiffStatusByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uuid.UUID) ([]database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDiffStatusesByChatIDs(ctx, chatIDs)
|
||||
m.queryLatencies.WithLabelValues("GetChatDiffStatusesByChatIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDiffStatusesByChatIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessageByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessageByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessageByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessagesByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesForPromptByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessagesForPromptByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesForPromptByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigByProviderAndModel(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigByProviderAndModel").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigByProviderAndModel").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviderByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByProvider(ctx, provider)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviderByProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviders(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviders").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviders").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatQueuedMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatQueuedMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatQueuedMessages").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
|
||||
m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
|
||||
@@ -1030,6 +1207,22 @@ func (m queryMetricsStore) GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetEnabledChatModelConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetEnabledChatModelConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatModelConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetEnabledChatProviders(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetEnabledChatProviders").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatProviders").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetExternalAuthLink(ctx, arg)
|
||||
@@ -1190,6 +1383,14 @@ func (m queryMetricsStore) GetLatestCryptoKeyByFeature(ctx context.Context, feat
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetLatestPendingSubagentRequestIDByChatID(ctx context.Context, chatID uuid.UUID) (uuid.NullUUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetLatestPendingSubagentRequestIDByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetLatestPendingSubagentRequestIDByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetLatestPendingSubagentRequestIDByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetLatestWorkspaceAppStatusByAppID(ctx context.Context, appID uuid.UUID) (database.WorkspaceAppStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetLatestWorkspaceAppStatusByAppID(ctx, appID)
|
||||
@@ -1726,6 +1927,30 @@ func (m queryMetricsStore) GetRuntimeConfig(ctx context.Context, key string) (st
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetStaleChats(ctx, staleThreshold)
|
||||
m.queryLatencies.WithLabelValues("GetStaleChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetStaleChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetSubagentRequestDurationByChatIDAndRequestID(ctx context.Context, arg database.GetSubagentRequestDurationByChatIDAndRequestIDParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetSubagentRequestDurationByChatIDAndRequestID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetSubagentRequestDurationByChatIDAndRequestID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetSubagentRequestDurationByChatIDAndRequestID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetSubagentResponseMessageByChatIDAndRequestID(ctx context.Context, arg database.GetSubagentResponseMessageByChatIDAndRequestIDParams) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetSubagentResponseMessageByChatIDAndRequestID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetSubagentResponseMessageByChatIDAndRequestID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetSubagentResponseMessageByChatIDAndRequestID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetTailnetPeers(ctx, id)
|
||||
@@ -2694,6 +2919,46 @@ func (m queryMetricsStore) InsertAuditLog(ctx context.Context, arg database.Inse
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatMessage(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessage").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatModelConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatModelConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatModelConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatProvider(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatQueuedMessage(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatQueuedMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatQueuedMessage").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertCryptoKey(ctx, arg)
|
||||
@@ -3222,6 +3487,22 @@ func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChatsByRootID(ctx, rootChatID)
|
||||
m.queryLatencies.WithLabelValues("ListChatsByRootID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatsByRootID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChildChatsByParentID(ctx, parentChatID)
|
||||
m.queryLatencies.WithLabelValues("ListChildChatsByParentID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChildChatsByParentID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListProvisionerKeysByOrganization(ctx, organizationID)
|
||||
@@ -3302,6 +3583,14 @@ func (m queryMetricsStore) PaginatedOrganizationMembers(ctx context.Context, arg
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.PopNextQueuedMessage(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("PopNextQueuedMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "PopNextQueuedMessage").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx, templateID)
|
||||
@@ -3398,6 +3687,54 @@ func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.Up
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatModelConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatModelConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatModelConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatModelConfigByChatID(ctx context.Context, arg database.UpdateChatModelConfigByChatIDParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatModelConfigByChatID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatModelConfigByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatModelConfigByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatProvider(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatStatus(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatStatus").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatStatus").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatWorkspace(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatWorkspace").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspace").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateCryptoKeyDeletesAt(ctx, arg)
|
||||
@@ -4101,6 +4438,22 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDiffStatus").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDiffStatus").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatDiffStatusReference(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDiffStatusReference").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDiffStatusReference").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
|
||||
|
||||
@@ -44,6 +44,21 @@ func (m *MockStore) EXPECT() *MockStoreMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AcquireChat mocks base method.
|
||||
func (m *MockStore) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcquireChat", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcquireChat indicates an expected call of AcquireChat.
|
||||
func (mr *MockStoreMockRecorder) AcquireChat(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireChat", reflect.TypeOf((*MockStore)(nil).AcquireChat), ctx, arg)
|
||||
}
|
||||
|
||||
// AcquireLock mocks base method.
|
||||
func (m *MockStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -469,6 +484,20 @@ func (mr *MockStoreMockRecorder) DeleteAPIKeysByUserID(ctx, userID any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteAPIKeysByUserID), ctx, userID)
|
||||
}
|
||||
|
||||
// DeleteAllChatQueuedMessages mocks base method.
|
||||
func (m *MockStore) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllChatQueuedMessages", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllChatQueuedMessages indicates an expected call of DeleteAllChatQueuedMessages.
|
||||
func (mr *MockStoreMockRecorder) DeleteAllChatQueuedMessages(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).DeleteAllChatQueuedMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// DeleteAllTailnetTunnels mocks base method.
|
||||
func (m *MockStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -525,6 +554,76 @@ func (mr *MockStoreMockRecorder) DeleteBoundaryUsageStatsByReplicaID(ctx, replic
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteBoundaryUsageStatsByReplicaID", reflect.TypeOf((*MockStore)(nil).DeleteBoundaryUsageStatsByReplicaID), ctx, replicaID)
|
||||
}
|
||||
|
||||
// DeleteChatByID mocks base method.
|
||||
func (m *MockStore) DeleteChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatByID indicates an expected call of DeleteChatByID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatByID", reflect.TypeOf((*MockStore)(nil).DeleteChatByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteChatMessagesByChatID mocks base method.
|
||||
func (m *MockStore) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatMessagesByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatMessagesByChatID indicates an expected call of DeleteChatMessagesByChatID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatMessagesByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).DeleteChatMessagesByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// DeleteChatModelConfigByID mocks base method.
|
||||
func (m *MockStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatModelConfigByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatModelConfigByID indicates an expected call of DeleteChatModelConfigByID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatModelConfigByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteChatProviderByID mocks base method.
|
||||
func (m *MockStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatProviderByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatProviderByID indicates an expected call of DeleteChatProviderByID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatProviderByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatProviderByID", reflect.TypeOf((*MockStore)(nil).DeleteChatProviderByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteChatQueuedMessage mocks base method.
|
||||
func (m *MockStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatQueuedMessage", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatQueuedMessage indicates an expected call of DeleteChatQueuedMessage.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatQueuedMessage(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).DeleteChatQueuedMessage), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteCryptoKey mocks base method.
|
||||
func (m *MockStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1663,6 +1762,231 @@ func (mr *MockStoreMockRecorder) GetBoundaryUsageSummary(ctx, maxStalenessMs any
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBoundaryUsageSummary", reflect.TypeOf((*MockStore)(nil).GetBoundaryUsageSummary), ctx, maxStalenessMs)
|
||||
}
|
||||
|
||||
// GetChatByID mocks base method.
|
||||
func (m *MockStore) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatByID indicates an expected call of GetChatByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByID", reflect.TypeOf((*MockStore)(nil).GetChatByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatByIDForUpdate mocks base method.
|
||||
func (m *MockStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatByIDForUpdate", ctx, id)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatByIDForUpdate indicates an expected call of GetChatByIDForUpdate.
|
||||
func (mr *MockStoreMockRecorder) GetChatByIDForUpdate(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatByIDForUpdate), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatDiffStatusByChatID mocks base method.
|
||||
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDiffStatusByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(database.ChatDiffStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDiffStatusByChatID indicates an expected call of GetChatDiffStatusByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatDiffStatusByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDiffStatusByChatID", reflect.TypeOf((*MockStore)(nil).GetChatDiffStatusByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatDiffStatusesByChatIDs mocks base method.
|
||||
func (m *MockStore) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDiffStatusesByChatIDs", ctx, chatIds)
|
||||
ret0, _ := ret[0].([]database.ChatDiffStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDiffStatusesByChatIDs indicates an expected call of GetChatDiffStatusesByChatIDs.
|
||||
func (mr *MockStoreMockRecorder) GetChatDiffStatusesByChatIDs(ctx, chatIds any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDiffStatusesByChatIDs", reflect.TypeOf((*MockStore)(nil).GetChatDiffStatusesByChatIDs), ctx, chatIds)
|
||||
}
|
||||
|
||||
// GetChatMessageByID mocks base method.
|
||||
func (m *MockStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessageByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessageByID indicates an expected call of GetChatMessageByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageByID", reflect.TypeOf((*MockStore)(nil).GetChatMessageByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatID mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessagesByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].([]database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatID indicates an expected call of GetChatMessagesByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatMessagesForPromptByChatID mocks base method.
|
||||
func (m *MockStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessagesForPromptByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].([]database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessagesForPromptByChatID indicates an expected call of GetChatMessagesForPromptByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessagesForPromptByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesForPromptByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesForPromptByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatModelConfigByID mocks base method.
|
||||
func (m *MockStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatModelConfigByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatModelConfigByID indicates an expected call of GetChatModelConfigByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatModelConfigByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigByID", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatModelConfigByProviderAndModel mocks base method.
|
||||
func (m *MockStore) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatModelConfigByProviderAndModel", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatModelConfigByProviderAndModel indicates an expected call of GetChatModelConfigByProviderAndModel.
|
||||
func (mr *MockStoreMockRecorder) GetChatModelConfigByProviderAndModel(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigByProviderAndModel", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigByProviderAndModel), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatModelConfigs mocks base method.
|
||||
func (m *MockStore) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatModelConfigs", ctx)
|
||||
ret0, _ := ret[0].([]database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatModelConfigs indicates an expected call of GetChatModelConfigs.
|
||||
func (mr *MockStoreMockRecorder) GetChatModelConfigs(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetChatProviderByID mocks base method.
|
||||
func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatProviderByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatProviderByID indicates an expected call of GetChatProviderByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatProviderByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByID", reflect.TypeOf((*MockStore)(nil).GetChatProviderByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatProviderByProvider mocks base method.
|
||||
func (m *MockStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatProviderByProvider", ctx, provider)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatProviderByProvider indicates an expected call of GetChatProviderByProvider.
|
||||
func (mr *MockStoreMockRecorder) GetChatProviderByProvider(ctx, provider any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByProvider", reflect.TypeOf((*MockStore)(nil).GetChatProviderByProvider), ctx, provider)
|
||||
}
|
||||
|
||||
// GetChatProviders mocks base method.
|
||||
func (m *MockStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatProviders", ctx)
|
||||
ret0, _ := ret[0].([]database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatProviders indicates an expected call of GetChatProviders.
|
||||
func (mr *MockStoreMockRecorder) GetChatProviders(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviders", reflect.TypeOf((*MockStore)(nil).GetChatProviders), ctx)
|
||||
}
|
||||
|
||||
// GetChatQueuedMessages mocks base method.
|
||||
func (m *MockStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatQueuedMessages", ctx, chatID)
|
||||
ret0, _ := ret[0].([]database.ChatQueuedMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatQueuedMessages indicates an expected call of GetChatQueuedMessages.
|
||||
func (mr *MockStoreMockRecorder) GetChatQueuedMessages(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).GetChatQueuedMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatsByOwnerID mocks base method.
|
||||
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, ownerID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
|
||||
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, ownerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, ownerID)
|
||||
}
|
||||
|
||||
// GetConnectionLogsOffset mocks base method.
|
||||
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1888,6 +2212,36 @@ func (mr *MockStoreMockRecorder) GetEligibleProvisionerDaemonsByProvisionerJobID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEligibleProvisionerDaemonsByProvisionerJobIDs", reflect.TypeOf((*MockStore)(nil).GetEligibleProvisionerDaemonsByProvisionerJobIDs), ctx, provisionerJobIds)
|
||||
}
|
||||
|
||||
// GetEnabledChatModelConfigs mocks base method.
|
||||
func (m *MockStore) GetEnabledChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetEnabledChatModelConfigs", ctx)
|
||||
ret0, _ := ret[0].([]database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetEnabledChatModelConfigs indicates an expected call of GetEnabledChatModelConfigs.
|
||||
func (mr *MockStoreMockRecorder) GetEnabledChatModelConfigs(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledChatModelConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetEnabledChatProviders mocks base method.
|
||||
func (m *MockStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetEnabledChatProviders", ctx)
|
||||
ret0, _ := ret[0].([]database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetEnabledChatProviders indicates an expected call of GetEnabledChatProviders.
|
||||
func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx)
|
||||
}
|
||||
|
||||
// GetExternalAuthLink mocks base method.
|
||||
func (m *MockStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2188,6 +2542,21 @@ func (mr *MockStoreMockRecorder) GetLatestCryptoKeyByFeature(ctx, feature any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestCryptoKeyByFeature", reflect.TypeOf((*MockStore)(nil).GetLatestCryptoKeyByFeature), ctx, feature)
|
||||
}
|
||||
|
||||
// GetLatestPendingSubagentRequestIDByChatID mocks base method.
|
||||
func (m *MockStore) GetLatestPendingSubagentRequestIDByChatID(ctx context.Context, chatID uuid.UUID) (uuid.NullUUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetLatestPendingSubagentRequestIDByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(uuid.NullUUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetLatestPendingSubagentRequestIDByChatID indicates an expected call of GetLatestPendingSubagentRequestIDByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetLatestPendingSubagentRequestIDByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestPendingSubagentRequestIDByChatID", reflect.TypeOf((*MockStore)(nil).GetLatestPendingSubagentRequestIDByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetLatestWorkspaceAppStatusByAppID mocks base method.
|
||||
func (m *MockStore) GetLatestWorkspaceAppStatusByAppID(ctx context.Context, appID uuid.UUID) (database.WorkspaceAppStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -3193,6 +3562,51 @@ func (mr *MockStoreMockRecorder) GetRuntimeConfig(ctx, key any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRuntimeConfig", reflect.TypeOf((*MockStore)(nil).GetRuntimeConfig), ctx, key)
|
||||
}
|
||||
|
||||
// GetStaleChats mocks base method.
|
||||
func (m *MockStore) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetStaleChats", ctx, staleThreshold)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetStaleChats indicates an expected call of GetStaleChats.
|
||||
func (mr *MockStoreMockRecorder) GetStaleChats(ctx, staleThreshold any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStaleChats", reflect.TypeOf((*MockStore)(nil).GetStaleChats), ctx, staleThreshold)
|
||||
}
|
||||
|
||||
// GetSubagentRequestDurationByChatIDAndRequestID mocks base method.
|
||||
func (m *MockStore) GetSubagentRequestDurationByChatIDAndRequestID(ctx context.Context, arg database.GetSubagentRequestDurationByChatIDAndRequestIDParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetSubagentRequestDurationByChatIDAndRequestID", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSubagentRequestDurationByChatIDAndRequestID indicates an expected call of GetSubagentRequestDurationByChatIDAndRequestID.
|
||||
func (mr *MockStoreMockRecorder) GetSubagentRequestDurationByChatIDAndRequestID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubagentRequestDurationByChatIDAndRequestID", reflect.TypeOf((*MockStore)(nil).GetSubagentRequestDurationByChatIDAndRequestID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetSubagentResponseMessageByChatIDAndRequestID mocks base method.
|
||||
func (m *MockStore) GetSubagentResponseMessageByChatIDAndRequestID(ctx context.Context, arg database.GetSubagentResponseMessageByChatIDAndRequestIDParams) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetSubagentResponseMessageByChatIDAndRequestID", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSubagentResponseMessageByChatIDAndRequestID indicates an expected call of GetSubagentResponseMessageByChatIDAndRequestID.
|
||||
func (mr *MockStoreMockRecorder) GetSubagentResponseMessageByChatIDAndRequestID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubagentResponseMessageByChatIDAndRequestID", reflect.TypeOf((*MockStore)(nil).GetSubagentResponseMessageByChatIDAndRequestID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetTailnetPeers mocks base method.
|
||||
func (m *MockStore) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]database.TailnetPeer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5052,6 +5466,81 @@ func (mr *MockStoreMockRecorder) InsertAuditLog(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAuditLog", reflect.TypeOf((*MockStore)(nil).InsertAuditLog), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChat mocks base method.
|
||||
func (m *MockStore) InsertChat(ctx context.Context, arg database.InsertChatParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChat", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChat indicates an expected call of InsertChat.
|
||||
func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatMessage mocks base method.
|
||||
func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatMessage", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatMessage indicates an expected call of InsertChatMessage.
|
||||
func (mr *MockStoreMockRecorder) InsertChatMessage(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatMessage", reflect.TypeOf((*MockStore)(nil).InsertChatMessage), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatModelConfig mocks base method.
|
||||
func (m *MockStore) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatModelConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatModelConfig indicates an expected call of InsertChatModelConfig.
|
||||
func (mr *MockStoreMockRecorder) InsertChatModelConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatModelConfig", reflect.TypeOf((*MockStore)(nil).InsertChatModelConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatProvider mocks base method.
|
||||
func (m *MockStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatProvider", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatProvider indicates an expected call of InsertChatProvider.
|
||||
func (mr *MockStoreMockRecorder) InsertChatProvider(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatProvider", reflect.TypeOf((*MockStore)(nil).InsertChatProvider), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatQueuedMessage mocks base method.
|
||||
func (m *MockStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatQueuedMessage", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatQueuedMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatQueuedMessage indicates an expected call of InsertChatQueuedMessage.
|
||||
func (mr *MockStoreMockRecorder) InsertChatQueuedMessage(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).InsertChatQueuedMessage), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertCryptoKey mocks base method.
|
||||
func (m *MockStore) InsertCryptoKey(ctx context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6041,6 +6530,36 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeInterceptions(ctx, arg, p
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeInterceptions), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListChatsByRootID mocks base method.
|
||||
func (m *MockStore) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListChatsByRootID", ctx, rootChatID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListChatsByRootID indicates an expected call of ListChatsByRootID.
|
||||
func (mr *MockStoreMockRecorder) ListChatsByRootID(ctx, rootChatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatsByRootID", reflect.TypeOf((*MockStore)(nil).ListChatsByRootID), ctx, rootChatID)
|
||||
}
|
||||
|
||||
// ListChildChatsByParentID mocks base method.
|
||||
func (m *MockStore) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListChildChatsByParentID", ctx, parentChatID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListChildChatsByParentID indicates an expected call of ListChildChatsByParentID.
|
||||
func (mr *MockStoreMockRecorder) ListChildChatsByParentID(ctx, parentChatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChildChatsByParentID", reflect.TypeOf((*MockStore)(nil).ListChildChatsByParentID), ctx, parentChatID)
|
||||
}
|
||||
|
||||
// ListProvisionerKeysByOrganization mocks base method.
|
||||
func (m *MockStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6220,6 +6739,21 @@ func (mr *MockStoreMockRecorder) Ping(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockStore)(nil).Ping), ctx)
|
||||
}
|
||||
|
||||
// PopNextQueuedMessage mocks base method.
|
||||
func (m *MockStore) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (database.ChatQueuedMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PopNextQueuedMessage", ctx, chatID)
|
||||
ret0, _ := ret[0].(database.ChatQueuedMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// PopNextQueuedMessage indicates an expected call of PopNextQueuedMessage.
|
||||
func (mr *MockStoreMockRecorder) PopNextQueuedMessage(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopNextQueuedMessage", reflect.TypeOf((*MockStore)(nil).PopNextQueuedMessage), ctx, chatID)
|
||||
}
|
||||
|
||||
// ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate mocks base method.
|
||||
func (m *MockStore) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6393,6 +6927,96 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatByID mocks base method.
|
||||
func (m *MockStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatByID", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatByID indicates an expected call of UpdateChatByID.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatModelConfig mocks base method.
|
||||
func (m *MockStore) UpdateChatModelConfig(ctx context.Context, arg database.UpdateChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatModelConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatModelConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatModelConfig indicates an expected call of UpdateChatModelConfig.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatModelConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatModelConfig", reflect.TypeOf((*MockStore)(nil).UpdateChatModelConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatModelConfigByChatID mocks base method.
|
||||
func (m *MockStore) UpdateChatModelConfigByChatID(ctx context.Context, arg database.UpdateChatModelConfigByChatIDParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatModelConfigByChatID", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatModelConfigByChatID indicates an expected call of UpdateChatModelConfigByChatID.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatModelConfigByChatID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatModelConfigByChatID", reflect.TypeOf((*MockStore)(nil).UpdateChatModelConfigByChatID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatProvider mocks base method.
|
||||
func (m *MockStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatProvider", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatProvider indicates an expected call of UpdateChatProvider.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatProvider(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatProvider", reflect.TypeOf((*MockStore)(nil).UpdateChatProvider), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatStatus mocks base method.
|
||||
func (m *MockStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatStatus", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatStatus indicates an expected call of UpdateChatStatus.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatStatus(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatus", reflect.TypeOf((*MockStore)(nil).UpdateChatStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace mocks base method.
|
||||
func (m *MockStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatWorkspace", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace indicates an expected call of UpdateChatWorkspace.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatWorkspace(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspace", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspace), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateCryptoKeyDeletesAt mocks base method.
|
||||
func (m *MockStore) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7661,6 +8285,36 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatus mocks base method.
|
||||
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDiffStatus", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDiffStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatus indicates an expected call of UpsertChatDiffStatus.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDiffStatus(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatus", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatusReference mocks base method.
|
||||
func (m *MockStore) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDiffStatusReference", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDiffStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatusReference indicates an expected call of UpsertChatDiffStatusReference.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDiffStatusReference(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatusReference", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatusReference), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertConnectionLog mocks base method.
|
||||
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+203
@@ -258,6 +258,15 @@ CREATE TYPE build_reason AS ENUM (
|
||||
'task_resume'
|
||||
);
|
||||
|
||||
CREATE TYPE chat_status AS ENUM (
|
||||
'waiting',
|
||||
'pending',
|
||||
'running',
|
||||
'paused',
|
||||
'completed',
|
||||
'error'
|
||||
);
|
||||
|
||||
CREATE TYPE connection_status AS ENUM (
|
||||
'connected',
|
||||
'disconnected'
|
||||
@@ -1141,6 +1150,115 @@ COMMENT ON COLUMN boundary_usage_stats.window_start IS 'Start of the time window
|
||||
|
||||
COMMENT ON COLUMN boundary_usage_stats.updated_at IS 'Timestamp of the last update to this row.';
|
||||
|
||||
CREATE TABLE chat_diff_statuses (
|
||||
chat_id uuid NOT NULL,
|
||||
url text,
|
||||
pull_request_state text,
|
||||
changes_requested boolean DEFAULT false NOT NULL,
|
||||
additions integer DEFAULT 0 NOT NULL,
|
||||
deletions integer DEFAULT 0 NOT NULL,
|
||||
changed_files integer DEFAULT 0 NOT NULL,
|
||||
refreshed_at timestamp with time zone,
|
||||
stale_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
git_branch text DEFAULT ''::text NOT NULL,
|
||||
git_remote_origin text DEFAULT ''::text NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE chat_messages (
|
||||
id bigint NOT NULL,
|
||||
chat_id uuid NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
role text NOT NULL,
|
||||
content jsonb,
|
||||
tool_call_id text,
|
||||
thinking text,
|
||||
hidden boolean DEFAULT false NOT NULL,
|
||||
subagent_request_id uuid,
|
||||
subagent_event text,
|
||||
input_tokens bigint,
|
||||
output_tokens bigint,
|
||||
total_tokens bigint,
|
||||
reasoning_tokens bigint,
|
||||
cache_creation_tokens bigint,
|
||||
cache_read_tokens bigint,
|
||||
context_limit bigint,
|
||||
compressed boolean DEFAULT false NOT NULL,
|
||||
CONSTRAINT chat_messages_subagent_event_check CHECK (((subagent_event IS NULL) OR (subagent_event = ANY (ARRAY['request'::text, 'response'::text]))))
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_messages_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
ALTER SEQUENCE chat_messages_id_seq OWNED BY chat_messages.id;
|
||||
|
||||
CREATE TABLE chat_model_configs (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
provider text NOT NULL,
|
||||
model text NOT NULL,
|
||||
display_name text DEFAULT ''::text NOT NULL,
|
||||
enabled boolean DEFAULT true NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
context_limit bigint NOT NULL,
|
||||
compression_threshold integer NOT NULL,
|
||||
model_config jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
CONSTRAINT chat_model_configs_compression_threshold_check CHECK (((compression_threshold >= 0) AND (compression_threshold <= 100))),
|
||||
CONSTRAINT chat_model_configs_context_limit_check CHECK ((context_limit > 0))
|
||||
);
|
||||
|
||||
CREATE TABLE chat_providers (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
provider text NOT NULL,
|
||||
display_name text DEFAULT ''::text NOT NULL,
|
||||
api_key text DEFAULT ''::text NOT NULL,
|
||||
api_key_key_id text,
|
||||
enabled boolean DEFAULT true NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
base_url text DEFAULT ''::text NOT NULL,
|
||||
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text])))
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
|
||||
|
||||
CREATE TABLE chat_queued_messages (
|
||||
id bigint NOT NULL,
|
||||
chat_id uuid NOT NULL,
|
||||
content jsonb NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_queued_messages_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
ALTER SEQUENCE chat_queued_messages_id_seq OWNED BY chat_queued_messages.id;
|
||||
|
||||
CREATE TABLE chats (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
workspace_id uuid,
|
||||
workspace_agent_id uuid,
|
||||
title text DEFAULT 'New Chat'::text NOT NULL,
|
||||
status chat_status DEFAULT 'waiting'::chat_status NOT NULL,
|
||||
model_config jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
worker_id uuid,
|
||||
started_at timestamp with time zone,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
parent_chat_id uuid,
|
||||
root_chat_id uuid
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
id uuid NOT NULL,
|
||||
connect_time timestamp with time zone NOT NULL,
|
||||
@@ -2939,6 +3057,10 @@ CREATE VIEW workspaces_expanded AS
|
||||
|
||||
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
|
||||
|
||||
ALTER TABLE ONLY chat_messages ALTER COLUMN id SET DEFAULT nextval('chat_messages_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages ALTER COLUMN id SET DEFAULT nextval('chat_queued_messages_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY licenses ALTER COLUMN id SET DEFAULT nextval('licenses_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY provisioner_job_logs ALTER COLUMN id SET DEFAULT nextval('provisioner_job_logs_id_seq'::regclass);
|
||||
@@ -2975,6 +3097,27 @@ ALTER TABLE ONLY audit_logs
|
||||
ALTER TABLE ONLY boundary_usage_stats
|
||||
ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
|
||||
ALTER TABLE ONLY chat_messages
|
||||
ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_providers
|
||||
ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_providers
|
||||
ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY connection_logs
|
||||
ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3300,6 +3443,36 @@ CREATE INDEX idx_audit_log_user_id ON audit_logs USING btree (user_id);
|
||||
|
||||
CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
|
||||
|
||||
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat ON chat_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_id, created_at);
|
||||
|
||||
CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USING btree (chat_id, created_at DESC, id DESC) WHERE ((compressed = true) AND (role = 'system'::text) AND (hidden = true));
|
||||
|
||||
CREATE INDEX idx_chat_messages_subagent_request ON chat_messages USING btree (chat_id, subagent_request_id, created_at) WHERE (subagent_request_id IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs USING btree (provider);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_provider_model ON chat_model_configs USING btree (provider, model);
|
||||
|
||||
CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_owner ON chats USING btree (owner_id);
|
||||
|
||||
CREATE INDEX idx_chats_parent_chat_id ON chats USING btree (parent_chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
|
||||
|
||||
CREATE INDEX idx_chats_root_chat_id ON chats USING btree (root_chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_workspace ON chats USING btree (workspace_id);
|
||||
|
||||
CREATE INDEX idx_connection_logs_connect_time_desc ON connection_logs USING btree (connect_time DESC);
|
||||
|
||||
CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);
|
||||
@@ -3546,6 +3719,36 @@ ALTER TABLE ONLY aibridge_interceptions
|
||||
ALTER TABLE ONLY api_keys
|
||||
ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_messages
|
||||
ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_provider_fkey FOREIGN KEY (provider) REFERENCES chat_providers(provider) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_providers
|
||||
ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_parent_chat_id_fkey FOREIGN KEY (parent_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_root_chat_id_fkey FOREIGN KEY (root_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_workspace_agent_id_fkey FOREIGN KEY (workspace_agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY connection_logs
|
||||
ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -8,6 +8,16 @@ type ForeignKeyConstraint string
|
||||
const (
|
||||
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatModelConfigsProvider ForeignKeyConstraint = "chat_model_configs_provider_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_provider_fkey FOREIGN KEY (provider) REFERENCES chat_providers(provider) ON DELETE CASCADE;
|
||||
ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatsOwnerID ForeignKeyConstraint = "chats_owner_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatsParentChatID ForeignKeyConstraint = "chats_parent_chat_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_parent_chat_id_fkey FOREIGN KEY (parent_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsRootChatID ForeignKeyConstraint = "chats_root_chat_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_root_chat_id_fkey FOREIGN KEY (root_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsWorkspaceAgentID ForeignKeyConstraint = "chats_workspace_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_workspace_agent_id_fkey FOREIGN KEY (workspace_agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsWorkspaceID ForeignKeyConstraint = "chats_workspace_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE SET NULL;
|
||||
ForeignKeyConnectionLogsOrganizationID ForeignKeyConstraint = "connection_logs_organization_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyConnectionLogsWorkspaceID ForeignKeyConstraint = "connection_logs_workspace_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
|
||||
ForeignKeyConnectionLogsWorkspaceOwnerID ForeignKeyConstraint = "connection_logs_workspace_owner_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
DROP TABLE IF EXISTS chat_messages;
|
||||
DROP TABLE IF EXISTS chats;
|
||||
DROP TYPE IF EXISTS chat_status;
|
||||
@@ -0,0 +1,42 @@
|
||||
CREATE TYPE chat_status AS ENUM (
|
||||
'waiting', -- Waiting for user input or workspace
|
||||
'pending', -- Queued, waiting for a coderd replica to pick up
|
||||
'running', -- Being processed by a coderd replica
|
||||
'paused', -- Manually paused by user
|
||||
'completed', -- Finished (no pending work)
|
||||
'error' -- Failed, needs user intervention
|
||||
);
|
||||
|
||||
CREATE TABLE chats (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
workspace_id UUID REFERENCES workspaces(id) ON DELETE SET NULL,
|
||||
workspace_agent_id UUID REFERENCES workspace_agents(id) ON DELETE SET NULL,
|
||||
title TEXT NOT NULL DEFAULT 'New Chat',
|
||||
status chat_status NOT NULL DEFAULT 'waiting',
|
||||
model_config JSONB NOT NULL DEFAULT '{}',
|
||||
-- Locking fields for multi-replica safety
|
||||
worker_id UUID,
|
||||
started_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chats_owner ON chats(owner_id);
|
||||
CREATE INDEX idx_chats_workspace ON chats(workspace_id);
|
||||
CREATE INDEX idx_chats_pending ON chats(status) WHERE status = 'pending';
|
||||
|
||||
CREATE TABLE chat_messages (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
role TEXT NOT NULL, -- 'user', 'assistant', 'system', 'tool'
|
||||
content JSONB, -- Text content or structured data
|
||||
tool_calls JSONB, -- For assistant messages with tool calls
|
||||
tool_call_id TEXT, -- For tool result messages
|
||||
thinking TEXT, -- Extended thinking content (if any)
|
||||
hidden BOOLEAN NOT NULL DEFAULT FALSE -- For system/hidden messages
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat ON chat_messages(chat_id);
|
||||
CREATE INDEX idx_chat_messages_chat_created ON chat_messages(chat_id, created_at);
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS chat_git_changes;
|
||||
@@ -0,0 +1,13 @@
|
||||
CREATE TABLE chat_git_changes (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
file_path TEXT NOT NULL,
|
||||
change_type TEXT NOT NULL, -- 'added', 'modified', 'deleted', 'renamed'
|
||||
old_path TEXT, -- For renames
|
||||
diff_summary TEXT, -- Optional: lines added/removed summary
|
||||
detected_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
|
||||
UNIQUE(chat_id, file_path)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_git_changes_chat ON chat_git_changes(chat_id);
|
||||
@@ -0,0 +1,2 @@
|
||||
DROP TABLE IF EXISTS chat_model_configs;
|
||||
DROP TABLE IF EXISTS chat_providers;
|
||||
@@ -0,0 +1,29 @@
|
||||
CREATE TABLE chat_providers (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
provider TEXT NOT NULL UNIQUE CHECK (provider IN ('openai', 'anthropic')),
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
api_key TEXT NOT NULL DEFAULT '',
|
||||
api_key_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
|
||||
|
||||
CREATE INDEX idx_chat_providers_enabled ON chat_providers(enabled);
|
||||
|
||||
CREATE TABLE chat_model_configs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
provider TEXT NOT NULL REFERENCES chat_providers(provider) ON DELETE CASCADE,
|
||||
model TEXT NOT NULL,
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
|
||||
UNIQUE(provider, model)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs(enabled);
|
||||
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs(provider);
|
||||
@@ -0,0 +1,16 @@
|
||||
DROP TABLE IF EXISTS chat_messages;
|
||||
|
||||
CREATE TABLE chat_messages (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
role TEXT NOT NULL, -- 'user', 'assistant', 'system', 'tool'
|
||||
content JSONB, -- Text content or structured data
|
||||
tool_calls JSONB, -- For assistant messages with tool calls
|
||||
tool_call_id TEXT, -- For tool result messages
|
||||
thinking TEXT, -- Extended thinking content (if any)
|
||||
hidden BOOLEAN NOT NULL DEFAULT FALSE -- For system/hidden messages
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat ON chat_messages(chat_id);
|
||||
CREATE INDEX idx_chat_messages_chat_created ON chat_messages(chat_id, created_at);
|
||||
@@ -0,0 +1,16 @@
|
||||
-- This migration intentionally recreates chat_messages in the new shape.
|
||||
DROP TABLE IF EXISTS chat_messages;
|
||||
|
||||
CREATE TABLE chat_messages (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
role TEXT NOT NULL, -- 'user', 'assistant', 'system', 'tool'
|
||||
content JSONB, -- Text content or structured data
|
||||
tool_call_id TEXT, -- For tool result messages
|
||||
thinking TEXT, -- Extended thinking content (if any)
|
||||
hidden BOOLEAN NOT NULL DEFAULT FALSE -- For system/hidden messages
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat ON chat_messages(chat_id);
|
||||
CREATE INDEX idx_chat_messages_chat_created ON chat_messages(chat_id, created_at);
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS chat_diff_statuses;
|
||||
@@ -0,0 +1,16 @@
|
||||
CREATE TABLE chat_diff_statuses (
|
||||
chat_id UUID PRIMARY KEY REFERENCES chats(id) ON DELETE CASCADE,
|
||||
github_pr_url TEXT,
|
||||
pull_request_state TEXT NOT NULL DEFAULT '',
|
||||
pull_request_open BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
changes_requested BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
additions INTEGER NOT NULL DEFAULT 0,
|
||||
deletions INTEGER NOT NULL DEFAULT 0,
|
||||
changed_files INTEGER NOT NULL DEFAULT 0,
|
||||
refreshed_at TIMESTAMPTZ,
|
||||
stale_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses(stale_at);
|
||||
@@ -0,0 +1,13 @@
|
||||
CREATE TABLE chat_git_changes (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
file_path TEXT NOT NULL,
|
||||
change_type TEXT NOT NULL, -- 'added', 'modified', 'deleted', 'renamed'
|
||||
old_path TEXT, -- For renames
|
||||
diff_summary TEXT, -- Optional: lines added/removed summary
|
||||
detected_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
|
||||
UNIQUE(chat_id, file_path)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_git_changes_chat ON chat_git_changes(chat_id);
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS chat_git_changes;
|
||||
@@ -0,0 +1,11 @@
|
||||
ALTER TABLE chat_diff_statuses
|
||||
ALTER COLUMN pull_request_state SET NOT NULL,
|
||||
ALTER COLUMN pull_request_state SET DEFAULT '';
|
||||
|
||||
ALTER TABLE chat_diff_statuses
|
||||
RENAME COLUMN url TO github_pr_url;
|
||||
|
||||
ALTER TABLE chat_diff_statuses
|
||||
ADD COLUMN pull_request_open BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
DROP COLUMN IF EXISTS git_branch,
|
||||
DROP COLUMN IF EXISTS git_remote_origin;
|
||||
@@ -0,0 +1,12 @@
|
||||
ALTER TABLE chat_diff_statuses
|
||||
ADD COLUMN git_branch TEXT NOT NULL DEFAULT '',
|
||||
ADD COLUMN git_remote_origin TEXT NOT NULL DEFAULT '',
|
||||
DROP COLUMN pull_request_open;
|
||||
|
||||
ALTER TABLE chat_diff_statuses
|
||||
RENAME COLUMN github_pr_url TO url;
|
||||
|
||||
ALTER TABLE chat_diff_statuses
|
||||
ALTER COLUMN pull_request_state DROP NOT NULL,
|
||||
ALTER COLUMN pull_request_state DROP DEFAULT,
|
||||
ALTER COLUMN pull_request_state TYPE TEXT;
|
||||
@@ -0,0 +1,12 @@
|
||||
ALTER TABLE chat_providers
|
||||
DROP CONSTRAINT IF EXISTS chat_providers_provider_check;
|
||||
|
||||
ALTER TABLE chat_providers
|
||||
ADD CONSTRAINT chat_providers_provider_check CHECK (
|
||||
provider = ANY (
|
||||
ARRAY[
|
||||
'openai'::text,
|
||||
'anthropic'::text
|
||||
]
|
||||
)
|
||||
);
|
||||
@@ -0,0 +1,18 @@
|
||||
ALTER TABLE chat_providers
|
||||
DROP CONSTRAINT IF EXISTS chat_providers_provider_check;
|
||||
|
||||
ALTER TABLE chat_providers
|
||||
ADD CONSTRAINT chat_providers_provider_check CHECK (
|
||||
provider = ANY (
|
||||
ARRAY[
|
||||
'anthropic'::text,
|
||||
'azure'::text,
|
||||
'bedrock'::text,
|
||||
'google'::text,
|
||||
'openai'::text,
|
||||
'openai-compat'::text,
|
||||
'openrouter'::text,
|
||||
'vercel'::text
|
||||
]
|
||||
)
|
||||
);
|
||||
@@ -0,0 +1,6 @@
|
||||
DROP INDEX IF EXISTS idx_chats_root_chat_id;
|
||||
DROP INDEX IF EXISTS idx_chats_parent_chat_id;
|
||||
|
||||
ALTER TABLE chats
|
||||
DROP COLUMN IF EXISTS root_chat_id,
|
||||
DROP COLUMN IF EXISTS parent_chat_id;
|
||||
@@ -0,0 +1,13 @@
|
||||
ALTER TABLE chats
|
||||
ADD COLUMN parent_chat_id UUID REFERENCES chats(id) ON DELETE SET NULL,
|
||||
ADD COLUMN root_chat_id UUID REFERENCES chats(id) ON DELETE SET NULL;
|
||||
|
||||
-- Existing chats are non-delegated; make them their own root.
|
||||
UPDATE chats
|
||||
SET
|
||||
root_chat_id = id
|
||||
WHERE
|
||||
root_chat_id IS NULL;
|
||||
|
||||
CREATE INDEX idx_chats_parent_chat_id ON chats(parent_chat_id);
|
||||
CREATE INDEX idx_chats_root_chat_id ON chats(root_chat_id);
|
||||
@@ -0,0 +1,8 @@
|
||||
DROP INDEX IF EXISTS idx_chat_messages_subagent_request;
|
||||
|
||||
ALTER TABLE chat_messages
|
||||
DROP CONSTRAINT IF EXISTS chat_messages_subagent_event_check;
|
||||
|
||||
ALTER TABLE chat_messages
|
||||
DROP COLUMN IF EXISTS subagent_event,
|
||||
DROP COLUMN IF EXISTS subagent_request_id;
|
||||
@@ -0,0 +1,11 @@
|
||||
ALTER TABLE chat_messages
|
||||
ADD COLUMN subagent_request_id UUID,
|
||||
ADD COLUMN subagent_event TEXT;
|
||||
|
||||
ALTER TABLE chat_messages
|
||||
ADD CONSTRAINT chat_messages_subagent_event_check
|
||||
CHECK (subagent_event IS NULL OR subagent_event IN ('request', 'response'));
|
||||
|
||||
CREATE INDEX idx_chat_messages_subagent_request
|
||||
ON chat_messages(chat_id, subagent_request_id, created_at)
|
||||
WHERE subagent_request_id IS NOT NULL;
|
||||
@@ -0,0 +1,10 @@
|
||||
ALTER TABLE chat_messages
|
||||
DROP COLUMN IF EXISTS context_limit,
|
||||
DROP COLUMN IF EXISTS cache_read_tokens,
|
||||
DROP COLUMN IF EXISTS cache_creation_tokens,
|
||||
DROP COLUMN IF EXISTS reasoning_tokens,
|
||||
DROP COLUMN IF EXISTS total_tokens,
|
||||
DROP COLUMN IF EXISTS output_tokens,
|
||||
DROP COLUMN IF EXISTS input_tokens,
|
||||
DROP COLUMN IF EXISTS cached_output_tokens,
|
||||
DROP COLUMN IF EXISTS cached_input_tokens;
|
||||
@@ -0,0 +1,39 @@
|
||||
ALTER TABLE chat_messages
|
||||
ADD COLUMN IF NOT EXISTS input_tokens BIGINT,
|
||||
ADD COLUMN IF NOT EXISTS output_tokens BIGINT,
|
||||
ADD COLUMN IF NOT EXISTS total_tokens BIGINT,
|
||||
ADD COLUMN IF NOT EXISTS reasoning_tokens BIGINT,
|
||||
ADD COLUMN IF NOT EXISTS cache_creation_tokens BIGINT,
|
||||
ADD COLUMN IF NOT EXISTS cache_read_tokens BIGINT,
|
||||
ADD COLUMN IF NOT EXISTS context_limit BIGINT;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = current_schema()
|
||||
AND table_name = 'chat_messages'
|
||||
AND column_name = 'cached_output_tokens'
|
||||
) THEN
|
||||
UPDATE chat_messages
|
||||
SET cache_creation_tokens = COALESCE(cache_creation_tokens, cached_output_tokens)
|
||||
WHERE cached_output_tokens IS NOT NULL;
|
||||
END IF;
|
||||
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = current_schema()
|
||||
AND table_name = 'chat_messages'
|
||||
AND column_name = 'cached_input_tokens'
|
||||
) THEN
|
||||
UPDATE chat_messages
|
||||
SET cache_read_tokens = COALESCE(cache_read_tokens, cached_input_tokens)
|
||||
WHERE cached_input_tokens IS NOT NULL;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
ALTER TABLE chat_messages
|
||||
DROP COLUMN IF EXISTS cached_output_tokens,
|
||||
DROP COLUMN IF EXISTS cached_input_tokens;
|
||||
@@ -0,0 +1,12 @@
|
||||
DROP INDEX IF EXISTS idx_chat_messages_compressed_summary_boundary;
|
||||
|
||||
ALTER TABLE chat_messages
|
||||
DROP COLUMN IF EXISTS compressed;
|
||||
|
||||
ALTER TABLE chat_model_configs
|
||||
DROP CONSTRAINT IF EXISTS chat_model_configs_compression_threshold_check,
|
||||
DROP CONSTRAINT IF EXISTS chat_model_configs_context_limit_check;
|
||||
|
||||
ALTER TABLE chat_model_configs
|
||||
DROP COLUMN IF EXISTS compression_threshold,
|
||||
DROP COLUMN IF EXISTS context_limit;
|
||||
@@ -0,0 +1,28 @@
|
||||
ALTER TABLE chat_model_configs
|
||||
ADD COLUMN IF NOT EXISTS context_limit BIGINT,
|
||||
ADD COLUMN IF NOT EXISTS compression_threshold INTEGER;
|
||||
|
||||
-- Backfill existing rows so context compression can operate safely by default.
|
||||
UPDATE chat_model_configs
|
||||
SET
|
||||
context_limit = COALESCE(context_limit, 200000),
|
||||
compression_threshold = COALESCE(compression_threshold, 70);
|
||||
|
||||
ALTER TABLE chat_model_configs
|
||||
ALTER COLUMN context_limit SET NOT NULL,
|
||||
ALTER COLUMN compression_threshold SET NOT NULL;
|
||||
|
||||
ALTER TABLE chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_context_limit_check
|
||||
CHECK (context_limit > 0),
|
||||
ADD CONSTRAINT chat_model_configs_compression_threshold_check
|
||||
CHECK (compression_threshold >= 0 AND compression_threshold <= 100);
|
||||
|
||||
ALTER TABLE chat_messages
|
||||
ADD COLUMN IF NOT EXISTS compressed BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_messages_compressed_summary_boundary
|
||||
ON chat_messages(chat_id, created_at DESC, id DESC)
|
||||
WHERE compressed = TRUE
|
||||
AND role = 'system'
|
||||
AND hidden = TRUE;
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS chat_queued_messages;
|
||||
@@ -0,0 +1,8 @@
|
||||
CREATE TABLE chat_queued_messages (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
content jsonb NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages(chat_id);
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE chat_providers
|
||||
DROP COLUMN base_url;
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE chat_providers
|
||||
ADD COLUMN base_url TEXT NOT NULL DEFAULT '';
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE chat_model_configs
|
||||
DROP COLUMN IF EXISTS model_config;
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE chat_model_configs
|
||||
ADD COLUMN IF NOT EXISTS model_config JSONB NOT NULL DEFAULT '{}'::jsonb;
|
||||
+19
@@ -0,0 +1,19 @@
|
||||
DROP INDEX IF EXISTS idx_chat_model_configs_provider_model;
|
||||
|
||||
WITH ranked AS (
|
||||
SELECT
|
||||
id,
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY provider, model
|
||||
ORDER BY updated_at DESC, created_at DESC, id DESC
|
||||
) AS rownum
|
||||
FROM chat_model_configs
|
||||
)
|
||||
DELETE FROM chat_model_configs AS cmc
|
||||
USING ranked
|
||||
WHERE
|
||||
cmc.id = ranked.id
|
||||
AND ranked.rownum > 1;
|
||||
|
||||
ALTER TABLE chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_provider_model_key UNIQUE (provider, model);
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE chat_model_configs
|
||||
DROP CONSTRAINT IF EXISTS chat_model_configs_provider_model_key;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_model_configs_provider_model
|
||||
ON chat_model_configs(provider, model);
|
||||
+37
@@ -0,0 +1,37 @@
|
||||
INSERT INTO chat_providers (
|
||||
id,
|
||||
provider,
|
||||
display_name,
|
||||
api_key,
|
||||
api_key_key_id,
|
||||
enabled,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
'0a8b2f84-b5a8-4c44-8c9f-e58c44a534a7',
|
||||
'openai',
|
||||
'OpenAI',
|
||||
'',
|
||||
NULL,
|
||||
TRUE,
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:00+00'
|
||||
);
|
||||
|
||||
INSERT INTO chat_model_configs (
|
||||
id,
|
||||
provider,
|
||||
model,
|
||||
display_name,
|
||||
enabled,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
'9af5f8d5-6a57-4505-8a69-3d6c787b95fd',
|
||||
'openai',
|
||||
'gpt-5.2',
|
||||
'GPT 5.2',
|
||||
TRUE,
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:00+00'
|
||||
);
|
||||
@@ -165,6 +165,10 @@ func (t TaskTable) RBACObject() rbac.Object {
|
||||
InOrg(t.OrganizationID)
|
||||
}
|
||||
|
||||
func (c Chat) RBACObject() rbac.Object {
|
||||
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String())
|
||||
}
|
||||
|
||||
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
|
||||
switch s {
|
||||
case ApiKeyScopeCoderAll:
|
||||
|
||||
@@ -268,7 +268,7 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
|
||||
pq.Array(arg.TemplateIDs),
|
||||
pq.Array(arg.WorkspaceIds),
|
||||
arg.Name,
|
||||
arg.HasAgent,
|
||||
pq.Array(arg.HasAgentStatuses),
|
||||
arg.AgentInactiveDisconnectTimeoutSeconds,
|
||||
arg.Dormant,
|
||||
arg.LastUsedBefore,
|
||||
|
||||
@@ -1028,6 +1028,76 @@ func AllBuildReasonValues() []BuildReason {
|
||||
}
|
||||
}
|
||||
|
||||
type ChatStatus string
|
||||
|
||||
const (
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
)
|
||||
|
||||
func (e *ChatStatus) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ChatStatus(s)
|
||||
case string:
|
||||
*e = ChatStatus(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ChatStatus: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullChatStatus struct {
|
||||
ChatStatus ChatStatus `json:"chat_status"`
|
||||
Valid bool `json:"valid"` // Valid is true if ChatStatus is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (ns *NullChatStatus) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
ns.ChatStatus, ns.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
ns.Valid = true
|
||||
return ns.ChatStatus.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullChatStatus) Value() (driver.Value, error) {
|
||||
if !ns.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return string(ns.ChatStatus), nil
|
||||
}
|
||||
|
||||
func (e ChatStatus) Valid() bool {
|
||||
switch e {
|
||||
case ChatStatusWaiting,
|
||||
ChatStatusPending,
|
||||
ChatStatusRunning,
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AllChatStatusValues() []ChatStatus {
|
||||
return []ChatStatus{
|
||||
ChatStatusWaiting,
|
||||
ChatStatusPending,
|
||||
ChatStatusRunning,
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError,
|
||||
}
|
||||
}
|
||||
|
||||
type ConnectionStatus string
|
||||
|
||||
const (
|
||||
@@ -3732,6 +3802,92 @@ type BoundaryUsageStat struct {
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type Chat struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
WorkspaceAgentID uuid.NullUUID `db:"workspace_agent_id" json:"workspace_agent_id"`
|
||||
Title string `db:"title" json:"title"`
|
||||
Status ChatStatus `db:"status" json:"status"`
|
||||
ModelConfig json.RawMessage `db:"model_config" json:"model_config"`
|
||||
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
Url sql.NullString `db:"url" json:"url"`
|
||||
PullRequestState sql.NullString `db:"pull_request_state" json:"pull_request_state"`
|
||||
ChangesRequested bool `db:"changes_requested" json:"changes_requested"`
|
||||
Additions int32 `db:"additions" json:"additions"`
|
||||
Deletions int32 `db:"deletions" json:"deletions"`
|
||||
ChangedFiles int32 `db:"changed_files" json:"changed_files"`
|
||||
RefreshedAt sql.NullTime `db:"refreshed_at" json:"refreshed_at"`
|
||||
StaleAt time.Time `db:"stale_at" json:"stale_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
GitBranch string `db:"git_branch" json:"git_branch"`
|
||||
GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"`
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
Role string `db:"role" json:"role"`
|
||||
Content pqtype.NullRawMessage `db:"content" json:"content"`
|
||||
ToolCallID sql.NullString `db:"tool_call_id" json:"tool_call_id"`
|
||||
Thinking sql.NullString `db:"thinking" json:"thinking"`
|
||||
Hidden bool `db:"hidden" json:"hidden"`
|
||||
SubagentRequestID uuid.NullUUID `db:"subagent_request_id" json:"subagent_request_id"`
|
||||
SubagentEvent sql.NullString `db:"subagent_event" json:"subagent_event"`
|
||||
InputTokens sql.NullInt64 `db:"input_tokens" json:"input_tokens"`
|
||||
OutputTokens sql.NullInt64 `db:"output_tokens" json:"output_tokens"`
|
||||
TotalTokens sql.NullInt64 `db:"total_tokens" json:"total_tokens"`
|
||||
ReasoningTokens sql.NullInt64 `db:"reasoning_tokens" json:"reasoning_tokens"`
|
||||
CacheCreationTokens sql.NullInt64 `db:"cache_creation_tokens" json:"cache_creation_tokens"`
|
||||
CacheReadTokens sql.NullInt64 `db:"cache_read_tokens" json:"cache_read_tokens"`
|
||||
ContextLimit sql.NullInt64 `db:"context_limit" json:"context_limit"`
|
||||
Compressed bool `db:"compressed" json:"compressed"`
|
||||
}
|
||||
|
||||
type ChatModelConfig struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ContextLimit int64 `db:"context_limit" json:"context_limit"`
|
||||
CompressionThreshold int32 `db:"compression_threshold" json:"compression_threshold"`
|
||||
ModelConfig json.RawMessage `db:"model_config" json:"model_config"`
|
||||
}
|
||||
|
||||
type ChatProvider struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
// The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
}
|
||||
|
||||
type ChatQueuedMessage struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
Content json.RawMessage `db:"content" json:"content"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
type ConnectionLog struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ConnectTime time.Time `db:"connect_time" json:"connect_time"`
|
||||
|
||||
@@ -12,6 +12,9 @@ import (
|
||||
)
|
||||
|
||||
type sqlcQuerier interface {
|
||||
// Acquires a pending chat for processing. Uses SKIP LOCKED to prevent
|
||||
// multiple replicas from acquiring the same chat.
|
||||
AcquireChat(ctx context.Context, arg AcquireChatParams) (Chat, error)
|
||||
// Blocks until the lock is acquired.
|
||||
//
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
@@ -81,6 +84,7 @@ type sqlcQuerier interface {
|
||||
CustomRoles(ctx context.Context, arg CustomRolesParams) ([]CustomRole, error)
|
||||
DeleteAPIKeyByID(ctx context.Context, id string) error
|
||||
DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error
|
||||
DeleteAllTailnetTunnels(ctx context.Context, arg DeleteAllTailnetTunnelsParams) error
|
||||
// Deletes all existing webpush subscriptions.
|
||||
// This should be called when the VAPID keypair is regenerated, as the old
|
||||
@@ -90,6 +94,11 @@ type sqlcQuerier interface {
|
||||
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
|
||||
// Deletes boundary usage statistics for a specific replica.
|
||||
DeleteBoundaryUsageStatsByReplicaID(ctx context.Context, replicaID uuid.UUID) error
|
||||
DeleteChatByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error
|
||||
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
|
||||
DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error)
|
||||
DeleteCustomRole(ctx context.Context, arg DeleteCustomRoleParams) error
|
||||
DeleteExpiredAPIKeys(ctx context.Context, arg DeleteExpiredAPIKeysParams) (int64, error)
|
||||
@@ -200,6 +209,21 @@ type sqlcQuerier interface {
|
||||
// include data where window_start is within the given interval to exclude
|
||||
// stale data.
|
||||
GetBoundaryUsageSummary(ctx context.Context, maxStalenessMs int64) (GetBoundaryUsageSummaryRow, error)
|
||||
GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error)
|
||||
GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error)
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
|
||||
GetChatModelConfigByProviderAndModel(ctx context.Context, arg GetChatModelConfigByProviderAndModelParams) (ChatModelConfig, error)
|
||||
GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
|
||||
GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error)
|
||||
GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error)
|
||||
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
|
||||
GetChatsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]Chat, error)
|
||||
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
|
||||
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
|
||||
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
|
||||
@@ -215,6 +239,8 @@ type sqlcQuerier interface {
|
||||
GetDeploymentWorkspaceAgentUsageStats(ctx context.Context, createdAt time.Time) (GetDeploymentWorkspaceAgentUsageStatsRow, error)
|
||||
GetDeploymentWorkspaceStats(ctx context.Context) (GetDeploymentWorkspaceStatsRow, error)
|
||||
GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.Context, provisionerJobIds []uuid.UUID) ([]GetEligibleProvisionerDaemonsByProvisionerJobIDsRow, error)
|
||||
GetEnabledChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
|
||||
GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetExternalAuthLink(ctx context.Context, arg GetExternalAuthLinkParams) (ExternalAuthLink, error)
|
||||
GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error)
|
||||
GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, arg GetFailedWorkspaceBuildsByTemplateIDParams) ([]GetFailedWorkspaceBuildsByTemplateIDRow, error)
|
||||
@@ -251,6 +277,7 @@ type sqlcQuerier interface {
|
||||
GetInboxNotificationsByUserID(ctx context.Context, arg GetInboxNotificationsByUserIDParams) ([]InboxNotification, error)
|
||||
GetLastUpdateCheck(ctx context.Context) (string, error)
|
||||
GetLatestCryptoKeyByFeature(ctx context.Context, feature CryptoKeyFeature) (CryptoKey, error)
|
||||
GetLatestPendingSubagentRequestIDByChatID(ctx context.Context, chatID uuid.UUID) (uuid.NullUUID, error)
|
||||
GetLatestWorkspaceAppStatusByAppID(ctx context.Context, appID uuid.UUID) (WorkspaceAppStatus, error)
|
||||
GetLatestWorkspaceAppStatusesByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAppStatus, error)
|
||||
GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
|
||||
@@ -353,6 +380,11 @@ type sqlcQuerier interface {
|
||||
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
|
||||
GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error)
|
||||
GetRuntimeConfig(ctx context.Context, key string) (string, error)
|
||||
// Find chats that appear stuck (running but no heartbeat).
|
||||
// Used for recovery after coderd crashes.
|
||||
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
|
||||
GetSubagentRequestDurationByChatIDAndRequestID(ctx context.Context, arg GetSubagentRequestDurationByChatIDAndRequestIDParams) (int64, error)
|
||||
GetSubagentResponseMessageByChatIDAndRequestID(ctx context.Context, arg GetSubagentResponseMessageByChatIDAndRequestIDParams) (ChatMessage, error)
|
||||
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
|
||||
GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerBindingsRow, error)
|
||||
GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerIDsRow, error)
|
||||
@@ -550,6 +582,11 @@ type sqlcQuerier interface {
|
||||
// every member of the org.
|
||||
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
|
||||
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
|
||||
InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error)
|
||||
InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error)
|
||||
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
|
||||
InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error)
|
||||
InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error)
|
||||
InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error)
|
||||
InsertCustomRole(ctx context.Context, arg InsertCustomRoleParams) (CustomRole, error)
|
||||
InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error
|
||||
@@ -632,6 +669,8 @@ type sqlcQuerier interface {
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
|
||||
ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]Chat, error)
|
||||
ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]Chat, error)
|
||||
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
@@ -648,6 +687,7 @@ type sqlcQuerier interface {
|
||||
// - Use both to get a specific org member row
|
||||
OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error)
|
||||
PaginatedOrganizationMembers(ctx context.Context, arg PaginatedOrganizationMembersParams) ([]PaginatedOrganizationMembersRow, error)
|
||||
PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (ChatQueuedMessage, error)
|
||||
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
|
||||
@@ -671,6 +711,12 @@ type sqlcQuerier interface {
|
||||
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
|
||||
UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error)
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
|
||||
UpdateChatModelConfigByChatID(ctx context.Context, arg UpdateChatModelConfigByChatIDParams) (Chat, error)
|
||||
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
|
||||
UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error)
|
||||
UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error)
|
||||
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
|
||||
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
|
||||
UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error)
|
||||
@@ -766,6 +812,8 @@ type sqlcQuerier interface {
|
||||
// the current in-memory state. Returns true if this was an insert (new period),
|
||||
// false if update.
|
||||
UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error)
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
|
||||
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error
|
||||
// The default proxy is implied and not actually stored in the database.
|
||||
|
||||
@@ -7395,6 +7395,47 @@ func TestGetTaskByWorkspaceID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteTaskDeletesTaskSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
template := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
task := dbgen.Task(t, db, database.TaskTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
TemplateVersionID: templateVersion.ID,
|
||||
Prompt: "Test prompt",
|
||||
})
|
||||
|
||||
err := db.UpsertTaskSnapshot(ctx, database.UpsertTaskSnapshotParams{
|
||||
TaskID: task.ID,
|
||||
LogSnapshot: json.RawMessage(`{"messages":[]}`),
|
||||
LogSnapshotCreatedAt: dbtime.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.DeleteTask(ctx, database.DeleteTaskParams{
|
||||
ID: task.ID,
|
||||
DeletedAt: dbtime.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.GetTaskSnapshot(ctx, task.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
}
|
||||
|
||||
func TestTaskNameUniqueness(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+1886
-25
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,92 @@
|
||||
-- name: GetChatModelConfigByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_model_configs
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
|
||||
-- name: GetChatModelConfigByProviderAndModel :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_model_configs
|
||||
WHERE
|
||||
provider = @provider::text
|
||||
AND model = @model::text
|
||||
ORDER BY
|
||||
updated_at DESC,
|
||||
created_at DESC,
|
||||
id DESC
|
||||
LIMIT 1;
|
||||
|
||||
-- name: GetChatModelConfigs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_model_configs
|
||||
ORDER BY
|
||||
provider ASC,
|
||||
model ASC,
|
||||
updated_at DESC,
|
||||
id DESC;
|
||||
|
||||
-- name: GetEnabledChatModelConfigs :many
|
||||
SELECT
|
||||
cmc.*
|
||||
FROM
|
||||
chat_model_configs cmc
|
||||
JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider
|
||||
WHERE
|
||||
cmc.enabled = TRUE
|
||||
AND cp.enabled = TRUE
|
||||
ORDER BY
|
||||
cmc.provider ASC,
|
||||
cmc.model ASC,
|
||||
cmc.updated_at DESC,
|
||||
cmc.id DESC;
|
||||
|
||||
-- name: InsertChatModelConfig :one
|
||||
INSERT INTO chat_model_configs (
|
||||
provider,
|
||||
model,
|
||||
display_name,
|
||||
enabled,
|
||||
context_limit,
|
||||
compression_threshold,
|
||||
model_config
|
||||
) VALUES (
|
||||
@provider::text,
|
||||
@model::text,
|
||||
@display_name::text,
|
||||
@enabled::boolean,
|
||||
@context_limit::bigint,
|
||||
@compression_threshold::integer,
|
||||
@model_config::jsonb
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatModelConfig :one
|
||||
UPDATE
|
||||
chat_model_configs
|
||||
SET
|
||||
provider = @provider::text,
|
||||
model = @model::text,
|
||||
display_name = @display_name::text,
|
||||
enabled = @enabled::boolean,
|
||||
context_limit = @context_limit::bigint,
|
||||
compression_threshold = @compression_threshold::integer,
|
||||
model_config = @model_config::jsonb,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: DeleteChatModelConfigByID :exec
|
||||
DELETE FROM
|
||||
chat_model_configs
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
@@ -0,0 +1,73 @@
|
||||
-- name: GetChatProviderByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
|
||||
-- name: GetChatProviderByProvider :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
provider = @provider::text;
|
||||
|
||||
-- name: GetChatProviders :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
ORDER BY
|
||||
provider ASC;
|
||||
|
||||
-- name: GetEnabledChatProviders :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
enabled = TRUE
|
||||
ORDER BY
|
||||
provider ASC;
|
||||
|
||||
-- name: InsertChatProvider :one
|
||||
INSERT INTO chat_providers (
|
||||
provider,
|
||||
display_name,
|
||||
api_key,
|
||||
base_url,
|
||||
api_key_key_id,
|
||||
enabled
|
||||
) VALUES (
|
||||
@provider::text,
|
||||
@display_name::text,
|
||||
@api_key::text,
|
||||
@base_url::text,
|
||||
sqlc.narg('api_key_key_id')::text,
|
||||
@enabled::boolean
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatProvider :one
|
||||
UPDATE
|
||||
chat_providers
|
||||
SET
|
||||
display_name = @display_name::text,
|
||||
api_key = @api_key::text,
|
||||
base_url = @base_url::text,
|
||||
api_key_key_id = sqlc.narg('api_key_key_id')::text,
|
||||
enabled = @enabled::boolean,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: DeleteChatProviderByID :exec
|
||||
DELETE FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
@@ -0,0 +1,465 @@
|
||||
-- name: DeleteChatByID :exec
|
||||
DELETE FROM
|
||||
chats
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
|
||||
-- name: DeleteChatMessagesByChatID :exec
|
||||
DELETE FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid;
|
||||
|
||||
-- name: GetChatByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
|
||||
-- name: GetChatMessageByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
id = @id::bigint;
|
||||
|
||||
-- name: GetChatMessagesByChatID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
ORDER BY
|
||||
created_at ASC;
|
||||
|
||||
-- name: GetChatMessagesForPromptByChatID :many
|
||||
WITH latest_compressed_summary AS (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND role = 'system'
|
||||
AND hidden = TRUE
|
||||
AND compressed = TRUE
|
||||
ORDER BY
|
||||
created_at DESC,
|
||||
id DESC
|
||||
LIMIT
|
||||
1
|
||||
)
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND (
|
||||
(
|
||||
role = 'system'
|
||||
AND hidden = TRUE
|
||||
AND compressed = FALSE
|
||||
)
|
||||
OR (
|
||||
compressed = FALSE
|
||||
AND (
|
||||
NOT EXISTS (
|
||||
SELECT
|
||||
1
|
||||
FROM
|
||||
latest_compressed_summary
|
||||
)
|
||||
OR id > (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
latest_compressed_summary
|
||||
)
|
||||
)
|
||||
)
|
||||
OR id = (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
latest_compressed_summary
|
||||
)
|
||||
)
|
||||
ORDER BY
|
||||
created_at ASC,
|
||||
id ASC;
|
||||
|
||||
-- name: GetChatsByOwnerID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
owner_id = @owner_id::uuid
|
||||
ORDER BY
|
||||
updated_at DESC;
|
||||
|
||||
-- name: ListChildChatsByParentID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
parent_chat_id = @parent_chat_id::uuid
|
||||
ORDER BY
|
||||
created_at ASC;
|
||||
|
||||
-- name: ListChatsByRootID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
root_chat_id = @root_chat_id::uuid
|
||||
ORDER BY
|
||||
created_at ASC;
|
||||
|
||||
-- name: InsertChat :one
|
||||
INSERT INTO chats (
|
||||
owner_id,
|
||||
workspace_id,
|
||||
workspace_agent_id,
|
||||
parent_chat_id,
|
||||
root_chat_id,
|
||||
title,
|
||||
model_config
|
||||
) VALUES (
|
||||
@owner_id::uuid,
|
||||
sqlc.narg('workspace_id')::uuid,
|
||||
sqlc.narg('workspace_agent_id')::uuid,
|
||||
sqlc.narg('parent_chat_id')::uuid,
|
||||
sqlc.narg('root_chat_id')::uuid,
|
||||
@title::text,
|
||||
@model_config::jsonb
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: InsertChatMessage :one
|
||||
INSERT INTO chat_messages (
|
||||
chat_id,
|
||||
role,
|
||||
content,
|
||||
tool_call_id,
|
||||
thinking,
|
||||
hidden,
|
||||
subagent_request_id,
|
||||
subagent_event,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
total_tokens,
|
||||
reasoning_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
context_limit,
|
||||
compressed
|
||||
) VALUES (
|
||||
@chat_id::uuid,
|
||||
@role::text,
|
||||
sqlc.narg('content')::jsonb,
|
||||
sqlc.narg('tool_call_id')::text,
|
||||
sqlc.narg('thinking')::text,
|
||||
@hidden::boolean,
|
||||
sqlc.narg('subagent_request_id')::uuid,
|
||||
sqlc.narg('subagent_event')::text,
|
||||
sqlc.narg('input_tokens')::bigint,
|
||||
sqlc.narg('output_tokens')::bigint,
|
||||
sqlc.narg('total_tokens')::bigint,
|
||||
sqlc.narg('reasoning_tokens')::bigint,
|
||||
sqlc.narg('cache_creation_tokens')::bigint,
|
||||
sqlc.narg('cache_read_tokens')::bigint,
|
||||
sqlc.narg('context_limit')::bigint,
|
||||
COALESCE(sqlc.narg('compressed')::boolean, FALSE)
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: GetLatestPendingSubagentRequestIDByChatID :one
|
||||
WITH requests AS (
|
||||
SELECT
|
||||
subagent_request_id,
|
||||
MAX(created_at) AS requested_at
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND subagent_request_id IS NOT NULL
|
||||
AND subagent_event = 'request'
|
||||
GROUP BY
|
||||
subagent_request_id
|
||||
)
|
||||
SELECT
|
||||
COALESCE(
|
||||
requests.subagent_request_id,
|
||||
'00000000-0000-0000-0000-000000000000'::uuid
|
||||
) AS subagent_request_id
|
||||
FROM
|
||||
requests
|
||||
WHERE
|
||||
NOT EXISTS (
|
||||
SELECT
|
||||
1
|
||||
FROM
|
||||
chat_messages responses
|
||||
WHERE
|
||||
responses.chat_id = @chat_id::uuid
|
||||
AND responses.subagent_request_id = requests.subagent_request_id
|
||||
AND responses.subagent_event = 'response'
|
||||
)
|
||||
ORDER BY
|
||||
requests.requested_at DESC
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetSubagentResponseMessageByChatIDAndRequestID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND subagent_request_id = @subagent_request_id::uuid
|
||||
AND subagent_event = 'response'
|
||||
ORDER BY
|
||||
created_at DESC
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetSubagentRequestDurationByChatIDAndRequestID :one
|
||||
WITH request AS (
|
||||
SELECT
|
||||
MIN(created_at) AS created_at
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND subagent_request_id = @subagent_request_id::uuid
|
||||
AND subagent_event = 'request'
|
||||
),
|
||||
response AS (
|
||||
SELECT
|
||||
MAX(created_at) AS created_at
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND subagent_request_id = @subagent_request_id::uuid
|
||||
AND subagent_event = 'response'
|
||||
)
|
||||
SELECT
|
||||
COALESCE(
|
||||
CAST(EXTRACT(EPOCH FROM (response.created_at - request.created_at)) * 1000 AS BIGINT),
|
||||
0::BIGINT
|
||||
)::BIGINT AS duration_ms
|
||||
FROM
|
||||
request,
|
||||
response;
|
||||
|
||||
-- name: UpdateChatByID :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
title = @title::text,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatWorkspace :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
workspace_id = sqlc.narg('workspace_id')::uuid,
|
||||
workspace_agent_id = sqlc.narg('workspace_agent_id')::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatModelConfigByChatID :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
model_config = @model_config::jsonb,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: AcquireChat :one
|
||||
-- Acquires a pending chat for processing. Uses SKIP LOCKED to prevent
|
||||
-- multiple replicas from acquiring the same chat.
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
status = 'running'::chat_status,
|
||||
started_at = @started_at::timestamptz,
|
||||
updated_at = @started_at::timestamptz,
|
||||
worker_id = @worker_id::uuid
|
||||
WHERE
|
||||
id = (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
status = 'pending'::chat_status
|
||||
ORDER BY
|
||||
updated_at ASC
|
||||
FOR UPDATE
|
||||
SKIP LOCKED
|
||||
LIMIT
|
||||
1
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatStatus :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
status = @status::chat_status,
|
||||
worker_id = sqlc.narg('worker_id')::uuid,
|
||||
started_at = sqlc.narg('started_at')::timestamptz,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: GetStaleChats :many
|
||||
-- Find chats that appear stuck (running but no heartbeat).
|
||||
-- Used for recovery after coderd crashes.
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
status = 'running'::chat_status
|
||||
AND started_at < @stale_threshold::timestamptz;
|
||||
|
||||
-- name: GetChatDiffStatusByChatID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_diff_statuses
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid;
|
||||
|
||||
-- name: GetChatDiffStatusesByChatIDs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_diff_statuses
|
||||
WHERE
|
||||
chat_id = ANY(@chat_ids::uuid[]);
|
||||
|
||||
-- name: UpsertChatDiffStatusReference :one
|
||||
INSERT INTO chat_diff_statuses (
|
||||
chat_id,
|
||||
url,
|
||||
git_branch,
|
||||
git_remote_origin,
|
||||
stale_at
|
||||
) VALUES (
|
||||
@chat_id::uuid,
|
||||
sqlc.narg('url')::text,
|
||||
@git_branch::text,
|
||||
@git_remote_origin::text,
|
||||
@stale_at::timestamptz
|
||||
)
|
||||
ON CONFLICT (chat_id) DO UPDATE
|
||||
SET
|
||||
url = CASE
|
||||
WHEN EXCLUDED.url IS NOT NULL THEN EXCLUDED.url
|
||||
ELSE chat_diff_statuses.url
|
||||
END,
|
||||
git_branch = CASE
|
||||
WHEN EXCLUDED.git_branch != '' THEN EXCLUDED.git_branch
|
||||
ELSE chat_diff_statuses.git_branch
|
||||
END,
|
||||
git_remote_origin = CASE
|
||||
WHEN EXCLUDED.git_remote_origin != '' THEN EXCLUDED.git_remote_origin
|
||||
ELSE chat_diff_statuses.git_remote_origin
|
||||
END,
|
||||
stale_at = EXCLUDED.stale_at,
|
||||
updated_at = NOW()
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpsertChatDiffStatus :one
|
||||
INSERT INTO chat_diff_statuses (
|
||||
chat_id,
|
||||
url,
|
||||
pull_request_state,
|
||||
changes_requested,
|
||||
additions,
|
||||
deletions,
|
||||
changed_files,
|
||||
refreshed_at,
|
||||
stale_at
|
||||
) VALUES (
|
||||
@chat_id::uuid,
|
||||
sqlc.narg('url')::text,
|
||||
sqlc.narg('pull_request_state')::text,
|
||||
@changes_requested::boolean,
|
||||
@additions::integer,
|
||||
@deletions::integer,
|
||||
@changed_files::integer,
|
||||
@refreshed_at::timestamptz,
|
||||
@stale_at::timestamptz
|
||||
)
|
||||
ON CONFLICT (chat_id) DO UPDATE
|
||||
SET
|
||||
url = EXCLUDED.url,
|
||||
pull_request_state = EXCLUDED.pull_request_state,
|
||||
changes_requested = EXCLUDED.changes_requested,
|
||||
additions = EXCLUDED.additions,
|
||||
deletions = EXCLUDED.deletions,
|
||||
changed_files = EXCLUDED.changed_files,
|
||||
refreshed_at = EXCLUDED.refreshed_at,
|
||||
stale_at = EXCLUDED.stale_at,
|
||||
updated_at = NOW()
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: InsertChatQueuedMessage :one
|
||||
INSERT INTO chat_queued_messages (chat_id, content)
|
||||
VALUES (@chat_id, @content)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetChatQueuedMessages :many
|
||||
SELECT * FROM chat_queued_messages
|
||||
WHERE chat_id = @chat_id
|
||||
ORDER BY id ASC;
|
||||
|
||||
-- name: DeleteChatQueuedMessage :exec
|
||||
DELETE FROM chat_queued_messages WHERE id = @id AND chat_id = @chat_id;
|
||||
|
||||
-- name: DeleteAllChatQueuedMessages :exec
|
||||
DELETE FROM chat_queued_messages WHERE chat_id = @chat_id;
|
||||
|
||||
-- name: PopNextQueuedMessage :one
|
||||
DELETE FROM chat_queued_messages
|
||||
WHERE id = (
|
||||
SELECT cqm.id FROM chat_queued_messages cqm
|
||||
WHERE cqm.chat_id = @chat_id
|
||||
ORDER BY cqm.id ASC
|
||||
LIMIT 1
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetChatByIDForUpdate :one
|
||||
SELECT * FROM chats WHERE id = @id::uuid FOR UPDATE;
|
||||
@@ -57,13 +57,31 @@ AND CASE WHEN @status::text != '' THEN tws.status = @status::task_status ELSE TR
|
||||
ORDER BY tws.created_at DESC;
|
||||
|
||||
-- name: DeleteTask :one
|
||||
UPDATE tasks
|
||||
SET
|
||||
deleted_at = @deleted_at::timestamptz
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
AND deleted_at IS NULL
|
||||
RETURNING *;
|
||||
WITH deleted_task AS (
|
||||
UPDATE
|
||||
tasks
|
||||
SET
|
||||
deleted_at = @deleted_at::timestamptz
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
AND deleted_at IS NULL
|
||||
RETURNING id
|
||||
), deleted_task_snapshot AS (
|
||||
DELETE FROM
|
||||
task_snapshots
|
||||
USING
|
||||
deleted_task
|
||||
WHERE
|
||||
task_snapshots.task_id = deleted_task.id
|
||||
)
|
||||
SELECT
|
||||
tasks.*
|
||||
FROM
|
||||
tasks
|
||||
JOIN
|
||||
deleted_task
|
||||
ON
|
||||
tasks.id = deleted_task.id;
|
||||
|
||||
|
||||
-- name: UpdateTaskPrompt :one
|
||||
|
||||
@@ -292,7 +292,7 @@ WHERE
|
||||
-- Filter by agent status
|
||||
-- has-agent: is only applicable for workspaces in "start" transition. Stopped and deleted workspaces don't have agents.
|
||||
AND CASE
|
||||
WHEN @has_agent :: text != '' THEN
|
||||
WHEN array_length(@has_agent_statuses :: text[], 1) > 0 THEN
|
||||
(
|
||||
SELECT COUNT(*)
|
||||
FROM
|
||||
@@ -303,14 +303,14 @@ WHERE
|
||||
workspace_agents.resource_id = workspace_resources.id
|
||||
WHERE
|
||||
workspace_resources.job_id = latest_build.provisioner_job_id AND
|
||||
latest_build.transition = 'start'::workspace_transition AND
|
||||
-- Filter out deleted sub agents.
|
||||
workspace_agents.deleted = FALSE AND
|
||||
@has_agent = (
|
||||
CASE
|
||||
WHEN workspace_agents.first_connected_at IS NULL THEN
|
||||
CASE
|
||||
WHEN workspace_agents.connection_timeout_seconds > 0 AND NOW() - workspace_agents.created_at > workspace_agents.connection_timeout_seconds * INTERVAL '1 second' THEN
|
||||
latest_build.transition = 'start'::workspace_transition AND
|
||||
-- Filter out deleted sub agents.
|
||||
workspace_agents.deleted = FALSE AND
|
||||
(
|
||||
CASE
|
||||
WHEN workspace_agents.first_connected_at IS NULL THEN
|
||||
CASE
|
||||
WHEN workspace_agents.connection_timeout_seconds > 0 AND NOW() - workspace_agents.created_at > workspace_agents.connection_timeout_seconds * INTERVAL '1 second' THEN
|
||||
'timeout'
|
||||
ELSE
|
||||
'connecting'
|
||||
@@ -321,12 +321,12 @@ WHERE
|
||||
'disconnected'
|
||||
WHEN workspace_agents.last_connected_at IS NOT NULL THEN
|
||||
'connected'
|
||||
ELSE
|
||||
NULL
|
||||
END
|
||||
)
|
||||
) > 0
|
||||
ELSE true
|
||||
ELSE
|
||||
NULL
|
||||
END
|
||||
) = ANY(@has_agent_statuses :: text[])
|
||||
) > 0
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by dormant workspaces.
|
||||
AND CASE
|
||||
@@ -398,7 +398,7 @@ WHERE
|
||||
filtered_workspaces fw
|
||||
ORDER BY
|
||||
-- To ensure that 'favorite' workspaces show up first in the list only for their owner.
|
||||
CASE WHEN owner_id = @requester_id AND favorite THEN 0 ELSE 1 END ASC,
|
||||
CASE WHEN favorite AND owner_username = (SELECT users.username FROM users WHERE users.id = @requester_id) THEN 0 ELSE 1 END ASC,
|
||||
(latest_build_completed_at IS NOT NULL AND
|
||||
latest_build_canceled_at IS NULL AND
|
||||
latest_build_error IS NULL AND
|
||||
|
||||
@@ -14,6 +14,13 @@ const (
|
||||
UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id);
|
||||
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
|
||||
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
|
||||
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
|
||||
UniqueChatProvidersProviderKey UniqueConstraint = "chat_providers_provider_key" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
|
||||
UniqueChatQueuedMessagesPkey UniqueConstraint = "chat_queued_messages_pkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatsPkey UniqueConstraint = "chats_pkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
|
||||
UniqueConnectionLogsPkey UniqueConstraint = "connection_logs_pkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
|
||||
UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence);
|
||||
UniqueCustomRolesUniqueKey UniqueConstraint = "custom_roles_unique_key" // ALTER TABLE ONLY custom_roles ADD CONSTRAINT custom_roles_unique_key UNIQUE (name, organization_id);
|
||||
|
||||
@@ -15,13 +15,24 @@ type requestIDContextKey struct{}
|
||||
|
||||
// RequestID returns the ID of the request.
|
||||
func RequestID(r *http.Request) uuid.UUID {
|
||||
rid, ok := r.Context().Value(requestIDContextKey{}).(uuid.UUID)
|
||||
rid, ok := RequestIDOptional(r)
|
||||
if !ok {
|
||||
panic("developer error: request id middleware not provided")
|
||||
}
|
||||
return rid
|
||||
}
|
||||
|
||||
// RequestIDOptional returns the request ID when present.
|
||||
func RequestIDOptional(r *http.Request) (uuid.UUID, bool) {
|
||||
rid, ok := r.Context().Value(requestIDContextKey{}).(uuid.UUID)
|
||||
return rid, ok
|
||||
}
|
||||
|
||||
// WithRequestID stores a request ID in the context.
|
||||
func WithRequestID(ctx context.Context, rid uuid.UUID) context.Context {
|
||||
return context.WithValue(ctx, requestIDContextKey{}, rid)
|
||||
}
|
||||
|
||||
// AttachRequestID adds a request ID to each HTTP request.
|
||||
func AttachRequestID(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
@@ -31,3 +33,16 @@ func TestRequestID(t *testing.T) {
|
||||
require.NotEmpty(t, res.Header.Get("X-Coder-Request-ID"))
|
||||
require.NotEmpty(t, rw.Body.Bytes())
|
||||
}
|
||||
|
||||
func TestRequestIDHelpers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
requestID := uuid.New()
|
||||
ctx := httpmw.WithRequestID(context.Background(), requestID)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||||
|
||||
gotRequestID, ok := httpmw.RequestIDOptional(req)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, requestID, gotRequestID)
|
||||
require.Equal(t, requestID, httpmw.RequestID(req))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func ChatEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload ChatEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event"))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
|
||||
type ChatEvent struct {
|
||||
Kind ChatEventKind `json:"kind"`
|
||||
Chat codersdk.Chat `json:"chat"`
|
||||
}
|
||||
|
||||
type ChatEventKind string
|
||||
|
||||
const (
|
||||
ChatEventKindStatusChange ChatEventKind = "status_change"
|
||||
ChatEventKindTitleChange ChatEventKind = "title_change"
|
||||
ChatEventKindCreated ChatEventKind = "created"
|
||||
ChatEventKindDeleted ChatEventKind = "deleted"
|
||||
)
|
||||
@@ -72,6 +72,16 @@ var (
|
||||
Type: "boundary_usage",
|
||||
}
|
||||
|
||||
// ResourceChat
|
||||
// Valid Actions
|
||||
// - "ActionCreate" :: create a new chat
|
||||
// - "ActionDelete" :: delete a chat
|
||||
// - "ActionRead" :: read chat messages and metadata
|
||||
// - "ActionUpdate" :: update chat title or settings
|
||||
ResourceChat = Object{
|
||||
Type: "chat",
|
||||
}
|
||||
|
||||
// ResourceConnectionLog
|
||||
// Valid Actions
|
||||
// - "ActionRead" :: read connection logs
|
||||
@@ -427,6 +437,7 @@ func AllResources() []Objecter {
|
||||
ResourceAssignRole,
|
||||
ResourceAuditLog,
|
||||
ResourceBoundaryUsage,
|
||||
ResourceChat,
|
||||
ResourceConnectionLog,
|
||||
ResourceCryptoKey,
|
||||
ResourceDebugInfo,
|
||||
|
||||
@@ -75,6 +75,13 @@ var taskActions = map[Action]ActionDefinition{
|
||||
ActionDelete: "delete task",
|
||||
}
|
||||
|
||||
var chatActions = map[Action]ActionDefinition{
|
||||
ActionCreate: "create a new chat",
|
||||
ActionRead: "read chat messages and metadata",
|
||||
ActionUpdate: "update chat title or settings",
|
||||
ActionDelete: "delete a chat",
|
||||
}
|
||||
|
||||
// RBACPermissions is indexed by the type
|
||||
var RBACPermissions = map[string]PermissionDefinition{
|
||||
// Wildcard is every object, and the action "*" provides all actions.
|
||||
@@ -101,6 +108,9 @@ var RBACPermissions = map[string]PermissionDefinition{
|
||||
"task": {
|
||||
Actions: taskActions,
|
||||
},
|
||||
"chat": {
|
||||
Actions: chatActions,
|
||||
},
|
||||
// Dormant workspaces have the same perms as workspaces.
|
||||
"workspace_dormant": {
|
||||
Actions: workspaceActions,
|
||||
|
||||
@@ -28,6 +28,10 @@ const (
|
||||
ScopeBoundaryUsageDelete ScopeName = "boundary_usage:delete"
|
||||
ScopeBoundaryUsageRead ScopeName = "boundary_usage:read"
|
||||
ScopeBoundaryUsageUpdate ScopeName = "boundary_usage:update"
|
||||
ScopeChatCreate ScopeName = "chat:create"
|
||||
ScopeChatDelete ScopeName = "chat:delete"
|
||||
ScopeChatRead ScopeName = "chat:read"
|
||||
ScopeChatUpdate ScopeName = "chat:update"
|
||||
ScopeConnectionLogRead ScopeName = "connection_log:read"
|
||||
ScopeConnectionLogUpdate ScopeName = "connection_log:update"
|
||||
ScopeCryptoKeyCreate ScopeName = "crypto_key:create"
|
||||
@@ -186,6 +190,10 @@ func (e ScopeName) Valid() bool {
|
||||
ScopeBoundaryUsageDelete,
|
||||
ScopeBoundaryUsageRead,
|
||||
ScopeBoundaryUsageUpdate,
|
||||
ScopeChatCreate,
|
||||
ScopeChatDelete,
|
||||
ScopeChatRead,
|
||||
ScopeChatUpdate,
|
||||
ScopeConnectionLogRead,
|
||||
ScopeConnectionLogUpdate,
|
||||
ScopeCryptoKeyCreate,
|
||||
@@ -345,6 +353,10 @@ func AllScopeNameValues() []ScopeName {
|
||||
ScopeBoundaryUsageDelete,
|
||||
ScopeBoundaryUsageRead,
|
||||
ScopeBoundaryUsageUpdate,
|
||||
ScopeChatCreate,
|
||||
ScopeChatDelete,
|
||||
ScopeChatRead,
|
||||
ScopeChatUpdate,
|
||||
ScopeConnectionLogRead,
|
||||
ScopeConnectionLogUpdate,
|
||||
ScopeCryptoKeyCreate,
|
||||
|
||||
@@ -253,7 +253,7 @@ func Workspaces(ctx context.Context, db database.Store, query string, page coder
|
||||
filter.TemplateName = parser.String(values, "", "template")
|
||||
filter.Name = parser.String(values, "", "name")
|
||||
filter.Status = string(httpapi.ParseCustom(parser, values, "", "status", httpapi.ParseEnum[database.WorkspaceStatus]))
|
||||
filter.HasAgent = parser.String(values, "", "has-agent")
|
||||
filter.HasAgentStatuses = parser.Strings(values, nil, "has-agent")
|
||||
filter.Dormant = parser.Boolean(values, false, "dormant")
|
||||
filter.LastUsedAfter = parser.Time3339Nano(values, time.Time{}, "last_used_after")
|
||||
filter.LastUsedBefore = parser.Time3339Nano(values, time.Time{}, "last_used_before")
|
||||
@@ -272,6 +272,15 @@ func Workspaces(ctx context.Context, db database.Store, query string, page coder
|
||||
// TODO: support "me" by passing in the actorID
|
||||
filter.SharedWithUserID = parseUser(ctx, db, parser, values, "shared_with_user", uuid.Nil)
|
||||
filter.SharedWithGroupID = parseGroup(ctx, db, parser, values, "shared_with_group")
|
||||
// Translate healthy filter to has-agent statuses.
|
||||
// healthy:true = connected, healthy:false = disconnected or timeout.
|
||||
if healthy := parser.NullableBoolean(values, sql.NullBool{}, "healthy"); healthy.Valid {
|
||||
if healthy.Bool {
|
||||
filter.HasAgentStatuses = append(filter.HasAgentStatuses, "connected")
|
||||
} else {
|
||||
filter.HasAgentStatuses = append(filter.HasAgentStatuses, "disconnected", "timeout")
|
||||
}
|
||||
}
|
||||
|
||||
type paramMatch struct {
|
||||
name string
|
||||
|
||||
@@ -141,7 +141,6 @@ const AgentAPIVersionREST = "1.0"
|
||||
func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.PatchLogs
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
@@ -1948,11 +1947,16 @@ func convertWorkspaceAgentMetadata(db []database.WorkspaceAgentMetadatum) []code
|
||||
// @Param match query string true "Match"
|
||||
// @Param id query string true "Provider ID"
|
||||
// @Param listen query bool false "Wait for a new token to be issued"
|
||||
// @Param workdir query string false "Working directory used for git context refresh"
|
||||
// @Success 200 {object} agentsdk.ExternalAuthResponse
|
||||
// @Router /workspaceagents/me/external-auth [get]
|
||||
func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
query := r.URL.Query()
|
||||
gitRef := chatGitRef{
|
||||
Branch: strings.TrimSpace(query.Get("git_branch")),
|
||||
RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")),
|
||||
}
|
||||
// Either match or configID must be provided!
|
||||
match := query.Get("match")
|
||||
if match == "" {
|
||||
@@ -1975,7 +1979,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
|
||||
// listen determines if the request will wait for a
|
||||
// new token to be issued!
|
||||
listen := r.URL.Query().Has("listen")
|
||||
listen := query.Has("listen")
|
||||
|
||||
var externalAuthConfig *externalauth.Config
|
||||
for _, extAuth := range api.ExternalAuthConfigs {
|
||||
@@ -2046,6 +2050,12 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
// Persist git refs as soon as the agent requests external auth so branch
|
||||
// context is retained even if the flow requires an out-of-band login.
|
||||
if gitRef.Branch != "" || gitRef.RemoteOrigin != "" {
|
||||
api.storeChatGitRef(dbauthz.AsSystemRestricted(ctx), workspace.ID, workspace.OwnerID, gitRef)
|
||||
}
|
||||
|
||||
var previousToken *database.ExternalAuthLink
|
||||
// handleRetrying will attempt to continually check for a new token
|
||||
// if listen is true. This is useful if an error is encountered in the
|
||||
@@ -2059,7 +2069,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace)
|
||||
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, gitRef)
|
||||
}
|
||||
|
||||
// This is the URL that will redirect the user with a state token.
|
||||
@@ -2117,10 +2127,11 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
})
|
||||
return
|
||||
}
|
||||
api.triggerWorkspaceChatDiffStatusRefresh(workspace, gitRef)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace) {
|
||||
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, gitRef chatGitRef) {
|
||||
// Since we're ticking frequently and this sign-in operation is rare,
|
||||
// we are OK with polling to avoid the complexity of pubsub.
|
||||
ticker, done := api.NewTicker(time.Second)
|
||||
@@ -2190,6 +2201,7 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
|
||||
})
|
||||
return
|
||||
}
|
||||
api.triggerWorkspaceChatDiffStatusRefresh(workspace, gitRef)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2765,6 +2765,52 @@ func TestWorkspaceAgentExternalAuthListen(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkspaceAgentExternalAuthStoresGitRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const providerID = "github"
|
||||
|
||||
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
ExternalAuthConfigs: []*externalauth.Config{
|
||||
{
|
||||
ID: providerID,
|
||||
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
|
||||
},
|
||||
},
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
|
||||
agents[0].Directory = tmpDir
|
||||
return agents
|
||||
}).Do()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
workspaceID := r.Workspace.ID
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Message: "Track branch status from external auth.",
|
||||
WorkspaceID: &workspaceID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
_, err = agentClient.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
|
||||
ID: providerID,
|
||||
GitBranch: "feature/cache-git-ref",
|
||||
GitRemoteOrigin: "https://github.com/coder/coder.git",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := api.Database.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "feature/cache-git-ref", status.GitBranch)
|
||||
require.Equal(t, "https://github.com/coder/coder.git", status.GitRemoteOrigin)
|
||||
}
|
||||
|
||||
func TestOwnedWorkspacesCoordinate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -605,6 +605,17 @@ type ExternalAuthRequest struct {
|
||||
ID string
|
||||
// Match is an arbitrary string matched against the regex of the provider.
|
||||
Match string
|
||||
// Workdir is an optional working directory used for follow-up workspace
|
||||
// context refreshes.
|
||||
Workdir string
|
||||
// GitBranch is the current git branch in the working directory.
|
||||
// Sent by the agent so the control plane can resolve diffs
|
||||
// without SSHing into the workspace.
|
||||
GitBranch string
|
||||
// GitRemoteOrigin is the remote origin URL of the git repository.
|
||||
// Sent by the agent so the control plane can resolve diffs
|
||||
// without SSHing into the workspace.
|
||||
GitRemoteOrigin string
|
||||
// Listen indicates that the request should be long-lived and listen for
|
||||
// a new token to be requested.
|
||||
Listen bool
|
||||
@@ -620,6 +631,15 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
|
||||
if req.Listen {
|
||||
q.Set("listen", "true")
|
||||
}
|
||||
if req.Workdir != "" {
|
||||
q.Set("workdir", req.Workdir)
|
||||
}
|
||||
if req.GitBranch != "" {
|
||||
q.Set("git_branch", req.GitBranch)
|
||||
}
|
||||
if req.GitRemoteOrigin != "" {
|
||||
q.Set("git_remote_origin", req.GitRemoteOrigin)
|
||||
}
|
||||
reqURL := "/api/v2/workspaceagents/me/external-auth?" + q.Encode()
|
||||
res, err := c.SDK.Request(ctx, http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -153,3 +153,52 @@ func TestRewriteDERPMap(t *testing.T) {
|
||||
require.Equal(t, "coconuts.org", node.HostName)
|
||||
require.Equal(t, 44558, node.DERPPort)
|
||||
}
|
||||
|
||||
func TestExternalAuthRequestWorkdirQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("IncludesWorkdirWhenSet", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const expectedWorkdir = "/tmp/repo with spaces"
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/api/v2/workspaceagents/me/external-auth", r.URL.Path)
|
||||
require.Equal(t, "true", r.URL.Query().Get("listen"))
|
||||
require.Equal(t, expectedWorkdir, r.URL.Query().Get("workdir"))
|
||||
_, _ = w.Write([]byte(`{"type":"github","access_token":"token"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
parsedURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := agentsdk.New(parsedURL, agentsdk.WithFixedToken("token"))
|
||||
_, err = client.ExternalAuth(testutil.Context(t, testutil.WaitShort), agentsdk.ExternalAuthRequest{
|
||||
Match: "github.com",
|
||||
Listen: true,
|
||||
Workdir: expectedWorkdir,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("OmitsWorkdirWhenNotSet", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/api/v2/workspaceagents/me/external-auth", r.URL.Path)
|
||||
require.Equal(t, "", r.URL.Query().Get("workdir"))
|
||||
require.False(t, r.URL.Query().Has("workdir"))
|
||||
_, _ = w.Write([]byte(`{"type":"github","access_token":"token"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
parsedURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := agentsdk.New(parsedURL, agentsdk.WithFixedToken("token"))
|
||||
_, err = client.ExternalAuth(testutil.Context(t, testutil.WaitShort), agentsdk.ExternalAuthRequest{
|
||||
Match: "github.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -33,6 +33,11 @@ const (
|
||||
APIKeyScopeBoundaryUsageDelete APIKeyScope = "boundary_usage:delete"
|
||||
APIKeyScopeBoundaryUsageRead APIKeyScope = "boundary_usage:read"
|
||||
APIKeyScopeBoundaryUsageUpdate APIKeyScope = "boundary_usage:update"
|
||||
APIKeyScopeChatAll APIKeyScope = "chat:*"
|
||||
APIKeyScopeChatCreate APIKeyScope = "chat:create"
|
||||
APIKeyScopeChatDelete APIKeyScope = "chat:delete"
|
||||
APIKeyScopeChatRead APIKeyScope = "chat:read"
|
||||
APIKeyScopeChatUpdate APIKeyScope = "chat:update"
|
||||
APIKeyScopeCoderAll APIKeyScope = "coder:all"
|
||||
APIKeyScopeCoderApikeysManageSelf APIKeyScope = "coder:apikeys.manage_self"
|
||||
APIKeyScopeCoderApplicationConnect APIKeyScope = "coder:application_connect"
|
||||
|
||||
@@ -0,0 +1,760 @@
|
||||
package codersdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ChatStatus represents the status of a chat.
|
||||
type ChatStatus string
|
||||
|
||||
const (
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
)
|
||||
|
||||
// ChatWorkspaceMode represents how chat tools access files and commands.
|
||||
type ChatWorkspaceMode string
|
||||
|
||||
const (
|
||||
ChatWorkspaceModeWorkspace ChatWorkspaceMode = "workspace"
|
||||
ChatWorkspaceModeLocal ChatWorkspaceMode = "local"
|
||||
)
|
||||
|
||||
// Chat represents a chat session with an AI agent.
|
||||
type Chat struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
OwnerID uuid.UUID `json:"owner_id" format:"uuid"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
WorkspaceAgentID *uuid.UUID `json:"workspace_agent_id,omitempty" format:"uuid"`
|
||||
WorkspaceMode ChatWorkspaceMode `json:"workspace_mode,omitempty"`
|
||||
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
|
||||
RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"`
|
||||
Title string `json:"title"`
|
||||
Status ChatStatus `json:"status"`
|
||||
DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"`
|
||||
ModelConfig json.RawMessage `json:"model_config,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatMessage represents a single message in a chat.
|
||||
type ChatMessage struct {
|
||||
ID int64 `json:"id"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Parts []ChatMessagePart `json:"parts,omitempty"`
|
||||
ToolCallID *string `json:"tool_call_id,omitempty"`
|
||||
Thinking *string `json:"thinking,omitempty"`
|
||||
Hidden bool `json:"hidden"`
|
||||
InputTokens *int64 `json:"input_tokens,omitempty"`
|
||||
OutputTokens *int64 `json:"output_tokens,omitempty"`
|
||||
TotalTokens *int64 `json:"total_tokens,omitempty"`
|
||||
ReasoningTokens *int64 `json:"reasoning_tokens,omitempty"`
|
||||
CacheCreationTokens *int64 `json:"cache_creation_tokens,omitempty"`
|
||||
CacheReadTokens *int64 `json:"cache_read_tokens,omitempty"`
|
||||
ContextLimit *int64 `json:"context_limit,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessagePartType represents a structured message part type.
|
||||
type ChatMessagePartType string
|
||||
|
||||
const (
|
||||
ChatMessagePartTypeText ChatMessagePartType = "text"
|
||||
ChatMessagePartTypeReasoning ChatMessagePartType = "reasoning"
|
||||
ChatMessagePartTypeToolCall ChatMessagePartType = "tool-call"
|
||||
ChatMessagePartTypeToolResult ChatMessagePartType = "tool-result"
|
||||
ChatMessagePartTypeSource ChatMessagePartType = "source"
|
||||
ChatMessagePartTypeFile ChatMessagePartType = "file"
|
||||
)
|
||||
|
||||
// ChatToolResultMetadata exposes commonly used tool-result fields for rendering.
|
||||
type ChatToolResultMetadata struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
ExitCode *int `json:"exit_code,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
MimeType string `json:"mime_type,omitempty"`
|
||||
Created *bool `json:"created,omitempty"`
|
||||
WorkspaceID string `json:"workspace_id,omitempty"`
|
||||
WorkspaceAgentID string `json:"workspace_agent_id,omitempty"`
|
||||
WorkspaceName string `json:"workspace_name,omitempty"`
|
||||
WorkspaceURL string `json:"workspace_url,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessagePart is a structured chunk of a chat message.
|
||||
type ChatMessagePart struct {
|
||||
Type ChatMessagePartType `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
Args json.RawMessage `json:"args,omitempty"`
|
||||
ArgsDelta string `json:"args_delta,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
ResultDelta string `json:"result_delta,omitempty"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
ResultMeta *ChatToolResultMetadata `json:"result_meta,omitempty"`
|
||||
SourceID string `json:"source_id,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Title string `json:"title,omitempty"`
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data []byte `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ChatInputPartType represents an input part type for user chat input.
|
||||
type ChatInputPartType string
|
||||
|
||||
const (
|
||||
ChatInputPartTypeText ChatInputPartType = "text"
|
||||
)
|
||||
|
||||
// ChatInputPart is a single user input part for creating a chat.
|
||||
type ChatInputPart struct {
|
||||
Type ChatInputPartType `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// ChatInput is the structured user input payload for chat creation.
|
||||
type ChatInput struct {
|
||||
Parts []ChatInputPart `json:"parts"`
|
||||
}
|
||||
|
||||
// CreateChatRequest is the request to create a new chat.
|
||||
type CreateChatRequest struct {
|
||||
Input *ChatInput `json:"input,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
WorkspaceAgentID *uuid.UUID `json:"workspace_agent_id,omitempty" format:"uuid"`
|
||||
WorkspaceMode ChatWorkspaceMode `json:"workspace_mode,omitempty"`
|
||||
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
|
||||
Model string `json:"model,omitempty"`
|
||||
ModelConfig json.RawMessage `json:"model_config,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateChatRequest is the request to update a chat.
|
||||
type UpdateChatRequest struct {
|
||||
Title string `json:"title"`
|
||||
}
|
||||
|
||||
// CreateChatMessageRequest is the request to add a message to a chat.
|
||||
type CreateChatMessageRequest struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
ToolCallID *string `json:"tool_call_id,omitempty"`
|
||||
Thinking *string `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// CreateChatMessageResponse is the response from adding a message to a chat.
|
||||
type CreateChatMessageResponse struct {
|
||||
Messages []ChatMessage `json:"messages,omitempty"`
|
||||
QueuedMessage *ChatQueuedMessage `json:"queued_message,omitempty"`
|
||||
Queued bool `json:"queued"`
|
||||
}
|
||||
|
||||
// ChatWithMessages is a chat along with its messages.
|
||||
type ChatWithMessages struct {
|
||||
Chat Chat `json:"chat"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
QueuedMessages []ChatQueuedMessage `json:"queued_messages"`
|
||||
}
|
||||
|
||||
// ChatModelProviderUnavailableReason explains why a provider cannot be used.
|
||||
type ChatModelProviderUnavailableReason string
|
||||
|
||||
const (
|
||||
ChatModelProviderUnavailableMissingAPIKey ChatModelProviderUnavailableReason = "missing_api_key"
|
||||
ChatModelProviderUnavailableFetchFailed ChatModelProviderUnavailableReason = "fetch_failed"
|
||||
)
|
||||
|
||||
// ChatModel represents a model in the chat model catalog.
|
||||
type ChatModel struct {
|
||||
ID string `json:"id"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
// ChatModelProvider represents provider availability and model results.
|
||||
type ChatModelProvider struct {
|
||||
Provider string `json:"provider"`
|
||||
Available bool `json:"available"`
|
||||
UnavailableReason ChatModelProviderUnavailableReason `json:"unavailable_reason,omitempty"`
|
||||
Models []ChatModel `json:"models"`
|
||||
}
|
||||
|
||||
// ChatModelsResponse is the catalog returned from chat model discovery.
|
||||
type ChatModelsResponse struct {
|
||||
Providers []ChatModelProvider `json:"providers"`
|
||||
}
|
||||
|
||||
// ChatProviderConfigSource describes how a provider entry is sourced.
|
||||
type ChatProviderConfigSource string
|
||||
|
||||
const (
|
||||
ChatProviderConfigSourceDatabase ChatProviderConfigSource = "database"
|
||||
ChatProviderConfigSourceEnvPreset ChatProviderConfigSource = "env_preset"
|
||||
ChatProviderConfigSourceSupported ChatProviderConfigSource = "supported"
|
||||
)
|
||||
|
||||
// ChatProviderConfig is an admin-managed provider configuration.
|
||||
type ChatProviderConfig struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
Provider string `json:"provider"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
HasAPIKey bool `json:"has_api_key"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
Source ChatProviderConfigSource `json:"source"`
|
||||
CreatedAt time.Time `json:"created_at,omitempty" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty" format:"date-time"`
|
||||
}
|
||||
|
||||
// CreateChatProviderConfigRequest creates a chat provider config.
|
||||
type CreateChatProviderConfigRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateChatProviderConfigRequest updates a chat provider config.
|
||||
type UpdateChatProviderConfigRequest struct {
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
APIKey *string `json:"api_key,omitempty"`
|
||||
BaseURL *string `json:"base_url,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelConfig is an admin-managed model configuration.
|
||||
type ChatModelConfig struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ContextLimit int64 `json:"context_limit"`
|
||||
CompressionThreshold int32 `json:"compression_threshold"`
|
||||
ModelConfig *ChatModelCallConfig `json:"model_config,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatModelProviderOptions contains typed provider-specific options.
|
||||
//
|
||||
// Note: Azure models use the `openai` options shape.
|
||||
// Note: Bedrock models use the `anthropic` options shape.
|
||||
type ChatModelProviderOptions struct {
|
||||
OpenAI *ChatModelOpenAIProviderOptions `json:"openai,omitempty"`
|
||||
Anthropic *ChatModelAnthropicProviderOptions `json:"anthropic,omitempty"`
|
||||
Google *ChatModelGoogleProviderOptions `json:"google,omitempty"`
|
||||
OpenAICompat *ChatModelOpenAICompatProviderOptions `json:"openaicompat,omitempty"`
|
||||
OpenRouter *ChatModelOpenRouterProviderOptions `json:"openrouter,omitempty"`
|
||||
Vercel *ChatModelVercelProviderOptions `json:"vercel,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelOpenAIProviderOptions configures OpenAI provider behavior.
|
||||
type ChatModelOpenAIProviderOptions struct {
|
||||
Include []string `json:"include,omitempty"`
|
||||
Instructions *string `json:"instructions,omitempty"`
|
||||
LogitBias map[string]int64 `json:"logit_bias,omitempty"`
|
||||
LogProbs *bool `json:"log_probs,omitempty"`
|
||||
TopLogProbs *int64 `json:"top_log_probs,omitempty"`
|
||||
MaxToolCalls *int64 `json:"max_tool_calls,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
User *string `json:"user,omitempty"`
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
ReasoningSummary *string `json:"reasoning_summary,omitempty"`
|
||||
MaxCompletionTokens *int64 `json:"max_completion_tokens,omitempty"`
|
||||
TextVerbosity *string `json:"text_verbosity,omitempty"`
|
||||
Prediction map[string]any `json:"prediction,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
PromptCacheKey *string `json:"prompt_cache_key,omitempty"`
|
||||
SafetyIdentifier *string `json:"safety_identifier,omitempty"`
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
StructuredOutputs *bool `json:"structured_outputs,omitempty"`
|
||||
StrictJSONSchema *bool `json:"strict_json_schema,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelAnthropicThinkingOptions configures Anthropic thinking budget.
|
||||
type ChatModelAnthropicThinkingOptions struct {
|
||||
BudgetTokens *int64 `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelAnthropicProviderOptions configures Anthropic provider behavior.
|
||||
type ChatModelAnthropicProviderOptions struct {
|
||||
SendReasoning *bool `json:"send_reasoning,omitempty"`
|
||||
Thinking *ChatModelAnthropicThinkingOptions `json:"thinking,omitempty"`
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelGoogleThinkingConfig configures Google thinking behavior.
|
||||
type ChatModelGoogleThinkingConfig struct {
|
||||
ThinkingBudget *int64 `json:"thinking_budget,omitempty"`
|
||||
IncludeThoughts *bool `json:"include_thoughts,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelGoogleSafetySetting configures Google safety filtering.
|
||||
type ChatModelGoogleSafetySetting struct {
|
||||
Category string `json:"category,omitempty"`
|
||||
Threshold string `json:"threshold,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelGoogleProviderOptions configures Google provider behavior.
|
||||
type ChatModelGoogleProviderOptions struct {
|
||||
ThinkingConfig *ChatModelGoogleThinkingConfig `json:"thinking_config,omitempty"`
|
||||
CachedContent string `json:"cached_content,omitempty"`
|
||||
SafetySettings []ChatModelGoogleSafetySetting `json:"safety_settings,omitempty"`
|
||||
Threshold string `json:"threshold,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelOpenAICompatProviderOptions configures OpenAI-compatible behavior.
|
||||
type ChatModelOpenAICompatProviderOptions struct {
|
||||
User *string `json:"user,omitempty"`
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelOpenRouterReasoningOptions configures OpenRouter reasoning behavior.
|
||||
type ChatModelOpenRouterReasoningOptions struct {
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
Exclude *bool `json:"exclude,omitempty"`
|
||||
MaxTokens *int64 `json:"max_tokens,omitempty"`
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelOpenRouterProvider configures OpenRouter routing preferences.
|
||||
type ChatModelOpenRouterProvider struct {
|
||||
Order []string `json:"order,omitempty"`
|
||||
AllowFallbacks *bool `json:"allow_fallbacks,omitempty"`
|
||||
RequireParameters *bool `json:"require_parameters,omitempty"`
|
||||
DataCollection *string `json:"data_collection,omitempty"`
|
||||
Only []string `json:"only,omitempty"`
|
||||
Ignore []string `json:"ignore,omitempty"`
|
||||
Quantizations []string `json:"quantizations,omitempty"`
|
||||
Sort *string `json:"sort,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelOpenRouterProviderOptions configures OpenRouter provider behavior.
|
||||
type ChatModelOpenRouterProviderOptions struct {
|
||||
Reasoning *ChatModelOpenRouterReasoningOptions `json:"reasoning,omitempty"`
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"`
|
||||
IncludeUsage *bool `json:"include_usage,omitempty"`
|
||||
LogitBias map[string]int64 `json:"logit_bias,omitempty"`
|
||||
LogProbs *bool `json:"log_probs,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
User *string `json:"user,omitempty"`
|
||||
Provider *ChatModelOpenRouterProvider `json:"provider,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelVercelReasoningOptions configures Vercel reasoning behavior.
|
||||
type ChatModelVercelReasoningOptions struct {
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
MaxTokens *int64 `json:"max_tokens,omitempty"`
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
Exclude *bool `json:"exclude,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelVercelGatewayProviderOptions configures Vercel routing behavior.
|
||||
type ChatModelVercelGatewayProviderOptions struct {
|
||||
Order []string `json:"order,omitempty"`
|
||||
Models []string `json:"models,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelVercelProviderOptions configures Vercel provider behavior.
|
||||
type ChatModelVercelProviderOptions struct {
|
||||
Reasoning *ChatModelVercelReasoningOptions `json:"reasoning,omitempty"`
|
||||
ProviderOptions *ChatModelVercelGatewayProviderOptions `json:"providerOptions,omitempty"`
|
||||
User *string `json:"user,omitempty"`
|
||||
LogitBias map[string]int64 `json:"logit_bias,omitempty"`
|
||||
LogProbs *bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs *int64 `json:"top_logprobs,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"`
|
||||
}
|
||||
|
||||
// ChatModelCallConfig configures per-call model behavior defaults.
|
||||
type ChatModelCallConfig struct {
|
||||
MaxOutputTokens *int64 `json:"max_output_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int64 `json:"top_k,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
ProviderOptions *ChatModelProviderOptions `json:"provider_options,omitempty"`
|
||||
}
|
||||
|
||||
// CreateChatModelConfigRequest creates a chat model config.
|
||||
type CreateChatModelConfigRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
ContextLimit *int64 `json:"context_limit,omitempty"`
|
||||
CompressionThreshold *int32 `json:"compression_threshold,omitempty"`
|
||||
ModelConfig *ChatModelCallConfig `json:"model_config,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateChatModelConfigRequest updates a chat model config.
|
||||
type UpdateChatModelConfigRequest struct {
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
ContextLimit *int64 `json:"context_limit,omitempty"`
|
||||
CompressionThreshold *int32 `json:"compression_threshold,omitempty"`
|
||||
ModelConfig *ChatModelCallConfig `json:"model_config,omitempty"`
|
||||
}
|
||||
|
||||
// ChatGitChange represents a git file change detected during a chat session.
|
||||
type ChatGitChange struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
FilePath string `json:"file_path"`
|
||||
ChangeType string `json:"change_type"` // added, modified, deleted, renamed
|
||||
OldPath *string `json:"old_path,omitempty"`
|
||||
DiffSummary *string `json:"diff_summary,omitempty"`
|
||||
DetectedAt time.Time `json:"detected_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatDiffStatus represents cached diff status for a chat. The URL
|
||||
// may point to a pull request or a branch page depending on whether
|
||||
// a PR has been opened.
|
||||
type ChatDiffStatus struct {
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
URL *string `json:"url,omitempty"`
|
||||
PullRequestState *string `json:"pull_request_state,omitempty"`
|
||||
ChangesRequested bool `json:"changes_requested"`
|
||||
Additions int32 `json:"additions"`
|
||||
Deletions int32 `json:"deletions"`
|
||||
ChangedFiles int32 `json:"changed_files"`
|
||||
RefreshedAt *time.Time `json:"refreshed_at,omitempty" format:"date-time"`
|
||||
StaleAt *time.Time `json:"stale_at,omitempty" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatDiffContents represents the resolved diff text for a chat.
|
||||
type ChatDiffContents struct {
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
Provider *string `json:"provider,omitempty"`
|
||||
RemoteOrigin *string `json:"remote_origin,omitempty"`
|
||||
Branch *string `json:"branch,omitempty"`
|
||||
PullRequestURL *string `json:"pull_request_url,omitempty"`
|
||||
Diff string `json:"diff,omitempty"`
|
||||
}
|
||||
|
||||
// ChatStreamEventType represents the kind of chat stream update.
|
||||
type ChatStreamEventType string
|
||||
|
||||
const (
|
||||
ChatStreamEventTypeMessagePart ChatStreamEventType = "message_part"
|
||||
ChatStreamEventTypeMessage ChatStreamEventType = "message"
|
||||
ChatStreamEventTypeStatus ChatStreamEventType = "status"
|
||||
ChatStreamEventTypeError ChatStreamEventType = "error"
|
||||
ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update"
|
||||
)
|
||||
|
||||
// ChatQueuedMessage represents a queued message waiting to be processed.
|
||||
type ChatQueuedMessage struct {
|
||||
ID int64 `json:"id"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatStreamMessagePart is a streamed message part update.
|
||||
type ChatStreamMessagePart struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Part ChatMessagePart `json:"part"`
|
||||
}
|
||||
|
||||
// ChatStreamStatus represents an updated chat status.
|
||||
type ChatStreamStatus struct {
|
||||
Status ChatStatus `json:"status"`
|
||||
}
|
||||
|
||||
// ChatStreamError represents an error event in the stream.
|
||||
type ChatStreamError struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ChatStreamEvent represents a real-time update for chat streaming.
|
||||
type ChatStreamEvent struct {
|
||||
Type ChatStreamEventType `json:"type"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
Message *ChatMessage `json:"message,omitempty"`
|
||||
MessagePart *ChatStreamMessagePart `json:"message_part,omitempty"`
|
||||
Status *ChatStreamStatus `json:"status,omitempty"`
|
||||
Error *ChatStreamError `json:"error,omitempty"`
|
||||
QueuedMessages []ChatQueuedMessage `json:"queued_messages,omitempty"`
|
||||
}
|
||||
|
||||
// ListChats returns all chats for the authenticated user.
|
||||
func (c *Client) ListChats(ctx context.Context) ([]Chat, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/v2/chats", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
var chats []Chat
|
||||
return chats, json.NewDecoder(res.Body).Decode(&chats)
|
||||
}
|
||||
|
||||
// ListChatModels returns the available chat model catalog.
|
||||
func (c *Client) ListChatModels(ctx context.Context) (ChatModelsResponse, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/v2/chats/models", nil)
|
||||
if err != nil {
|
||||
return ChatModelsResponse{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatModelsResponse{}, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var catalog ChatModelsResponse
|
||||
return catalog, json.NewDecoder(res.Body).Decode(&catalog)
|
||||
}
|
||||
|
||||
// ListChatProviders returns admin-managed chat provider configs.
|
||||
func (c *Client) ListChatProviders(ctx context.Context) ([]ChatProviderConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/v2/chats/providers", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var providers []ChatProviderConfig
|
||||
return providers, json.NewDecoder(res.Body).Decode(&providers)
|
||||
}
|
||||
|
||||
// CreateChatProvider creates an admin-managed chat provider config.
|
||||
func (c *Client) CreateChatProvider(ctx context.Context, req CreateChatProviderConfigRequest) (ChatProviderConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, "/api/v2/chats/providers", req)
|
||||
if err != nil {
|
||||
return ChatProviderConfig{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return ChatProviderConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var provider ChatProviderConfig
|
||||
return provider, json.NewDecoder(res.Body).Decode(&provider)
|
||||
}
|
||||
|
||||
// UpdateChatProvider updates an admin-managed chat provider config.
|
||||
func (c *Client) UpdateChatProvider(ctx context.Context, providerID uuid.UUID, req UpdateChatProviderConfigRequest) (ChatProviderConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/v2/chats/providers/%s", providerID), req)
|
||||
if err != nil {
|
||||
return ChatProviderConfig{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatProviderConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var provider ChatProviderConfig
|
||||
return provider, json.NewDecoder(res.Body).Decode(&provider)
|
||||
}
|
||||
|
||||
// DeleteChatProvider deletes an admin-managed chat provider config.
|
||||
func (c *Client) DeleteChatProvider(ctx context.Context, providerID uuid.UUID) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/v2/chats/providers/%s", providerID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListChatModelConfigs returns admin-managed chat model configs.
|
||||
func (c *Client) ListChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/v2/chats/model-configs", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var configs []ChatModelConfig
|
||||
return configs, json.NewDecoder(res.Body).Decode(&configs)
|
||||
}
|
||||
|
||||
// CreateChatModelConfig creates an admin-managed chat model config.
|
||||
func (c *Client) CreateChatModelConfig(ctx context.Context, req CreateChatModelConfigRequest) (ChatModelConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, "/api/v2/chats/model-configs", req)
|
||||
if err != nil {
|
||||
return ChatModelConfig{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return ChatModelConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var config ChatModelConfig
|
||||
return config, json.NewDecoder(res.Body).Decode(&config)
|
||||
}
|
||||
|
||||
// UpdateChatModelConfig updates an admin-managed chat model config.
|
||||
func (c *Client) UpdateChatModelConfig(ctx context.Context, modelConfigID uuid.UUID, req UpdateChatModelConfigRequest) (ChatModelConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/v2/chats/model-configs/%s", modelConfigID), req)
|
||||
if err != nil {
|
||||
return ChatModelConfig{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatModelConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var config ChatModelConfig
|
||||
return config, json.NewDecoder(res.Body).Decode(&config)
|
||||
}
|
||||
|
||||
// DeleteChatModelConfig deletes an admin-managed chat model config.
|
||||
func (c *Client) DeleteChatModelConfig(ctx context.Context, modelConfigID uuid.UUID) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/v2/chats/model-configs/%s", modelConfigID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateChat creates a new chat.
|
||||
func (c *Client) CreateChat(ctx context.Context, req CreateChatRequest) (Chat, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, "/api/v2/chats", req)
|
||||
if err != nil {
|
||||
return Chat{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return Chat{}, ReadBodyAsError(res)
|
||||
}
|
||||
var chat Chat
|
||||
return chat, json.NewDecoder(res.Body).Decode(&chat)
|
||||
}
|
||||
|
||||
// GetChat returns a chat by ID, including its messages.
|
||||
func (c *Client) GetChat(ctx context.Context, chatID uuid.UUID) (ChatWithMessages, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/chats/%s", chatID), nil)
|
||||
if err != nil {
|
||||
return ChatWithMessages{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatWithMessages{}, ReadBodyAsError(res)
|
||||
}
|
||||
var chat ChatWithMessages
|
||||
return chat, json.NewDecoder(res.Body).Decode(&chat)
|
||||
}
|
||||
|
||||
// DeleteChat deletes a chat by ID.
|
||||
func (c *Client) DeleteChat(ctx context.Context, chatID uuid.UUID) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/v2/chats/%s", chatID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateChatMessage adds a message to a chat.
|
||||
func (c *Client) CreateChatMessage(ctx context.Context, chatID uuid.UUID, req CreateChatMessageRequest) (CreateChatMessageResponse, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/chats/%s/messages", chatID), req)
|
||||
if err != nil {
|
||||
return CreateChatMessageResponse{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return CreateChatMessageResponse{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp CreateChatMessageResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// InterruptChat cancels an in-flight chat run and leaves it waiting.
|
||||
func (c *Client) InterruptChat(ctx context.Context, chatID uuid.UUID) (Chat, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/chats/%s/interrupt", chatID), nil)
|
||||
if err != nil {
|
||||
return Chat{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return Chat{}, ReadBodyAsError(res)
|
||||
}
|
||||
var chat Chat
|
||||
return chat, json.NewDecoder(res.Body).Decode(&chat)
|
||||
}
|
||||
|
||||
// GetChatGitChanges returns git changes for a chat.
|
||||
func (c *Client) GetChatGitChanges(ctx context.Context, chatID uuid.UUID) ([]ChatGitChange, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/chats/%s/git-changes", chatID), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
var changes []ChatGitChange
|
||||
return changes, json.NewDecoder(res.Body).Decode(&changes)
|
||||
}
|
||||
|
||||
// GetChatDiffStatus returns cached GitHub pull request diff status for a chat.
|
||||
func (c *Client) GetChatDiffStatus(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/chats/%s/diff-status", chatID), nil)
|
||||
if err != nil {
|
||||
return ChatDiffStatus{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatDiffStatus{}, ReadBodyAsError(res)
|
||||
}
|
||||
var status ChatDiffStatus
|
||||
return status, json.NewDecoder(res.Body).Decode(&status)
|
||||
}
|
||||
|
||||
// GetChatDiffContents returns resolved diff contents for a chat.
|
||||
func (c *Client) GetChatDiffContents(ctx context.Context, chatID uuid.UUID) (ChatDiffContents, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/chats/%s/diff", chatID), nil)
|
||||
if err != nil {
|
||||
return ChatDiffContents{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatDiffContents{}, ReadBodyAsError(res)
|
||||
}
|
||||
var diff ChatDiffContents
|
||||
return diff, json.NewDecoder(res.Body).Decode(&diff)
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package codersdk_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestChatModelProviderOptions_MarshalJSON_UsesPlainProviderPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sendReasoning := true
|
||||
effort := "high"
|
||||
|
||||
raw, err := json.Marshal(codersdk.ChatModelProviderOptions{
|
||||
Anthropic: &codersdk.ChatModelAnthropicProviderOptions{
|
||||
SendReasoning: &sendReasoning,
|
||||
Effort: &effort,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(raw), `"type":"anthropic.options"`)
|
||||
require.NotContains(t, string(raw), `"data":`)
|
||||
require.Contains(t, string(raw), `"send_reasoning":true`)
|
||||
require.Contains(t, string(raw), `"effort":"high"`)
|
||||
}
|
||||
|
||||
func TestChatModelProviderOptions_UnmarshalJSON_ParsesPlainProviderPayloads(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
raw := []byte(`{
|
||||
"anthropic": {
|
||||
"send_reasoning": true,
|
||||
"effort": "high"
|
||||
}
|
||||
}`)
|
||||
|
||||
var decoded codersdk.ChatModelProviderOptions
|
||||
err := json.Unmarshal(raw, &decoded)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decoded.Anthropic)
|
||||
require.NotNil(t, decoded.Anthropic.SendReasoning)
|
||||
require.True(t, *decoded.Anthropic.SendReasoning)
|
||||
require.NotNil(t, decoded.Anthropic.Effort)
|
||||
require.Equal(
|
||||
t,
|
||||
"high",
|
||||
*decoded.Anthropic.Effort,
|
||||
)
|
||||
}
|
||||
+227
-66
@@ -583,68 +583,69 @@ type DeploymentValues struct {
|
||||
DocsURL serpent.URL `json:"docs_url,omitempty"`
|
||||
RedirectToAccessURL serpent.Bool `json:"redirect_to_access_url,omitempty"`
|
||||
// HTTPAddress is a string because it may be set to zero to disable.
|
||||
HTTPAddress serpent.String `json:"http_address,omitempty" typescript:",notnull"`
|
||||
AutobuildPollInterval serpent.Duration `json:"autobuild_poll_interval,omitempty"`
|
||||
JobReaperDetectorInterval serpent.Duration `json:"job_hang_detector_interval,omitempty"`
|
||||
DERP DERP `json:"derp,omitempty" typescript:",notnull"`
|
||||
Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"`
|
||||
Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"`
|
||||
CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"`
|
||||
EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"`
|
||||
PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"`
|
||||
PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxOpen serpent.Int64 `json:"pg_conn_max_open,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxIdle serpent.String `json:"pg_conn_max_idle,omitempty" typescript:",notnull"`
|
||||
OAuth2 OAuth2Config `json:"oauth2,omitempty" typescript:",notnull"`
|
||||
OIDC OIDCConfig `json:"oidc,omitempty" typescript:",notnull"`
|
||||
Telemetry TelemetryConfig `json:"telemetry,omitempty" typescript:",notnull"`
|
||||
TLS TLSConfig `json:"tls,omitempty" typescript:",notnull"`
|
||||
Trace TraceConfig `json:"trace,omitempty" typescript:",notnull"`
|
||||
HTTPCookies HTTPCookieConfig `json:"http_cookies,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurity serpent.Int64 `json:"strict_transport_security,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurityOptions serpent.StringArray `json:"strict_transport_security_options,omitempty" typescript:",notnull"`
|
||||
SSHKeygenAlgorithm serpent.String `json:"ssh_keygen_algorithm,omitempty" typescript:",notnull"`
|
||||
MetricsCacheRefreshInterval serpent.Duration `json:"metrics_cache_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentStatRefreshInterval serpent.Duration `json:"agent_stat_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentFallbackTroubleshootingURL serpent.URL `json:"agent_fallback_troubleshooting_url,omitempty" typescript:",notnull"`
|
||||
BrowserOnly serpent.Bool `json:"browser_only,omitempty" typescript:",notnull"`
|
||||
SCIMAPIKey serpent.String `json:"scim_api_key,omitempty" typescript:",notnull"`
|
||||
ExternalTokenEncryptionKeys serpent.StringArray `json:"external_token_encryption_keys,omitempty" typescript:",notnull"`
|
||||
Provisioner ProvisionerConfig `json:"provisioner,omitempty" typescript:",notnull"`
|
||||
RateLimit RateLimitConfig `json:"rate_limit,omitempty" typescript:",notnull"`
|
||||
Experiments serpent.StringArray `json:"experiments,omitempty" typescript:",notnull"`
|
||||
UpdateCheck serpent.Bool `json:"update_check,omitempty" typescript:",notnull"`
|
||||
Swagger SwaggerConfig `json:"swagger,omitempty" typescript:",notnull"`
|
||||
Logging LoggingConfig `json:"logging,omitempty" typescript:",notnull"`
|
||||
Dangerous DangerousConfig `json:"dangerous,omitempty" typescript:",notnull"`
|
||||
DisablePathApps serpent.Bool `json:"disable_path_apps,omitempty" typescript:",notnull"`
|
||||
Sessions SessionLifetime `json:"session_lifetime,omitempty" typescript:",notnull"`
|
||||
DisablePasswordAuth serpent.Bool `json:"disable_password_auth,omitempty" typescript:",notnull"`
|
||||
Support SupportConfig `json:"support,omitempty" typescript:",notnull"`
|
||||
EnableAuthzRecording serpent.Bool `json:"enable_authz_recording,omitempty" typescript:",notnull"`
|
||||
ExternalAuthConfigs serpent.Struct[[]ExternalAuthConfig] `json:"external_auth,omitempty" typescript:",notnull"`
|
||||
SSHConfig SSHConfig `json:"config_ssh,omitempty" typescript:",notnull"`
|
||||
WgtunnelHost serpent.String `json:"wgtunnel_host,omitempty" typescript:",notnull"`
|
||||
DisableOwnerWorkspaceExec serpent.Bool `json:"disable_owner_workspace_exec,omitempty" typescript:",notnull"`
|
||||
DisableWorkspaceSharing serpent.Bool `json:"disable_workspace_sharing,omitempty" typescript:",notnull"`
|
||||
ProxyHealthStatusInterval serpent.Duration `json:"proxy_health_status_interval,omitempty" typescript:",notnull"`
|
||||
EnableTerraformDebugMode serpent.Bool `json:"enable_terraform_debug_mode,omitempty" typescript:",notnull"`
|
||||
UserQuietHoursSchedule UserQuietHoursScheduleConfig `json:"user_quiet_hours_schedule,omitempty" typescript:",notnull"`
|
||||
WebTerminalRenderer serpent.String `json:"web_terminal_renderer,omitempty" typescript:",notnull"`
|
||||
AllowWorkspaceRenames serpent.Bool `json:"allow_workspace_renames,omitempty" typescript:",notnull"`
|
||||
Healthcheck HealthcheckConfig `json:"healthcheck,omitempty" typescript:",notnull"`
|
||||
Retention RetentionConfig `json:"retention,omitempty" typescript:",notnull"`
|
||||
CLIUpgradeMessage serpent.String `json:"cli_upgrade_message,omitempty" typescript:",notnull"`
|
||||
TermsOfServiceURL serpent.String `json:"terms_of_service_url,omitempty" typescript:",notnull"`
|
||||
Notifications NotificationsConfig `json:"notifications,omitempty" typescript:",notnull"`
|
||||
AdditionalCSPPolicy serpent.StringArray `json:"additional_csp_policy,omitempty" typescript:",notnull"`
|
||||
WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"`
|
||||
Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"`
|
||||
HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"`
|
||||
AI AIConfig `json:"ai,omitempty"`
|
||||
StatsCollection StatsCollectionConfig `json:"stats_collection,omitempty" typescript:",notnull"`
|
||||
HTTPAddress serpent.String `json:"http_address,omitempty" typescript:",notnull"`
|
||||
AutobuildPollInterval serpent.Duration `json:"autobuild_poll_interval,omitempty"`
|
||||
JobReaperDetectorInterval serpent.Duration `json:"job_hang_detector_interval,omitempty"`
|
||||
DERP DERP `json:"derp,omitempty" typescript:",notnull"`
|
||||
Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"`
|
||||
Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"`
|
||||
CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"`
|
||||
EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"`
|
||||
PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"`
|
||||
PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxOpen serpent.Int64 `json:"pg_conn_max_open,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxIdle serpent.String `json:"pg_conn_max_idle,omitempty" typescript:",notnull"`
|
||||
OAuth2 OAuth2Config `json:"oauth2,omitempty" typescript:",notnull"`
|
||||
OIDC OIDCConfig `json:"oidc,omitempty" typescript:",notnull"`
|
||||
Telemetry TelemetryConfig `json:"telemetry,omitempty" typescript:",notnull"`
|
||||
TLS TLSConfig `json:"tls,omitempty" typescript:",notnull"`
|
||||
Trace TraceConfig `json:"trace,omitempty" typescript:",notnull"`
|
||||
HTTPCookies HTTPCookieConfig `json:"http_cookies,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurity serpent.Int64 `json:"strict_transport_security,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurityOptions serpent.StringArray `json:"strict_transport_security_options,omitempty" typescript:",notnull"`
|
||||
SSHKeygenAlgorithm serpent.String `json:"ssh_keygen_algorithm,omitempty" typescript:",notnull"`
|
||||
MetricsCacheRefreshInterval serpent.Duration `json:"metrics_cache_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentStatRefreshInterval serpent.Duration `json:"agent_stat_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentFallbackTroubleshootingURL serpent.URL `json:"agent_fallback_troubleshooting_url,omitempty" typescript:",notnull"`
|
||||
BrowserOnly serpent.Bool `json:"browser_only,omitempty" typescript:",notnull"`
|
||||
SCIMAPIKey serpent.String `json:"scim_api_key,omitempty" typescript:",notnull"`
|
||||
ExternalTokenEncryptionKeys serpent.StringArray `json:"external_token_encryption_keys,omitempty" typescript:",notnull"`
|
||||
Provisioner ProvisionerConfig `json:"provisioner,omitempty" typescript:",notnull"`
|
||||
RateLimit RateLimitConfig `json:"rate_limit,omitempty" typescript:",notnull"`
|
||||
Experiments serpent.StringArray `json:"experiments,omitempty" typescript:",notnull"`
|
||||
UpdateCheck serpent.Bool `json:"update_check,omitempty" typescript:",notnull"`
|
||||
Swagger SwaggerConfig `json:"swagger,omitempty" typescript:",notnull"`
|
||||
Logging LoggingConfig `json:"logging,omitempty" typescript:",notnull"`
|
||||
Dangerous DangerousConfig `json:"dangerous,omitempty" typescript:",notnull"`
|
||||
DisablePathApps serpent.Bool `json:"disable_path_apps,omitempty" typescript:",notnull"`
|
||||
Sessions SessionLifetime `json:"session_lifetime,omitempty" typescript:",notnull"`
|
||||
DisablePasswordAuth serpent.Bool `json:"disable_password_auth,omitempty" typescript:",notnull"`
|
||||
Support SupportConfig `json:"support,omitempty" typescript:",notnull"`
|
||||
EnableAuthzRecording serpent.Bool `json:"enable_authz_recording,omitempty" typescript:",notnull"`
|
||||
ExternalAuthConfigs serpent.Struct[[]ExternalAuthConfig] `json:"external_auth,omitempty" typescript:",notnull"`
|
||||
ExternalAuthGithubDefaultProviderEnable serpent.Bool `json:"external_auth_github_default_provider_enable,omitempty" typescript:",notnull"`
|
||||
SSHConfig SSHConfig `json:"config_ssh,omitempty" typescript:",notnull"`
|
||||
WgtunnelHost serpent.String `json:"wgtunnel_host,omitempty" typescript:",notnull"`
|
||||
DisableOwnerWorkspaceExec serpent.Bool `json:"disable_owner_workspace_exec,omitempty" typescript:",notnull"`
|
||||
DisableWorkspaceSharing serpent.Bool `json:"disable_workspace_sharing,omitempty" typescript:",notnull"`
|
||||
ProxyHealthStatusInterval serpent.Duration `json:"proxy_health_status_interval,omitempty" typescript:",notnull"`
|
||||
EnableTerraformDebugMode serpent.Bool `json:"enable_terraform_debug_mode,omitempty" typescript:",notnull"`
|
||||
UserQuietHoursSchedule UserQuietHoursScheduleConfig `json:"user_quiet_hours_schedule,omitempty" typescript:",notnull"`
|
||||
WebTerminalRenderer serpent.String `json:"web_terminal_renderer,omitempty" typescript:",notnull"`
|
||||
AllowWorkspaceRenames serpent.Bool `json:"allow_workspace_renames,omitempty" typescript:",notnull"`
|
||||
Healthcheck HealthcheckConfig `json:"healthcheck,omitempty" typescript:",notnull"`
|
||||
Retention RetentionConfig `json:"retention,omitempty" typescript:",notnull"`
|
||||
CLIUpgradeMessage serpent.String `json:"cli_upgrade_message,omitempty" typescript:",notnull"`
|
||||
TermsOfServiceURL serpent.String `json:"terms_of_service_url,omitempty" typescript:",notnull"`
|
||||
Notifications NotificationsConfig `json:"notifications,omitempty" typescript:",notnull"`
|
||||
AdditionalCSPPolicy serpent.StringArray `json:"additional_csp_policy,omitempty" typescript:",notnull"`
|
||||
WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"`
|
||||
Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"`
|
||||
HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"`
|
||||
AI AIConfig `json:"ai,omitempty"`
|
||||
StatsCollection StatsCollectionConfig `json:"stats_collection,omitempty" typescript:",notnull"`
|
||||
|
||||
Config serpent.YAMLConfigPath `json:"config,omitempty" typescript:",notnull"`
|
||||
WriteConfig serpent.Bool `json:"write_config,omitempty" typescript:",notnull"`
|
||||
@@ -3043,6 +3044,15 @@ Write out the current server config as YAML to stdout.`,
|
||||
Value: &c.ExternalAuthConfigs,
|
||||
Hidden: true,
|
||||
},
|
||||
{
|
||||
Name: "External Auth GitHub Default Provider Enable",
|
||||
Description: "Enable the default GitHub external auth provider managed by Coder.",
|
||||
Flag: "external-auth-github-default-provider-enable",
|
||||
Env: "CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE",
|
||||
YAML: "externalAuthGithubDefaultProviderEnable",
|
||||
Value: &c.ExternalAuthGithubDefaultProviderEnable,
|
||||
Default: "true",
|
||||
},
|
||||
{
|
||||
Name: "Custom wgtunnel Host",
|
||||
Description: `Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By default, this will pick the best available wgtunnel server hosted by Coder. e.g. "tunnel.example.com".`,
|
||||
@@ -3473,6 +3483,106 @@ Write out the current server config as YAML to stdout.`,
|
||||
Group: &deploymentGroupClient,
|
||||
YAML: "hideAITasks",
|
||||
},
|
||||
{
|
||||
Name: "Chat System Prompt",
|
||||
Description: "Default system prompt inserted into new chats.",
|
||||
Flag: "chat-system-prompt",
|
||||
Env: "CODER_CHAT_SYSTEM_PROMPT",
|
||||
Value: &c.AI.Chat.SystemPrompt,
|
||||
Default: `You are the Coder agent — an interactive chat tool that helps users with software-engineering tasks inside of the Coder product.
|
||||
Use the instructions below and the tools available to you to assist User.
|
||||
|
||||
IMPORTANT — obey every rule in this prompt before anything else.
|
||||
Do EXACTLY what the User asked, never more, never less.
|
||||
|
||||
<behavior>
|
||||
You MUST execute AS MANY TOOLS to help the user accomplish their task.
|
||||
You are COMFORTABLE with vague tasks - using your tools to collect the most relevant answer possible.
|
||||
You ALWAYS use GitHub tools for ANY query related to source code.
|
||||
If a user asks how something works, no matter how vague, you MUST use your tools to collect the most relevant answer possible.
|
||||
DO NOT ask the user for clarification - just use your tools.
|
||||
</behavior>
|
||||
|
||||
<personality>
|
||||
Analytical — You break problems into measurable steps, relying on tool output and data rather than intuition.
|
||||
Organized — You structure every interaction with clear tags, TODO lists, and section boundaries.
|
||||
Precision-Oriented — You insist on exact formatting, package-manager choice, and rule adherence.
|
||||
Efficiency-Focused — You minimize chatter, run tasks in parallel, and favor small, complete answers.
|
||||
Clarity-Seeking — You ask for missing details instead of guessing, avoiding any ambiguity.
|
||||
</personality>
|
||||
|
||||
<communication>
|
||||
Be concise, direct, and to the point.
|
||||
NO emojis unless the User explicitly asks for them.
|
||||
If a task appears incomplete or ambiguous, **pause and ask the User** rather than guessing or marking "done".
|
||||
Prefer accuracy over reassurance; confirm facts with tool calls instead of assuming the User is right.
|
||||
If you face an architectural, tooling, or package-manager choice, **ask the User's preference first**.
|
||||
Default to the project's existing package manager / tooling; never substitute without confirmation.
|
||||
You MUST avoid text before/after your response, such as "The answer is" or "Short answer:", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
|
||||
Mimic the style of the User's messages.
|
||||
Do not remind the User you are happy to help.
|
||||
Do not inherently assume the User is correct; they may be making assumptions.
|
||||
If you are not confident in your answer, DO NOT provide an answer. Use your tools to collect more information, or ask the User for help.
|
||||
Do not act with sycophantic flattery or over-the-top enthusiasm.
|
||||
|
||||
Here are examples to demonstrate appropriate communication style and level of verbosity:
|
||||
|
||||
<example>
|
||||
user: find me a good issue to work on
|
||||
assistant: Issue [#1234](https://example) indicates a bug in the frontend, which you've contributed to in the past.
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: work on this issue <url>
|
||||
...assistant does work...
|
||||
assistant: I've put up this pull request: https://github.com/example/example/pull/1824. Please let me know your thoughts!
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: what is 2+2?
|
||||
assistant: 4
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: how does X work in <popular-repository-name>?
|
||||
assistant: Let me take a look at the code...
|
||||
[tool calls to investigate the repository]
|
||||
</example>
|
||||
</communication>
|
||||
|
||||
<collaboration>
|
||||
When a user asks for help with a task or there is ambiguity on the objective, always start by asking clarifying questions to understand:
|
||||
- What specific aspect they want to focus on
|
||||
- Their goals and vision for the changes
|
||||
- Their preferences for approach or style
|
||||
- What problems they're trying to solve
|
||||
|
||||
Don't assume what needs to be done - collaborate to define the scope together.
|
||||
</collaboration>`,
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "chat_system_prompt",
|
||||
},
|
||||
{
|
||||
Name: "Chat Title Generation Prompt",
|
||||
Description: "Prompt used to generate chat titles from the first user message.",
|
||||
Flag: "chat-title-generation-prompt",
|
||||
Env: "CODER_CHAT_TITLE_GENERATION_PROMPT",
|
||||
Value: &c.AI.Chat.TitleGenerationPrompt,
|
||||
Default: "Generate a concise title (max 8 words) for the user's first message. " +
|
||||
"Return plain text only, with no surrounding quotes.",
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "chat_title_generation_prompt",
|
||||
},
|
||||
{
|
||||
Name: "Agent Local Workspace",
|
||||
Description: "Enable admin-only local workspace mode for agent chats.",
|
||||
Flag: "agent-local-workspace",
|
||||
Env: "CODER_AGENT_LOCAL_WORKSPACE",
|
||||
Value: &c.AI.Chat.LocalWorkspace,
|
||||
Default: "false",
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "agent_local_workspace",
|
||||
},
|
||||
|
||||
// AI Bridge Options
|
||||
{
|
||||
@@ -3495,6 +3605,16 @@ Write out the current server config as YAML to stdout.`,
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "openai_base_url",
|
||||
},
|
||||
{
|
||||
Name: "Chat OpenAI Models URL",
|
||||
Description: "Override URL used to list OpenAI models for the chat model catalog.",
|
||||
Flag: "chat-openai-models-url",
|
||||
Env: "CODER_CHAT_OPENAI_MODELS_URL",
|
||||
Value: &c.AI.BridgeConfig.OpenAI.ModelsURL,
|
||||
Default: "",
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "chat_openai_models_url",
|
||||
},
|
||||
{
|
||||
Name: "AI Bridge OpenAI Key",
|
||||
Description: "The key to authenticate against the OpenAI API.",
|
||||
@@ -3515,6 +3635,16 @@ Write out the current server config as YAML to stdout.`,
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "anthropic_base_url",
|
||||
},
|
||||
{
|
||||
Name: "Chat Anthropic Models URL",
|
||||
Description: "Override URL used to list Anthropic models for the chat model catalog.",
|
||||
Flag: "chat-anthropic-models-url",
|
||||
Env: "CODER_CHAT_ANTHROPIC_MODELS_URL",
|
||||
Value: &c.AI.BridgeConfig.Anthropic.ModelsURL,
|
||||
Default: "",
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "chat_anthropic_models_url",
|
||||
},
|
||||
{
|
||||
Name: "AI Bridge Anthropic Key",
|
||||
Description: "The key to authenticate against the Anthropic API.",
|
||||
@@ -3525,6 +3655,26 @@ Write out the current server config as YAML to stdout.`,
|
||||
Group: &deploymentGroupAIBridge,
|
||||
Annotations: serpent.Annotations{}.Mark(annotationSecretKey, "true"),
|
||||
},
|
||||
{
|
||||
Name: "Chat Models Allowlist",
|
||||
Description: "Comma-separated allowlist of models for the chat model catalog.",
|
||||
Flag: "chat-models-allowlist",
|
||||
Env: "CODER_CHAT_MODELS_ALLOWLIST",
|
||||
Value: &c.AI.BridgeConfig.ModelsAllowlist,
|
||||
Default: "",
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "chat_models_allowlist",
|
||||
},
|
||||
{
|
||||
Name: "Chat Models Denylist",
|
||||
Description: "Comma-separated denylist of models for the chat model catalog.",
|
||||
Flag: "chat-models-denylist",
|
||||
Env: "CODER_CHAT_MODELS_DENYLIST",
|
||||
Value: &c.AI.BridgeConfig.ModelsDenylist,
|
||||
Default: "",
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "chat_models_denylist",
|
||||
},
|
||||
{
|
||||
Name: "AI Bridge Bedrock Base URL",
|
||||
Description: "The base URL to use for the AWS Bedrock API. Use this setting to specify an exact URL to use. Takes precedence " +
|
||||
@@ -3857,6 +4007,8 @@ type AIBridgeConfig struct {
|
||||
Enabled serpent.Bool `json:"enabled" typescript:",notnull"`
|
||||
OpenAI AIBridgeOpenAIConfig `json:"openai" typescript:",notnull"`
|
||||
Anthropic AIBridgeAnthropicConfig `json:"anthropic" typescript:",notnull"`
|
||||
ModelsAllowlist serpent.String `json:"models_allowlist" typescript:",notnull"`
|
||||
ModelsDenylist serpent.String `json:"models_denylist" typescript:",notnull"`
|
||||
Bedrock AIBridgeBedrockConfig `json:"bedrock" typescript:",notnull"`
|
||||
InjectCoderMCPTools serpent.Bool `json:"inject_coder_mcp_tools" typescript:",notnull"`
|
||||
Retention serpent.Duration `json:"retention" typescript:",notnull"`
|
||||
@@ -3874,13 +4026,15 @@ type AIBridgeConfig struct {
|
||||
}
|
||||
|
||||
type AIBridgeOpenAIConfig struct {
|
||||
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
|
||||
Key serpent.String `json:"key" typescript:",notnull"`
|
||||
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
|
||||
ModelsURL serpent.String `json:"models_url" typescript:",notnull"`
|
||||
Key serpent.String `json:"key" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIBridgeAnthropicConfig struct {
|
||||
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
|
||||
Key serpent.String `json:"key" typescript:",notnull"`
|
||||
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
|
||||
ModelsURL serpent.String `json:"models_url" typescript:",notnull"`
|
||||
Key serpent.String `json:"key" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIBridgeBedrockConfig struct {
|
||||
@@ -3902,9 +4056,16 @@ type AIBridgeProxyConfig struct {
|
||||
UpstreamProxyCA serpent.String `json:"upstream_proxy_ca" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIChatConfig struct {
|
||||
SystemPrompt serpent.String `json:"system_prompt" typescript:",notnull"`
|
||||
TitleGenerationPrompt serpent.String `json:"title_generation_prompt" typescript:",notnull"`
|
||||
LocalWorkspace serpent.Bool `json:"local_workspace" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
BridgeConfig AIBridgeConfig `json:"bridge,omitempty"`
|
||||
BridgeProxyConfig AIBridgeProxyConfig `json:"aibridge_proxy,omitempty"`
|
||||
Chat AIChatConfig `json:"chat,omitempty"`
|
||||
}
|
||||
|
||||
type SupportConfig struct {
|
||||
|
||||
@@ -11,6 +11,7 @@ const (
|
||||
ResourceAssignRole RBACResource = "assign_role"
|
||||
ResourceAuditLog RBACResource = "audit_log"
|
||||
ResourceBoundaryUsage RBACResource = "boundary_usage"
|
||||
ResourceChat RBACResource = "chat"
|
||||
ResourceConnectionLog RBACResource = "connection_log"
|
||||
ResourceCryptoKey RBACResource = "crypto_key"
|
||||
ResourceDebugInfo RBACResource = "debug_info"
|
||||
@@ -81,6 +82,7 @@ var RBACResourceActions = map[RBACResource][]RBACAction{
|
||||
ResourceAssignRole: {ActionAssign, ActionRead, ActionUnassign},
|
||||
ResourceAuditLog: {ActionCreate, ActionRead},
|
||||
ResourceBoundaryUsage: {ActionDelete, ActionRead, ActionUpdate},
|
||||
ResourceChat: {ActionCreate, ActionDelete, ActionRead, ActionUpdate},
|
||||
ResourceConnectionLog: {ActionRead, ActionUpdate},
|
||||
ResourceCryptoKey: {ActionCreate, ActionDelete, ActionRead, ActionUpdate},
|
||||
ResourceDebugInfo: {ActionRead},
|
||||
|
||||
+2172
-2115
File diff suppressed because it is too large
Load Diff
Generated
+6
-5
@@ -253,11 +253,12 @@ curl -X GET http://coder-server:8080/api/v2/workspaceagents/me/external-auth?mat
|
||||
|
||||
### Parameters
|
||||
|
||||
| Name | In | Type | Required | Description |
|
||||
|----------|-------|---------|----------|-----------------------------------|
|
||||
| `match` | query | string | true | Match |
|
||||
| `id` | query | string | true | Provider ID |
|
||||
| `listen` | query | boolean | false | Wait for a new token to be issued |
|
||||
| Name | In | Type | Required | Description |
|
||||
|-----------|-------|---------|----------|------------------------------------------------|
|
||||
| `match` | query | string | true | Match |
|
||||
| `id` | query | string | true | Provider ID |
|
||||
| `listen` | query | boolean | false | Wait for a new token to be issued |
|
||||
| `workdir` | query | string | false | Working directory used for git context refresh |
|
||||
|
||||
### Example responses
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user