Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 801b467b75 | |||
| 192b25e30a | |||
| 8c7111fe4a |
@@ -98,15 +98,6 @@ linters-settings:
|
||||
# stdlib port.
|
||||
checks: ["all", "-SA1019"]
|
||||
|
||||
tagalign:
|
||||
align: true
|
||||
sort: true
|
||||
strict: true
|
||||
order:
|
||||
- json
|
||||
- db
|
||||
- table
|
||||
|
||||
goimports:
|
||||
local-prefixes: coder.com,cdr.dev,go.coder.com,github.com/cdr,github.com/coder
|
||||
|
||||
@@ -271,7 +262,6 @@ linters:
|
||||
# - wastedassign
|
||||
|
||||
- staticcheck
|
||||
- tagalign
|
||||
# In Go, it's possible for a package to test it's internal functionality
|
||||
# without testing any exported functions. This is enabled to promote
|
||||
# decomposing a package before testing it's internals. A function caller
|
||||
|
||||
@@ -136,10 +136,18 @@ endif
|
||||
# the search path so that these exclusions match.
|
||||
FIND_EXCLUSIONS= \
|
||||
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
|
||||
|
||||
# Source files used for make targets, evaluated on use.
|
||||
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
|
||||
|
||||
# Same as GO_SRC_FILES but excluding certain files that have problematic
|
||||
# Makefile dependencies (e.g. pnpm).
|
||||
MOST_GO_SRC_FILES := $(shell \
|
||||
find . \
|
||||
$(FIND_EXCLUSIONS) \
|
||||
-type f \
|
||||
-name '*.go' \
|
||||
-not -name '*_test.go' \
|
||||
-not -wholename './agent/agentcontainers/dcspec/dcspec_gen.go' \
|
||||
)
|
||||
# All the shell files in the repo, excluding ignored files.
|
||||
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
|
||||
|
||||
@@ -506,10 +514,7 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
|
||||
cp "$<" "$$output_file"
|
||||
.PHONY: install
|
||||
|
||||
# Only wildcard the go files in the develop directory to avoid rebuilds
|
||||
# when project files are changd. Technically changes to some imports may
|
||||
# not be detected, but it's unlikely to cause any issues.
|
||||
build/.bin/develop: go.mod go.sum $(wildcard scripts/develop/*.go)
|
||||
build/.bin/develop: go.mod go.sum $(GO_SRC_FILES)
|
||||
CGO_ENABLED=0 go build -o $@ ./scripts/develop
|
||||
|
||||
BOLD := $(shell tput bold 2>/dev/null)
|
||||
|
||||
+1
-1
@@ -389,7 +389,7 @@ func (a *agent) init() {
|
||||
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
|
||||
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
|
||||
desktop := agentdesktop.NewPortableDesktop(
|
||||
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
|
||||
a.logger.Named("desktop"), a.execer, a.scriptDataDir,
|
||||
)
|
||||
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
|
||||
@@ -2,9 +2,13 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -20,6 +24,28 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
portableDesktopVersion = "v0.0.4"
|
||||
downloadRetries = 3
|
||||
downloadRetryDelay = time.Second
|
||||
)
|
||||
|
||||
// platformBinaries maps GOARCH to download URL and expected SHA-256
|
||||
// digest for each supported platform.
|
||||
var platformBinaries = map[string]struct {
|
||||
URL string
|
||||
SHA256 string
|
||||
}{
|
||||
"amd64": {
|
||||
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-x64",
|
||||
SHA256: "a04e05e6c7d6f2e6b3acbf1729a7b21271276300b4fee321f4ffee6136538317",
|
||||
},
|
||||
"arm64": {
|
||||
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-arm64",
|
||||
SHA256: "b8cb9142dc32d46a608f25229cbe8168ff2a3aadc54253c74ff54cd347e16ca6",
|
||||
},
|
||||
}
|
||||
|
||||
// portableDesktopOutput is the JSON output from
|
||||
// `portabledesktop up --json`.
|
||||
type portableDesktopOutput struct {
|
||||
@@ -52,31 +78,43 @@ type screenshotOutput struct {
|
||||
// portableDesktop implements Desktop by shelling out to the
|
||||
// portabledesktop CLI via agentexec.Execer.
|
||||
type portableDesktop struct {
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
scriptBinDir string // coder script bin directory
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
dataDir string // agent's ScriptDataDir, used for binary caching
|
||||
|
||||
mu sync.Mutex
|
||||
session *desktopSession // nil until started
|
||||
binPath string // resolved path to binary, cached
|
||||
closed bool
|
||||
|
||||
// httpClient is used for downloading the binary. If nil,
|
||||
// http.DefaultClient is used.
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewPortableDesktop creates a Desktop backed by the portabledesktop
|
||||
// CLI binary, using execer to spawn child processes. scriptBinDir is
|
||||
// the coder script bin directory checked for the binary.
|
||||
// CLI binary, using execer to spawn child processes. dataDir is used
|
||||
// to cache the downloaded binary.
|
||||
func NewPortableDesktop(
|
||||
logger slog.Logger,
|
||||
execer agentexec.Execer,
|
||||
scriptBinDir string,
|
||||
dataDir string,
|
||||
) Desktop {
|
||||
return &portableDesktop{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
dataDir: dataDir,
|
||||
}
|
||||
}
|
||||
|
||||
// httpDo returns the HTTP client to use for downloads.
|
||||
func (p *portableDesktop) httpDo() *http.Client {
|
||||
if p.httpClient != nil {
|
||||
return p.httpClient
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
// Start launches the desktop session (idempotent).
|
||||
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
|
||||
p.mu.Lock()
|
||||
@@ -361,8 +399,8 @@ func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, e
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
// ensureBinary resolves the portabledesktop binary from PATH or the
|
||||
// coder script bin directory. It must be called while p.mu is held.
|
||||
// ensureBinary resolves or downloads the portabledesktop binary. It
|
||||
// must be called while p.mu is held.
|
||||
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
|
||||
if p.binPath != "" {
|
||||
return nil
|
||||
@@ -377,23 +415,130 @@ func (p *portableDesktop) ensureBinary(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. Check the coder script bin directory.
|
||||
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
|
||||
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
|
||||
// On Windows, permission bits don't indicate executability,
|
||||
// so accept any regular file.
|
||||
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
|
||||
p.logger.Info(ctx, "found portabledesktop in script bin directory",
|
||||
slog.F("path", scriptBinPath),
|
||||
// 2. Platform checks.
|
||||
if runtime.GOOS != "linux" {
|
||||
return xerrors.New("portabledesktop is only supported on Linux")
|
||||
}
|
||||
bin, ok := platformBinaries[runtime.GOARCH]
|
||||
if !ok {
|
||||
return xerrors.Errorf("unsupported architecture for portabledesktop: %s", runtime.GOARCH)
|
||||
}
|
||||
|
||||
// 3. Check cache.
|
||||
cacheDir := filepath.Join(p.dataDir, "portabledesktop", bin.SHA256)
|
||||
cachedPath := filepath.Join(cacheDir, "portabledesktop")
|
||||
|
||||
if info, err := os.Stat(cachedPath); err == nil && !info.IsDir() {
|
||||
// Verify it is executable.
|
||||
if info.Mode()&0o100 != 0 {
|
||||
p.logger.Info(ctx, "using cached portabledesktop binary",
|
||||
slog.F("path", cachedPath),
|
||||
)
|
||||
p.binPath = scriptBinPath
|
||||
p.binPath = cachedPath
|
||||
return nil
|
||||
}
|
||||
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
|
||||
slog.F("path", scriptBinPath),
|
||||
slog.F("mode", info.Mode().String()),
|
||||
}
|
||||
|
||||
// 4. Download with retry.
|
||||
p.logger.Info(ctx, "downloading portabledesktop binary",
|
||||
slog.F("url", bin.URL),
|
||||
slog.F("version", portableDesktopVersion),
|
||||
slog.F("arch", runtime.GOARCH),
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
for attempt := range downloadRetries {
|
||||
if err := downloadBinary(ctx, p.httpDo(), bin.URL, bin.SHA256, cachedPath); err != nil {
|
||||
lastErr = err
|
||||
p.logger.Warn(ctx, "download attempt failed",
|
||||
slog.F("attempt", attempt+1),
|
||||
slog.F("max_attempts", downloadRetries),
|
||||
slog.Error(err),
|
||||
)
|
||||
if attempt < downloadRetries-1 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(downloadRetryDelay):
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
p.binPath = cachedPath
|
||||
p.logger.Info(ctx, "downloaded portabledesktop binary",
|
||||
slog.F("path", cachedPath),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
return xerrors.Errorf("download portabledesktop after %d attempts: %w", downloadRetries, lastErr)
|
||||
}
|
||||
|
||||
// downloadBinary fetches a binary from url, verifies its SHA-256
|
||||
// digest matches expectedSHA256, and atomically writes it to destPath.
|
||||
func downloadBinary(ctx context.Context, client *http.Client, url, expectedSHA256, destPath string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(destPath), 0o700); err != nil {
|
||||
return xerrors.Errorf("create cache directory: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create HTTP request: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("HTTP GET %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return xerrors.Errorf("HTTP GET %s: status %d", url, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Write to a temp file in the same directory so the final rename
|
||||
// is atomic on the same filesystem.
|
||||
tmpFile, err := os.CreateTemp(filepath.Dir(destPath), "portabledesktop-download-*")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
// Clean up the temp file on any error path.
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
// Stream the response body while computing SHA-256.
|
||||
hasher := sha256.New()
|
||||
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
|
||||
return xerrors.Errorf("download body: %w", err)
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return xerrors.Errorf("close temp file: %w", err)
|
||||
}
|
||||
|
||||
// Verify digest.
|
||||
actualSHA256 := hex.EncodeToString(hasher.Sum(nil))
|
||||
if actualSHA256 != expectedSHA256 {
|
||||
return xerrors.Errorf(
|
||||
"SHA-256 mismatch: expected %s, got %s",
|
||||
expectedSHA256, actualSHA256,
|
||||
)
|
||||
}
|
||||
|
||||
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
|
||||
if err := os.Chmod(tmpPath, 0o700); err != nil {
|
||||
return xerrors.Errorf("chmod: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, destPath); err != nil {
|
||||
return xerrors.Errorf("rename to final path: %w", err)
|
||||
}
|
||||
|
||||
success = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,11 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -72,6 +77,7 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
// The "up" script prints the JSON line then sleeps until
|
||||
// the context is canceled (simulating a long-running process).
|
||||
@@ -82,13 +88,13 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
ctx := context.Background()
|
||||
cfg, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -105,6 +111,7 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
@@ -113,13 +120,13 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
ctx := context.Background()
|
||||
cfg1, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -147,6 +154,7 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
@@ -155,13 +163,13 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
ctx := context.Background()
|
||||
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -172,6 +180,7 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
@@ -180,13 +189,13 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
ctx := context.Background()
|
||||
_, err := pd.Screenshot(ctx, ScreenshotOptions{
|
||||
TargetWidth: 800,
|
||||
TargetHeight: 600,
|
||||
@@ -278,13 +287,13 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
err := tt.invoke(context.Background(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
@@ -363,13 +372,13 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
err := tt.invoke(context.Background(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
@@ -395,13 +404,13 @@ func TestPortableDesktop_CursorPosition(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
x, y, err := pd.CursorPosition(t.Context())
|
||||
x, y, err := pd.CursorPosition(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, x)
|
||||
assert.Equal(t, 200, y)
|
||||
@@ -419,13 +428,13 @@ func TestPortableDesktop_Close(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
ctx := context.Background()
|
||||
_, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -448,6 +457,81 @@ func TestPortableDesktop_Close(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "desktop is closed")
|
||||
}
|
||||
|
||||
// --- downloadBinary tests ---
|
||||
|
||||
func TestDownloadBinary_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho portable\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the file exists and has correct content.
|
||||
got, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
|
||||
// Verify executable permissions.
|
||||
info, err := os.Stat(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, info.Mode()&0o700, "binary should be executable")
|
||||
}
|
||||
|
||||
func TestDownloadBinary_ChecksumMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("real binary content"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, wrongSHA, destPath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "SHA-256 mismatch")
|
||||
|
||||
// The destination file should not exist (temp file cleaned up).
|
||||
_, statErr := os.Stat(destPath)
|
||||
assert.True(t, os.IsNotExist(statErr), "dest file should not exist after checksum failure")
|
||||
|
||||
// No leftover temp files in the directory.
|
||||
entries, err := os.ReadDir(destDir)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, entries, "no leftover temp files should remain")
|
||||
}
|
||||
|
||||
func TestDownloadBinary_HTTPError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, "irrelevant", destPath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "status 404")
|
||||
}
|
||||
|
||||
// --- ensureBinary tests ---
|
||||
|
||||
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
|
||||
@@ -457,89 +541,173 @@ func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
|
||||
// immediately without doing any work.
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "/already/set",
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "/already/set",
|
||||
}
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/already/set", pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_UsesScriptBinDir(t *testing.T) {
|
||||
func TestEnsureBinary_UsesCachedBinary(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
scriptBinDir := t.TempDir()
|
||||
binPath := filepath.Join(scriptBinDir, "portabledesktop")
|
||||
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
require.NoError(t, os.Chmod(binPath, 0o755))
|
||||
bin, ok := platformBinaries[runtime.GOARCH]
|
||||
if !ok {
|
||||
t.Skipf("no platformBinary entry for %s", runtime.GOARCH)
|
||||
}
|
||||
|
||||
dataDir := t.TempDir()
|
||||
cacheDir := filepath.Join(dataDir, "portabledesktop", bin.SHA256)
|
||||
require.NoError(t, os.MkdirAll(cacheDir, 0o700))
|
||||
|
||||
cachedPath := filepath.Join(cacheDir, "portabledesktop")
|
||||
require.NoError(t, os.WriteFile(cachedPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: dataDir,
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binPath, pd.binPath)
|
||||
assert.Equal(t, cachedPath, pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_ScriptBinDirNotExecutable(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Windows does not support Unix permission bits")
|
||||
}
|
||||
func TestEnsureBinary_Downloads(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
// environment and we override the package-level platformBinaries.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
scriptBinDir := t.TempDir()
|
||||
binPath := filepath.Join(scriptBinDir, "portabledesktop")
|
||||
// Write without execute permission.
|
||||
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
_ = binPath
|
||||
binaryContent := []byte("#!/bin/sh\necho downloaded\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Save and restore platformBinaries for this test.
|
||||
origBinaries := platformBinaries
|
||||
platformBinaries = map[string]struct {
|
||||
URL string
|
||||
SHA256 string
|
||||
}{
|
||||
runtime.GOARCH: {
|
||||
URL: srv.URL + "/portabledesktop",
|
||||
SHA256: expectedSHA,
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() { platformBinaries = origBinaries })
|
||||
|
||||
dataDir := t.TempDir()
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: dataDir,
|
||||
httpClient: srv.Client(),
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
// Ensure PATH doesn't contain a real portabledesktop binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPath := filepath.Join(dataDir, "portabledesktop", expectedSHA, "portabledesktop")
|
||||
assert.Equal(t, expectedPath, pd.binPath)
|
||||
|
||||
// Verify the downloaded file has correct content.
|
||||
got, err := os.ReadFile(expectedPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_NotFound(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
func TestEnsureBinary_RetriesOnFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: t.TempDir(), // empty directory
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
binaryContent := []byte("#!/bin/sh\necho retried\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
var mu sync.Mutex
|
||||
attempt := 0
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
mu.Lock()
|
||||
current := attempt
|
||||
attempt++
|
||||
mu.Unlock()
|
||||
|
||||
// Fail the first 2 attempts, succeed on the third.
|
||||
if current < 2 {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Test downloadBinary directly to avoid time.Sleep in
|
||||
// ensureBinary's retry loop. We call it 3 times to simulate
|
||||
// what ensureBinary would do.
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
var lastErr error
|
||||
for i := range 3 {
|
||||
lastErr = downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
|
||||
if lastErr == nil {
|
||||
break
|
||||
}
|
||||
if i < 2 {
|
||||
// In the real code, ensureBinary sleeps here.
|
||||
// We skip the sleep in tests.
|
||||
continue
|
||||
}
|
||||
}
|
||||
require.NoError(t, lastErr, "download should succeed on the third attempt")
|
||||
|
||||
got, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
|
||||
mu.Lock()
|
||||
assert.Equal(t, 3, attempt, "server should have been hit 3 times")
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// Ensure that portableDesktop satisfies the Desktop interface at
|
||||
// compile time. This uses the unexported type so it lives in the
|
||||
// internal test package.
|
||||
var _ Desktop = (*portableDesktop)(nil)
|
||||
|
||||
// Silence the linter about unused imports — agentexec.DefaultExecer
|
||||
// is used in TestEnsureBinary_UsesCachedBinPath and others, and
|
||||
// fmt.Sscanf is used indirectly via the implementation.
|
||||
var (
|
||||
_ = agentexec.DefaultExecer
|
||||
_ = fmt.Sprintf
|
||||
)
|
||||
|
||||
@@ -140,16 +140,16 @@ func (c *Client) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateAppS
|
||||
|
||||
// SyncStatusResponse contains the status information for a unit.
|
||||
type SyncStatusResponse struct {
|
||||
UnitName unit.ID `json:"unit_name" table:"unit,default_sort"`
|
||||
Status unit.Status `json:"status" table:"status"`
|
||||
IsReady bool `json:"is_ready" table:"ready"`
|
||||
Dependencies []DependencyInfo `json:"dependencies" table:"dependencies"`
|
||||
UnitName unit.ID `table:"unit,default_sort" json:"unit_name"`
|
||||
Status unit.Status `table:"status" json:"status"`
|
||||
IsReady bool `table:"ready" json:"is_ready"`
|
||||
Dependencies []DependencyInfo `table:"dependencies" json:"dependencies"`
|
||||
}
|
||||
|
||||
// DependencyInfo contains information about a unit dependency.
|
||||
type DependencyInfo struct {
|
||||
DependsOn unit.ID `json:"depends_on" table:"depends on,default_sort"`
|
||||
RequiredStatus unit.Status `json:"required_status" table:"required status"`
|
||||
CurrentStatus unit.Status `json:"current_status" table:"current status"`
|
||||
IsSatisfied bool `json:"is_satisfied" table:"satisfied"`
|
||||
DependsOn unit.ID `table:"depends on,default_sort" json:"depends_on"`
|
||||
RequiredStatus unit.Status `table:"required status" json:"required_status"`
|
||||
CurrentStatus unit.Status `table:"current status" json:"current_status"`
|
||||
IsSatisfied bool `table:"satisfied" json:"is_satisfied"`
|
||||
}
|
||||
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
func NewLicenseFormatter() *cliui.OutputFormatter {
|
||||
type tableLicense struct {
|
||||
ID int32 `table:"id,default_sort"`
|
||||
UUID uuid.UUID `table:"uuid" format:"uuid"`
|
||||
UploadedAt time.Time `table:"uploaded at" format:"date-time"`
|
||||
UUID uuid.UUID `table:"uuid" format:"uuid"`
|
||||
UploadedAt time.Time `table:"uploaded at" format:"date-time"`
|
||||
// Features is the formatted string for the license claims.
|
||||
// Used for the table view.
|
||||
Features string `table:"features"`
|
||||
|
||||
@@ -214,7 +214,7 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
|
||||
} else {
|
||||
updated, err = client.CreateOrganizationRole(ctx, customRole)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create role: %w", err)
|
||||
return xerrors.Errorf("patch role: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -524,7 +524,7 @@ type roleTableRow struct {
|
||||
Name string `table:"name,default_sort"`
|
||||
DisplayName string `table:"display name"`
|
||||
OrganizationID string `table:"organization id"`
|
||||
SitePermissions string `table:"site permissions"`
|
||||
SitePermissions string ` table:"site permissions"`
|
||||
// map[<org_id>] -> Permissions
|
||||
OrganizationPermissions string `table:"organization permissions"`
|
||||
UserPermissions string `table:"user permissions"`
|
||||
|
||||
@@ -32,9 +32,9 @@ func (r *RootCmd) provisionerJobs() *serpent.Command {
|
||||
|
||||
func (r *RootCmd) provisionerJobsList() *serpent.Command {
|
||||
type provisionerJobRow struct {
|
||||
codersdk.ProvisionerJob ` table:"provisioner_job,recursive_inline,nosort"`
|
||||
codersdk.ProvisionerJob `table:"provisioner_job,recursive_inline,nosort"`
|
||||
OrganizationName string `json:"organization_name" table:"organization"`
|
||||
Queue string `json:"-" table:"queue"`
|
||||
Queue string `json:"-" table:"queue"`
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
+1
-1
@@ -31,7 +31,7 @@ func (r *RootCmd) Provisioners() *serpent.Command {
|
||||
|
||||
func (r *RootCmd) provisionerList() *serpent.Command {
|
||||
type provisionerDaemonRow struct {
|
||||
codersdk.ProvisionerDaemon ` table:"provisioner_daemon,recursive_inline"`
|
||||
codersdk.ProvisionerDaemon `table:"provisioner_daemon,recursive_inline"`
|
||||
OrganizationName string `json:"organization_name" table:"organization"`
|
||||
}
|
||||
var (
|
||||
|
||||
+3
-3
@@ -350,11 +350,11 @@ func displaySchedule(ws codersdk.Workspace, out io.Writer) error {
|
||||
// scheduleListRow is a row in the schedule list.
|
||||
// this is required for proper JSON output.
|
||||
type scheduleListRow struct {
|
||||
WorkspaceName string `json:"workspace" table:"workspace,default_sort"`
|
||||
StartsAt string `json:"starts_at" table:"starts at"`
|
||||
WorkspaceName string `json:"workspace" table:"workspace,default_sort"`
|
||||
StartsAt string `json:"starts_at" table:"starts at"`
|
||||
StartsNext string `json:"starts_next" table:"starts next"`
|
||||
StopsAfter string `json:"stops_after" table:"stops after"`
|
||||
StopsNext string `json:"stops_next" table:"stops next"`
|
||||
StopsNext string `json:"stops_next" table:"stops next"`
|
||||
}
|
||||
|
||||
func scheduleListRowFromWorkspace(now time.Time, workspace codersdk.Workspace) scheduleListRow {
|
||||
|
||||
+4
-4
@@ -284,9 +284,9 @@ func (*RootCmd) statDisk(fs afero.Fs) *serpent.Command {
|
||||
}
|
||||
|
||||
type statsRow struct {
|
||||
HostCPU *clistat.Result `json:"host_cpu" table:"host cpu,default_sort"`
|
||||
HostMemory *clistat.Result `json:"host_memory" table:"host memory"`
|
||||
Disk *clistat.Result `json:"home_disk" table:"home disk"`
|
||||
ContainerCPU *clistat.Result `json:"container_cpu" table:"container cpu"`
|
||||
HostCPU *clistat.Result `json:"host_cpu" table:"host cpu,default_sort"`
|
||||
HostMemory *clistat.Result `json:"host_memory" table:"host memory"`
|
||||
Disk *clistat.Result `json:"home_disk" table:"home disk"`
|
||||
ContainerCPU *clistat.Result `json:"container_cpu" table:"container cpu"`
|
||||
ContainerMemory *clistat.Result `json:"container_memory" table:"container memory"`
|
||||
}
|
||||
|
||||
+1
-1
@@ -155,7 +155,7 @@ func taskWatchIsEnded(task codersdk.Task) bool {
|
||||
}
|
||||
|
||||
type taskStatusRow struct {
|
||||
codersdk.Task ` table:"r,recursive_inline"`
|
||||
codersdk.Task `table:"r,recursive_inline"`
|
||||
ChangedAgo string `json:"-" table:"state changed"`
|
||||
Healthy bool `json:"-" table:"healthy"`
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ func (r *RootCmd) templateVersionsList() *serpent.Command {
|
||||
|
||||
type templateVersionRow struct {
|
||||
// For json format:
|
||||
TemplateVersion codersdk.TemplateVersion ` table:"-"`
|
||||
TemplateVersion codersdk.TemplateVersion `table:"-"`
|
||||
ActiveJSON bool `json:"active" table:"-"`
|
||||
|
||||
// For table format:
|
||||
|
||||
@@ -24,10 +24,6 @@ OPTIONS:
|
||||
-p, --password string
|
||||
Specifies a password for the new user.
|
||||
|
||||
--service-account bool
|
||||
Create a user account intended to be used by a service or as an
|
||||
intermediary rather than by a human.
|
||||
|
||||
-u, --username string
|
||||
Specifies a username for the new user.
|
||||
|
||||
|
||||
+7
-7
@@ -161,14 +161,14 @@ type tokenListRow struct {
|
||||
codersdk.APIKey `table:"-"`
|
||||
|
||||
// For table format:
|
||||
ID string `json:"-" table:"id,default_sort"`
|
||||
ID string `json:"-" table:"id,default_sort"`
|
||||
TokenName string `json:"token_name" table:"name"`
|
||||
Scopes string `json:"-" table:"scopes"`
|
||||
Allow string `json:"-" table:"allow list"`
|
||||
LastUsed time.Time `json:"-" table:"last used"`
|
||||
ExpiresAt time.Time `json:"-" table:"expires at"`
|
||||
CreatedAt time.Time `json:"-" table:"created at"`
|
||||
Owner string `json:"-" table:"owner"`
|
||||
Scopes string `json:"-" table:"scopes"`
|
||||
Allow string `json:"-" table:"allow list"`
|
||||
LastUsed time.Time `json:"-" table:"last used"`
|
||||
ExpiresAt time.Time `json:"-" table:"expires at"`
|
||||
CreatedAt time.Time `json:"-" table:"created at"`
|
||||
Owner string `json:"-" table:"owner"`
|
||||
}
|
||||
|
||||
func tokenListRowFromToken(token codersdk.APIKeyWithOwner) tokenListRow {
|
||||
|
||||
+12
-37
@@ -17,14 +17,13 @@ import (
|
||||
|
||||
func (r *RootCmd) userCreate() *serpent.Command {
|
||||
var (
|
||||
email string
|
||||
username string
|
||||
name string
|
||||
password string
|
||||
disableLogin bool
|
||||
loginType string
|
||||
serviceAccount bool
|
||||
orgContext = NewOrganizationContext()
|
||||
email string
|
||||
username string
|
||||
name string
|
||||
password string
|
||||
disableLogin bool
|
||||
loginType string
|
||||
orgContext = NewOrganizationContext()
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "create",
|
||||
@@ -33,23 +32,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
serpent.RequireNArgs(0),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
if serviceAccount {
|
||||
switch {
|
||||
case loginType != "":
|
||||
return xerrors.New("You cannot use --login-type with --service-account")
|
||||
case password != "":
|
||||
return xerrors.New("You cannot use --password with --service-account")
|
||||
case email != "":
|
||||
return xerrors.New("You cannot use --email with --service-account")
|
||||
case disableLogin:
|
||||
return xerrors.New("You cannot use --disable-login with --service-account")
|
||||
}
|
||||
}
|
||||
|
||||
if disableLogin && loginType != "" {
|
||||
return xerrors.New("You cannot specify both --disable-login and --login-type")
|
||||
}
|
||||
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -77,7 +59,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if email == "" && !serviceAccount {
|
||||
if email == "" {
|
||||
email, err = cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: "Email:",
|
||||
Validate: func(s string) error {
|
||||
@@ -105,7 +87,10 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
}
|
||||
}
|
||||
userLoginType := codersdk.LoginTypePassword
|
||||
if disableLogin || serviceAccount {
|
||||
if disableLogin && loginType != "" {
|
||||
return xerrors.New("You cannot specify both --disable-login and --login-type")
|
||||
}
|
||||
if disableLogin {
|
||||
userLoginType = codersdk.LoginTypeNone
|
||||
} else if loginType != "" {
|
||||
userLoginType = codersdk.LoginType(loginType)
|
||||
@@ -126,7 +111,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
Password: password,
|
||||
OrganizationIDs: []uuid.UUID{organization.ID},
|
||||
UserLoginType: userLoginType,
|
||||
ServiceAccount: serviceAccount,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -143,10 +127,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
case codersdk.LoginTypeOIDC:
|
||||
authenticationMethod = `Login is authenticated through the configured OIDC provider.`
|
||||
}
|
||||
if serviceAccount {
|
||||
email = "n/a"
|
||||
authenticationMethod = "Service accounts must authenticate with a token and cannot log in."
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, `A new user has been created!
|
||||
Share the instructions below to get them started.
|
||||
@@ -214,11 +194,6 @@ Create a workspace `+pretty.Sprint(cliui.DefaultStyles.Code, "coder create")+`!
|
||||
)),
|
||||
Value: serpent.StringOf(&loginType),
|
||||
},
|
||||
{
|
||||
Flag: "service-account",
|
||||
Description: "Create a user account intended to be used by a service or as an intermediary rather than by a human.",
|
||||
Value: serpent.BoolOf(&serviceAccount),
|
||||
},
|
||||
}
|
||||
|
||||
orgContext.AttachOptions(cmd)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -125,56 +124,4 @@ func TestUserCreate(t *testing.T) {
|
||||
assert.Equal(t, args[5], created.Username)
|
||||
assert.Empty(t, created.Name)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
name: "ServiceAccount",
|
||||
args: []string{"--service-account", "-u", "dean"},
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountLoginType",
|
||||
args: []string{"--service-account", "-u", "dean", "--login-type", "none"},
|
||||
err: "You cannot use --login-type with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountDisableLogin",
|
||||
args: []string{"--service-account", "-u", "dean", "--disable-login"},
|
||||
err: "You cannot use --disable-login with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountEmail",
|
||||
args: []string{"--service-account", "-u", "dean", "--email", "dean@coder.com"},
|
||||
err: "You cannot use --email with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountPassword",
|
||||
args: []string{"--service-account", "-u", "dean", "--password", "1n5ecureP4ssw0rd!"},
|
||||
err: "You cannot use --password with --service-account",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
inv, root := clitest.New(t, append([]string{"users", "create"}, tt.args...)...)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
err := inv.Run()
|
||||
if tt.err == "" {
|
||||
require.NoError(t, err)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created, err := client.User(ctx, "dean")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, codersdk.LoginTypeNone, created.LoginType)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+6
-6
@@ -11,13 +11,13 @@ import (
|
||||
)
|
||||
|
||||
type whoamiRow struct {
|
||||
URL string `json:"url" table:"URL,default_sort"`
|
||||
Username string `json:"username" table:"Username"`
|
||||
UserID string `json:"user_id" table:"ID"`
|
||||
OrganizationIDs string `json:"-" table:"Orgs"`
|
||||
URL string `json:"url" table:"URL,default_sort"`
|
||||
Username string `json:"username" table:"Username"`
|
||||
UserID string `json:"user_id" table:"ID"`
|
||||
OrganizationIDs string `json:"-" table:"Orgs"`
|
||||
OrganizationIDsJSON []string `json:"organization_ids" table:"-"`
|
||||
Roles string `json:"-" table:"Roles"`
|
||||
RolesJSON map[string][]string `json:"roles" table:"-"`
|
||||
Roles string `json:"-" table:"Roles"`
|
||||
RolesJSON map[string][]string `json:"roles" table:"-"`
|
||||
}
|
||||
|
||||
func (r whoamiRow) String() string {
|
||||
|
||||
Generated
+2
-52
@@ -4819,7 +4819,7 @@ const docTemplate = `{
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
],
|
||||
@@ -4827,7 +4827,7 @@ const docTemplate = `{
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18310,9 +18310,6 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -18723,19 +18720,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ShareableWorkspaceOwners": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"none",
|
||||
"everyone",
|
||||
"service_accounts"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ShareableWorkspaceOwnersNone",
|
||||
"ShareableWorkspaceOwnersEveryone",
|
||||
"ShareableWorkspaceOwnersServiceAccounts"
|
||||
]
|
||||
},
|
||||
"codersdk.SharedWorkspaceActor": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19634,9 +19618,6 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -20347,21 +20328,7 @@ const docTemplate = `{
|
||||
"codersdk.UpdateWorkspaceSharingSettingsRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"shareable_workspace_owners": {
|
||||
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
|
||||
"enum": [
|
||||
"none",
|
||||
"everyone",
|
||||
"service_accounts"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
|
||||
}
|
||||
]
|
||||
},
|
||||
"sharing_disabled": {
|
||||
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use ` + "`" + `ShareableWorkspaceOwners` + "`" + ` instead",
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
@@ -20483,9 +20450,6 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -22205,21 +22169,7 @@ const docTemplate = `{
|
||||
"codersdk.WorkspaceSharingSettings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"shareable_workspace_owners": {
|
||||
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
|
||||
"enum": [
|
||||
"none",
|
||||
"everyone",
|
||||
"service_accounts"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
|
||||
}
|
||||
]
|
||||
},
|
||||
"sharing_disabled": {
|
||||
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use ` + "`" + `ShareableWorkspaceOwners` + "`" + ` instead",
|
||||
"type": "boolean"
|
||||
},
|
||||
"sharing_globally_disabled": {
|
||||
|
||||
Generated
+2
-40
@@ -4262,7 +4262,7 @@
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
],
|
||||
@@ -4270,7 +4270,7 @@
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -16708,9 +16708,6 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -17112,15 +17109,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ShareableWorkspaceOwners": {
|
||||
"type": "string",
|
||||
"enum": ["none", "everyone", "service_accounts"],
|
||||
"x-enum-varnames": [
|
||||
"ShareableWorkspaceOwnersNone",
|
||||
"ShareableWorkspaceOwnersEveryone",
|
||||
"ShareableWorkspaceOwnersServiceAccounts"
|
||||
]
|
||||
},
|
||||
"codersdk.SharedWorkspaceActor": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -17982,9 +17970,6 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -18660,17 +18645,7 @@
|
||||
"codersdk.UpdateWorkspaceSharingSettingsRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"shareable_workspace_owners": {
|
||||
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
|
||||
"enum": ["none", "everyone", "service_accounts"],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
|
||||
}
|
||||
]
|
||||
},
|
||||
"sharing_disabled": {
|
||||
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use `ShareableWorkspaceOwners` instead",
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
@@ -18774,9 +18749,6 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -20412,17 +20384,7 @@
|
||||
"codersdk.WorkspaceSharingSettings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"shareable_workspace_owners": {
|
||||
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
|
||||
"enum": ["none", "everyone", "service_accounts"],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
|
||||
}
|
||||
]
|
||||
},
|
||||
"sharing_disabled": {
|
||||
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use `ShareableWorkspaceOwners` instead",
|
||||
"type": "boolean"
|
||||
},
|
||||
"sharing_globally_disabled": {
|
||||
|
||||
+272
-484
File diff suppressed because it is too large
Load Diff
@@ -2,20 +2,13 @@ package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
)
|
||||
|
||||
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
|
||||
@@ -91,135 +84,3 @@ func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
|
||||
require.ErrorContains(t, err, loadErr.Error())
|
||||
require.Equal(t, chat, refreshed)
|
||||
}
|
||||
|
||||
func TestResolveInstructionsReusesTurnLocalWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
workspaceAgent := database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
OperatingSystem: "linux",
|
||||
Directory: "/home/coder/project",
|
||||
ExpandedDirectory: "/home/coder/project",
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
|
||||
workspacesdk.LSResponse{},
|
||||
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
|
||||
).Times(1)
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/project/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
nil,
|
||||
"",
|
||||
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
|
||||
).Times(1)
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
instructionCache: make(map[uuid.UUID]cachedInstruction),
|
||||
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
return conn, func() {}, nil
|
||||
},
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
instruction := server.resolveInstructions(
|
||||
ctx,
|
||||
chat,
|
||||
workspaceCtx.getWorkspaceAgent,
|
||||
workspaceCtx.getWorkspaceConn,
|
||||
)
|
||||
require.Contains(t, instruction, "Operating System: linux")
|
||||
require.Contains(t, instruction, "Working Directory: /home/coder/project")
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
initialAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
refreshedAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{initialAgent}, nil),
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{refreshedAgent}, nil),
|
||||
)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
|
||||
var dialed []uuid.UUID
|
||||
server := &Server{db: db}
|
||||
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
dialed = append(dialed, agentID)
|
||||
if agentID == initialAgent.ID {
|
||||
return nil, nil, xerrors.New("dial failed")
|
||||
}
|
||||
return conn, func() {}, nil
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, conn, gotConn)
|
||||
require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed)
|
||||
}
|
||||
|
||||
+5
-580
@@ -374,72 +374,6 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
require.Len(t, messages, 1)
|
||||
}
|
||||
|
||||
func TestSendMessageQueuesWhenWaitingWithQueuedBacklog(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "queue-when-waiting-with-backlog",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("older queued"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
||||
ChatID: chat.ID,
|
||||
Content: queuedContent,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("newer queued")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Queued)
|
||||
require.NotNil(t, result.QueuedMessage)
|
||||
require.Equal(t, database.ChatStatusWaiting, result.Chat.Status)
|
||||
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queued, 2)
|
||||
|
||||
olderSDK := db2sdk.ChatQueuedMessage(queued[0])
|
||||
require.Len(t, olderSDK.Content, 1)
|
||||
require.Equal(t, "older queued", olderSDK.Content[0].Text)
|
||||
|
||||
newerSDK := db2sdk.ChatQueuedMessage(queued[1])
|
||||
require.Len(t, newerSDK.Content, 1)
|
||||
require.Equal(t, "newer queued", newerSDK.Content[0].Text)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
}
|
||||
|
||||
func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -591,457 +525,6 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
require.False(t, chatFromDB.WorkerID.Valid)
|
||||
}
|
||||
|
||||
func TestCreateChatRejectsWhenUsageLimitReached(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: 100,
|
||||
Period: string(codersdk.ChatUsageLimitPeriodDay),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
existingChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
Title: "existing-limit-chat",
|
||||
LastModelConfigID: model.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("assistant"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: existingChat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Content: assistantContent,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
TotalTokens: sql.NullInt64{},
|
||||
ReasoningTokens: sql.NullInt64{},
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
beforeChats, err := db.GetChats(ctx, database.GetChatsParams{
|
||||
OwnerID: user.ID,
|
||||
AfterID: uuid.Nil,
|
||||
OffsetOpt: 0,
|
||||
LimitOpt: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, beforeChats, 1)
|
||||
|
||||
_, err = replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "over-limit",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
var limitErr *chatd.UsageLimitExceededError
|
||||
require.ErrorAs(t, err, &limitErr)
|
||||
require.Equal(t, int64(100), limitErr.LimitMicros)
|
||||
require.Equal(t, int64(100), limitErr.ConsumedMicros)
|
||||
|
||||
afterChats, err := db.GetChats(ctx, database.GetChatsParams{
|
||||
OwnerID: user.ID,
|
||||
AfterID: uuid.Nil,
|
||||
OffsetOpt: 0,
|
||||
LimitOpt: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, afterChats, len(beforeChats))
|
||||
}
|
||||
|
||||
func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: 100,
|
||||
Period: string(codersdk.ChatUsageLimitPeriodDay),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "queued-limit-reached",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
queuedResult, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, queuedResult.Queued)
|
||||
require.NotNil(t, queuedResult.QueuedMessage)
|
||||
|
||||
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("assistant"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Content: assistantContent,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
TotalTokens: sql.NullInt64{},
|
||||
ReasoningTokens: sql.NullInt64{},
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
||||
ChatID: chat.ID,
|
||||
QueuedMessageID: queuedResult.QueuedMessage.ID,
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatMessageRoleUser, result.PromotedMessage.Role)
|
||||
|
||||
chat, err = db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusPending, chat.Status)
|
||||
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, queued)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 3)
|
||||
require.Equal(t, database.ChatMessageRoleUser, messages[2].Role)
|
||||
}
|
||||
|
||||
func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: 100,
|
||||
Period: string(codersdk.ChatUsageLimitPeriodDay),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
streamStarted := make(chan struct{})
|
||||
interrupted := make(chan struct{})
|
||||
allowFinish := make(chan struct{})
|
||||
var requestCount atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
if requestCount.Add(1) == 1 {
|
||||
chunks := make(chan chattest.OpenAIChunk, 1)
|
||||
go func() {
|
||||
defer close(chunks)
|
||||
chunks <- chattest.OpenAITextChunks("partial")[0]
|
||||
select {
|
||||
case <-streamStarted:
|
||||
default:
|
||||
close(streamStarted)
|
||||
}
|
||||
<-req.Context().Done()
|
||||
select {
|
||||
case <-interrupted:
|
||||
default:
|
||||
close(interrupted)
|
||||
}
|
||||
<-allowFinish
|
||||
}()
|
||||
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("done")...,
|
||||
)
|
||||
})
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "interrupt-autopromote-limit",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-streamStarted:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, queuedResult.Queued)
|
||||
require.NotNil(t, queuedResult.QueuedMessage)
|
||||
|
||||
// Send "later queued" immediately after "queued" while the first
|
||||
// message is still in chat_queued_messages. The existing backlog
|
||||
// (len(existingQueued) > 0) guarantees this is queued regardless
|
||||
// of chat status, avoiding a race where the auto-promoted "queued"
|
||||
// message finishes processing before we can send this.
|
||||
laterQueuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("later queued")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, laterQueuedResult.Queued)
|
||||
require.NotNil(t, laterQueuedResult.QueuedMessage)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-interrupted:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
spendChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{},
|
||||
ParentChatID: uuid.NullUUID{},
|
||||
RootChatID: uuid.NullUUID{},
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "other-spend",
|
||||
Mode: database.NullChatMode{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("spent elsewhere"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: spendChat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Content: assistantContent,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
TotalTokens: sql.NullInt64{},
|
||||
ReasoningTokens: sql.NullInt64{},
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
close(allowFinish)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
queued, dbErr := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
if dbErr != nil || len(queued) != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil || fromDB.Status != database.ChatStatusWaiting {
|
||||
return false
|
||||
}
|
||||
|
||||
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
userTexts := make([]string, 0, 3)
|
||||
for _, message := range messages {
|
||||
if message.Role != database.ChatMessageRoleUser {
|
||||
continue
|
||||
}
|
||||
sdkMessage := db2sdk.ChatMessage(message)
|
||||
if len(sdkMessage.Content) != 1 {
|
||||
continue
|
||||
}
|
||||
userTexts = append(userTexts, sdkMessage.Content[0].Text)
|
||||
}
|
||||
if len(userTexts) != 3 {
|
||||
return false
|
||||
}
|
||||
return userTexts[0] == "hello" && userTexts[1] == "queued" && userTexts[2] == "later queued"
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestEditMessageRejectsWhenUsageLimitReached(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: 100,
|
||||
Period: string(codersdk.ChatUsageLimitPeriodDay),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "edit-limit-reached",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("original")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
editedMessageID := messages[0].ID
|
||||
|
||||
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("assistant"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Content: assistantContent,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
TotalTokens: sql.NullInt64{},
|
||||
ReasoningTokens: sql.NullInt64{},
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
EditedMessageID: editedMessageID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
var limitErr *chatd.UsageLimitExceededError
|
||||
require.ErrorAs(t, err, &limitErr)
|
||||
require.Equal(t, int64(100), limitErr.LimitMicros)
|
||||
require.Equal(t, int64(100), limitErr.ConsumedMicros)
|
||||
|
||||
messages, err = db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 2)
|
||||
originalMessage := db2sdk.ChatMessage(messages[0])
|
||||
require.Len(t, originalMessage.Content, 1)
|
||||
require.Equal(t, "original", originalMessage.Content[0].Text)
|
||||
}
|
||||
|
||||
func TestEditMessageRejectsMissingMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -2239,10 +1722,12 @@ func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
|
||||
const assistantText = "I have completed the task successfully and all tests are passing now."
|
||||
const summaryText = "Completed task and verified all tests pass."
|
||||
|
||||
var nonStreamingRequests atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
nonStreamingRequests.Add(1)
|
||||
// Non-streaming calls are used for title
|
||||
// generation and push summary generation.
|
||||
// Return the summary text for both — the title
|
||||
// result is irrelevant to this test.
|
||||
return chattest.OpenAINonStreamingResponse(summaryText)
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
@@ -2289,63 +1774,6 @@ func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
|
||||
"push body should be the LLM-generated summary")
|
||||
require.NotEqual(t, "Agent has finished running.", msg.Body,
|
||||
"push body should not use the default fallback text")
|
||||
require.Equal(t, int32(1), nonStreamingRequests.Load(),
|
||||
"expected exactly one non-streaming request for push summary generation")
|
||||
}
|
||||
|
||||
func TestSuccessfulChatSendsWebPushFallbackWithoutSummaryForEmptyAssistantText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
var nonStreamingRequests atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
nonStreamingRequests.Add(1)
|
||||
return chattest.OpenAINonStreamingResponse("unexpected summary request")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks(" ")...,
|
||||
)
|
||||
})
|
||||
|
||||
mockPush := &mockWebpushDispatcher{}
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
WebpushDispatcher: mockPush,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
_, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "empty-summary-push-test",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
return mockPush.dispatchCount.Load() >= 1
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
msg := mockPush.getLastMessage()
|
||||
require.Equal(t, "Agent has finished running.", msg.Body,
|
||||
"push body should fall back when the final assistant text is empty")
|
||||
require.Equal(t, int32(0), nonStreamingRequests.Load(),
|
||||
"push summary should not be requested when final assistant text has no usable text")
|
||||
}
|
||||
|
||||
func TestComputerUseSubagentToolsAndModel(t *testing.T) {
|
||||
@@ -2517,9 +1945,6 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.UpsertChatDesktopEnabled(ctx, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build workspace + agent records so getWorkspaceConn can
|
||||
// resolve the agent for the computer use child.
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
@@ -2681,7 +2106,7 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) {
|
||||
|
||||
// 6. Verify the child chat has Mode = computer_use in
|
||||
// the DB.
|
||||
allChats, err := db.GetChats(ctx, database.GetChatsParams{
|
||||
allChats, err := db.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -82,7 +82,7 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
|
||||
options := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelReasoningOptions{
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Enabled: boolPtr(true),
|
||||
},
|
||||
Provider: &codersdk.ChatModelOpenRouterProvider{
|
||||
@@ -92,7 +92,7 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
}
|
||||
defaults := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelReasoningOptions{
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Enabled: boolPtr(false),
|
||||
Exclude: boolPtr(true),
|
||||
MaxTokens: int64Ptr(123),
|
||||
|
||||
@@ -78,9 +78,9 @@ type ProcessToolOptions struct {
|
||||
|
||||
// ExecuteArgs are the parameters accepted by the execute tool.
|
||||
type ExecuteArgs struct {
|
||||
Command string `json:"command" description:"The shell command to execute."`
|
||||
Timeout *string `json:"timeout,omitempty" description:"Timeout duration (e.g. '30s', '5m'). Default is 10s. Only applies to foreground commands."`
|
||||
WorkDir *string `json:"workdir,omitempty" description:"Working directory for the command."`
|
||||
Command string `json:"command" description:"The shell command to execute."`
|
||||
Timeout *string `json:"timeout,omitempty" description:"Timeout duration (e.g. '30s', '5m'). Default is 10s. Only applies to foreground commands."`
|
||||
WorkDir *string `json:"workdir,omitempty" description:"Working directory for the command."`
|
||||
RunInBackground *bool `json:"run_in_background,omitempty" description:"Run this command in the background without blocking. Use for long-running processes like dev servers, file watchers, or builds that run longer than 5 seconds. Do NOT use shell & to background processes — it will not work correctly. Always use this parameter instead."`
|
||||
}
|
||||
|
||||
|
||||
@@ -62,7 +62,6 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
messages []database.ChatMessage,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
generatedTitle *generatedChatTitle,
|
||||
logger slog.Logger,
|
||||
) {
|
||||
input, ok := titleInput(chat, messages)
|
||||
@@ -112,7 +111,6 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
return
|
||||
}
|
||||
chat.Title = title
|
||||
generatedTitle.Store(title)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,14 +84,6 @@ func (p *Server) isAnthropicConfigured(ctx context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Server) isDesktopEnabled(ctx context.Context) bool {
|
||||
enabled, err := p.db.GetChatDesktopEnabled(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (p *Server) subagentTools(ctx context.Context, currentChat func() database.Chat) []fantasy.AgentTool {
|
||||
tools := []fantasy.AgentTool{
|
||||
fantasy.NewAgentTool(
|
||||
@@ -261,8 +253,9 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
|
||||
}
|
||||
|
||||
// Only include the computer use tool when an Anthropic
|
||||
// provider is configured and desktop is enabled.
|
||||
if p.isAnthropicConfigured(ctx) && p.isDesktopEnabled(ctx) {
|
||||
// provider is configured, since it requires an Anthropic
|
||||
// model.
|
||||
if p.isAnthropicConfigured(ctx) {
|
||||
tools = append(tools, fantasy.NewAgentTool(
|
||||
"spawn_computer_use_agent",
|
||||
"Spawn a dedicated computer use agent that can see the desktop "+
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
@@ -145,20 +144,14 @@ func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func chatdTestContext(t *testing.T) context.Context {
|
||||
t.Helper()
|
||||
return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong))
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// No Anthropic key in ProviderAPIKeys.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// Create a root parent chat.
|
||||
@@ -183,13 +176,12 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// Provide an Anthropic key so the provider check passes.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// Create a root parent chat.
|
||||
@@ -240,42 +232,16 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
|
||||
assert.Contains(t, resp.Content, "delegated chats cannot create child subagents")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_DesktopDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-desktop-disabled",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
assert.Nil(t, tool, "spawn_computer_use_agent tool must be omitted when desktop is disabled")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// Provide an Anthropic key so the tool can proceed.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// The parent uses an OpenAI model.
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// ComputeUsagePeriodBounds returns the UTC-aligned start and end bounds for the
|
||||
// active usage-limit period containing now.
|
||||
func ComputeUsagePeriodBounds(now time.Time, period codersdk.ChatUsageLimitPeriod) (start, end time.Time) {
|
||||
utcNow := now.UTC()
|
||||
|
||||
switch period {
|
||||
case codersdk.ChatUsageLimitPeriodDay:
|
||||
start = time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
|
||||
end = start.AddDate(0, 0, 1)
|
||||
case codersdk.ChatUsageLimitPeriodWeek:
|
||||
// Walk backward to Monday of the current ISO week.
|
||||
// ISO 8601 weeks always start on Monday, so this never
|
||||
// crosses an ISO-week boundary.
|
||||
start = time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
|
||||
for start.Weekday() != time.Monday {
|
||||
start = start.AddDate(0, 0, -1)
|
||||
}
|
||||
end = start.AddDate(0, 0, 7)
|
||||
case codersdk.ChatUsageLimitPeriodMonth:
|
||||
start = time.Date(utcNow.Year(), utcNow.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
end = start.AddDate(0, 1, 0)
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown chat usage limit period: %q", period))
|
||||
}
|
||||
|
||||
return start, end
|
||||
}
|
||||
|
||||
// ResolveUsageLimitStatus resolves the current usage-limit status for userID.
|
||||
//
|
||||
// Note: There is a potential race condition where two concurrent messages
|
||||
// from the same user can both pass the limit check if processed in
|
||||
// parallel, allowing brief overage. This is acceptable because:
|
||||
// - Cost is only known after the LLM API returns.
|
||||
// - Overage is bounded by message cost × concurrency.
|
||||
// - Fail-open is the deliberate design choice for this feature.
|
||||
//
|
||||
// Architecture note: today this path enforces one period globally
|
||||
// (day/week/month) from config.
|
||||
// To support simultaneous periods, add nullable
|
||||
// daily/weekly/monthly_limit_micros columns on override tables, where NULL
|
||||
// means no limit for that period.
|
||||
// Then scan spend once over the widest active window with conditional SUMs
|
||||
// for each period and compare each spend/limit pair Go-side, blocking on
|
||||
// whichever period is tightest.
|
||||
func ResolveUsageLimitStatus(ctx context.Context, db database.Store, userID uuid.UUID, now time.Time) (*codersdk.ChatUsageLimitStatus, error) {
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon access for
|
||||
// deployment config reads and cross-user chat spend aggregation.
|
||||
authCtx := dbauthz.AsChatd(ctx)
|
||||
|
||||
config, err := db.GetChatUsageLimitConfig(authCtx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if !config.Enabled {
|
||||
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
|
||||
}
|
||||
|
||||
period, ok := mapDBPeriodToSDK(config.Period)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("invalid chat usage limit period %q", config.Period)
|
||||
}
|
||||
|
||||
// Resolve effective limit in a single query:
|
||||
// individual override > group limit > global default.
|
||||
effectiveLimit, err := db.ResolveUserChatSpendLimit(authCtx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// -1 means limits are disabled (shouldn't happen since we checked above,
|
||||
// but handle gracefully).
|
||||
if effectiveLimit < 0 {
|
||||
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
|
||||
}
|
||||
|
||||
start, end := ComputeUsagePeriodBounds(now, period)
|
||||
|
||||
spendTotal, err := db.GetUserChatSpendInPeriod(authCtx, database.GetUserChatSpendInPeriodParams{
|
||||
UserID: userID,
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &codersdk.ChatUsageLimitStatus{
|
||||
IsLimited: true,
|
||||
Period: period,
|
||||
SpendLimitMicros: &effectiveLimit,
|
||||
CurrentSpend: spendTotal,
|
||||
PeriodStart: start,
|
||||
PeriodEnd: end,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func mapDBPeriodToSDK(dbPeriod string) (codersdk.ChatUsageLimitPeriod, bool) {
|
||||
switch dbPeriod {
|
||||
case string(codersdk.ChatUsageLimitPeriodDay):
|
||||
return codersdk.ChatUsageLimitPeriodDay, true
|
||||
case string(codersdk.ChatUsageLimitPeriodWeek):
|
||||
return codersdk.ChatUsageLimitPeriodWeek, true
|
||||
case string(codersdk.ChatUsageLimitPeriodMonth):
|
||||
return codersdk.ChatUsageLimitPeriodMonth, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
@@ -1,132 +0,0 @@
|
||||
package chatd //nolint:testpackage // Keeps chatd unit tests in the package.
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestComputeUsagePeriodBounds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newYork, err := time.LoadLocation("America/New_York")
|
||||
if err != nil {
|
||||
t.Fatalf("load America/New_York: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
now time.Time
|
||||
period codersdk.ChatUsageLimitPeriod
|
||||
wantStart time.Time
|
||||
wantEnd time.Time
|
||||
}{
|
||||
{
|
||||
name: "day/mid_day",
|
||||
now: time.Date(2025, time.June, 15, 14, 30, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "day/midnight_exactly",
|
||||
now: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "day/end_of_day",
|
||||
now: time.Date(2025, time.June, 15, 23, 59, 59, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/wednesday",
|
||||
now: time.Date(2025, time.June, 11, 10, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/monday",
|
||||
now: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/sunday",
|
||||
now: time.Date(2025, time.June, 15, 23, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/year_boundary",
|
||||
now: time.Date(2024, time.December, 31, 12, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2024, time.December, 30, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.January, 6, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/mid_month",
|
||||
now: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/first_day",
|
||||
now: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/last_day",
|
||||
now: time.Date(2025, time.June, 30, 23, 59, 59, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/february",
|
||||
now: time.Date(2025, time.February, 15, 12, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.February, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.March, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/leap_year_february",
|
||||
now: time.Date(2024, time.February, 29, 12, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2024, time.March, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "day/non_utc_timezone",
|
||||
now: time.Date(2025, time.June, 15, 22, 0, 0, 0, newYork),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 17, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
start, end := ComputeUsagePeriodBounds(tc.now, tc.period)
|
||||
if !start.Equal(tc.wantStart) {
|
||||
t.Errorf("start: got %v, want %v", start, tc.wantStart)
|
||||
}
|
||||
if !end.Equal(tc.wantEnd) {
|
||||
t.Errorf("end: got %v, want %v", end, tc.wantEnd)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+60
-626
@@ -82,30 +82,6 @@ type chatDiffReference struct {
|
||||
RepositoryRef *chatRepositoryRef
|
||||
}
|
||||
|
||||
func writeChatUsageLimitExceeded(
|
||||
ctx context.Context,
|
||||
rw http.ResponseWriter,
|
||||
limitErr *chatd.UsageLimitExceededError,
|
||||
) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.ChatUsageLimitExceededResponse{
|
||||
Response: codersdk.Response{
|
||||
Message: "Chat usage limit exceeded.",
|
||||
},
|
||||
SpentMicros: limitErr.ConsumedMicros,
|
||||
LimitMicros: limitErr.LimitMicros,
|
||||
ResetsAt: limitErr.PeriodEnd,
|
||||
})
|
||||
}
|
||||
|
||||
func maybeWriteLimitErr(ctx context.Context, rw http.ResponseWriter, err error) bool {
|
||||
var limitErr *chatd.UsageLimitExceededError
|
||||
if errors.As(err, &limitErr) {
|
||||
writeChatUsageLimitExceeded(ctx, rw, limitErr)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
@@ -189,7 +165,7 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
params := database.GetChatsParams{
|
||||
params := database.GetChatsByOwnerIDParams{
|
||||
OwnerID: apiKey.UserID,
|
||||
Archived: searchParams.Archived,
|
||||
AfterID: paginationParams.AfterID,
|
||||
@@ -199,7 +175,7 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
||||
LimitOpt: int32(paginationParams.Limit),
|
||||
}
|
||||
|
||||
chats, err := api.Database.GetChats(ctx, params)
|
||||
chats, err := api.Database.GetChatsByOwnerID(ctx, params)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to list chats.",
|
||||
@@ -292,9 +268,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
InitialUserContent: contentBlocks,
|
||||
})
|
||||
if err != nil {
|
||||
if maybeWriteLimitErr(ctx, rw, err) {
|
||||
return
|
||||
}
|
||||
if database.IsForeignKeyViolation(
|
||||
err,
|
||||
database.ForeignKeyChatsLastModelConfigID,
|
||||
@@ -458,12 +431,7 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
||||
chatBreakdowns = append(chatBreakdowns, convertChatCostChatBreakdown(chat))
|
||||
}
|
||||
|
||||
usageStatus, err := chatd.ResolveUsageLimitStatus(ctx, api.Database, targetUser.ID, time.Now())
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to resolve usage limit status", slog.Error(err))
|
||||
}
|
||||
|
||||
response := codersdk.ChatCostSummary{
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatCostSummary{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
TotalCostMicros: summary.TotalCostMicros,
|
||||
@@ -475,12 +443,7 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
||||
TotalCacheCreationTokens: summary.TotalCacheCreationTokens,
|
||||
ByModel: modelBreakdowns,
|
||||
ByChat: chatBreakdowns,
|
||||
}
|
||||
if usageStatus != nil {
|
||||
response.UsageLimit = usageStatus
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, response)
|
||||
})
|
||||
}
|
||||
|
||||
func (api *API) chatCostUsers(rw http.ResponseWriter, r *http.Request) {
|
||||
@@ -584,445 +547,6 @@ func (api *API) chatCostUsers(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Get chat usage limit config
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) getChatUsageLimitConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
config, configErr := api.Database.GetChatUsageLimitConfig(ctx)
|
||||
if configErr != nil && !errors.Is(configErr, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat usage limit config.",
|
||||
Detail: configErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
overrideRows, err := api.Database.ListChatUsageLimitOverrides(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to list chat usage limit overrides.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
groupOverrides, err := api.Database.ListChatUsageLimitGroupOverrides(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to list group usage limit overrides.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
unpricedModelCount, err := api.Database.CountEnabledModelsWithoutPricing(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to count unpriced chat models.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response := codersdk.ChatUsageLimitConfigResponse{
|
||||
ChatUsageLimitConfig: codersdk.ChatUsageLimitConfig{},
|
||||
UnpricedModelCount: unpricedModelCount,
|
||||
Overrides: make([]codersdk.ChatUsageLimitOverride, 0, len(overrideRows)),
|
||||
GroupOverrides: make([]codersdk.ChatUsageLimitGroupOverride, 0, len(groupOverrides)),
|
||||
}
|
||||
if configErr == nil {
|
||||
response.Period = codersdk.ChatUsageLimitPeriod(config.Period)
|
||||
response.UpdatedAt = config.UpdatedAt
|
||||
if config.Enabled {
|
||||
response.SpendLimitMicros = ptr.Ref(config.DefaultLimitMicros)
|
||||
}
|
||||
}
|
||||
|
||||
for _, row := range overrideRows {
|
||||
response.Overrides = append(response.Overrides, codersdk.ChatUsageLimitOverride{
|
||||
UserID: row.UserID,
|
||||
Username: row.Username,
|
||||
Name: row.Name,
|
||||
AvatarURL: row.AvatarURL,
|
||||
SpendLimitMicros: nullInt64Ptr(row.SpendLimitMicros),
|
||||
})
|
||||
}
|
||||
|
||||
for _, glo := range groupOverrides {
|
||||
response.GroupOverrides = append(response.GroupOverrides, codersdk.ChatUsageLimitGroupOverride{
|
||||
GroupID: glo.GroupID,
|
||||
GroupName: glo.GroupName,
|
||||
GroupDisplayName: glo.GroupDisplayName,
|
||||
GroupAvatarURL: glo.GroupAvatarUrl,
|
||||
MemberCount: glo.MemberCount,
|
||||
SpendLimitMicros: nullInt64Ptr(glo.SpendLimitMicros),
|
||||
})
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// @Summary Update chat usage limit config
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) updateChatUsageLimitConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.ChatUsageLimitConfig
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
params := database.UpsertChatUsageLimitConfigParams{
|
||||
Enabled: false,
|
||||
DefaultLimitMicros: 0,
|
||||
Period: "",
|
||||
}
|
||||
if req.SpendLimitMicros == nil {
|
||||
if req.Period != "" && !req.Period.Valid() {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid chat usage limit period.",
|
||||
Detail: "Period must be one of: day, week, month.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
params.Enabled = false
|
||||
params.DefaultLimitMicros = 0
|
||||
params.Period = string(req.Period)
|
||||
if params.Period == "" {
|
||||
params.Period = string(codersdk.ChatUsageLimitPeriodMonth)
|
||||
}
|
||||
} else {
|
||||
if *req.SpendLimitMicros <= 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid chat usage limit spend limit.",
|
||||
Detail: "Spend limit must be greater than 0.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !req.Period.Valid() {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid chat usage limit period.",
|
||||
Detail: "Period must be one of: day, week, month.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
params.Enabled = true
|
||||
params.DefaultLimitMicros = *req.SpendLimitMicros
|
||||
params.Period = string(req.Period)
|
||||
}
|
||||
|
||||
config, err := api.Database.UpsertChatUsageLimitConfig(ctx, params)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to update chat usage limit config.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response := codersdk.ChatUsageLimitConfig{
|
||||
Period: codersdk.ChatUsageLimitPeriod(config.Period),
|
||||
UpdatedAt: config.UpdatedAt,
|
||||
}
|
||||
if config.Enabled {
|
||||
response.SpendLimitMicros = ptr.Ref(config.DefaultLimitMicros)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// @Summary Get my chat usage limit status
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
// getMyChatUsageLimitStatus returns the current usage-limit status for the
|
||||
// authenticated user. No additional RBAC check is required because the
|
||||
// endpoint always operates on the requesting user's own data via
|
||||
// httpmw.APIKey(r).UserID.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) getMyChatUsageLimitStatus(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
status, err := chatd.ResolveUsageLimitStatus(ctx, api.Database, httpmw.APIKey(r).UserID, time.Now())
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat usage limit status.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if status == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatUsageLimitStatus{IsLimited: false})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, status)
|
||||
}
|
||||
|
||||
// @Summary Upsert chat usage limit override
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) upsertChatUsageLimitOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := parseChatUsageLimitUserID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpsertChatUsageLimitOverrideRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if req.SpendLimitMicros <= 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid chat usage limit override.",
|
||||
Detail: "Spend limit must be greater than 0.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := api.Database.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "User not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to look up chat usage limit user.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = api.Database.UpsertChatUsageLimitUserOverride(ctx, database.UpsertChatUsageLimitUserOverrideParams{
|
||||
UserID: userID,
|
||||
SpendLimitMicros: req.SpendLimitMicros,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to upsert chat usage limit override.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatUsageLimitOverride{
|
||||
UserID: user.ID,
|
||||
Username: user.Username,
|
||||
Name: user.Name,
|
||||
AvatarURL: user.AvatarURL,
|
||||
SpendLimitMicros: nullInt64Ptr(sql.NullInt64{Int64: req.SpendLimitMicros, Valid: true}),
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Delete chat usage limit override
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) deleteChatUsageLimitOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := parseChatUsageLimitUserID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := api.Database.GetUserByID(ctx, userID); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
writeChatUsageLimitUserNotFound(ctx, rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to look up chat usage limit user.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if _, err := api.Database.GetChatUsageLimitUserOverride(ctx, userID); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
writeChatUsageLimitOverrideNotFound(ctx, rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to look up chat usage limit override.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := api.Database.DeleteChatUsageLimitUserOverride(ctx, userID); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to delete chat usage limit override.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// @Summary Upsert chat usage limit group override
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) upsertChatUsageLimitGroupOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
groupIDStr := chi.URLParam(r, "group")
|
||||
groupID, err := uuid.Parse(groupIDStr)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid group ID.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpdateChatUsageLimitGroupOverrideRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.SpendLimitMicros <= 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid chat usage limit group override.",
|
||||
Detail: "Spend limit (in microdollars) must be greater than 0.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
group, err := api.Database.GetGroupByID(ctx, groupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Group not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to look up group details.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = api.Database.UpsertChatUsageLimitGroupOverride(ctx, database.UpsertChatUsageLimitGroupOverrideParams{
|
||||
GroupID: groupID,
|
||||
SpendLimitMicros: req.SpendLimitMicros,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to upsert group usage limit override.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
memberCount, err := api.Database.GetGroupMembersCountByGroupID(ctx, database.GetGroupMembersCountByGroupIDParams{
|
||||
GroupID: groupID,
|
||||
IncludeSystem: false,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
writeChatUsageLimitGroupNotFound(ctx, rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to fetch group member count.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatUsageLimitGroupOverride{
|
||||
GroupID: group.ID,
|
||||
GroupName: group.Name,
|
||||
GroupDisplayName: group.DisplayName,
|
||||
GroupAvatarURL: group.AvatarURL,
|
||||
MemberCount: memberCount,
|
||||
SpendLimitMicros: nullInt64Ptr(sql.NullInt64{Int64: req.SpendLimitMicros, Valid: true}),
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Delete chat usage limit group override
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) deleteChatUsageLimitGroupOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
groupIDStr := chi.URLParam(r, "group")
|
||||
groupID, err := uuid.Parse(groupIDStr)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid group ID.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := api.Database.GetGroupByID(ctx, groupID); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
writeChatUsageLimitGroupNotFound(ctx, rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to look up group details.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if _, err := api.Database.GetChatUsageLimitGroupOverride(ctx, groupID); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
writeChatUsageLimitGroupOverrideNotFound(ctx, rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to look up group usage limit override.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := api.Database.DeleteChatUsageLimitGroupOverride(ctx, groupID); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to delete group usage limit override.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
@@ -1369,58 +893,64 @@ func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) {
|
||||
logger.Debug(ctx, "desktop Bicopy finished")
|
||||
}
|
||||
|
||||
// patchChat updates a chat resource. Currently supports toggling the
|
||||
// archived state via the Archived field.
|
||||
func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) archiveChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
var req codersdk.UpdateChatRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
if chat.Archived {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Chat is already archived.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Archived != nil {
|
||||
archived := *req.Archived
|
||||
if archived == chat.Archived {
|
||||
state := "archived"
|
||||
if !archived {
|
||||
state = "not archived"
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Chat is already %s.", state),
|
||||
})
|
||||
return
|
||||
}
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify
|
||||
// active subscribers. Fall back to direct DB for the
|
||||
// simple archive flag — no streaming state is involved.
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.ArchiveChat(ctx, chat.ID)
|
||||
} else {
|
||||
err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to archive chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify active
|
||||
// subscribers. Fall back to direct DB for the simple
|
||||
// archive flag — no streaming state is involved.
|
||||
if archived {
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.ArchiveChat(ctx, chat)
|
||||
} else {
|
||||
err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
} else {
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.UnarchiveChat(ctx, chat)
|
||||
} else {
|
||||
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
action := "archive"
|
||||
if !archived {
|
||||
action = "unarchive"
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: fmt.Sprintf("Failed to %s chat.", action),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *API) unarchiveChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
if !chat.Archived {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Chat is not archived.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify
|
||||
// active subscribers. Fall back to direct DB for the
|
||||
// simple unarchive flag — no streaming state is involved.
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.UnarchiveChat(ctx, chat.ID)
|
||||
} else {
|
||||
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to unarchive chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
@@ -1466,9 +996,6 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
},
|
||||
)
|
||||
if sendErr != nil {
|
||||
if maybeWriteLimitErr(ctx, rw, sendErr) {
|
||||
return
|
||||
}
|
||||
if xerrors.Is(sendErr, chatd.ErrMessageQueueFull) {
|
||||
httpapi.Write(ctx, rw, http.StatusTooManyRequests, codersdk.Response{
|
||||
Message: "Message queue is full.",
|
||||
@@ -1541,10 +1068,6 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
||||
Content: contentBlocks,
|
||||
})
|
||||
if editErr != nil {
|
||||
if maybeWriteLimitErr(ctx, rw, editErr) {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case xerrors.Is(editErr, chatd.ErrEditedMessageNotFound):
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
@@ -1635,9 +1158,6 @@ func (api *API) promoteChatQueuedMessage(rw http.ResponseWriter, r *http.Request
|
||||
})
|
||||
|
||||
if txErr != nil {
|
||||
if maybeWriteLimitErr(ctx, rw, txErr) {
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to promote queued message.",
|
||||
Detail: txErr.Error(),
|
||||
@@ -2519,14 +2039,14 @@ func (api *API) getChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPrompt{
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPromptResponse{
|
||||
SystemPrompt: prompt,
|
||||
})
|
||||
}
|
||||
|
||||
func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
var req codersdk.ChatSystemPrompt
|
||||
var req codersdk.UpdateChatSystemPromptRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
@@ -2554,49 +2074,6 @@ func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
enabled, err := api.Database.GetChatDesktopEnabled(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching desktop setting.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatDesktopEnabledResponse{
|
||||
EnableDesktop: enabled,
|
||||
})
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) putChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpdateChatDesktopEnabledRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if err := api.Database.UpsertChatDesktopEnabled(ctx, req.EnableDesktop); httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
} else if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating desktop setting.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
@@ -2619,7 +2096,7 @@ func (api *API) getUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
|
||||
customPrompt = ""
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPrompt{
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPromptResponse{
|
||||
CustomPrompt: customPrompt,
|
||||
})
|
||||
}
|
||||
@@ -2631,7 +2108,7 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
var params codersdk.UserChatCustomPrompt
|
||||
var params codersdk.UpdateUserChatCustomPromptRequest
|
||||
if !httpapi.Read(ctx, rw, r, ¶ms) {
|
||||
return
|
||||
}
|
||||
@@ -2658,7 +2135,7 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPrompt{
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPromptResponse{
|
||||
CustomPrompt: updatedConfig.Value,
|
||||
})
|
||||
}
|
||||
@@ -3878,49 +3355,6 @@ func chatModelConfigToUpdateParams(
|
||||
}
|
||||
}
|
||||
|
||||
func nullInt64Ptr(n sql.NullInt64) *int64 {
|
||||
if !n.Valid {
|
||||
return nil
|
||||
}
|
||||
return &n.Int64
|
||||
}
|
||||
|
||||
func writeChatUsageLimitUserNotFound(ctx context.Context, rw http.ResponseWriter) {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "User not found.",
|
||||
})
|
||||
}
|
||||
|
||||
func writeChatUsageLimitOverrideNotFound(ctx context.Context, rw http.ResponseWriter) {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Chat usage limit override not found.",
|
||||
})
|
||||
}
|
||||
|
||||
func writeChatUsageLimitGroupOverrideNotFound(ctx context.Context, rw http.ResponseWriter) {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Chat usage limit group override not found.",
|
||||
})
|
||||
}
|
||||
|
||||
func writeChatUsageLimitGroupNotFound(ctx context.Context, rw http.ResponseWriter) {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Group not found.",
|
||||
})
|
||||
}
|
||||
|
||||
func parseChatUsageLimitUserID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
||||
userID, err := uuid.Parse(chi.URLParam(r, "user"))
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid chat usage limit user ID.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return uuid.Nil, false
|
||||
}
|
||||
return userID, true
|
||||
}
|
||||
|
||||
func parseChatProviderID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
||||
providerID, err := uuid.Parse(chi.URLParam(r, "providerConfig"))
|
||||
if err != nil {
|
||||
|
||||
+12
-502
@@ -2,7 +2,6 @@ package coderd_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -17,19 +16,15 @@ import (
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/websocket"
|
||||
@@ -60,93 +55,6 @@ func newChatClientWithDatabase(t testing.TB) (*codersdk.Client, database.Store)
|
||||
})
|
||||
}
|
||||
|
||||
func requireChatUsageLimitExceededError(
|
||||
t *testing.T,
|
||||
err error,
|
||||
wantSpentMicros int64,
|
||||
wantLimitMicros int64,
|
||||
wantResetsAt time.Time,
|
||||
) *codersdk.ChatUsageLimitExceededResponse {
|
||||
t.Helper()
|
||||
|
||||
sdkErr, ok := codersdk.AsError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
require.Equal(t, "Chat usage limit exceeded.", sdkErr.Message)
|
||||
|
||||
limitErr := codersdk.ChatUsageLimitExceededFrom(err)
|
||||
require.NotNil(t, limitErr)
|
||||
require.Equal(t, "Chat usage limit exceeded.", limitErr.Message)
|
||||
require.Equal(t, wantSpentMicros, limitErr.SpentMicros)
|
||||
require.Equal(t, wantLimitMicros, limitErr.LimitMicros)
|
||||
require.True(
|
||||
t,
|
||||
limitErr.ResetsAt.Equal(wantResetsAt),
|
||||
"expected resets_at %s, got %s",
|
||||
wantResetsAt.UTC().Format(time.RFC3339),
|
||||
limitErr.ResetsAt.UTC().Format(time.RFC3339),
|
||||
)
|
||||
|
||||
return limitErr
|
||||
}
|
||||
|
||||
func enableDailyChatUsageLimit(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
limitMicros int64,
|
||||
) time.Time {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.UpsertChatUsageLimitConfig(
|
||||
dbauthz.AsSystemRestricted(ctx),
|
||||
database.UpsertChatUsageLimitConfigParams{
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: limitMicros,
|
||||
Period: string(codersdk.ChatUsageLimitPeriodDay),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, periodEnd := chatd.ComputeUsagePeriodBounds(time.Now(), codersdk.ChatUsageLimitPeriodDay)
|
||||
return periodEnd
|
||||
}
|
||||
|
||||
func insertAssistantCostMessage(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
modelConfigID uuid.UUID,
|
||||
totalCostMicros int64,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("assistant"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Content: assistantContent,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
TotalTokens: sql.NullInt64{},
|
||||
ReasoningTokens: sql.NullInt64{},
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{Int64: totalCostMicros, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPostChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -417,33 +325,6 @@ func TestPostChats(t *testing.T) {
|
||||
require.Equal(t, "Invalid input part.", sdkErr.Message)
|
||||
require.Equal(t, `content[0].type "image" is not supported.`, sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("UsageLimitExceeded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
|
||||
|
||||
existingChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "existing-limit-chat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
insertAssistantCostMessage(ctx, t, db, existingChat.ID, modelConfig.ID, 100)
|
||||
|
||||
_, err = client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "over limit",
|
||||
}},
|
||||
})
|
||||
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListChats(t *testing.T) {
|
||||
@@ -1688,7 +1569,7 @@ func TestArchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatsBeforeArchive, 2)
|
||||
|
||||
err = client.UpdateChat(ctx, chatToArchive.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
err = client.ArchiveChat(ctx, chatToArchive.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Default (no filter) returns only non-archived chats.
|
||||
@@ -1722,7 +1603,7 @@ func TestArchiveChat(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
err := client.ArchiveChat(ctx, uuid.New())
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
@@ -1765,7 +1646,7 @@ func TestArchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the parent via the API.
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
err = client.ArchiveChat(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// archived:false should exclude the entire archived family.
|
||||
@@ -1812,7 +1693,7 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the chat first.
|
||||
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
err = client.ArchiveChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's archived.
|
||||
@@ -1823,7 +1704,7 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
require.Len(t, archivedChats, 1)
|
||||
require.True(t, archivedChats[0].Archived)
|
||||
// Unarchive the chat.
|
||||
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
err = client.UnarchiveChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's no longer archived.
|
||||
@@ -1862,9 +1743,10 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Trying to unarchive a non-archived chat should fail.
|
||||
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
err = client.UnarchiveChat(ctx, chat.ID)
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1872,7 +1754,7 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
err := client.UnarchiveChat(ctx, uuid.New())
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
@@ -2001,34 +1883,6 @@ func TestPostChatMessages(t *testing.T) {
|
||||
require.Equal(t, "content[0].text cannot be empty.", sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("UsageLimitExceeded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "initial message for usage-limit test",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
|
||||
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
|
||||
|
||||
_, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "over limit",
|
||||
}},
|
||||
})
|
||||
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
|
||||
})
|
||||
|
||||
t.Run("ChatNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -2789,46 +2643,6 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
require.True(t, foundFileInChat, "chat should preserve file_id after edit")
|
||||
})
|
||||
|
||||
t.Run("UsageLimitExceeded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello before edit",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var userMessageID int64
|
||||
for _, message := range messagesResult.Messages {
|
||||
if message.Role == codersdk.ChatMessageRoleUser {
|
||||
userMessageID = message.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotZero(t, userMessageID)
|
||||
|
||||
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
|
||||
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
|
||||
|
||||
_, err = client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "edited over limit",
|
||||
}},
|
||||
})
|
||||
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
|
||||
})
|
||||
|
||||
t.Run("MessageNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -3538,81 +3352,6 @@ func TestPromoteChatQueuedMessage(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PromotesAlreadyQueuedMessageAfterLimitReached", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
enableDailyChatUsageLimit(ctx, t, db, 100)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "promote queued usage limit",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
const queuedText = "queued message for promote route"
|
||||
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(queuedText),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
queuedMessage, err := db.InsertChatQueuedMessage(
|
||||
dbauthz.AsSystemRestricted(ctx),
|
||||
database.InsertChatQueuedMessageParams{
|
||||
ChatID: chat.ID,
|
||||
Content: queuedContent,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
|
||||
|
||||
_, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
promoteRes, err := client.Request(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
fmt.Sprintf("/api/experimental/chats/%s/queue/%d/promote", chat.ID, queuedMessage.ID),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer promoteRes.Body.Close()
|
||||
require.Equal(t, http.StatusOK, promoteRes.StatusCode)
|
||||
|
||||
var promoted codersdk.ChatMessage
|
||||
err = json.NewDecoder(promoteRes.Body).Decode(&promoted)
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, promoted.ID)
|
||||
require.Equal(t, chat.ID, promoted.ChatID)
|
||||
require.Equal(t, codersdk.ChatMessageRoleUser, promoted.Role)
|
||||
|
||||
foundPromotedText := false
|
||||
for _, part := range promoted.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == queuedText {
|
||||
foundPromotedText = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundPromotedText)
|
||||
|
||||
queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
for _, queued := range queuedMessages {
|
||||
require.NotEqual(t, queuedMessage.ID, queued.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InvalidQueuedMessageID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -3644,133 +3383,6 @@ func TestPromoteChatQueuedMessage(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatUsageLimitOverrideRoutes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UpsertUserOverrideRequiresPositiveSpendLimit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _ := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
|
||||
res, err := client.Request(
|
||||
ctx,
|
||||
http.MethodPut,
|
||||
fmt.Sprintf("/api/experimental/chats/usage-limits/overrides/%s", member.ID),
|
||||
map[string]any{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
err = codersdk.ReadBodyAsError(res)
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid chat usage limit override.", sdkErr.Message)
|
||||
require.Equal(t, "Spend limit must be greater than 0.", sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("UpsertUserOverrideMissingUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.UpsertChatUsageLimitOverride(ctx, uuid.New(), codersdk.UpsertChatUsageLimitOverrideRequest{
|
||||
SpendLimitMicros: 7_000_000,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusNotFound)
|
||||
require.Equal(t, "User not found.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("DeleteUserOverrideMissingUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
err := client.DeleteChatUsageLimitOverride(ctx, uuid.New())
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "User not found.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("DeleteUserOverrideMissingOverride", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
|
||||
err := client.DeleteChatUsageLimitOverride(ctx, member.ID)
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Chat usage limit override not found.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("UpsertGroupOverrideIncludesMemberCount", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
group := dbgen.Group(t, db, database.Group{OrganizationID: firstUser.OrganizationID})
|
||||
dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: member.ID})
|
||||
dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: database.PrebuildsSystemUserID})
|
||||
|
||||
override, err := client.UpsertChatUsageLimitGroupOverride(ctx, group.ID, codersdk.UpsertChatUsageLimitGroupOverrideRequest{
|
||||
SpendLimitMicros: 7_000_000,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, group.ID, override.GroupID)
|
||||
require.EqualValues(t, 1, override.MemberCount)
|
||||
require.NotNil(t, override.SpendLimitMicros)
|
||||
require.EqualValues(t, 7_000_000, *override.SpendLimitMicros)
|
||||
|
||||
config, err := client.GetChatUsageLimitConfig(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
var listed *codersdk.ChatUsageLimitGroupOverride
|
||||
for i := range config.GroupOverrides {
|
||||
if config.GroupOverrides[i].GroupID == group.ID {
|
||||
listed = &config.GroupOverrides[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, listed)
|
||||
require.EqualValues(t, 1, listed.MemberCount)
|
||||
})
|
||||
|
||||
t.Run("UpsertGroupOverrideMissingGroup", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.UpsertChatUsageLimitGroupOverride(ctx, uuid.New(), codersdk.UpsertChatUsageLimitGroupOverrideRequest{
|
||||
SpendLimitMicros: 7_000_000,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusNotFound)
|
||||
require.Equal(t, "Group not found.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("DeleteGroupOverrideMissingOverride", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
group := dbgen.Group(t, db, database.Group{OrganizationID: firstUser.OrganizationID})
|
||||
|
||||
err := client.DeleteChatUsageLimitGroupOverride(ctx, group.ID)
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Chat usage limit group override not found.", sdkErr.Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostChatFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -4512,7 +4124,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
t.Run("AdminCanSet", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
SystemPrompt: "You are a helpful coding assistant.",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -4526,7 +4138,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Unset by sending an empty string.
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
SystemPrompt: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -4539,7 +4151,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
t.Run("NonAdminFails", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
SystemPrompt: "This should fail.",
|
||||
})
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
@@ -4560,7 +4172,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
tooLong := strings.Repeat("a", 131073)
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
SystemPrompt: tooLong,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
@@ -4568,108 +4180,6 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatDesktopEnabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReturnsFalseWhenUnset", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
resp, err := adminClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("AdminCanSetTrue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := adminClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("AdminCanSetFalse", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
// Set true first, then set false.
|
||||
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := adminClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("NonAdminCanRead", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := memberClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("NonAdminWriteFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
err := memberClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
})
|
||||
|
||||
t.Run("UnauthenticatedFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
anonClient := codersdk.New(adminClient.URL)
|
||||
_, err := anonClient.GetChatDesktopEnabled(ctx)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
|
||||
t.Helper()
|
||||
|
||||
|
||||
+2
-16
@@ -1157,8 +1157,6 @@ func New(options *Options) *API {
|
||||
r.Route("/config", func(r chi.Router) {
|
||||
r.Get("/system-prompt", api.getChatSystemPrompt)
|
||||
r.Put("/system-prompt", api.putChatSystemPrompt)
|
||||
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
|
||||
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
|
||||
r.Get("/user-prompt", api.getUserChatCustomPrompt)
|
||||
r.Put("/user-prompt", api.putUserChatCustomPrompt)
|
||||
})
|
||||
@@ -1180,25 +1178,13 @@ func New(options *Options) *API {
|
||||
r.Delete("/", api.deleteChatModelConfig)
|
||||
})
|
||||
})
|
||||
r.Route("/usage-limits", func(r chi.Router) {
|
||||
r.Get("/", api.getChatUsageLimitConfig)
|
||||
r.Put("/", api.updateChatUsageLimitConfig)
|
||||
r.Get("/status", api.getMyChatUsageLimitStatus)
|
||||
r.Route("/overrides/{user}", func(r chi.Router) {
|
||||
r.Put("/", api.upsertChatUsageLimitOverride)
|
||||
r.Delete("/", api.deleteChatUsageLimitOverride)
|
||||
})
|
||||
r.Route("/group-overrides/{group}", func(r chi.Router) {
|
||||
r.Put("/", api.upsertChatUsageLimitGroupOverride)
|
||||
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
|
||||
})
|
||||
})
|
||||
r.Route("/{chat}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractChatParam(options.Database))
|
||||
r.Get("/", api.getChat)
|
||||
r.Get("/git/watch", api.watchChatGit)
|
||||
r.Get("/desktop", api.watchChatDesktop)
|
||||
r.Patch("/", api.patchChat)
|
||||
r.Post("/archive", api.archiveChat)
|
||||
r.Post("/unarchive", api.unarchiveChat)
|
||||
r.Get("/messages", api.getChatMessages)
|
||||
r.Post("/messages", api.postChatMessages)
|
||||
r.Patch("/messages/{message}", api.patchChatMessage)
|
||||
|
||||
@@ -879,15 +879,6 @@ func createAnotherUserRetry(t testing.TB, client *codersdk.Client, organizationI
|
||||
m(&req)
|
||||
}
|
||||
|
||||
// Service accounts cannot have a password or email and must
|
||||
// use login_type=none. Enforce this after mutators so callers
|
||||
// only need to set ServiceAccount=true.
|
||||
if req.ServiceAccount {
|
||||
req.Password = ""
|
||||
req.Email = ""
|
||||
req.UserLoginType = codersdk.LoginTypeNone
|
||||
}
|
||||
|
||||
user, err := client.CreateUserWithOrgs(context.Background(), req)
|
||||
var apiError *codersdk.Error
|
||||
// If the user already exists by username or email conflict, try again up to "retries" times.
|
||||
|
||||
@@ -6,27 +6,22 @@ type CheckConstraint string
|
||||
|
||||
// CheckConstraint enums.
|
||||
const (
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersChatSpendLimitMicrosCheck CheckConstraint = "users_chat_spend_limit_micros_check" // users
|
||||
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
|
||||
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
|
||||
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
)
|
||||
|
||||
@@ -195,14 +195,13 @@ func MinimalUserFromVisibleUser(user database.VisibleUser) codersdk.MinimalUser
|
||||
|
||||
func ReducedUser(user database.User) codersdk.ReducedUser {
|
||||
return codersdk.ReducedUser{
|
||||
MinimalUser: MinimalUser(user),
|
||||
Email: user.Email,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
LastSeenAt: user.LastSeenAt,
|
||||
Status: codersdk.UserStatus(user.Status),
|
||||
LoginType: codersdk.LoginType(user.LoginType),
|
||||
IsServiceAccount: user.IsServiceAccount,
|
||||
MinimalUser: MinimalUser(user),
|
||||
Email: user.Email,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
LastSeenAt: user.LastSeenAt,
|
||||
Status: codersdk.UserStatus(user.Status),
|
||||
LoginType: codersdk.LoginType(user.LoginType),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1264,7 +1264,7 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID uuid.UUID, added, re
|
||||
// System roles are stored in the database but have a fixed, code-defined
|
||||
// meaning. Do not rewrite the name for them so the static "who can assign
|
||||
// what" mapping applies.
|
||||
if !rolestore.IsSystemRoleName(roleName.Name) {
|
||||
if !rbac.SystemRoleName(roleName.Name) {
|
||||
// To support a dynamic mapping of what roles can assign what, we need
|
||||
// to store this in the database. For now, just use a static role so
|
||||
// owners and org admins can assign roles.
|
||||
@@ -1726,13 +1726,6 @@ func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountCon
|
||||
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.CountEnabledModelsWithoutPricing(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
|
||||
return nil, err
|
||||
@@ -1861,20 +1854,6 @@ func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.Dele
|
||||
return q.db.DeleteChatQueuedMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatUsageLimitGroupOverride(ctx, groupID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatUsageLimitUserOverride(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -2145,12 +2124,12 @@ func (q *querier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) erro
|
||||
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, params database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
// This is a system-only function.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteWorkspaceACLsByOrganization(ctx, params)
|
||||
return q.db.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
|
||||
@@ -2482,17 +2461,6 @@ func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCo
|
||||
return q.db.GetChatCostSummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
// The desktop-enabled flag is a deployment-wide setting read by any
|
||||
// authenticated chat user and by chatd when deciding whether to expose
|
||||
// computer-use tooling. We only require that an explicit actor is present
|
||||
// in the context so unauthenticated calls fail closed.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return false, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatDesktopEnabled(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
@@ -2643,33 +2611,8 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
return q.db.GetChatSystemPrompt(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatUsageLimitConfig{}, err
|
||||
}
|
||||
return q.db.GetChatUsageLimitConfig(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.GetChatUsageLimitGroupOverrideRow{}, err
|
||||
}
|
||||
return q.db.GetChatUsageLimitGroupOverride(ctx, groupID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.GetChatUsageLimitUserOverrideRow{}, err
|
||||
}
|
||||
return q.db.GetChatUsageLimitUserOverride(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.GetAuthorizedChats(ctx, arg, prep)
|
||||
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
@@ -3337,6 +3280,12 @@ func (q *querier) GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uui
|
||||
return q.db.GetProvisionerJobTimingsByJobID(ctx, jobID)
|
||||
}
|
||||
|
||||
func (q *querier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) {
|
||||
// TODO: Remove this once we have a proper rbac check for provisioner jobs.
|
||||
// Details in https://github.com/coder/coder/issues/16160
|
||||
return q.db.GetProvisionerJobsByIDs(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
|
||||
// TODO: Remove this once we have a proper rbac check for provisioner jobs.
|
||||
// Details in https://github.com/coder/coder/issues/16160
|
||||
@@ -3822,13 +3771,6 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
|
||||
return q.db.GetUserChatCustomPrompt(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.GetUserChatSpendInPeriod(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
@@ -3836,13 +3778,6 @@ func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64,
|
||||
return q.db.GetUserCount(ctx, includeSystem)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.GetUserGroupSpendLimit(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
|
||||
// Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms.
|
||||
if err := q.authorizeContext(ctx, policy.ActionViewInsights, rbac.ResourceTemplate); err != nil {
|
||||
@@ -4512,13 +4447,6 @@ func (q *querier) InsertAIBridgeInterception(ctx context.Context, arg database.I
|
||||
return insert(q.log, q.auth, rbac.ResourceAibridgeInterception.WithOwner(arg.InitiatorID.String()), q.db.InsertAIBridgeInterception)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
|
||||
return database.AIBridgeModelThought{}, err
|
||||
}
|
||||
return q.db.InsertAIBridgeModelThought(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
// All aibridge_token_usages records belong to the initiator of their associated interception.
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
|
||||
@@ -5208,20 +5136,6 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
|
||||
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
|
||||
}
|
||||
|
||||
func (q *querier) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListChatUsageLimitGroupOverrides(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListChatUsageLimitOverrides(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
|
||||
}
|
||||
@@ -5341,13 +5255,6 @@ func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveU
|
||||
return q.db.RemoveUserFromGroups(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.ResolveUserChatSpendLimit(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
@@ -6559,13 +6466,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
|
||||
return q.db.UpsertBoundaryUsageStats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatDesktopEnabled(ctx, enableDesktop)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
@@ -6597,27 +6497,6 @@ func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) erro
|
||||
return q.db.UpsertChatSystemPrompt(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatUsageLimitConfig{}, err
|
||||
}
|
||||
return q.db.UpsertChatUsageLimitConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.UpsertChatUsageLimitGroupOverrideRow{}, err
|
||||
}
|
||||
return q.db.UpsertChatUsageLimitGroupOverride(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.UpsertChatUsageLimitUserOverrideRow{}, err
|
||||
}
|
||||
return q.db.UpsertChatUsageLimitUserOverride(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
|
||||
return database.ConnectionLog{}, err
|
||||
@@ -6900,7 +6779,3 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.ListAIBridgeModels(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
return q.GetChats(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -513,10 +513,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatCostSummary(gomock.Any(), arg).Return(row, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(row)
|
||||
}))
|
||||
s.Run("CountEnabledModelsWithoutPricing", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().CountEnabledModelsWithoutPricing(gomock.Any()).Return(int64(3), nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(int64(3))
|
||||
}))
|
||||
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
@@ -618,17 +614,12 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
params := database.GetChatsParams{}
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
s.Run("GetAuthorizedChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
params := database.GetChatsParams{}
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
|
||||
// No asserts here because it re-routes through GetChats which uses SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
c1 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
c2 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
params := database.GetChatsByOwnerIDParams{OwnerID: c1.OwnerID}
|
||||
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), params).Return([]database.Chat{c1, c2}, nil).AnyTimes()
|
||||
check.Args(params).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
|
||||
}))
|
||||
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -641,10 +632,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatSystemPrompt(gomock.Any()).Return("prompt", nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
@@ -854,146 +841,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
|
||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
|
||||
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetUserChatSpendInPeriod", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetUserChatSpendInPeriodParams{
|
||||
UserID: uuid.New(),
|
||||
StartTime: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndTime: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
spend := int64(123)
|
||||
dbm.EXPECT().GetUserChatSpendInPeriod(gomock.Any(), arg).Return(spend, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.UserID.String()), policy.ActionRead).Returns(spend)
|
||||
}))
|
||||
s.Run("GetUserGroupSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
limit := int64(456)
|
||||
dbm.EXPECT().GetUserGroupSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
|
||||
}))
|
||||
s.Run("ResolveUserChatSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
limit := int64(789)
|
||||
dbm.EXPECT().ResolveUserChatSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
|
||||
}))
|
||||
s.Run("GetChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
now := dbtime.Now()
|
||||
config := database.ChatUsageLimitConfig{
|
||||
ID: 1,
|
||||
Singleton: true,
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: 1_000_000,
|
||||
Period: "monthly",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
dbm.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(config, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
groupID := uuid.New()
|
||||
override := database.GetChatUsageLimitGroupOverrideRow{
|
||||
GroupID: groupID,
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 2_000_000, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().GetChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(override, nil).AnyTimes()
|
||||
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
|
||||
}))
|
||||
s.Run("GetChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
override := database.GetChatUsageLimitUserOverrideRow{
|
||||
UserID: userID,
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 3_000_000, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().GetChatUsageLimitUserOverride(gomock.Any(), userID).Return(override, nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
|
||||
}))
|
||||
s.Run("ListChatUsageLimitGroupOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
overrides := []database.ListChatUsageLimitGroupOverridesRow{{
|
||||
GroupID: uuid.New(),
|
||||
GroupName: "group-name",
|
||||
GroupDisplayName: "Group Name",
|
||||
GroupAvatarUrl: "https://example.com/group.png",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 4_000_000, Valid: true},
|
||||
MemberCount: 5,
|
||||
}}
|
||||
dbm.EXPECT().ListChatUsageLimitGroupOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
|
||||
}))
|
||||
s.Run("ListChatUsageLimitOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
overrides := []database.ListChatUsageLimitOverridesRow{{
|
||||
UserID: uuid.New(),
|
||||
Username: "usage-limit-user",
|
||||
Name: "Usage Limit User",
|
||||
AvatarURL: "https://example.com/avatar.png",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 5_000_000, Valid: true},
|
||||
}}
|
||||
dbm.EXPECT().ListChatUsageLimitOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
|
||||
}))
|
||||
s.Run("UpsertChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
now := dbtime.Now()
|
||||
arg := database.UpsertChatUsageLimitConfigParams{
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: 6_000_000,
|
||||
Period: "monthly",
|
||||
}
|
||||
config := database.ChatUsageLimitConfig{
|
||||
ID: 1,
|
||||
Singleton: true,
|
||||
Enabled: arg.Enabled,
|
||||
DefaultLimitMicros: arg.DefaultLimitMicros,
|
||||
Period: arg.Period,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
dbm.EXPECT().UpsertChatUsageLimitConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("UpsertChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.UpsertChatUsageLimitGroupOverrideParams{
|
||||
SpendLimitMicros: 7_000_000,
|
||||
GroupID: uuid.New(),
|
||||
}
|
||||
override := database.UpsertChatUsageLimitGroupOverrideRow{
|
||||
GroupID: arg.GroupID,
|
||||
Name: "group",
|
||||
DisplayName: "Group",
|
||||
AvatarURL: "",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().UpsertChatUsageLimitGroupOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
|
||||
}))
|
||||
s.Run("UpsertChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.UpsertChatUsageLimitUserOverrideParams{
|
||||
SpendLimitMicros: 8_000_000,
|
||||
UserID: uuid.New(),
|
||||
}
|
||||
override := database.UpsertChatUsageLimitUserOverrideRow{
|
||||
UserID: arg.UserID,
|
||||
Username: "user",
|
||||
Name: "User",
|
||||
AvatarURL: "",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().UpsertChatUsageLimitUserOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
|
||||
}))
|
||||
s.Run("DeleteChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
groupID := uuid.New()
|
||||
dbm.EXPECT().DeleteChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(nil).AnyTimes()
|
||||
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestFile() {
|
||||
@@ -1493,7 +1340,7 @@ func (s *MethodTestSuite) TestOrganization() {
|
||||
org := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
arg := database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
||||
ID: org.ID,
|
||||
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
|
||||
WorkspaceSharingDisabled: true,
|
||||
}
|
||||
dbm.EXPECT().GetOrganizationByID(gomock.Any(), org.ID).Return(org, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateOrganizationWorkspaceSharingSettings(gomock.Any(), arg).Return(org, nil).AnyTimes()
|
||||
@@ -2412,12 +2259,9 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
check.Args(w.ID).Asserts(w, policy.ActionShare)
|
||||
}))
|
||||
s.Run("DeleteWorkspaceACLsByOrganization", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.DeleteWorkspaceACLsByOrganizationParams{
|
||||
OrganizationID: uuid.New(),
|
||||
ExcludeServiceAccounts: false,
|
||||
}
|
||||
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
orgID := uuid.New()
|
||||
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), orgID).Return(nil).AnyTimes()
|
||||
check.Args(orgID).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
@@ -5183,17 +5027,6 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
check.Args(params).Asserts(intc, policy.ActionCreate)
|
||||
}))
|
||||
|
||||
s.Run("InsertAIBridgeModelThought", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
intID := uuid.UUID{2}
|
||||
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
|
||||
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
|
||||
|
||||
params := database.InsertAIBridgeModelThoughtParams{InterceptionID: intc.ID}
|
||||
expected := testutil.Fake(s.T(), faker, database.AIBridgeModelThought{InterceptionID: intc.ID})
|
||||
db.EXPECT().InsertAIBridgeModelThought(gomock.Any(), params).Return(expected, nil).AnyTimes()
|
||||
check.Args(params).Asserts(intc, policy.ActionUpdate)
|
||||
}))
|
||||
|
||||
s.Run("InsertAIBridgeTokenUsage", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
intID := uuid.UUID{2}
|
||||
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
|
||||
|
||||
@@ -29,7 +29,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/v2/coderd/rbac/rolestore"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
)
|
||||
|
||||
@@ -144,7 +143,7 @@ func (s *MethodTestSuite) Mocked(testCaseF func(dmb *dbmock.MockStore, faker *go
|
||||
UUID: pair.OrganizationID,
|
||||
Valid: pair.OrganizationID != uuid.Nil,
|
||||
},
|
||||
IsSystem: rolestore.IsSystemRoleName(pair.Name),
|
||||
IsSystem: rbac.SystemRoleName(pair.Name),
|
||||
ID: uuid.New(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -650,26 +650,34 @@ func Organization(t testing.TB, db database.Store, orig database.Organization) d
|
||||
})
|
||||
require.NoError(t, err, "insert organization")
|
||||
|
||||
// Populate the placeholder system roles (created by DB
|
||||
// trigger/migration) so org members have expected permissions.
|
||||
//nolint:gocritic // ReconcileSystemRole needs the system:update
|
||||
// Populate the placeholder organization-member system role (created by
|
||||
// DB trigger/migration) so org members have expected permissions.
|
||||
//nolint:gocritic // ReconcileOrgMemberRole needs the system:update
|
||||
// permission that `genCtx` does not have.
|
||||
sysCtx := dbauthz.AsSystemRestricted(genCtx)
|
||||
for roleName := range rolestore.SystemRoleNames {
|
||||
role := database.CustomRole{
|
||||
Name: roleName,
|
||||
OrganizationID: uuid.NullUUID{UUID: org.ID, Valid: true},
|
||||
}
|
||||
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// The trigger that creates the placeholder role didn't run (e.g.,
|
||||
// triggers were disabled in the test). Create the role manually.
|
||||
err = rolestore.CreateSystemRole(sysCtx, db, org, roleName)
|
||||
require.NoError(t, err, "create role "+roleName)
|
||||
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
|
||||
}
|
||||
require.NoError(t, err, "reconcile role "+roleName)
|
||||
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
|
||||
Name: rbac.RoleOrgMember(),
|
||||
OrganizationID: uuid.NullUUID{
|
||||
UUID: org.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}, org.WorkspaceSharingDisabled)
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// The trigger that creates the placeholder role didn't run (e.g.,
|
||||
// triggers were disabled in the test). Create the role manually.
|
||||
err = rolestore.CreateOrgMemberRole(sysCtx, db, org)
|
||||
require.NoError(t, err, "create organization-member role")
|
||||
|
||||
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
|
||||
Name: rbac.RoleOrgMember(),
|
||||
OrganizationID: uuid.NullUUID{
|
||||
UUID: org.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}, org.WorkspaceSharingDisabled)
|
||||
}
|
||||
require.NoError(t, err, "reconcile organization-member role")
|
||||
|
||||
return org
|
||||
}
|
||||
|
||||
@@ -288,14 +288,6 @@ func (m queryMetricsStore) CountConnectionLogs(ctx context.Context, arg database
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountEnabledModelsWithoutPricing(ctx)
|
||||
m.queryLatencies.WithLabelValues("CountEnabledModelsWithoutPricing").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountEnabledModelsWithoutPricing").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountInProgressPrebuilds(ctx)
|
||||
@@ -416,22 +408,6 @@ func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg data
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatUsageLimitGroupOverride(ctx, groupID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitGroupOverride").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatUsageLimitUserOverride(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitUserOverride").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
|
||||
@@ -696,11 +672,10 @@ func (m queryMetricsStore) DeleteWorkspaceACLByID(ctx context.Context, id uuid.U
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, arg)
|
||||
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
|
||||
m.queryLatencies.WithLabelValues("DeleteWorkspaceACLsByOrganization").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteWorkspaceACLsByOrganization").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -1048,14 +1023,6 @@ func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDesktopEnabled(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatDesktopEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDesktopEnabled").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
|
||||
@@ -1176,35 +1143,11 @@ func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatUsageLimitConfig(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatUsageLimitConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatUsageLimitGroupOverride(ctx, groupID)
|
||||
m.queryLatencies.WithLabelValues("GetChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitGroupOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatUsageLimitUserOverride(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitUserOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChats").Inc()
|
||||
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
|
||||
m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -1888,6 +1831,14 @@ func (m queryMetricsStore) GetProvisionerJobTimingsByJobID(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetProvisionerJobsByIDs(ctx, ids)
|
||||
m.queryLatencies.WithLabelValues("GetProvisionerJobsByIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetProvisionerJobsByIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetProvisionerJobsByIDsWithQueuePosition(ctx, arg)
|
||||
@@ -2328,14 +2279,6 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetUserChatSpendInPeriod").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatSpendInPeriod").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserCount(ctx, includeSystem)
|
||||
@@ -2344,14 +2287,6 @@ func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool)
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserGroupSpendLimit(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetUserGroupSpendLimit").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserGroupSpendLimit").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserLatencyInsights(ctx, arg)
|
||||
@@ -2960,14 +2895,6 @@ func (m queryMetricsStore) InsertAIBridgeInterception(ctx context.Context, arg d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertAIBridgeModelThought(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertAIBridgeModelThought").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertAIBridgeModelThought").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertAIBridgeTokenUsage(ctx, arg)
|
||||
@@ -3592,22 +3519,6 @@ func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChatUsageLimitGroupOverrides(ctx)
|
||||
m.queryLatencies.WithLabelValues("ListChatUsageLimitGroupOverrides").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitGroupOverrides").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChatUsageLimitOverrides(ctx)
|
||||
m.queryLatencies.WithLabelValues("ListChatUsageLimitOverrides").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitOverrides").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListProvisionerKeysByOrganization(ctx, organizationID)
|
||||
@@ -3720,14 +3631,6 @@ func (m queryMetricsStore) RemoveUserFromGroups(ctx context.Context, arg databas
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ResolveUserChatSpendLimit(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ResolveUserChatSpendLimit").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ResolveUserChatSpendLimit").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
|
||||
@@ -3980,7 +3883,6 @@ func (m queryMetricsStore) UpdateOrganizationWorkspaceSharingSettings(ctx contex
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateOrganizationWorkspaceSharingSettings(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateOrganizationWorkspaceSharingSettings").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateOrganizationWorkspaceSharingSettings").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4560,14 +4462,6 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatDesktopEnabled(ctx, enableDesktop)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDesktopEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDesktopEnabled").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
|
||||
@@ -4592,30 +4486,6 @@ func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value str
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatUsageLimitGroupOverride(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitGroupOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatUsageLimitUserOverride(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitUserOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
|
||||
@@ -4911,11 +4781,3 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeModels").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAuthorizedChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -424,21 +424,6 @@ func (mr *MockStoreMockRecorder) CountConnectionLogs(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountConnectionLogs", reflect.TypeOf((*MockStore)(nil).CountConnectionLogs), ctx, arg)
|
||||
}
|
||||
|
||||
// CountEnabledModelsWithoutPricing mocks base method.
|
||||
func (m *MockStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountEnabledModelsWithoutPricing", ctx)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountEnabledModelsWithoutPricing indicates an expected call of CountEnabledModelsWithoutPricing.
|
||||
func (mr *MockStoreMockRecorder) CountEnabledModelsWithoutPricing(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEnabledModelsWithoutPricing", reflect.TypeOf((*MockStore)(nil).CountEnabledModelsWithoutPricing), ctx)
|
||||
}
|
||||
|
||||
// CountInProgressPrebuilds mocks base method.
|
||||
func (m *MockStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -654,34 +639,6 @@ func (mr *MockStoreMockRecorder) DeleteChatQueuedMessage(ctx, arg any) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).DeleteChatQueuedMessage), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitGroupOverride mocks base method.
|
||||
func (m *MockStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatUsageLimitGroupOverride", ctx, groupID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitGroupOverride indicates an expected call of DeleteChatUsageLimitGroupOverride.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitGroupOverride), ctx, groupID)
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitUserOverride mocks base method.
|
||||
func (m *MockStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatUsageLimitUserOverride", ctx, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitUserOverride indicates an expected call of DeleteChatUsageLimitUserOverride.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitUserOverride), ctx, userID)
|
||||
}
|
||||
|
||||
// DeleteCryptoKey mocks base method.
|
||||
func (m *MockStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1155,17 +1112,17 @@ func (mr *MockStoreMockRecorder) DeleteWorkspaceACLByID(ctx, id any) *gomock.Cal
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization mocks base method.
|
||||
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, organizationID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization indicates an expected call of DeleteWorkspaceACLsByOrganization.
|
||||
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, arg any) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, organizationID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, organizationID)
|
||||
}
|
||||
|
||||
// DeleteWorkspaceAgentPortShare mocks base method.
|
||||
@@ -1731,21 +1688,6 @@ func (mr *MockStoreMockRecorder) GetAuthorizedAuditLogsOffset(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedAuditLogsOffset", reflect.TypeOf((*MockStore)(nil).GetAuthorizedAuditLogsOffset), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// GetAuthorizedChats mocks base method.
|
||||
func (m *MockStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthorizedChats", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAuthorizedChats indicates an expected call of GetAuthorizedChats.
|
||||
func (mr *MockStoreMockRecorder) GetAuthorizedChats(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedChats", reflect.TypeOf((*MockStore)(nil).GetAuthorizedChats), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// GetAuthorizedConnectionLogsOffset mocks base method.
|
||||
func (m *MockStore) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1911,21 +1853,6 @@ func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDesktopEnabled", ctx)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDesktopEnabled indicates an expected call of GetChatDesktopEnabled.
|
||||
func (mr *MockStoreMockRecorder) GetChatDesktopEnabled(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).GetChatDesktopEnabled), ctx)
|
||||
}
|
||||
|
||||
// GetChatDiffStatusByChatID mocks base method.
|
||||
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2151,64 +2078,19 @@ func (mr *MockStoreMockRecorder) GetChatSystemPrompt(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatSystemPrompt), ctx)
|
||||
}
|
||||
|
||||
// GetChatUsageLimitConfig mocks base method.
|
||||
func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
// GetChatsByOwnerID mocks base method.
|
||||
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatUsageLimitConfig", ctx)
|
||||
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatUsageLimitConfig indicates an expected call of GetChatUsageLimitConfig.
|
||||
func (mr *MockStoreMockRecorder) GetChatUsageLimitConfig(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitConfig), ctx)
|
||||
}
|
||||
|
||||
// GetChatUsageLimitGroupOverride mocks base method.
|
||||
func (m *MockStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatUsageLimitGroupOverride", ctx, groupID)
|
||||
ret0, _ := ret[0].(database.GetChatUsageLimitGroupOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatUsageLimitGroupOverride indicates an expected call of GetChatUsageLimitGroupOverride.
|
||||
func (mr *MockStoreMockRecorder) GetChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitGroupOverride), ctx, groupID)
|
||||
}
|
||||
|
||||
// GetChatUsageLimitUserOverride mocks base method.
|
||||
func (m *MockStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatUsageLimitUserOverride", ctx, userID)
|
||||
ret0, _ := ret[0].(database.GetChatUsageLimitUserOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatUsageLimitUserOverride indicates an expected call of GetChatUsageLimitUserOverride.
|
||||
func (mr *MockStoreMockRecorder) GetChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitUserOverride), ctx, userID)
|
||||
}
|
||||
|
||||
// GetChats mocks base method.
|
||||
func (m *MockStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChats", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChats indicates an expected call of GetChats.
|
||||
func (mr *MockStoreMockRecorder) GetChats(ctx, arg any) *gomock.Call {
|
||||
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
|
||||
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChats", reflect.TypeOf((*MockStore)(nil).GetChats), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetConnectionLogsOffset mocks base method.
|
||||
@@ -3486,6 +3368,21 @@ func (mr *MockStoreMockRecorder) GetProvisionerJobTimingsByJobID(ctx, jobID any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobTimingsByJobID", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobTimingsByJobID), ctx, jobID)
|
||||
}
|
||||
|
||||
// GetProvisionerJobsByIDs mocks base method.
|
||||
func (m *MockStore) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProvisionerJobsByIDs", ctx, ids)
|
||||
ret0, _ := ret[0].([]database.ProvisionerJob)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProvisionerJobsByIDs indicates an expected call of GetProvisionerJobsByIDs.
|
||||
func (mr *MockStoreMockRecorder) GetProvisionerJobsByIDs(ctx, ids any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobsByIDs", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobsByIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetProvisionerJobsByIDsWithQueuePosition mocks base method.
|
||||
func (m *MockStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4341,21 +4238,6 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatSpendInPeriod mocks base method.
|
||||
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserChatSpendInPeriod", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserChatSpendInPeriod indicates an expected call of GetUserChatSpendInPeriod.
|
||||
func (mr *MockStoreMockRecorder) GetUserChatSpendInPeriod(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatSpendInPeriod", reflect.TypeOf((*MockStore)(nil).GetUserChatSpendInPeriod), ctx, arg)
|
||||
}
|
||||
|
||||
// GetUserCount mocks base method.
|
||||
func (m *MockStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4371,21 +4253,6 @@ func (mr *MockStoreMockRecorder) GetUserCount(ctx, includeSystem any) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserCount", reflect.TypeOf((*MockStore)(nil).GetUserCount), ctx, includeSystem)
|
||||
}
|
||||
|
||||
// GetUserGroupSpendLimit mocks base method.
|
||||
func (m *MockStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserGroupSpendLimit", ctx, userID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserGroupSpendLimit indicates an expected call of GetUserGroupSpendLimit.
|
||||
func (mr *MockStoreMockRecorder) GetUserGroupSpendLimit(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserGroupSpendLimit", reflect.TypeOf((*MockStore)(nil).GetUserGroupSpendLimit), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserLatencyInsights mocks base method.
|
||||
func (m *MockStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5540,21 +5407,6 @@ func (mr *MockStoreMockRecorder) InsertAIBridgeInterception(ctx, arg any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeInterception", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeInterception), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertAIBridgeModelThought mocks base method.
|
||||
func (m *MockStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertAIBridgeModelThought", ctx, arg)
|
||||
ret0, _ := ret[0].(database.AIBridgeModelThought)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertAIBridgeModelThought indicates an expected call of InsertAIBridgeModelThought.
|
||||
func (mr *MockStoreMockRecorder) InsertAIBridgeModelThought(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeModelThought", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeModelThought), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertAIBridgeTokenUsage mocks base method.
|
||||
func (m *MockStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6740,36 +6592,6 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListChatUsageLimitGroupOverrides mocks base method.
|
||||
func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListChatUsageLimitGroupOverrides", ctx)
|
||||
ret0, _ := ret[0].([]database.ListChatUsageLimitGroupOverridesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListChatUsageLimitGroupOverrides indicates an expected call of ListChatUsageLimitGroupOverrides.
|
||||
func (mr *MockStoreMockRecorder) ListChatUsageLimitGroupOverrides(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitGroupOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitGroupOverrides), ctx)
|
||||
}
|
||||
|
||||
// ListChatUsageLimitOverrides mocks base method.
|
||||
func (m *MockStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListChatUsageLimitOverrides", ctx)
|
||||
ret0, _ := ret[0].([]database.ListChatUsageLimitOverridesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListChatUsageLimitOverrides indicates an expected call of ListChatUsageLimitOverrides.
|
||||
func (mr *MockStoreMockRecorder) ListChatUsageLimitOverrides(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitOverrides), ctx)
|
||||
}
|
||||
|
||||
// ListProvisionerKeysByOrganization mocks base method.
|
||||
func (m *MockStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7008,21 +6830,6 @@ func (mr *MockStoreMockRecorder) RemoveUserFromGroups(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), ctx, arg)
|
||||
}
|
||||
|
||||
// ResolveUserChatSpendLimit mocks base method.
|
||||
func (m *MockStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ResolveUserChatSpendLimit", ctx, userID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ResolveUserChatSpendLimit indicates an expected call of ResolveUserChatSpendLimit.
|
||||
func (mr *MockStoreMockRecorder) ResolveUserChatSpendLimit(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveUserChatSpendLimit", reflect.TypeOf((*MockStore)(nil).ResolveUserChatSpendLimit), ctx, userID)
|
||||
}
|
||||
|
||||
// RevokeDBCryptKey mocks base method.
|
||||
func (m *MockStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8525,20 +8332,6 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDesktopEnabled", ctx, enableDesktop)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatDesktopEnabled indicates an expected call of UpsertChatDesktopEnabled.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDesktopEnabled(ctx, enableDesktop any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).UpsertChatDesktopEnabled), ctx, enableDesktop)
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatus mocks base method.
|
||||
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8583,51 +8376,6 @@ func (mr *MockStoreMockRecorder) UpsertChatSystemPrompt(ctx, value any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatSystemPrompt), ctx, value)
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitConfig mocks base method.
|
||||
func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatUsageLimitConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitConfig indicates an expected call of UpsertChatUsageLimitConfig.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitGroupOverride mocks base method.
|
||||
func (m *MockStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatUsageLimitGroupOverride", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UpsertChatUsageLimitGroupOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitGroupOverride indicates an expected call of UpsertChatUsageLimitGroupOverride.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitGroupOverride(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitGroupOverride), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitUserOverride mocks base method.
|
||||
func (m *MockStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatUsageLimitUserOverride", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UpsertChatUsageLimitUserOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitUserOverride indicates an expected call of UpsertChatUsageLimitUserOverride.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitUserOverride(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitUserOverride), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertConnectionLog mocks base method.
|
||||
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+5
-73
@@ -512,12 +512,6 @@ CREATE TYPE resource_type AS ENUM (
|
||||
'ai_seat'
|
||||
);
|
||||
|
||||
CREATE TYPE shareable_workspace_owners AS ENUM (
|
||||
'none',
|
||||
'everyone',
|
||||
'service_accounts'
|
||||
);
|
||||
|
||||
CREATE TYPE startup_script_behavior AS ENUM (
|
||||
'blocking',
|
||||
'non-blocking'
|
||||
@@ -798,7 +792,7 @@ BEGIN
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION insert_organization_system_roles() RETURNS trigger
|
||||
CREATE FUNCTION insert_org_member_system_role() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
@@ -813,8 +807,7 @@ BEGIN
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES
|
||||
(
|
||||
) VALUES (
|
||||
'organization-member',
|
||||
'',
|
||||
NEW.id,
|
||||
@@ -825,18 +818,6 @@ BEGIN
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
),
|
||||
(
|
||||
'organization-service-account',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
@@ -1105,15 +1086,6 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
metadata jsonb,
|
||||
created_at timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
|
||||
|
||||
CREATE TABLE aibridge_token_usages (
|
||||
id uuid NOT NULL,
|
||||
interception_id uuid NOT NULL,
|
||||
@@ -1346,28 +1318,6 @@ CREATE SEQUENCE chat_queued_messages_id_seq
|
||||
|
||||
ALTER SEQUENCE chat_queued_messages_id_seq OWNED BY chat_queued_messages.id;
|
||||
|
||||
CREATE TABLE chat_usage_limit_config (
|
||||
id bigint NOT NULL,
|
||||
singleton boolean DEFAULT true NOT NULL,
|
||||
enabled boolean DEFAULT false NOT NULL,
|
||||
default_limit_micros bigint DEFAULT 0 NOT NULL,
|
||||
period text DEFAULT 'month'::text NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT chat_usage_limit_config_default_limit_micros_check CHECK ((default_limit_micros >= 0)),
|
||||
CONSTRAINT chat_usage_limit_config_period_check CHECK ((period = ANY (ARRAY['day'::text, 'week'::text, 'month'::text]))),
|
||||
CONSTRAINT chat_usage_limit_config_singleton_check CHECK (singleton)
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_usage_limit_config_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
ALTER SEQUENCE chat_usage_limit_config_id_seq OWNED BY chat_usage_limit_config.id;
|
||||
|
||||
CREATE TABLE chats (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
@@ -1524,9 +1474,7 @@ CREATE TABLE groups (
|
||||
avatar_url text DEFAULT ''::text NOT NULL,
|
||||
quota_allowance integer DEFAULT 0 NOT NULL,
|
||||
display_name text DEFAULT ''::text NOT NULL,
|
||||
source group_source DEFAULT 'user'::group_source NOT NULL,
|
||||
chat_spend_limit_micros bigint,
|
||||
CONSTRAINT groups_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0)))
|
||||
source group_source DEFAULT 'user'::group_source NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN groups.display_name IS 'Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.';
|
||||
@@ -1561,9 +1509,7 @@ CREATE TABLE users (
|
||||
one_time_passcode_expires_at timestamp with time zone,
|
||||
is_system boolean DEFAULT false NOT NULL,
|
||||
is_service_account boolean DEFAULT false NOT NULL,
|
||||
chat_spend_limit_micros bigint,
|
||||
CONSTRAINT one_time_passcode_set CHECK ((((hashed_one_time_passcode IS NULL) AND (one_time_passcode_expires_at IS NULL)) OR ((hashed_one_time_passcode IS NOT NULL) AND (one_time_passcode_expires_at IS NOT NULL)))),
|
||||
CONSTRAINT users_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0))),
|
||||
CONSTRAINT users_email_not_empty CHECK (((is_service_account = true) = (email = ''::text))),
|
||||
CONSTRAINT users_service_account_login_type CHECK (((is_service_account = false) OR (login_type = 'none'::login_type))),
|
||||
CONSTRAINT users_username_min_length CHECK ((length(username) >= 1))
|
||||
@@ -1851,11 +1797,9 @@ CREATE TABLE organizations (
|
||||
display_name text NOT NULL,
|
||||
icon text DEFAULT ''::text NOT NULL,
|
||||
deleted boolean DEFAULT false NOT NULL,
|
||||
shareable_workspace_owners shareable_workspace_owners DEFAULT 'everyone'::shareable_workspace_owners NOT NULL
|
||||
workspace_sharing_disabled boolean DEFAULT false NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
|
||||
|
||||
CREATE TABLE parameter_schemas (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
@@ -3212,8 +3156,6 @@ ALTER TABLE ONLY chat_messages ALTER COLUMN id SET DEFAULT nextval('chat_message
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages ALTER COLUMN id SET DEFAULT nextval('chat_queued_messages_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY chat_usage_limit_config ALTER COLUMN id SET DEFAULT nextval('chat_usage_limit_config_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY licenses ALTER COLUMN id SET DEFAULT nextval('licenses_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY provisioner_job_logs ALTER COLUMN id SET DEFAULT nextval('provisioner_job_logs_id_seq'::regclass);
|
||||
@@ -3274,12 +3216,6 @@ ALTER TABLE ONLY chat_providers
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_usage_limit_config
|
||||
ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_usage_limit_config
|
||||
ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3592,8 +3528,6 @@ CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptio
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_thread_root_id ON aibridge_interceptions USING btree (thread_root_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_token_usages_interception_id ON aibridge_token_usages USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_token_usages_provider_response_id ON aibridge_token_usages USING btree (provider_response_id);
|
||||
@@ -3634,8 +3568,6 @@ CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USIN
|
||||
|
||||
CREATE INDEX idx_chat_messages_created_at ON chat_messages USING btree (created_at);
|
||||
|
||||
CREATE INDEX idx_chat_messages_owner_spend ON chat_messages USING btree (chat_id, created_at) WHERE (total_cost_micros IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs USING btree (provider);
|
||||
@@ -3884,7 +3816,7 @@ CREATE TRIGGER trigger_delete_oauth2_provider_app_token AFTER DELETE ON oauth2_p
|
||||
|
||||
CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted();
|
||||
|
||||
CREATE TRIGGER trigger_insert_organization_system_roles AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_organization_system_roles();
|
||||
CREATE TRIGGER trigger_insert_org_member_system_role AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_org_member_system_role();
|
||||
|
||||
CREATE TRIGGER trigger_nullify_next_start_at_on_workspace_autostart_modificati AFTER UPDATE ON workspaces FOR EACH ROW EXECUTE FUNCTION nullify_next_start_at_on_workspace_autostart_modification();
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ func TestCustomQueriesSyncedRowScan(t *testing.T) {
|
||||
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
|
||||
"GetWorkspaces": "GetAuthorizedWorkspaces",
|
||||
"GetUsers": "GetAuthorizedUsers",
|
||||
"GetChats": "GetAuthorizedChats",
|
||||
}
|
||||
|
||||
// Scan custom
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
DROP INDEX IF EXISTS idx_chat_messages_owner_spend;
|
||||
ALTER TABLE groups DROP COLUMN IF EXISTS chat_spend_limit_micros;
|
||||
ALTER TABLE users DROP COLUMN IF EXISTS chat_spend_limit_micros;
|
||||
DROP TABLE IF EXISTS chat_usage_limit_config;
|
||||
@@ -1,32 +0,0 @@
|
||||
-- 1. Singleton config table
|
||||
CREATE TABLE chat_usage_limit_config (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
-- Only one row allowed (enforced by CHECK).
|
||||
singleton BOOLEAN NOT NULL DEFAULT TRUE CHECK (singleton),
|
||||
UNIQUE (singleton),
|
||||
enabled BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
-- Limit per user per period, in micro-dollars (1 USD = 1,000,000).
|
||||
default_limit_micros BIGINT NOT NULL DEFAULT 0
|
||||
CHECK (default_limit_micros >= 0),
|
||||
-- Period length: 'day', 'week', or 'month'.
|
||||
period TEXT NOT NULL DEFAULT 'month'
|
||||
CHECK (period IN ('day', 'week', 'month')),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Seed a single disabled row so reads never return empty.
|
||||
INSERT INTO chat_usage_limit_config (singleton) VALUES (TRUE);
|
||||
|
||||
-- 2. Per-user overrides (inline on users table).
|
||||
ALTER TABLE users ADD COLUMN chat_spend_limit_micros BIGINT DEFAULT NULL
|
||||
CHECK (chat_spend_limit_micros IS NULL OR chat_spend_limit_micros > 0);
|
||||
|
||||
-- 3. Per-group overrides (inline on groups table).
|
||||
ALTER TABLE groups ADD COLUMN chat_spend_limit_micros BIGINT DEFAULT NULL
|
||||
CHECK (chat_spend_limit_micros IS NULL OR chat_spend_limit_micros > 0);
|
||||
|
||||
-- Speed up per-user spend aggregation in the usage-limit hot path.
|
||||
CREATE INDEX idx_chat_messages_owner_spend
|
||||
ON chat_messages (chat_id, created_at)
|
||||
WHERE total_cost_micros IS NOT NULL;
|
||||
@@ -1,3 +0,0 @@
|
||||
DROP INDEX idx_aibridge_model_thoughts_interception_id;
|
||||
|
||||
DROP TABLE aibridge_model_thoughts;
|
||||
@@ -1,10 +0,0 @@
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id UUID NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
metadata jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
|
||||
|
||||
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts(interception_id);
|
||||
-52
@@ -1,52 +0,0 @@
|
||||
DELETE FROM custom_roles
|
||||
WHERE name = 'organization-service-account' AND is_system = true;
|
||||
|
||||
ALTER TABLE organizations
|
||||
ADD COLUMN workspace_sharing_disabled boolean NOT NULL DEFAULT false;
|
||||
|
||||
-- Migrate back: 'none' -> disabled, everything else -> enabled.
|
||||
UPDATE organizations
|
||||
SET workspace_sharing_disabled = true
|
||||
WHERE shareable_workspace_owners = 'none';
|
||||
|
||||
ALTER TABLE organizations DROP COLUMN shareable_workspace_owners;
|
||||
|
||||
DROP TYPE shareable_workspace_owners;
|
||||
|
||||
-- Restore the original single-role trigger from migration 408.
|
||||
DROP TRIGGER IF EXISTS trigger_insert_organization_system_roles ON organizations;
|
||||
DROP FUNCTION IF EXISTS insert_organization_system_roles;
|
||||
|
||||
CREATE OR REPLACE FUNCTION insert_org_member_system_role() RETURNS trigger AS $$
|
||||
BEGIN
|
||||
INSERT INTO custom_roles (
|
||||
name,
|
||||
display_name,
|
||||
organization_id,
|
||||
site_permissions,
|
||||
org_permissions,
|
||||
user_permissions,
|
||||
member_permissions,
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
'organization-member',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_insert_org_member_system_role
|
||||
AFTER INSERT ON organizations
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION insert_org_member_system_role();
|
||||
@@ -1,101 +0,0 @@
|
||||
CREATE TYPE shareable_workspace_owners AS ENUM ('none', 'everyone', 'service_accounts');
|
||||
|
||||
ALTER TABLE organizations
|
||||
ADD COLUMN shareable_workspace_owners shareable_workspace_owners NOT NULL DEFAULT 'everyone';
|
||||
|
||||
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
|
||||
|
||||
-- Migrate existing data from the boolean column.
|
||||
UPDATE organizations
|
||||
SET shareable_workspace_owners = 'none'
|
||||
WHERE workspace_sharing_disabled = true;
|
||||
|
||||
ALTER TABLE organizations DROP COLUMN workspace_sharing_disabled;
|
||||
|
||||
-- Defensively rename any existing 'organization-service-account' roles
|
||||
-- so they don't collide with the new system role.
|
||||
UPDATE custom_roles
|
||||
SET name = name || '-' || id::text
|
||||
-- lower(name) is part of the existing unique index
|
||||
WHERE lower(name) = 'organization-service-account';
|
||||
|
||||
-- Create skeleton organization-service-account system roles for all
|
||||
-- existing organizations, mirroring what migration 408 did for
|
||||
-- organization-member.
|
||||
INSERT INTO custom_roles (
|
||||
name,
|
||||
display_name,
|
||||
organization_id,
|
||||
site_permissions,
|
||||
org_permissions,
|
||||
user_permissions,
|
||||
member_permissions,
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
'organization-service-account',
|
||||
'',
|
||||
id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
FROM
|
||||
organizations;
|
||||
|
||||
-- Replace the single-role trigger with one that creates both system
|
||||
-- roles when a new organization is inserted.
|
||||
DROP TRIGGER IF EXISTS trigger_insert_org_member_system_role ON organizations;
|
||||
DROP FUNCTION IF EXISTS insert_org_member_system_role;
|
||||
|
||||
CREATE OR REPLACE FUNCTION insert_organization_system_roles() RETURNS trigger AS $$
|
||||
BEGIN
|
||||
INSERT INTO custom_roles (
|
||||
name,
|
||||
display_name,
|
||||
organization_id,
|
||||
site_permissions,
|
||||
org_permissions,
|
||||
user_permissions,
|
||||
member_permissions,
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES
|
||||
(
|
||||
'organization-member',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
),
|
||||
(
|
||||
'organization-service-account',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_insert_organization_system_roles
|
||||
AFTER INSERT ON organizations
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION insert_organization_system_roles();
|
||||
Vendored
-28
@@ -1,28 +0,0 @@
|
||||
-- Fixture for migration 000443_three_options_for_allowed_workspace_sharing.
|
||||
-- Inserts a custom role named 'Organization-Service-Account' (mixed case)
|
||||
-- to ensure the migration's case-insensitive rename catches it.
|
||||
INSERT INTO custom_roles (
|
||||
name,
|
||||
display_name,
|
||||
organization_id,
|
||||
site_permissions,
|
||||
org_permissions,
|
||||
user_permissions,
|
||||
member_permissions,
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
VALUES (
|
||||
'Organization-Service-Account',
|
||||
'User-created role',
|
||||
'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1',
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
false,
|
||||
NOW(),
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT DO NOTHING;
|
||||
@@ -1,5 +0,0 @@
|
||||
UPDATE users SET chat_spend_limit_micros = 5000000
|
||||
WHERE id = 'fc1511ef-4fcf-4a3b-98a1-8df64160e35a';
|
||||
|
||||
UPDATE groups SET chat_spend_limit_micros = 10000000
|
||||
WHERE id = 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1';
|
||||
-13
@@ -1,13 +0,0 @@
|
||||
INSERT INTO
|
||||
aibridge_model_thoughts (
|
||||
interception_id,
|
||||
content,
|
||||
metadata,
|
||||
created_at
|
||||
)
|
||||
VALUES (
|
||||
'be003e1e-b38f-43bf-847d-928074dd0aa8', -- from 000370_aibridge.up.sql
|
||||
'The user is asking about their workspaces. I should use the coder_list_workspaces tool to retrieve this information.',
|
||||
'{"source": "commentary"}',
|
||||
'2025-09-15 12:45:19.123456+00'
|
||||
);
|
||||
@@ -52,7 +52,6 @@ type customQuerier interface {
|
||||
auditLogQuerier
|
||||
connectionLogQuerier
|
||||
aibridgeQuerier
|
||||
chatQuerier
|
||||
}
|
||||
|
||||
type templateQuerier interface {
|
||||
@@ -452,7 +451,6 @@ func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams,
|
||||
&i.OneTimePasscodeExpiresAt,
|
||||
&i.IsSystem,
|
||||
&i.IsServiceAccount,
|
||||
&i.ChatSpendLimitMicros,
|
||||
&i.Count,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
@@ -739,68 +737,6 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
|
||||
return count, nil
|
||||
}
|
||||
|
||||
type chatQuerier interface {
|
||||
GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]Chat, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]Chat, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, rbac.ConfigChats())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
filtered, err := insertAuthorizedFilter(getChats, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
// The name comment is for metric tracking
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedChats :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.OwnerID,
|
||||
arg.Archived,
|
||||
arg.AfterID,
|
||||
arg.OffsetOpt,
|
||||
arg.LimitOpt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Chat
|
||||
for rows.Next() {
|
||||
var i Chat
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
type aibridgeQuerier interface {
|
||||
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error)
|
||||
CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
|
||||
+12
-94
@@ -3131,67 +3131,6 @@ func AllResourceTypeValues() []ResourceType {
|
||||
}
|
||||
}
|
||||
|
||||
type ShareableWorkspaceOwners string
|
||||
|
||||
const (
|
||||
ShareableWorkspaceOwnersNone ShareableWorkspaceOwners = "none"
|
||||
ShareableWorkspaceOwnersEveryone ShareableWorkspaceOwners = "everyone"
|
||||
ShareableWorkspaceOwnersServiceAccounts ShareableWorkspaceOwners = "service_accounts"
|
||||
)
|
||||
|
||||
func (e *ShareableWorkspaceOwners) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ShareableWorkspaceOwners(s)
|
||||
case string:
|
||||
*e = ShareableWorkspaceOwners(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ShareableWorkspaceOwners: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullShareableWorkspaceOwners struct {
|
||||
ShareableWorkspaceOwners ShareableWorkspaceOwners `json:"shareable_workspace_owners"`
|
||||
Valid bool `json:"valid"` // Valid is true if ShareableWorkspaceOwners is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (ns *NullShareableWorkspaceOwners) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
ns.ShareableWorkspaceOwners, ns.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
ns.Valid = true
|
||||
return ns.ShareableWorkspaceOwners.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullShareableWorkspaceOwners) Value() (driver.Value, error) {
|
||||
if !ns.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return string(ns.ShareableWorkspaceOwners), nil
|
||||
}
|
||||
|
||||
func (e ShareableWorkspaceOwners) Valid() bool {
|
||||
switch e {
|
||||
case ShareableWorkspaceOwnersNone,
|
||||
ShareableWorkspaceOwnersEveryone,
|
||||
ShareableWorkspaceOwnersServiceAccounts:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AllShareableWorkspaceOwnersValues() []ShareableWorkspaceOwners {
|
||||
return []ShareableWorkspaceOwners{
|
||||
ShareableWorkspaceOwnersNone,
|
||||
ShareableWorkspaceOwnersEveryone,
|
||||
ShareableWorkspaceOwnersServiceAccounts,
|
||||
}
|
||||
}
|
||||
|
||||
type StartupScriptBehavior string
|
||||
|
||||
const (
|
||||
@@ -4038,14 +3977,6 @@ type AIBridgeInterception struct {
|
||||
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
|
||||
}
|
||||
|
||||
// Audit log of model thinking in intercepted requests in AI Bridge
|
||||
type AIBridgeModelThought struct {
|
||||
InterceptionID uuid.UUID `db:"interception_id" json:"interception_id"`
|
||||
Content string `db:"content" json:"content"`
|
||||
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
// Audit log of tokens used by intercepted requests in AI Bridge
|
||||
type AIBridgeTokenUsage struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
@@ -4265,16 +4196,6 @@ type ChatQueuedMessage struct {
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
type ChatUsageLimitConfig struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
Singleton bool `db:"singleton" json:"singleton"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
DefaultLimitMicros int64 `db:"default_limit_micros" json:"default_limit_micros"`
|
||||
Period string `db:"period" json:"period"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type ConnectionLog struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ConnectTime time.Time `db:"connect_time" json:"connect_time"`
|
||||
@@ -4387,8 +4308,7 @@ type Group struct {
|
||||
// Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
// Source indicates how the group was created. It can be created by a user manually, or through some system process like OIDC group sync.
|
||||
Source GroupSource `db:"source" json:"source"`
|
||||
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
|
||||
Source GroupSource `db:"source" json:"source"`
|
||||
}
|
||||
|
||||
// Joins group members with user information, organization ID, group name. Includes both regular group members and organization members (as part of the "Everyone" group).
|
||||
@@ -4596,17 +4516,16 @@ type OAuth2ProviderAppToken struct {
|
||||
}
|
||||
|
||||
type Organization struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
IsDefault bool `db:"is_default" json:"is_default"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
Icon string `db:"icon" json:"icon"`
|
||||
Deleted bool `db:"deleted" json:"deleted"`
|
||||
// Controls whose workspaces can be shared: none, everyone, or service_accounts.
|
||||
ShareableWorkspaceOwners ShareableWorkspaceOwners `db:"shareable_workspace_owners" json:"shareable_workspace_owners"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
IsDefault bool `db:"is_default" json:"is_default"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
Icon string `db:"icon" json:"icon"`
|
||||
Deleted bool `db:"deleted" json:"deleted"`
|
||||
WorkspaceSharingDisabled bool `db:"workspace_sharing_disabled" json:"workspace_sharing_disabled"`
|
||||
}
|
||||
|
||||
type OrganizationMember struct {
|
||||
@@ -5159,8 +5078,7 @@ type User struct {
|
||||
// Determines if a user is a system user, and therefore cannot login or perform normal actions
|
||||
IsSystem bool `db:"is_system" json:"is_system"`
|
||||
// Determines if a user is an admin-managed account that cannot login
|
||||
IsServiceAccount bool `db:"is_service_account" json:"is_service_account"`
|
||||
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
|
||||
IsServiceAccount bool `db:"is_service_account" json:"is_service_account"`
|
||||
}
|
||||
|
||||
type UserConfig struct {
|
||||
|
||||
@@ -77,9 +77,6 @@ type sqlcQuerier interface {
|
||||
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
|
||||
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
|
||||
CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error)
|
||||
// Counts enabled, non-deleted model configs that lack both input and
|
||||
// output pricing in their JSONB options.cost configuration.
|
||||
CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error)
|
||||
// CountInProgressPrebuilds returns the number of in-progress prebuilds, grouped by preset ID and transition.
|
||||
// Prebuild considered in-progress if it's in the "pending", "starting", "stopping", or "deleting" state.
|
||||
CountInProgressPrebuilds(ctx context.Context) ([]CountInProgressPrebuildsRow, error)
|
||||
@@ -102,8 +99,6 @@ type sqlcQuerier interface {
|
||||
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
|
||||
DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error
|
||||
DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error)
|
||||
DeleteCustomRole(ctx context.Context, arg DeleteCustomRoleParams) error
|
||||
DeleteExpiredAPIKeys(ctx context.Context, arg DeleteExpiredAPIKeysParams) (int64, error)
|
||||
@@ -150,7 +145,7 @@ type sqlcQuerier interface {
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteWorkspaceACLsByOrganization(ctx context.Context, arg DeleteWorkspaceACLsByOrganizationParams) error
|
||||
DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error
|
||||
DeleteWorkspaceAgentPortShare(ctx context.Context, arg DeleteWorkspaceAgentPortShareParams) error
|
||||
DeleteWorkspaceAgentPortSharesByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||
DeleteWorkspaceSubAgentByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -234,7 +229,6 @@ type sqlcQuerier interface {
|
||||
// Aggregate cost summary for a single user within a date range.
|
||||
// Only counts assistant-role messages.
|
||||
GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error)
|
||||
GetChatDesktopEnabled(ctx context.Context) (bool, error)
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
|
||||
@@ -250,10 +244,7 @@ type sqlcQuerier interface {
|
||||
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
|
||||
GetChatSystemPrompt(ctx context.Context) (string, error)
|
||||
GetChatUsageLimitConfig(ctx context.Context) (ChatUsageLimitConfig, error)
|
||||
GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (GetChatUsageLimitGroupOverrideRow, error)
|
||||
GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (GetChatUsageLimitUserOverrideRow, error)
|
||||
GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, error)
|
||||
GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerIDParams) ([]Chat, error)
|
||||
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
|
||||
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
|
||||
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
|
||||
@@ -387,6 +378,7 @@ type sqlcQuerier interface {
|
||||
// Blocks until the row is available for update.
|
||||
GetProvisionerJobByIDWithLock(ctx context.Context, id uuid.UUID) (ProvisionerJob, error)
|
||||
GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uuid.UUID) ([]ProvisionerJobTiming, error)
|
||||
GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error)
|
||||
GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg GetProvisionerJobsByIDsWithQueuePositionParams) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error)
|
||||
GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner(ctx context.Context, arg GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerParams) ([]GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow, error)
|
||||
GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error)
|
||||
@@ -505,11 +497,7 @@ type sqlcQuerier interface {
|
||||
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
|
||||
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
|
||||
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
|
||||
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
// Returns the minimum (most restrictive) group limit for a user.
|
||||
// Returns -1 if the user has no group limits applied.
|
||||
GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error)
|
||||
// GetUserLatencyInsights returns the median and 95th percentile connection
|
||||
// latency that users have experienced. The result can be filtered on
|
||||
// template_ids, meaning only user data from workspaces based on those templates
|
||||
@@ -613,7 +601,6 @@ type sqlcQuerier interface {
|
||||
GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]GetWorkspacesEligibleForTransitionRow, error)
|
||||
GetWorkspacesForWorkspaceMetrics(ctx context.Context) ([]GetWorkspacesForWorkspaceMetricsRow, error)
|
||||
InsertAIBridgeInterception(ctx context.Context, arg InsertAIBridgeInterceptionParams) (AIBridgeInterception, error)
|
||||
InsertAIBridgeModelThought(ctx context.Context, arg InsertAIBridgeModelThoughtParams) (AIBridgeModelThought, error)
|
||||
InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIBridgeTokenUsageParams) (AIBridgeTokenUsage, error)
|
||||
InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBridgeToolUsageParams) (AIBridgeToolUsage, error)
|
||||
InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIBridgeUserPromptParams) (AIBridgeUserPrompt, error)
|
||||
@@ -710,8 +697,6 @@ type sqlcQuerier interface {
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
|
||||
ListChatUsageLimitGroupOverrides(ctx context.Context) ([]ListChatUsageLimitGroupOverridesRow, error)
|
||||
ListChatUsageLimitOverrides(ctx context.Context) ([]ListChatUsageLimitOverridesRow, error)
|
||||
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
@@ -732,12 +717,6 @@ type sqlcQuerier interface {
|
||||
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error)
|
||||
// Resolves the effective spend limit for a user using the hierarchy:
|
||||
// 1. Individual user override (highest priority)
|
||||
// 2. Minimum group limit across all user's groups
|
||||
// 3. Global default from config
|
||||
// Returns -1 if limits are not enabled.
|
||||
ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error)
|
||||
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
|
||||
// Note that this selects from the CTE, not the original table. The CTE is named
|
||||
// the same as the original table to trick sqlc into reusing the existing struct
|
||||
@@ -866,13 +845,9 @@ type sqlcQuerier interface {
|
||||
// cumulative values for unique counts (accurate period totals). Request counts
|
||||
// are always deltas, accumulated in DB. Returns true if insert, false if update.
|
||||
UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error)
|
||||
UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
UpsertChatSystemPrompt(ctx context.Context, value string) error
|
||||
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
|
||||
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
|
||||
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
|
||||
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
|
||||
// The default proxy is implied and not actually stored in the database.
|
||||
// So we need to store it's configuration here for display purposes.
|
||||
|
||||
+58
-391
@@ -1235,230 +1235,6 @@ func TestGetAuthorizedWorkspacesAndAgentsByOwnerID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAuthorizedChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
sqlDB := testSQLDB(t)
|
||||
err := migrations.Up(sqlDB)
|
||||
require.NoError(t, err)
|
||||
db := database.New(sqlDB)
|
||||
authorizer := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
|
||||
// Create users with different roles.
|
||||
owner := dbgen.User(t, db, database.User{
|
||||
RBACRoles: []string{rbac.RoleOwner().String()},
|
||||
})
|
||||
member := dbgen.User(t, db, database.User{})
|
||||
secondMember := dbgen.User(t, db, database.User{})
|
||||
|
||||
// Create FK dependencies: a chat provider and model config.
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create 3 chats owned by owner.
|
||||
for i := range 3 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("owner chat %d", i+1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create 2 chats owned by member.
|
||||
for i := range 2 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: member.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("member chat %d", i+1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("sqlQuerier", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
// Member should only see their own 2 chats.
|
||||
memberSubject, _, err := httpmw.UserRBACSubject(ctx, db, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
preparedMember, err := authorizer.Prepare(ctx, memberSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
require.NoError(t, err)
|
||||
memberRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedMember)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberRows, 2)
|
||||
for _, row := range memberRows {
|
||||
require.Equal(t, member.ID, row.OwnerID, "member should only see own chats")
|
||||
}
|
||||
|
||||
// Owner should see at least the 5 pre-created chats (site-wide
|
||||
// access). Parallel subtests may add more, so use GreaterOrEqual.
|
||||
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, db, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
preparedOwner, err := authorizer.Prepare(ctx, ownerSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
require.NoError(t, err)
|
||||
ownerRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOwner)
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, len(ownerRows), 5)
|
||||
|
||||
// secondMember has no chats and should see 0.
|
||||
secondSubject, _, err := httpmw.UserRBACSubject(ctx, db, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
preparedSecond, err := authorizer.Prepare(ctx, secondSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
require.NoError(t, err)
|
||||
secondRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedSecond)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secondRows, 0)
|
||||
|
||||
// Org admin should NOT see other users' chats — chats are
|
||||
// not org-scoped resources.
|
||||
orgs, err := db.GetOrganizations(ctx, database.GetOrganizationsParams{})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, orgs)
|
||||
orgAdmin := dbgen.User(t, db, database.User{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: orgAdmin.ID,
|
||||
OrganizationID: orgs[0].ID,
|
||||
Roles: []string{rbac.RoleOrgAdmin()},
|
||||
})
|
||||
orgAdminSubject, _, err := httpmw.UserRBACSubject(ctx, db, orgAdmin.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
preparedOrgAdmin, err := authorizer.Prepare(ctx, orgAdminSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
require.NoError(t, err)
|
||||
orgAdminRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOrgAdmin)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, orgAdminRows, 0, "org admin with no chats should see 0 chats")
|
||||
|
||||
// OwnerID filter: member queries their own chats.
|
||||
memberFilterSelf, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
||||
OwnerID: member.ID,
|
||||
}, preparedMember)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberFilterSelf, 2)
|
||||
|
||||
// OwnerID filter: member queries owner's chats → sees 0.
|
||||
memberFilterOwner, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
||||
OwnerID: owner.ID,
|
||||
}, preparedMember)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberFilterOwner, 0)
|
||||
})
|
||||
|
||||
t.Run("dbauthz", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
authzdb := dbauthz.New(db, authorizer, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
||||
|
||||
// As member: should see only own 2 chats.
|
||||
memberSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
memberCtx := dbauthz.As(ctx, memberSubject)
|
||||
memberRows, err := authzdb.GetChats(memberCtx, database.GetChatsParams{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberRows, 2)
|
||||
for _, row := range memberRows {
|
||||
require.Equal(t, member.ID, row.OwnerID, "member should only see own chats")
|
||||
}
|
||||
|
||||
// As owner: should see at least the 5 pre-created chats.
|
||||
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
ownerCtx := dbauthz.As(ctx, ownerSubject)
|
||||
ownerRows, err := authzdb.GetChats(ownerCtx, database.GetChatsParams{})
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, len(ownerRows), 5)
|
||||
|
||||
// As secondMember: should see 0 chats.
|
||||
secondSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
secondCtx := dbauthz.As(ctx, secondSubject)
|
||||
secondRows, err := authzdb.GetChats(secondCtx, database.GetChatsParams{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secondRows, 0)
|
||||
})
|
||||
|
||||
t.Run("pagination", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
// Use a dedicated user for pagination to avoid interference
|
||||
// with the other parallel subtests.
|
||||
paginationUser := dbgen.User(t, db, database.User{})
|
||||
for i := range 7 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: paginationUser.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("pagination chat %d", i+1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
pagUserSubject, _, err := httpmw.UserRBACSubject(ctx, db, paginationUser.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
preparedMember, err := authorizer.Prepare(ctx, pagUserSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch first page with limit=2.
|
||||
page1, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
||||
LimitOpt: 2,
|
||||
}, preparedMember)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, page1, 2)
|
||||
for _, row := range page1 {
|
||||
require.Equal(t, paginationUser.ID, row.OwnerID, "paginated results must belong to pagination user")
|
||||
}
|
||||
|
||||
// Fetch remaining pages and collect all chat IDs.
|
||||
allIDs := make(map[uuid.UUID]struct{})
|
||||
for _, row := range page1 {
|
||||
allIDs[row.ID] = struct{}{}
|
||||
}
|
||||
offset := int32(2)
|
||||
for {
|
||||
page, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
||||
LimitOpt: 2,
|
||||
OffsetOpt: offset,
|
||||
}, preparedMember)
|
||||
require.NoError(t, err)
|
||||
for _, row := range page {
|
||||
require.Equal(t, paginationUser.ID, row.OwnerID, "paginated results must belong to pagination user")
|
||||
allIDs[row.ID] = struct{}{}
|
||||
}
|
||||
if len(page) < 2 {
|
||||
break
|
||||
}
|
||||
offset += int32(len(page)) //nolint:gosec // Test code, pagination values are small.
|
||||
}
|
||||
|
||||
// All 7 member chats should be accounted for with no leakage.
|
||||
require.Len(t, allIDs, 7, "pagination should return all member chats exactly once")
|
||||
})
|
||||
}
|
||||
|
||||
func TestInsertWorkspaceAgentLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
@@ -2655,42 +2431,6 @@ func TestDeleteCustomRoleDoesNotDeleteSystemRole(t *testing.T) {
|
||||
require.True(t, roles[0].IsSystem)
|
||||
}
|
||||
|
||||
func TestGetAuthorizationUserRolesImpliedOrgRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
regularUser := dbgen.User(t, db, database.User{})
|
||||
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
|
||||
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org.ID,
|
||||
UserID: regularUser.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org.ID,
|
||||
UserID: saUser.ID,
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
wantMember := rbac.RoleOrgMember() + ":" + org.ID.String()
|
||||
wantSA := rbac.RoleOrgServiceAccount() + ":" + org.ID.String()
|
||||
|
||||
// Regular users get the implied organization-member role.
|
||||
regularRoles, err := db.GetAuthorizationUserRoles(ctx, regularUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, regularRoles.Roles, wantMember)
|
||||
require.NotContains(t, regularRoles.Roles, wantSA)
|
||||
|
||||
// Service accounts get the implied organization-service-account role.
|
||||
saRoles, err := db.GetAuthorizationUserRoles(ctx, saUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, saRoles.Roles, wantSA)
|
||||
require.NotContains(t, saRoles.Roles, wantMember)
|
||||
}
|
||||
|
||||
func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -2701,155 +2441,82 @@ func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
|
||||
|
||||
updated, err := db.UpdateOrganizationWorkspaceSharingSettings(ctx, database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
||||
ID: org.ID,
|
||||
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
|
||||
WorkspaceSharingDisabled: true,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ShareableWorkspaceOwnersNone, updated.ShareableWorkspaceOwners)
|
||||
require.True(t, updated.WorkspaceSharingDisabled)
|
||||
|
||||
got, err := db.GetOrganizationByID(ctx, org.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ShareableWorkspaceOwnersNone, got.ShareableWorkspaceOwners)
|
||||
require.True(t, got.WorkspaceSharingDisabled)
|
||||
}
|
||||
|
||||
func TestDeleteWorkspaceACLsByOrganization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("DeletesAll", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org1 := dbgen.Organization(t, db, database.Organization{})
|
||||
org2 := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org1 := dbgen.Organization(t, db, database.Organization{})
|
||||
org2 := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
owner1 := dbgen.User(t, db, database.User{})
|
||||
owner2 := dbgen.User(t, db, database.User{})
|
||||
sharedUser := dbgen.User(t, db, database.User{})
|
||||
sharedGroup := dbgen.Group(t, db, database.Group{
|
||||
OrganizationID: org1.ID,
|
||||
})
|
||||
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: owner1.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org2.ID,
|
||||
UserID: owner2.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: sharedUser.ID,
|
||||
})
|
||||
|
||||
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner1.ID,
|
||||
OrganizationID: org1.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
sharedUser.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
GroupACL: database.WorkspaceACL{
|
||||
sharedGroup.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner2.ID,
|
||||
OrganizationID: org2.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
uuid.NewString(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
|
||||
OrganizationID: org1.ID,
|
||||
ExcludeServiceAccounts: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, got1.UserACL)
|
||||
require.Empty(t, got1.GroupACL)
|
||||
|
||||
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, got2.UserACL)
|
||||
owner1 := dbgen.User(t, db, database.User{})
|
||||
owner2 := dbgen.User(t, db, database.User{})
|
||||
sharedUser := dbgen.User(t, db, database.User{})
|
||||
sharedGroup := dbgen.Group(t, db, database.Group{
|
||||
OrganizationID: org1.ID,
|
||||
})
|
||||
|
||||
t.Run("ExcludesServiceAccounts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: owner1.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org2.ID,
|
||||
UserID: owner2.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: sharedUser.ID,
|
||||
})
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
regularUser := dbgen.User(t, db, database.User{})
|
||||
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
|
||||
sharedUser := dbgen.User(t, db, database.User{})
|
||||
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org.ID,
|
||||
UserID: regularUser.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org.ID,
|
||||
UserID: saUser.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org.ID,
|
||||
UserID: sharedUser.ID,
|
||||
})
|
||||
|
||||
regularWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: regularUser.ID,
|
||||
OrganizationID: org.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
sharedUser.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
saWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: saUser.ID,
|
||||
OrganizationID: org.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
sharedUser.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
|
||||
OrganizationID: org.ID,
|
||||
ExcludeServiceAccounts: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Regular user workspace ACLs should be cleared.
|
||||
gotRegular, err := db.GetWorkspaceByID(ctx, regularWS.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, gotRegular.UserACL)
|
||||
|
||||
// Service account workspace ACLs should be preserved.
|
||||
gotSA, err := db.GetWorkspaceByID(ctx, saWS.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.WorkspaceACL{
|
||||
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner1.ID,
|
||||
OrganizationID: org1.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
sharedUser.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
}, gotSA.UserACL)
|
||||
})
|
||||
},
|
||||
GroupACL: database.WorkspaceACL{
|
||||
sharedGroup.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner2.ID,
|
||||
OrganizationID: org2.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
uuid.NewString(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
err := db.DeleteWorkspaceACLsByOrganization(ctx, org1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, got1.UserACL)
|
||||
require.Empty(t, got1.GroupACL)
|
||||
|
||||
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, got2.UserACL)
|
||||
}
|
||||
|
||||
func TestAuthorizedAuditLogs(t *testing.T) {
|
||||
|
||||
+102
-539
File diff suppressed because it is too large
Load Diff
@@ -53,14 +53,6 @@ INSERT INTO aibridge_tool_usages (
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: InsertAIBridgeModelThought :one
|
||||
INSERT INTO aibridge_model_thoughts (
|
||||
interception_id, content, metadata, created_at
|
||||
) VALUES (
|
||||
@interception_id, @content, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetAIBridgeInterceptionByID :one
|
||||
SELECT
|
||||
*
|
||||
@@ -370,11 +362,6 @@ WITH
|
||||
WHERE started_at < @before_time::timestamp with time zone
|
||||
),
|
||||
-- CTEs are executed in order.
|
||||
model_thoughts AS (
|
||||
DELETE FROM aibridge_model_thoughts
|
||||
WHERE interception_id IN (SELECT id FROM to_delete)
|
||||
RETURNING 1
|
||||
),
|
||||
tool_usages AS (
|
||||
DELETE FROM aibridge_tool_usages
|
||||
WHERE interception_id IN (SELECT id FROM to_delete)
|
||||
@@ -397,7 +384,6 @@ WITH
|
||||
)
|
||||
-- Cumulative count.
|
||||
SELECT (
|
||||
(SELECT COUNT(*) FROM model_thoughts) +
|
||||
(SELECT COUNT(*) FROM tool_usages) +
|
||||
(SELECT COUNT(*) FROM token_usages) +
|
||||
(SELECT COUNT(*) FROM user_prompts) +
|
||||
|
||||
@@ -113,16 +113,13 @@ ORDER BY
|
||||
created_at ASC,
|
||||
id ASC;
|
||||
|
||||
-- name: GetChats :many
|
||||
-- name: GetChatsByOwnerID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
CASE
|
||||
WHEN @owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN chats.owner_id = @owner_id
|
||||
ELSE true
|
||||
END
|
||||
owner_id = @owner_id::uuid
|
||||
AND CASE
|
||||
WHEN sqlc.narg('archived') :: boolean IS NULL THEN true
|
||||
ELSE chats.archived = sqlc.narg('archived') :: boolean
|
||||
@@ -146,8 +143,6 @@ WHERE
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedChats
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all rows, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
@@ -705,128 +700,3 @@ LIMIT
|
||||
sqlc.arg('page_limit')::int
|
||||
OFFSET
|
||||
sqlc.arg('page_offset')::int;
|
||||
|
||||
-- name: GetChatUsageLimitConfig :one
|
||||
SELECT * FROM chat_usage_limit_config WHERE singleton = TRUE LIMIT 1;
|
||||
|
||||
-- name: UpsertChatUsageLimitConfig :one
|
||||
INSERT INTO chat_usage_limit_config (singleton, enabled, default_limit_micros, period, updated_at)
|
||||
VALUES (TRUE, @enabled::boolean, @default_limit_micros::bigint, @period::text, NOW())
|
||||
ON CONFLICT (singleton) DO UPDATE SET
|
||||
enabled = EXCLUDED.enabled,
|
||||
default_limit_micros = EXCLUDED.default_limit_micros,
|
||||
period = EXCLUDED.period,
|
||||
updated_at = NOW()
|
||||
RETURNING *;
|
||||
|
||||
-- name: ListChatUsageLimitOverrides :many
|
||||
SELECT u.id AS user_id, u.username, u.name, u.avatar_url,
|
||||
u.chat_spend_limit_micros AS spend_limit_micros
|
||||
FROM users u
|
||||
WHERE u.chat_spend_limit_micros IS NOT NULL
|
||||
ORDER BY u.username ASC;
|
||||
|
||||
-- name: UpsertChatUsageLimitUserOverride :one
|
||||
UPDATE users
|
||||
SET chat_spend_limit_micros = @spend_limit_micros::bigint
|
||||
WHERE id = @user_id::uuid
|
||||
RETURNING id AS user_id, username, name, avatar_url, chat_spend_limit_micros AS spend_limit_micros;
|
||||
|
||||
-- name: DeleteChatUsageLimitUserOverride :exec
|
||||
UPDATE users SET chat_spend_limit_micros = NULL WHERE id = @user_id::uuid;
|
||||
|
||||
-- name: GetChatUsageLimitUserOverride :one
|
||||
SELECT id AS user_id, chat_spend_limit_micros AS spend_limit_micros
|
||||
FROM users
|
||||
WHERE id = @user_id::uuid AND chat_spend_limit_micros IS NOT NULL;
|
||||
|
||||
-- name: GetUserChatSpendInPeriod :one
|
||||
SELECT COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_spend_micros
|
||||
FROM chat_messages cm
|
||||
JOIN chats c ON c.id = cm.chat_id
|
||||
WHERE c.owner_id = @user_id::uuid
|
||||
AND cm.created_at >= @start_time::timestamptz
|
||||
AND cm.created_at < @end_time::timestamptz
|
||||
AND cm.total_cost_micros IS NOT NULL;
|
||||
|
||||
-- name: CountEnabledModelsWithoutPricing :one
|
||||
-- Counts enabled, non-deleted model configs that lack both input and
|
||||
-- output pricing in their JSONB options.cost configuration.
|
||||
SELECT COUNT(*)::bigint AS count
|
||||
FROM chat_model_configs
|
||||
WHERE enabled = TRUE
|
||||
AND deleted = FALSE
|
||||
AND (
|
||||
options->'cost' IS NULL
|
||||
OR options->'cost' = 'null'::jsonb
|
||||
OR (
|
||||
(options->'cost'->>'input_price_per_million_tokens' IS NULL)
|
||||
AND (options->'cost'->>'output_price_per_million_tokens' IS NULL)
|
||||
)
|
||||
);
|
||||
|
||||
-- name: ListChatUsageLimitGroupOverrides :many
|
||||
SELECT
|
||||
g.id AS group_id,
|
||||
g.name AS group_name,
|
||||
g.display_name AS group_display_name,
|
||||
g.avatar_url AS group_avatar_url,
|
||||
g.chat_spend_limit_micros AS spend_limit_micros,
|
||||
(SELECT COUNT(*)
|
||||
FROM group_members_expanded gme
|
||||
WHERE gme.group_id = g.id
|
||||
AND gme.user_is_system = FALSE) AS member_count
|
||||
FROM groups g
|
||||
WHERE g.chat_spend_limit_micros IS NOT NULL
|
||||
ORDER BY g.name ASC;
|
||||
|
||||
-- name: UpsertChatUsageLimitGroupOverride :one
|
||||
UPDATE groups
|
||||
SET chat_spend_limit_micros = @spend_limit_micros::bigint
|
||||
WHERE id = @group_id::uuid
|
||||
RETURNING id AS group_id, name, display_name, avatar_url, chat_spend_limit_micros AS spend_limit_micros;
|
||||
|
||||
-- name: DeleteChatUsageLimitGroupOverride :exec
|
||||
UPDATE groups SET chat_spend_limit_micros = NULL WHERE id = @group_id::uuid;
|
||||
|
||||
-- name: GetChatUsageLimitGroupOverride :one
|
||||
SELECT id AS group_id, chat_spend_limit_micros AS spend_limit_micros
|
||||
FROM groups
|
||||
WHERE id = @group_id::uuid AND chat_spend_limit_micros IS NOT NULL;
|
||||
|
||||
-- name: GetUserGroupSpendLimit :one
|
||||
-- Returns the minimum (most restrictive) group limit for a user.
|
||||
-- Returns -1 if the user has no group limits applied.
|
||||
SELECT COALESCE(MIN(g.chat_spend_limit_micros), -1)::bigint AS limit_micros
|
||||
FROM groups g
|
||||
JOIN group_members_expanded gme ON gme.group_id = g.id
|
||||
WHERE gme.user_id = @user_id::uuid
|
||||
AND g.chat_spend_limit_micros IS NOT NULL;
|
||||
|
||||
-- name: ResolveUserChatSpendLimit :one
|
||||
-- Resolves the effective spend limit for a user using the hierarchy:
|
||||
-- 1. Individual user override (highest priority)
|
||||
-- 2. Minimum group limit across all user's groups
|
||||
-- 3. Global default from config
|
||||
-- Returns -1 if limits are not enabled.
|
||||
SELECT CASE
|
||||
-- If limits are disabled, return -1.
|
||||
WHEN NOT cfg.enabled THEN -1
|
||||
-- Individual override takes priority.
|
||||
WHEN u.chat_spend_limit_micros IS NOT NULL THEN u.chat_spend_limit_micros
|
||||
-- Group limit (minimum across all user's groups) is next.
|
||||
WHEN gl.limit_micros IS NOT NULL THEN gl.limit_micros
|
||||
-- Fall back to global default.
|
||||
ELSE cfg.default_limit_micros
|
||||
END::bigint AS effective_limit_micros
|
||||
FROM chat_usage_limit_config cfg
|
||||
CROSS JOIN users u
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT MIN(g.chat_spend_limit_micros) AS limit_micros
|
||||
FROM groups g
|
||||
JOIN group_members_expanded gme ON gme.group_id = g.id
|
||||
WHERE gme.user_id = @user_id::uuid
|
||||
AND g.chat_spend_limit_micros IS NOT NULL
|
||||
) gl ON TRUE
|
||||
WHERE u.id = @user_id::uuid
|
||||
LIMIT 1;
|
||||
|
||||
@@ -147,7 +147,7 @@ WHERE
|
||||
UPDATE
|
||||
organizations
|
||||
SET
|
||||
shareable_workspace_owners = @shareable_workspace_owners,
|
||||
workspace_sharing_disabled = @workspace_sharing_disabled,
|
||||
updated_at = @updated_at
|
||||
WHERE
|
||||
id = @id
|
||||
|
||||
@@ -67,6 +67,14 @@ WHERE
|
||||
id = $1
|
||||
FOR UPDATE;
|
||||
|
||||
-- name: GetProvisionerJobsByIDs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
provisioner_jobs
|
||||
WHERE
|
||||
id = ANY(@ids :: uuid [ ]);
|
||||
|
||||
-- name: GetProvisionerJobsByIDsWithQueuePosition :many
|
||||
WITH filtered_provisioner_jobs AS (
|
||||
-- Step 1: Filter provisioner_jobs
|
||||
|
||||
@@ -140,23 +140,3 @@ SELECT
|
||||
-- name: UpsertChatSystemPrompt :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('agents_chat_system_prompt', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_system_prompt';
|
||||
|
||||
-- name: GetChatDesktopEnabled :one
|
||||
SELECT
|
||||
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop;
|
||||
|
||||
-- name: UpsertChatDesktopEnabled :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES (
|
||||
'agents_desktop_enabled',
|
||||
CASE
|
||||
WHEN sqlc.arg(enable_desktop)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET value = CASE
|
||||
WHEN sqlc.arg(enable_desktop)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
WHERE site_configs.key = 'agents_desktop_enabled';
|
||||
|
||||
@@ -391,21 +391,9 @@ SELECT
|
||||
array_agg(org_roles || ':' || organization_members.organization_id::text)
|
||||
FROM
|
||||
organization_members,
|
||||
-- All org members get an implied role for their orgs. Most members
|
||||
-- get organization-member, but service accounts will get
|
||||
-- organization-service-account instead. They're largely the same,
|
||||
-- but having them be distinct means we can allow configuring
|
||||
-- service-accounts to have slightly broader permissions–such as
|
||||
-- for workspace sharing.
|
||||
-- All org_members get the organization-member role for their orgs
|
||||
unnest(
|
||||
array_append(
|
||||
roles,
|
||||
CASE WHEN users.is_service_account THEN
|
||||
'organization-service-account'
|
||||
ELSE
|
||||
'organization-member'
|
||||
END
|
||||
)
|
||||
array_append(roles, 'organization-member')
|
||||
) AS org_roles
|
||||
WHERE
|
||||
user_id = users.id
|
||||
|
||||
@@ -955,13 +955,7 @@ SET
|
||||
group_acl = '{}'::jsonb,
|
||||
user_acl = '{}'::jsonb
|
||||
WHERE
|
||||
organization_id = @organization_id
|
||||
AND (
|
||||
NOT @exclude_service_accounts::boolean
|
||||
OR owner_id NOT IN (
|
||||
SELECT id FROM users WHERE is_service_account = true
|
||||
)
|
||||
);
|
||||
organization_id = @organization_id;
|
||||
|
||||
-- name: GetRegularWorkspaceCreateMetrics :many
|
||||
-- Count regular workspaces: only those whose first successful 'start' build
|
||||
|
||||
@@ -235,7 +235,6 @@ sql:
|
||||
aibridge_tool_usage: AIBridgeToolUsage
|
||||
aibridge_token_usage: AIBridgeTokenUsage
|
||||
aibridge_user_prompt: AIBridgeUserPrompt
|
||||
aibridge_model_thought: AIBridgeModelThought
|
||||
rules:
|
||||
- name: do-not-use-public-schema-in-queries
|
||||
message: "do not use public schema in queries"
|
||||
|
||||
+15
-15
@@ -20,28 +20,28 @@ import (
|
||||
// AuditOAuthConvertState is never stored in the database. It is stored in a cookie
|
||||
// clientside as a JWT. This type is provided for audit logging purposes.
|
||||
type AuditOAuthConvertState struct {
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
// The time at which the state string expires, a merge request times out if the user does not perform it quick enough.
|
||||
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
|
||||
FromLoginType LoginType `json:"from_login_type" db:"from_login_type"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
FromLoginType LoginType `db:"from_login_type" json:"from_login_type"`
|
||||
// The login type the user is converting to. Should be github or oidc.
|
||||
ToLoginType LoginType `json:"to_login_type" db:"to_login_type"`
|
||||
UserID uuid.UUID `json:"user_id" db:"user_id"`
|
||||
ToLoginType LoginType `db:"to_login_type" json:"to_login_type"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
}
|
||||
|
||||
type HealthSettings struct {
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
DismissedHealthchecks []string `json:"dismissed_healthchecks" db:"dismissed_healthchecks"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
DismissedHealthchecks []string `db:"dismissed_healthchecks" json:"dismissed_healthchecks"`
|
||||
}
|
||||
|
||||
type NotificationsSettings struct {
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
NotifierPaused bool `json:"notifier_paused" db:"notifier_paused"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
NotifierPaused bool `db:"notifier_paused" json:"notifier_paused"`
|
||||
}
|
||||
|
||||
type PrebuildsSettings struct {
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
ReconciliationPaused bool `json:"reconciliation_paused" db:"reconciliation_paused"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ReconciliationPaused bool `db:"reconciliation_paused" json:"reconciliation_paused"`
|
||||
}
|
||||
|
||||
type Actions []policy.Action
|
||||
@@ -237,9 +237,9 @@ func (a CustomRolePermission) String() string {
|
||||
|
||||
// NameOrganizationPair is used as a lookup tuple for custom role rows.
|
||||
type NameOrganizationPair struct {
|
||||
Name string `json:"name" db:"name"`
|
||||
Name string `db:"name" json:"name"`
|
||||
// OrganizationID if unset will assume a null column value
|
||||
OrganizationID uuid.UUID `json:"organization_id" db:"organization_id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
}
|
||||
|
||||
func (*NameOrganizationPair) Scan(_ interface{}) error {
|
||||
@@ -264,8 +264,8 @@ func (a NameOrganizationPair) Value() (driver.Value, error) {
|
||||
|
||||
// AgentIDNamePair is used as a result tuple for workspace and agent rows.
|
||||
type AgentIDNamePair struct {
|
||||
ID uuid.UUID `json:"id" db:"id"`
|
||||
Name string `json:"name" db:"name"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (p *AgentIDNamePair) Scan(src interface{}) error {
|
||||
|
||||
@@ -22,8 +22,6 @@ const (
|
||||
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
|
||||
UniqueChatProvidersProviderKey UniqueConstraint = "chat_providers_provider_key" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
|
||||
UniqueChatQueuedMessagesPkey UniqueConstraint = "chat_queued_messages_pkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatUsageLimitConfigPkey UniqueConstraint = "chat_usage_limit_config_pkey" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
|
||||
UniqueChatUsageLimitConfigSingletonKey UniqueConstraint = "chat_usage_limit_config_singleton_key" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
|
||||
UniqueChatsPkey UniqueConstraint = "chats_pkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
|
||||
UniqueConnectionLogsPkey UniqueConstraint = "connection_logs_pkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
|
||||
UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence);
|
||||
|
||||
@@ -48,8 +48,8 @@ type Store interface {
|
||||
UpsertChatDiffStatusReference(
|
||||
ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams,
|
||||
) (database.ChatDiffStatus, error)
|
||||
GetChats(
|
||||
ctx context.Context, arg database.GetChatsParams,
|
||||
GetChatsByOwnerID(
|
||||
ctx context.Context, arg database.GetChatsByOwnerIDParams,
|
||||
) ([]database.Chat, error)
|
||||
}
|
||||
|
||||
@@ -250,7 +250,7 @@ func (w *Worker) MarkStale(
|
||||
return
|
||||
}
|
||||
|
||||
chats, err := w.store.GetChats(ctx, database.GetChatsParams{
|
||||
chats, err := w.store.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
|
||||
OwnerID: ownerID,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -469,8 +469,8 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
|
||||
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
require.Equal(t, ownerID, arg.OwnerID)
|
||||
return []database.Chat{
|
||||
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
||||
@@ -478,12 +478,13 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
|
||||
{ID: chatOther, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
}, nil
|
||||
})
|
||||
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
mu.Lock()
|
||||
upsertRefCalls = append(upsertRefCalls, arg)
|
||||
mu.Unlock()
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(2)
|
||||
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
mu.Lock()
|
||||
upsertRefCalls = append(upsertRefCalls, arg)
|
||||
mu.Unlock()
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(2)
|
||||
|
||||
pub := func(_ context.Context, chatID uuid.UUID) error {
|
||||
mu.Lock()
|
||||
@@ -526,7 +527,7 @@ func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
|
||||
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
Return([]database.Chat{
|
||||
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
@@ -554,7 +555,7 @@ func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
|
||||
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
Return([]database.Chat{
|
||||
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
||||
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
||||
@@ -589,7 +590,7 @@ func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
|
||||
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
Return(nil, fmt.Errorf("db error"))
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
|
||||
@@ -688,15 +688,6 @@ func ConfigWithoutACL() regosql.ConvertConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigChats is the configuration for converting rego to SQL when
|
||||
// the target table is "chats", which has no organization_id or ACL
|
||||
// columns.
|
||||
func ConfigChats() regosql.ConvertConfig {
|
||||
return regosql.ConvertConfig{
|
||||
VariableConverter: regosql.ChatConverter(),
|
||||
}
|
||||
}
|
||||
|
||||
func ConfigWorkspaces() regosql.ConvertConfig {
|
||||
return regosql.ConvertConfig{
|
||||
VariableConverter: regosql.WorkspaceConverter(),
|
||||
|
||||
@@ -1404,8 +1404,8 @@ func testAuthorize(t *testing.T, name string, subject Subject, sets ...[]authTes
|
||||
// RoleByName won't resolve it here. Assume the default behavior: workspace
|
||||
// sharing enabled.
|
||||
func orgMemberRole(orgID uuid.UUID) Role {
|
||||
settings := OrgSettings{ShareableWorkspaceOwners: ShareableWorkspaceOwnersEveryone}
|
||||
perms := OrgMemberPermissions(settings)
|
||||
workspaceSharingDisabled := false
|
||||
orgPerms, memberPerms := OrgMemberPermissions(workspaceSharingDisabled)
|
||||
return Role{
|
||||
Identifier: ScopedRoleOrgMember(orgID),
|
||||
DisplayName: "",
|
||||
@@ -1413,8 +1413,8 @@ func orgMemberRole(orgID uuid.UUID) Role {
|
||||
User: []Permission{},
|
||||
ByOrgID: map[string]OrgPermissions{
|
||||
orgID.String(): {
|
||||
Org: perms.Org,
|
||||
Member: perms.Member,
|
||||
Org: orgPerms,
|
||||
Member: memberPerms,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -38,8 +38,8 @@ type Object struct {
|
||||
// Type is "workspace", "project", "app", etc
|
||||
Type string `json:"type"`
|
||||
|
||||
ACLUserList map[string][]policy.Action `json:"acl_user_list"`
|
||||
ACLGroupList map[string][]policy.Action `json:"acl_group_list"`
|
||||
ACLUserList map[string][]policy.Action ` json:"acl_user_list"`
|
||||
ACLGroupList map[string][]policy.Action ` json:"acl_group_list"`
|
||||
}
|
||||
|
||||
// String is not perfect, but decent enough for human display
|
||||
|
||||
@@ -282,22 +282,6 @@ neq(input.object.owner, "");
|
||||
p("'10d03e62-7703-4df5-a358-4f76577d4e2f' = id :: text") + " AND " + p("id :: text != ''") + " AND " + p("'' = ''"),
|
||||
),
|
||||
},
|
||||
{
|
||||
Name: "ChatOwnerMe",
|
||||
Queries: []string{
|
||||
`"me" = input.object.owner; input.object.owner != ""; input.object.org_owner = ""`,
|
||||
},
|
||||
ExpectedSQL: p(p("'me' = owner_id :: text") + " AND " + p("owner_id :: text != ''") + " AND " + p("'' = ''")),
|
||||
VariableConverter: regosql.ChatConverter(),
|
||||
},
|
||||
{
|
||||
Name: "ChatOrgScopedNeverMatches",
|
||||
Queries: []string{
|
||||
`input.object.org_owner = "org-id"`,
|
||||
},
|
||||
ExpectedSQL: p("'' = 'org-id'"),
|
||||
VariableConverter: regosql.ChatConverter(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -126,30 +126,6 @@ func NoACLConverter() *sqltypes.VariableConverter {
|
||||
return matcher
|
||||
}
|
||||
|
||||
// ChatConverter should be used for the chats table, which has no
|
||||
// organization_id, group_acl, or user_acl columns.
|
||||
func ChatConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
// The chats table has no organization_id column. Map org_owner
|
||||
// to a literal empty string so that:
|
||||
// - User-level ownership checks (org_owner = '') activate correctly.
|
||||
// - Org-scoped permissions never match (org_owner will never equal
|
||||
// a real org UUID), which is intentional since chats are not
|
||||
// org-scoped resources.
|
||||
// Note: custom org roles that include "chat" permissions will
|
||||
// silently have no effect because of this mapping.
|
||||
sqltypes.StringVarMatcher("''", []string{"input", "object", "org_owner"}),
|
||||
userOwnerMatcher(),
|
||||
)
|
||||
matcher.RegisterMatcher(
|
||||
sqltypes.AlwaysFalse(groupACLMatcher(matcher)),
|
||||
sqltypes.AlwaysFalse(userACLMatcher(matcher)),
|
||||
)
|
||||
|
||||
return matcher
|
||||
}
|
||||
|
||||
func DefaultVariableConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
|
||||
+79
-150
@@ -29,7 +29,6 @@ const (
|
||||
|
||||
orgAdmin string = "organization-admin"
|
||||
orgMember string = "organization-member"
|
||||
orgServiceAccount string = "organization-service-account"
|
||||
orgAuditor string = "organization-auditor"
|
||||
orgUserAdmin string = "organization-user-admin"
|
||||
orgTemplateAdmin string = "organization-template-admin"
|
||||
@@ -151,10 +150,6 @@ func RoleOrgMember() string {
|
||||
return orgMember
|
||||
}
|
||||
|
||||
func RoleOrgServiceAccount() string {
|
||||
return orgServiceAccount
|
||||
}
|
||||
|
||||
func RoleOrgAuditor() string {
|
||||
return orgAuditor
|
||||
}
|
||||
@@ -234,16 +229,31 @@ func allPermsExcept(excepts ...Objecter) []Permission {
|
||||
// https://github.com/coder/coder/issues/1194
|
||||
var builtInRoles map[string]func(orgID uuid.UUID) Role
|
||||
|
||||
// systemRoles are roles that have migrated from builtInRoles to
|
||||
// database storage. This migration is partial - permissions are still
|
||||
// generated at runtime and reconciled to the database, rather than
|
||||
// the database being the source of truth.
|
||||
var systemRoles = map[string]struct{}{
|
||||
RoleOrgMember(): {},
|
||||
}
|
||||
|
||||
func SystemRoleName(name string) bool {
|
||||
_, ok := systemRoles[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
type RoleOptions struct {
|
||||
NoOwnerWorkspaceExec bool
|
||||
NoWorkspaceSharing bool
|
||||
}
|
||||
|
||||
// ReservedRoleName exists because the database should only allow unique role
|
||||
// names, but some roles are built in. So these names are reserved
|
||||
// names, but some roles are built in or generated at runtime. So these names
|
||||
// are reserved
|
||||
func ReservedRoleName(name string) bool {
|
||||
_, ok := builtInRoles[name]
|
||||
return ok
|
||||
_, isBuiltIn := builtInRoles[name]
|
||||
_, isSystem := systemRoles[name]
|
||||
return isBuiltIn || isSystem
|
||||
}
|
||||
|
||||
// ReloadBuiltinRoles loads the static roles into the builtInRoles map.
|
||||
@@ -928,32 +938,21 @@ func PermissionsEqual(a, b []Permission) bool {
|
||||
return len(setA) == len(setB)
|
||||
}
|
||||
|
||||
// OrgSettings carries organization-level settings that affect system
|
||||
// role permissions. It lives in the rbac package to avoid a cyclic
|
||||
// dependency with the database package. Callers in rolestore map
|
||||
// database.Organization fields onto this struct.
|
||||
type OrgSettings struct {
|
||||
ShareableWorkspaceOwners ShareableWorkspaceOwners
|
||||
}
|
||||
type ShareableWorkspaceOwners string
|
||||
|
||||
const (
|
||||
ShareableWorkspaceOwnersNone ShareableWorkspaceOwners = "none"
|
||||
ShareableWorkspaceOwnersEveryone ShareableWorkspaceOwners = "everyone"
|
||||
ShareableWorkspaceOwnersServiceAccounts ShareableWorkspaceOwners = "service_accounts"
|
||||
)
|
||||
|
||||
// OrgRolePermissions holds the two permission sets that make up a
|
||||
// system role: org-wide permissions and member-scoped permissions.
|
||||
type OrgRolePermissions struct {
|
||||
Org []Permission
|
||||
Member []Permission
|
||||
}
|
||||
|
||||
// OrgMemberPermissions returns the permissions for the organization-member
|
||||
// system role, which can vary based on the organization's workspace sharing
|
||||
// settings.
|
||||
func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
|
||||
// system role. The results are then stored in the database and can vary per
|
||||
// organization based on the workspace_sharing_disabled setting.
|
||||
// This is the source of truth for org-member permissions, used by:
|
||||
// - the startup reconciliation routine, to keep permissions current with
|
||||
// RBAC resources
|
||||
// - the organization workspace sharing setting endpoint, when updating
|
||||
// the setting
|
||||
// - the org creation endpoint, when populating the organization-member
|
||||
// system role created by the DB trigger
|
||||
//
|
||||
//nolint:revive // workspaceSharingDisabled is an org setting
|
||||
func OrgMemberPermissions(workspaceSharingDisabled bool) (
|
||||
orgPerms, memberPerms []Permission,
|
||||
) {
|
||||
// Organization-level permissions that all org members get.
|
||||
orgPermMap := map[string][]policy.Action{
|
||||
// All users can see provisioner daemons for workspace creation.
|
||||
@@ -964,25 +963,58 @@ func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
|
||||
ResourceAssignOrgRole.Type: {policy.ActionRead},
|
||||
}
|
||||
|
||||
// In all modes of workspace sharing but `none`, members need to
|
||||
// see other org members (including service accounts) to either
|
||||
// share with them or get access to their shared workspaces,
|
||||
// resolved through GET /users/{user}/workspace/{workspace}
|
||||
if org.ShareableWorkspaceOwners != ShareableWorkspaceOwnersNone {
|
||||
// When workspace sharing is enabled, members need to see other org members
|
||||
// and groups to share workspaces with them.
|
||||
if !workspaceSharingDisabled {
|
||||
orgPermMap[ResourceOrganizationMember.Type] = []policy.Action{policy.ActionRead}
|
||||
}
|
||||
|
||||
// When workspace sharing is open to members, they also need to
|
||||
// see org groups to share with them.
|
||||
if org.ShareableWorkspaceOwners == ShareableWorkspaceOwnersEveryone {
|
||||
orgPermMap[ResourceGroup.Type] = []policy.Action{policy.ActionRead}
|
||||
}
|
||||
|
||||
orgPerms := Permissions(orgPermMap)
|
||||
orgPerms = Permissions(orgPermMap)
|
||||
|
||||
if org.ShareableWorkspaceOwners == ShareableWorkspaceOwnersNone {
|
||||
// Member-scoped permissions (resources owned by the member).
|
||||
// Uses allPermsExcept to automatically include permissions for new resources.
|
||||
memberPerms = append(
|
||||
allPermsExcept(
|
||||
ResourceWorkspaceDormant,
|
||||
ResourcePrebuiltWorkspace,
|
||||
ResourceUser,
|
||||
ResourceOrganizationMember,
|
||||
),
|
||||
Permissions(map[string][]policy.Action{
|
||||
// Reduced permission set on dormant workspaces. No build,
|
||||
// ssh, or exec.
|
||||
ResourceWorkspaceDormant.Type: {
|
||||
policy.ActionRead,
|
||||
policy.ActionDelete,
|
||||
policy.ActionCreate,
|
||||
policy.ActionUpdate,
|
||||
policy.ActionWorkspaceStop,
|
||||
policy.ActionCreateAgent,
|
||||
policy.ActionDeleteAgent,
|
||||
policy.ActionUpdateAgent,
|
||||
},
|
||||
// Can read their own organization member record.
|
||||
ResourceOrganizationMember.Type: {
|
||||
policy.ActionRead,
|
||||
},
|
||||
// Users can create provisioner daemons scoped to themselves.
|
||||
//
|
||||
// TODO(geokat): copied from the original built-in role
|
||||
// verbatim, but seems to be a no-op (not excepted above;
|
||||
// plus no owner is set for the ProvisionerDaemon RBAC
|
||||
// object).
|
||||
ResourceProvisionerDaemon.Type: {
|
||||
policy.ActionRead,
|
||||
policy.ActionCreate,
|
||||
policy.ActionUpdate,
|
||||
},
|
||||
})...,
|
||||
)
|
||||
|
||||
if workspaceSharingDisabled {
|
||||
// Org-level negation blocks sharing on ANY workspace in the
|
||||
// org. This overrides any positive permission from other
|
||||
// org. This overrides any positive permission from other
|
||||
// roles, including org-admin.
|
||||
orgPerms = append(orgPerms, Permission{
|
||||
Negate: true,
|
||||
@@ -991,108 +1023,5 @@ func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
|
||||
})
|
||||
}
|
||||
|
||||
// Uses allPermsExcept to automatically include permissions for new resources.
|
||||
memberPerms := append(
|
||||
allPermsExcept(
|
||||
ResourceWorkspaceDormant,
|
||||
ResourcePrebuiltWorkspace,
|
||||
ResourceUser,
|
||||
ResourceOrganizationMember,
|
||||
),
|
||||
Permissions(map[string][]policy.Action{
|
||||
// Reduced permission set on dormant workspaces. No build,
|
||||
// ssh, or exec.
|
||||
ResourceWorkspaceDormant.Type: {
|
||||
policy.ActionRead,
|
||||
policy.ActionDelete,
|
||||
policy.ActionCreate,
|
||||
policy.ActionUpdate,
|
||||
policy.ActionWorkspaceStop,
|
||||
policy.ActionCreateAgent,
|
||||
policy.ActionDeleteAgent,
|
||||
policy.ActionUpdateAgent,
|
||||
},
|
||||
// Can read their own organization member record.
|
||||
ResourceOrganizationMember.Type: {
|
||||
policy.ActionRead,
|
||||
},
|
||||
})...,
|
||||
)
|
||||
|
||||
if org.ShareableWorkspaceOwners != ShareableWorkspaceOwnersEveryone {
|
||||
memberPerms = append(memberPerms, Permission{
|
||||
Negate: true,
|
||||
ResourceType: ResourceWorkspace.Type,
|
||||
Action: policy.ActionShare,
|
||||
})
|
||||
}
|
||||
|
||||
return OrgRolePermissions{Org: orgPerms, Member: memberPerms}
|
||||
}
|
||||
|
||||
// OrgServiceAccountPermissions returns the permissions for the
|
||||
// organization-service-account system role, which can vary based on
|
||||
// the organization's workspace sharing settings.
|
||||
func OrgServiceAccountPermissions(org OrgSettings) OrgRolePermissions {
|
||||
// Organization-level permissions that all org service accounts get.
|
||||
orgPermMap := map[string][]policy.Action{
|
||||
// All users can see provisioner daemons for workspace creation.
|
||||
ResourceProvisionerDaemon.Type: {policy.ActionRead},
|
||||
// All org members can read the organization.
|
||||
ResourceOrganization.Type: {policy.ActionRead},
|
||||
// Can read available roles.
|
||||
ResourceAssignOrgRole.Type: {policy.ActionRead},
|
||||
}
|
||||
|
||||
// When workspace sharing is enabled, service accounts need to see
|
||||
// other org members and groups to share workspaces with them.
|
||||
if org.ShareableWorkspaceOwners != ShareableWorkspaceOwnersNone {
|
||||
orgPermMap[ResourceOrganizationMember.Type] = []policy.Action{policy.ActionRead}
|
||||
orgPermMap[ResourceGroup.Type] = []policy.Action{policy.ActionRead}
|
||||
}
|
||||
|
||||
orgPerms := Permissions(orgPermMap)
|
||||
|
||||
if org.ShareableWorkspaceOwners == ShareableWorkspaceOwnersNone {
|
||||
// Org-level negation blocks sharing on ANY workspace in the
|
||||
// org. If a service account has any other roles assigned,
|
||||
// this negation will override any positive perms in them, too.
|
||||
orgPerms = append(orgPerms, Permission{
|
||||
Negate: true,
|
||||
ResourceType: ResourceWorkspace.Type,
|
||||
Action: policy.ActionShare,
|
||||
})
|
||||
}
|
||||
|
||||
// service account-scoped permissions (resources owned by the
|
||||
// service account). Uses allPermsExcept to automatically include
|
||||
// permissions for new resources.
|
||||
memberPerms := append(
|
||||
allPermsExcept(
|
||||
ResourceWorkspaceDormant,
|
||||
ResourcePrebuiltWorkspace,
|
||||
ResourceUser,
|
||||
ResourceOrganizationMember,
|
||||
),
|
||||
Permissions(map[string][]policy.Action{
|
||||
// Reduced permission set on dormant workspaces. No build,
|
||||
// ssh, or exec.
|
||||
ResourceWorkspaceDormant.Type: {
|
||||
policy.ActionRead,
|
||||
policy.ActionDelete,
|
||||
policy.ActionCreate,
|
||||
policy.ActionUpdate,
|
||||
policy.ActionWorkspaceStop,
|
||||
policy.ActionCreateAgent,
|
||||
policy.ActionDeleteAgent,
|
||||
policy.ActionUpdateAgent,
|
||||
},
|
||||
// Can read their own organization member record.
|
||||
ResourceOrganizationMember.Type: {
|
||||
policy.ActionRead,
|
||||
},
|
||||
})...,
|
||||
)
|
||||
|
||||
return OrgRolePermissions{Org: orgPerms, Member: memberPerms}
|
||||
return orgPerms, memberPerms
|
||||
}
|
||||
|
||||
+40
-54
@@ -51,68 +51,54 @@ func TestBuiltInRoles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// permissionGranted checks whether a permission list contains a
|
||||
// matching entry for the target, accounting for wildcard actions.
|
||||
// It does not evaluate negations that may override a positive grant.
|
||||
func permissionGranted(perms []rbac.Permission, target rbac.Permission) bool {
|
||||
return slices.ContainsFunc(perms, func(p rbac.Permission) bool {
|
||||
return p.Negate == target.Negate &&
|
||||
p.ResourceType == target.ResourceType &&
|
||||
(p.Action == target.Action || p.Action == policy.WildcardSymbol)
|
||||
})
|
||||
}
|
||||
|
||||
func TestOrgSharingPermissions(t *testing.T) {
|
||||
func TestSystemRolesAreReservedRoleNames(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
permsFunc func(rbac.OrgSettings) rbac.OrgRolePermissions
|
||||
mode rbac.ShareableWorkspaceOwners
|
||||
orgReadMembers bool
|
||||
orgReadGroups bool
|
||||
orgNegateShare bool
|
||||
memberNegateShare bool
|
||||
}{
|
||||
{"Member/Everyone", rbac.OrgMemberPermissions, rbac.ShareableWorkspaceOwnersEveryone, true, true, false, false},
|
||||
{"Member/None", rbac.OrgMemberPermissions, rbac.ShareableWorkspaceOwnersNone, false, false, true, true},
|
||||
{"Member/ServiceAccounts", rbac.OrgMemberPermissions, rbac.ShareableWorkspaceOwnersServiceAccounts, true, false, false, true},
|
||||
{"ServiceAccount/Everyone", rbac.OrgServiceAccountPermissions, rbac.ShareableWorkspaceOwnersEveryone, true, true, false, false},
|
||||
{"ServiceAccount/None", rbac.OrgServiceAccountPermissions, rbac.ShareableWorkspaceOwnersNone, false, false, true, false},
|
||||
{"ServiceAccount/ServiceAccounts", rbac.OrgServiceAccountPermissions, rbac.ShareableWorkspaceOwnersServiceAccounts, true, true, false, false},
|
||||
}
|
||||
require.True(t, rbac.ReservedRoleName(rbac.RoleOrgMember()))
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
func TestOrgMemberPermissions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
perms := tt.permsFunc(rbac.OrgSettings{
|
||||
ShareableWorkspaceOwners: tt.mode,
|
||||
})
|
||||
t.Run("WorkspaceSharingEnabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, tt.orgReadMembers, permissionGranted(perms.Org, rbac.Permission{
|
||||
ResourceType: rbac.ResourceOrganizationMember.Type,
|
||||
Action: policy.ActionRead,
|
||||
}), "org read members")
|
||||
orgPerms, _ := rbac.OrgMemberPermissions(false)
|
||||
|
||||
assert.Equal(t, tt.orgReadGroups, permissionGranted(perms.Org, rbac.Permission{
|
||||
ResourceType: rbac.ResourceGroup.Type,
|
||||
Action: policy.ActionRead,
|
||||
}), "org read groups")
|
||||
require.True(t, slices.Contains(orgPerms, rbac.Permission{
|
||||
ResourceType: rbac.ResourceOrganizationMember.Type,
|
||||
Action: policy.ActionRead,
|
||||
}))
|
||||
require.True(t, slices.Contains(orgPerms, rbac.Permission{
|
||||
ResourceType: rbac.ResourceGroup.Type,
|
||||
Action: policy.ActionRead,
|
||||
}))
|
||||
require.False(t, slices.Contains(orgPerms, rbac.Permission{
|
||||
Negate: true,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.ActionShare,
|
||||
}))
|
||||
})
|
||||
|
||||
assert.Equal(t, tt.orgNegateShare, permissionGranted(perms.Org, rbac.Permission{
|
||||
Negate: true,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.ActionShare,
|
||||
}), "org negate share")
|
||||
t.Run("WorkspaceSharingDisabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, tt.memberNegateShare, permissionGranted(perms.Member, rbac.Permission{
|
||||
Negate: true,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.ActionShare,
|
||||
}), "member negate share")
|
||||
})
|
||||
}
|
||||
orgPerms, _ := rbac.OrgMemberPermissions(true)
|
||||
|
||||
require.False(t, slices.Contains(orgPerms, rbac.Permission{
|
||||
ResourceType: rbac.ResourceOrganizationMember.Type,
|
||||
Action: policy.ActionRead,
|
||||
}))
|
||||
require.False(t, slices.Contains(orgPerms, rbac.Permission{
|
||||
ResourceType: rbac.ResourceGroup.Type,
|
||||
Action: policy.ActionRead,
|
||||
}))
|
||||
require.True(t, slices.Contains(orgPerms, rbac.Permission{
|
||||
Negate: true,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.ActionShare,
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest
|
||||
|
||||
@@ -2,7 +2,6 @@ package rolestore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"maps"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -162,28 +161,13 @@ func ConvertDBRole(dbRole database.CustomRole) (rbac.Role, error) {
|
||||
return role, nil
|
||||
}
|
||||
|
||||
// System roles are defined in code but stored in the database,
|
||||
// allowing their permissions to be adjusted per-organization at
|
||||
// runtime based on org settings (e.g. workspace sharing).
|
||||
var systemRoles = map[string]permissionsFunc{
|
||||
rbac.RoleOrgMember(): rbac.OrgMemberPermissions,
|
||||
rbac.RoleOrgServiceAccount(): rbac.OrgServiceAccountPermissions,
|
||||
}
|
||||
|
||||
// permissionsFunc produces the desired permissions for a system role
|
||||
// given organization settings.
|
||||
type permissionsFunc func(rbac.OrgSettings) rbac.OrgRolePermissions
|
||||
|
||||
func IsSystemRoleName(name string) bool {
|
||||
_, ok := systemRoles[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
var SystemRoleNames = maps.Keys(systemRoles)
|
||||
|
||||
// ReconcileSystemRoles ensures that every organization's system roles
|
||||
// in the DB are up-to-date with the current RBAC definitions and
|
||||
// organization settings.
|
||||
// ReconcileSystemRoles ensures that every organization's org-member
|
||||
// system role in the DB is up-to-date with permissions reflecting
|
||||
// current RBAC resources and the organization's
|
||||
// workspace_sharing_disabled setting. Uses PostgreSQL advisory lock
|
||||
// (LockIDReconcileSystemRoles) to safely handle multi-instance
|
||||
// deployments. Uses set-based comparison to avoid unnecessary
|
||||
// database writes when permissions haven't changed.
|
||||
func ReconcileSystemRoles(ctx context.Context, log slog.Logger, db database.Store) error {
|
||||
return db.InTx(func(tx database.Store) error {
|
||||
// Acquire advisory lock to prevent concurrent updates from
|
||||
@@ -209,45 +193,36 @@ func ReconcileSystemRoles(ctx context.Context, log slog.Logger, db database.Stor
|
||||
return xerrors.Errorf("fetch custom roles: %w", err)
|
||||
}
|
||||
|
||||
// Index system roles by (org ID, role name) for quick lookup.
|
||||
type orgRoleKey struct {
|
||||
OrgID uuid.UUID
|
||||
RoleName string
|
||||
}
|
||||
roleIndex := make(map[orgRoleKey]database.CustomRole)
|
||||
// Find org-member roles and index by organization ID for quick lookup.
|
||||
rolesByOrg := make(map[uuid.UUID]database.CustomRole)
|
||||
for _, role := range customRoles {
|
||||
if role.IsSystem && IsSystemRoleName(role.Name) && role.OrganizationID.Valid {
|
||||
roleIndex[orgRoleKey{role.OrganizationID.UUID, role.Name}] = role
|
||||
if role.IsSystem && role.Name == rbac.RoleOrgMember() && role.OrganizationID.Valid {
|
||||
rolesByOrg[role.OrganizationID.UUID] = role
|
||||
}
|
||||
}
|
||||
|
||||
for _, org := range orgs {
|
||||
for roleName := range systemRoles {
|
||||
role, exists := roleIndex[orgRoleKey{org.ID, roleName}]
|
||||
if !exists {
|
||||
// Something is very wrong: the role should have been
|
||||
// created by the db trigger or migration. Log loudly and
|
||||
// try creating it as a last-ditch effort before giving up.
|
||||
log.Critical(ctx, "missing system role; trying to re-create",
|
||||
slog.F("organization_id", org.ID),
|
||||
slog.F("role_name", roleName))
|
||||
role, exists := rolesByOrg[org.ID]
|
||||
if !exists {
|
||||
// Something is very wrong: the role should have been created by the
|
||||
// database trigger or migration. Log loudly and try creating it as
|
||||
// a last-ditch effort before giving up.
|
||||
log.Critical(ctx, "missing organization-member system role; trying to re-create",
|
||||
slog.F("organization_id", org.ID))
|
||||
|
||||
err := CreateSystemRole(ctx, tx, org, roleName)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create missing %s system role for organization %s: %w",
|
||||
roleName, org.ID, err)
|
||||
}
|
||||
|
||||
// Nothing more to do; the new role's permissions are
|
||||
// up-to-date.
|
||||
continue
|
||||
if err := CreateOrgMemberRole(ctx, tx, org); err != nil {
|
||||
return xerrors.Errorf("create missing organization-member role for organization %s: %w",
|
||||
org.ID, err)
|
||||
}
|
||||
|
||||
_, _, err := ReconcileSystemRole(ctx, tx, role, org)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("reconcile %s system role for organization %s: %w",
|
||||
roleName, org.ID, err)
|
||||
}
|
||||
// Nothing more to do; the new role's permissions are up-to-date.
|
||||
continue
|
||||
}
|
||||
|
||||
_, _, err := ReconcileOrgMemberRole(ctx, tx, role, org.WorkspaceSharingDisabled)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("reconcile organization-member role for organization %s: %w",
|
||||
org.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,30 +230,28 @@ func ReconcileSystemRoles(ctx context.Context, log slog.Logger, db database.Stor
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// ReconcileSystemRole compares the given role's permissions against
|
||||
// the desired permissions produced by the permissions function based
|
||||
// on the organization's settings. If they differ, the DB row is
|
||||
// updated. Uses set-based comparison so permission ordering doesn't
|
||||
// matter. Returns the correct role and a boolean indicating whether
|
||||
// the reconciliation was necessary.
|
||||
//
|
||||
// IMPORTANT: Callers must hold database.LockIDReconcileSystemRoles
|
||||
// for the duration of the enclosing transaction.
|
||||
func ReconcileSystemRole(
|
||||
// ReconcileOrgMemberRole ensures passed-in org-member role's perms
|
||||
// are correct (current) and stored in the DB. Uses set-based
|
||||
// comparison to avoid unnecessary database writes when permissions
|
||||
// haven't changed. Returns the correct role and a boolean indicating
|
||||
// whether the reconciliation was necessary.
|
||||
// NOTE: Callers must acquire `database.LockIDReconcileSystemRoles` at
|
||||
// the start of the transaction and hold it for the transaction’s
|
||||
// duration. This prevents concurrent org-member reconciliation from
|
||||
// racing and producing inconsistent writes.
|
||||
func ReconcileOrgMemberRole(
|
||||
ctx context.Context,
|
||||
tx database.Store,
|
||||
in database.CustomRole,
|
||||
org database.Organization,
|
||||
) (database.CustomRole, bool, error) {
|
||||
permsFunc, ok := systemRoles[in.Name]
|
||||
if !ok {
|
||||
panic("dev error: no permissions function exists for role " + in.Name)
|
||||
}
|
||||
|
||||
workspaceSharingDisabled bool,
|
||||
) (
|
||||
database.CustomRole, bool, error,
|
||||
) {
|
||||
// All fields except OrgPermissions and MemberPermissions will be the same.
|
||||
out := in
|
||||
|
||||
// Paranoia check: we don't use these in custom roles yet.
|
||||
// TODO(geokat): Have these as check constraints in DB for now?
|
||||
out.SitePermissions = database.CustomRolePermissions{}
|
||||
out.UserPermissions = database.CustomRolePermissions{}
|
||||
out.DisplayName = ""
|
||||
@@ -286,14 +259,15 @@ func ReconcileSystemRole(
|
||||
inOrgPerms := ConvertDBPermissions(in.OrgPermissions)
|
||||
inMemberPerms := ConvertDBPermissions(in.MemberPermissions)
|
||||
|
||||
outPerms := permsFunc(orgSettings(org))
|
||||
outOrgPerms, outMemberPerms := rbac.OrgMemberPermissions(workspaceSharingDisabled)
|
||||
|
||||
match := rbac.PermissionsEqual(inOrgPerms, outPerms.Org) &&
|
||||
rbac.PermissionsEqual(inMemberPerms, outPerms.Member)
|
||||
// Compare using set-based comparison (order doesn't matter).
|
||||
match := rbac.PermissionsEqual(inOrgPerms, outOrgPerms) &&
|
||||
rbac.PermissionsEqual(inMemberPerms, outMemberPerms)
|
||||
|
||||
if !match {
|
||||
out.OrgPermissions = ConvertPermissionsToDB(outPerms.Org)
|
||||
out.MemberPermissions = ConvertPermissionsToDB(outPerms.Member)
|
||||
out.OrgPermissions = ConvertPermissionsToDB(outOrgPerms)
|
||||
out.MemberPermissions = ConvertPermissionsToDB(outMemberPerms)
|
||||
|
||||
_, err := tx.UpdateCustomRole(ctx, database.UpdateCustomRoleParams{
|
||||
Name: out.Name,
|
||||
@@ -305,50 +279,30 @@ func ReconcileSystemRole(
|
||||
MemberPermissions: out.MemberPermissions,
|
||||
})
|
||||
if err != nil {
|
||||
return out, !match, xerrors.Errorf("update %s system role for organization %s: %w",
|
||||
in.Name, in.OrganizationID.UUID, err)
|
||||
return out, !match, xerrors.Errorf("update organization-member custom role for organization %s: %w",
|
||||
in.OrganizationID.UUID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return out, !match, nil
|
||||
}
|
||||
|
||||
// orgSettings maps database.Organization fields to the
|
||||
// rbac.OrgSettings struct, bridging the database and rbac packages
|
||||
// without introducing a circular dependency.
|
||||
func orgSettings(org database.Organization) rbac.OrgSettings {
|
||||
return rbac.OrgSettings{
|
||||
ShareableWorkspaceOwners: rbac.ShareableWorkspaceOwners(org.ShareableWorkspaceOwners),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSystemRole inserts a new system role into the database with
|
||||
// permissions produced by permsFunc based on the organization's current
|
||||
// settings.
|
||||
func CreateSystemRole(
|
||||
ctx context.Context,
|
||||
tx database.Store,
|
||||
org database.Organization,
|
||||
roleName string,
|
||||
) error {
|
||||
permsFunc, ok := systemRoles[roleName]
|
||||
if !ok {
|
||||
panic("dev error: no permissions function exists for role " + roleName)
|
||||
}
|
||||
perms := permsFunc(orgSettings(org))
|
||||
// CreateOrgMemberRole creates an org-member system role for an organization.
|
||||
func CreateOrgMemberRole(ctx context.Context, tx database.Store, org database.Organization) error {
|
||||
orgPerms, memberPerms := rbac.OrgMemberPermissions(org.WorkspaceSharingDisabled)
|
||||
|
||||
_, err := tx.InsertCustomRole(ctx, database.InsertCustomRoleParams{
|
||||
Name: roleName,
|
||||
Name: rbac.RoleOrgMember(),
|
||||
DisplayName: "",
|
||||
OrganizationID: uuid.NullUUID{UUID: org.ID, Valid: true},
|
||||
SitePermissions: database.CustomRolePermissions{},
|
||||
OrgPermissions: ConvertPermissionsToDB(perms.Org),
|
||||
OrgPermissions: ConvertPermissionsToDB(orgPerms),
|
||||
UserPermissions: database.CustomRolePermissions{},
|
||||
MemberPermissions: ConvertPermissionsToDB(perms.Member),
|
||||
MemberPermissions: ConvertPermissionsToDB(memberPerms),
|
||||
IsSystem: true,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert %s role: %w", roleName, err)
|
||||
return xerrors.Errorf("insert org-member role: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -42,84 +42,68 @@ func TestExpandCustomRoleRoles(t *testing.T) {
|
||||
require.Len(t, roles, 1, "role found")
|
||||
}
|
||||
|
||||
func TestReconcileSystemRole(t *testing.T) {
|
||||
func TestReconcileOrgMemberRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roleName string
|
||||
permsFunc func(rbac.OrgSettings) rbac.OrgRolePermissions
|
||||
}{
|
||||
{"OrgMember", rbac.RoleOrgMember(), rbac.OrgMemberPermissions},
|
||||
{"ServiceAccount", rbac.RoleOrgServiceAccount(), rbac.OrgServiceAccountPermissions},
|
||||
}
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
existing, err := database.ExpectOne(db.CustomRoles(ctx, database.CustomRolesParams{
|
||||
LookupRoles: []database.NameOrganizationPair{
|
||||
{
|
||||
Name: tt.roleName,
|
||||
OrganizationID: org.ID,
|
||||
},
|
||||
},
|
||||
IncludeSystemRoles: true,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
existing, err := database.ExpectOne(db.CustomRoles(ctx, database.CustomRolesParams{
|
||||
LookupRoles: []database.NameOrganizationPair{
|
||||
{
|
||||
Name: rbac.RoleOrgMember(),
|
||||
OrganizationID: org.ID,
|
||||
},
|
||||
},
|
||||
IncludeSystemRoles: true,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Zero out permissions to simulate stale state.
|
||||
_, err = db.UpdateCustomRole(ctx, database.UpdateCustomRoleParams{
|
||||
Name: existing.Name,
|
||||
OrganizationID: uuid.NullUUID{
|
||||
UUID: org.ID,
|
||||
Valid: true,
|
||||
},
|
||||
DisplayName: "",
|
||||
SitePermissions: database.CustomRolePermissions{},
|
||||
UserPermissions: database.CustomRolePermissions{},
|
||||
OrgPermissions: database.CustomRolePermissions{},
|
||||
MemberPermissions: database.CustomRolePermissions{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.UpdateCustomRole(ctx, database.UpdateCustomRoleParams{
|
||||
Name: existing.Name,
|
||||
OrganizationID: uuid.NullUUID{
|
||||
UUID: org.ID,
|
||||
Valid: true,
|
||||
},
|
||||
DisplayName: "",
|
||||
SitePermissions: database.CustomRolePermissions{},
|
||||
UserPermissions: database.CustomRolePermissions{},
|
||||
OrgPermissions: database.CustomRolePermissions{},
|
||||
MemberPermissions: database.CustomRolePermissions{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
stale := existing
|
||||
stale.OrgPermissions = database.CustomRolePermissions{}
|
||||
stale.MemberPermissions = database.CustomRolePermissions{}
|
||||
stale := existing
|
||||
stale.OrgPermissions = database.CustomRolePermissions{}
|
||||
stale.MemberPermissions = database.CustomRolePermissions{}
|
||||
|
||||
reconciled, didUpdate, err := rolestore.ReconcileSystemRole(ctx, db, stale, org)
|
||||
require.NoError(t, err)
|
||||
require.True(t, didUpdate, "expected reconciliation to update stale permissions")
|
||||
reconciled, didUpdate, err := rolestore.ReconcileOrgMemberRole(ctx, db, stale, org.WorkspaceSharingDisabled)
|
||||
require.NoError(t, err)
|
||||
require.True(t, didUpdate, "expected reconciliation to update stale permissions")
|
||||
|
||||
dbstored, err := database.ExpectOne(db.CustomRoles(ctx, database.CustomRolesParams{
|
||||
LookupRoles: []database.NameOrganizationPair{
|
||||
{
|
||||
Name: tt.roleName,
|
||||
OrganizationID: org.ID,
|
||||
},
|
||||
},
|
||||
IncludeSystemRoles: true,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
got, err := database.ExpectOne(db.CustomRoles(ctx, database.CustomRolesParams{
|
||||
LookupRoles: []database.NameOrganizationPair{
|
||||
{
|
||||
Name: rbac.RoleOrgMember(),
|
||||
OrganizationID: org.ID,
|
||||
},
|
||||
},
|
||||
IncludeSystemRoles: true,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
|
||||
want := tt.permsFunc(rbac.OrgSettings{
|
||||
ShareableWorkspaceOwners: rbac.ShareableWorkspaceOwners(org.ShareableWorkspaceOwners),
|
||||
})
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(dbstored.OrgPermissions), want.Org))
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(dbstored.MemberPermissions), want.Member))
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(reconciled.OrgPermissions), want.Org))
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(reconciled.MemberPermissions), want.Member))
|
||||
wantOrg, wantMember := rbac.OrgMemberPermissions(org.WorkspaceSharingDisabled)
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.OrgPermissions), wantOrg))
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.MemberPermissions), wantMember))
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(reconciled.OrgPermissions), wantOrg))
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(reconciled.MemberPermissions), wantMember))
|
||||
|
||||
_, didUpdate, err = rolestore.ReconcileSystemRole(ctx, db, reconciled, org)
|
||||
require.NoError(t, err)
|
||||
require.False(t, didUpdate, "expected no-op reconciliation when permissions are already current")
|
||||
})
|
||||
}
|
||||
_, didUpdate, err = rolestore.ReconcileOrgMemberRole(ctx, db, reconciled, org.WorkspaceSharingDisabled)
|
||||
require.NoError(t, err)
|
||||
require.False(t, didUpdate, "expected no-op reconciliation when permissions are already current")
|
||||
}
|
||||
|
||||
func TestReconcileSystemRoles(t *testing.T) {
|
||||
@@ -134,7 +118,7 @@ func TestReconcileSystemRoles(t *testing.T) {
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
_, err := sqlDB.ExecContext(ctx, "UPDATE organizations SET shareable_workspace_owners = 'none' WHERE id = $1", org2.ID)
|
||||
_, err := sqlDB.ExecContext(ctx, "UPDATE organizations SET workspace_sharing_disabled = true WHERE id = $1", org2.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate a missing system role by bypassing the application's
|
||||
@@ -179,9 +163,9 @@ func TestReconcileSystemRoles(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, got.IsSystem)
|
||||
|
||||
want := rbac.OrgMemberPermissions(rbac.OrgSettings{ShareableWorkspaceOwners: rbac.ShareableWorkspaceOwners(org.ShareableWorkspaceOwners)})
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.OrgPermissions), want.Org))
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.MemberPermissions), want.Member))
|
||||
wantOrg, wantMember := rbac.OrgMemberPermissions(org.WorkspaceSharingDisabled)
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.OrgPermissions), wantOrg))
|
||||
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.MemberPermissions), wantMember))
|
||||
}
|
||||
|
||||
assertOrgMemberRole(t, org1.ID)
|
||||
|
||||
@@ -471,8 +471,8 @@ func Tasks(ctx context.Context, db database.Store, query string, actorID uuid.UU
|
||||
//
|
||||
// Supported query parameters:
|
||||
// - archived: boolean (default: false, excludes archived chats unless explicitly set)
|
||||
func Chats(query string) (database.GetChatsParams, []codersdk.ValidationError) {
|
||||
filter := database.GetChatsParams{
|
||||
func Chats(query string) (database.GetChatsByOwnerIDParams, []codersdk.ValidationError) {
|
||||
filter := database.GetChatsByOwnerIDParams{
|
||||
// Default to hiding archived chats.
|
||||
Archived: sql.NullBool{Bool: false, Valid: true},
|
||||
}
|
||||
|
||||
@@ -1222,27 +1222,27 @@ func TestSearchChats(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Name string
|
||||
Query string
|
||||
Expected database.GetChatsParams
|
||||
Expected database.GetChatsByOwnerIDParams
|
||||
ExpectedErrorContains string
|
||||
}{
|
||||
{
|
||||
Name: "Empty",
|
||||
Query: "",
|
||||
Expected: database.GetChatsParams{
|
||||
Expected: database.GetChatsByOwnerIDParams{
|
||||
Archived: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "ArchivedTrue",
|
||||
Query: "archived:true",
|
||||
Expected: database.GetChatsParams{
|
||||
Expected: database.GetChatsByOwnerIDParams{
|
||||
Archived: sql.NullBool{Bool: true, Valid: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "ArchivedFalse",
|
||||
Query: "archived:false",
|
||||
Expected: database.GetChatsParams{
|
||||
Expected: database.GetChatsByOwnerIDParams{
|
||||
Archived: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -55,9 +54,6 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
|
||||
})
|
||||
return
|
||||
}
|
||||
if invalidator, ok := api.WebpushDispatcher.(webpush.SubscriptionCacheInvalidator); ok {
|
||||
invalidator.InvalidateUser(user.ID)
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@@ -115,9 +111,6 @@ func (api *API) deleteUserWebpushSubscription(rw http.ResponseWriter, r *http.Re
|
||||
})
|
||||
return
|
||||
}
|
||||
if invalidator, ok := api.WebpushDispatcher.(webpush.SubscriptionCacheInvalidator); ok {
|
||||
invalidator.InvalidateUser(user.ID)
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
+7
-190
@@ -9,23 +9,18 @@ import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/util/singleflight"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const defaultSubscriptionCacheTTL = 3 * time.Minute
|
||||
|
||||
// Dispatcher is an interface that can be used to dispatch
|
||||
// web push notifications to clients such as browsers.
|
||||
type Dispatcher interface {
|
||||
@@ -38,36 +33,6 @@ type Dispatcher interface {
|
||||
PublicKey() string
|
||||
}
|
||||
|
||||
// SubscriptionCacheInvalidator is an optional interface that lets local
|
||||
// subscription mutation handlers invalidate cached subscriptions.
|
||||
type SubscriptionCacheInvalidator interface {
|
||||
InvalidateUser(userID uuid.UUID)
|
||||
}
|
||||
|
||||
type options struct {
|
||||
clock quartz.Clock
|
||||
subscriptionCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// Option configures optional behavior for a Webpusher.
|
||||
type Option func(*options)
|
||||
|
||||
// WithClock sets the clock used by the subscription cache. Defaults to a real
|
||||
// clock when not provided.
|
||||
func WithClock(clock quartz.Clock) Option {
|
||||
return func(o *options) {
|
||||
o.clock = clock
|
||||
}
|
||||
}
|
||||
|
||||
// WithSubscriptionCacheTTL sets the in-memory subscription cache TTL. Defaults
|
||||
// to three minutes when not provided or when given a non-positive duration.
|
||||
func WithSubscriptionCacheTTL(ttl time.Duration) Option {
|
||||
return func(o *options) {
|
||||
o.subscriptionCacheTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new Dispatcher to dispatch web push notifications.
|
||||
//
|
||||
// This is *not* integrated into the enqueue system unfortunately.
|
||||
@@ -76,21 +41,7 @@ func WithSubscriptionCacheTTL(ttl time.Duration) Option {
|
||||
// for updates inside of a workspace, which we want to be immediate.
|
||||
//
|
||||
// See: https://github.com/coder/internal/issues/528
|
||||
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string, opts ...Option) (Dispatcher, error) {
|
||||
cfg := options{
|
||||
clock: quartz.NewReal(),
|
||||
subscriptionCacheTTL: defaultSubscriptionCacheTTL,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&cfg)
|
||||
}
|
||||
if cfg.clock == nil {
|
||||
cfg.clock = quartz.NewReal()
|
||||
}
|
||||
if cfg.subscriptionCacheTTL <= 0 {
|
||||
cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
|
||||
}
|
||||
|
||||
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string) (Dispatcher, error) {
|
||||
keys, err := db.GetWebpushVAPIDKeys(ctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -112,23 +63,14 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
|
||||
}
|
||||
|
||||
return &Webpusher{
|
||||
vapidSub: vapidSub,
|
||||
store: db,
|
||||
log: log,
|
||||
VAPIDPublicKey: keys.VapidPublicKey,
|
||||
VAPIDPrivateKey: keys.VapidPrivateKey,
|
||||
clock: cfg.clock,
|
||||
subscriptionCacheTTL: cfg.subscriptionCacheTTL,
|
||||
subscriptionCache: make(map[uuid.UUID]cachedSubscriptions),
|
||||
subscriptionGenerations: make(map[uuid.UUID]uint64),
|
||||
vapidSub: vapidSub,
|
||||
store: db,
|
||||
log: log,
|
||||
VAPIDPublicKey: keys.VapidPublicKey,
|
||||
VAPIDPrivateKey: keys.VapidPrivateKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type cachedSubscriptions struct {
|
||||
subscriptions []database.WebpushSubscription
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type Webpusher struct {
|
||||
store database.Store
|
||||
log *slog.Logger
|
||||
@@ -141,18 +83,10 @@ type Webpusher struct {
|
||||
// the message payload.
|
||||
VAPIDPublicKey string
|
||||
VAPIDPrivateKey string
|
||||
|
||||
clock quartz.Clock
|
||||
|
||||
cacheMu sync.RWMutex
|
||||
subscriptionCache map[uuid.UUID]cachedSubscriptions
|
||||
subscriptionGenerations map[uuid.UUID]uint64
|
||||
subscriptionCacheTTL time.Duration
|
||||
subscriptionFetches singleflight.Group[string, []database.WebpushSubscription]
|
||||
}
|
||||
|
||||
func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
|
||||
subscriptions, err := n.subscriptionsForUser(ctx, userID)
|
||||
subscriptions, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get web push subscriptions by user ID: %w", err)
|
||||
}
|
||||
@@ -208,129 +142,12 @@ func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk
|
||||
err = n.store.DeleteWebpushSubscriptions(dbauthz.AsNotifier(ctx), cleanupSubscriptions)
|
||||
if err != nil {
|
||||
n.log.Error(ctx, "failed to delete stale push subscriptions", slog.Error(err))
|
||||
} else {
|
||||
n.pruneSubscriptions(userID, cleanupSubscriptions)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Webpusher) subscriptionsForUser(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
|
||||
if subscriptions, ok := n.cachedSubscriptions(userID); ok {
|
||||
return subscriptions, nil
|
||||
}
|
||||
|
||||
subscriptions, err, _ := n.subscriptionFetches.Do(userID.String(), func() ([]database.WebpushSubscription, error) {
|
||||
if cached, ok := n.cachedSubscriptions(userID); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
generation := n.subscriptionGeneration(userID)
|
||||
fetched, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n.storeSubscriptions(userID, generation, fetched)
|
||||
return slices.Clone(fetched), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return slices.Clone(subscriptions), nil
|
||||
}
|
||||
|
||||
func (n *Webpusher) cachedSubscriptions(userID uuid.UUID) ([]database.WebpushSubscription, bool) {
|
||||
n.cacheMu.RLock()
|
||||
entry, ok := n.subscriptionCache[userID]
|
||||
n.cacheMu.RUnlock()
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if n.clock.Now().Before(entry.expiresAt) {
|
||||
return slices.Clone(entry.subscriptions), true
|
||||
}
|
||||
|
||||
n.cacheMu.Lock()
|
||||
if current, ok := n.subscriptionCache[userID]; ok && !n.clock.Now().Before(current.expiresAt) {
|
||||
delete(n.subscriptionCache, userID)
|
||||
}
|
||||
n.cacheMu.Unlock()
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (n *Webpusher) subscriptionGeneration(userID uuid.UUID) uint64 {
|
||||
n.cacheMu.RLock()
|
||||
generation := n.subscriptionGenerations[userID]
|
||||
n.cacheMu.RUnlock()
|
||||
return generation
|
||||
}
|
||||
|
||||
func (n *Webpusher) storeSubscriptions(userID uuid.UUID, generation uint64, subscriptions []database.WebpushSubscription) {
|
||||
n.cacheMu.Lock()
|
||||
defer n.cacheMu.Unlock()
|
||||
|
||||
if n.subscriptionGenerations[userID] != generation {
|
||||
return
|
||||
}
|
||||
|
||||
n.subscriptionCache[userID] = cachedSubscriptions{
|
||||
subscriptions: slices.Clone(subscriptions),
|
||||
expiresAt: n.clock.Now().Add(n.subscriptionCacheTTL),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Webpusher) pruneSubscriptions(userID uuid.UUID, staleIDs []uuid.UUID) {
|
||||
if len(staleIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
stale := make(map[uuid.UUID]struct{}, len(staleIDs))
|
||||
for _, id := range staleIDs {
|
||||
stale[id] = struct{}{}
|
||||
}
|
||||
|
||||
n.cacheMu.Lock()
|
||||
defer n.cacheMu.Unlock()
|
||||
|
||||
entry, ok := n.subscriptionCache[userID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if !n.clock.Now().Before(entry.expiresAt) {
|
||||
delete(n.subscriptionCache, userID)
|
||||
return
|
||||
}
|
||||
|
||||
filtered := make([]database.WebpushSubscription, 0, len(entry.subscriptions))
|
||||
for _, subscription := range entry.subscriptions {
|
||||
if _, shouldDelete := stale[subscription.ID]; shouldDelete {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, subscription)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(n.subscriptionCache, userID)
|
||||
return
|
||||
}
|
||||
|
||||
entry.subscriptions = filtered
|
||||
n.subscriptionCache[userID] = entry
|
||||
}
|
||||
|
||||
// InvalidateUser clears the cached subscriptions for a user and advances
|
||||
// its invalidation generation. Local subscribe and unsubscribe handlers call
|
||||
// this after mutating subscriptions in the same process.
|
||||
func (n *Webpusher) InvalidateUser(userID uuid.UUID) {
|
||||
n.cacheMu.Lock()
|
||||
delete(n.subscriptionCache, userID)
|
||||
n.subscriptionGenerations[userID]++
|
||||
n.cacheMu.Unlock()
|
||||
n.subscriptionFetches.Forget(userID.String())
|
||||
}
|
||||
|
||||
func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string, keys webpush.Keys) (int, []byte, error) {
|
||||
// Copy the message to avoid modifying the original.
|
||||
cpy := slices.Clone(msg)
|
||||
|
||||
@@ -6,9 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -23,7 +21,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -31,20 +28,6 @@ const (
|
||||
validEndpointP256dhKey = "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk="
|
||||
)
|
||||
|
||||
type countingWebpushStore struct {
|
||||
database.Store
|
||||
getSubscriptionsCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (s *countingWebpushStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
|
||||
s.getSubscriptionsCalls.Add(1)
|
||||
return s.Store.GetWebpushSubscriptionsByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *countingWebpushStore) getCallCount() int32 {
|
||||
return s.getSubscriptionsCalls.Load()
|
||||
}
|
||||
|
||||
func TestPush(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -233,131 +216,6 @@ func TestPush(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, subscriptions, "No subscriptions should be returned")
|
||||
})
|
||||
|
||||
t.Run("CachesSubscriptionsWithinTTL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
rawStore, _ := dbtestutil.NewDB(t)
|
||||
store := &countingWebpushStore{Store: rawStore}
|
||||
var delivered atomic.Int32
|
||||
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
|
||||
delivered.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
|
||||
|
||||
user := dbgen.User(t, rawStore, database.User{})
|
||||
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: serverURL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := randomWebpushMessage(t)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(1), store.getCallCount(), "subscriptions should be read once within the TTL")
|
||||
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
|
||||
})
|
||||
|
||||
t.Run("RefreshesSubscriptionsAfterTTLExpires", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
rawStore, _ := dbtestutil.NewDB(t)
|
||||
store := &countingWebpushStore{Store: rawStore}
|
||||
var delivered atomic.Int32
|
||||
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
|
||||
delivered.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
|
||||
|
||||
user := dbgen.User(t, rawStore, database.User{})
|
||||
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: serverURL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := randomWebpushMessage(t)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
clock.Advance(time.Minute)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(2), store.getCallCount(), "dispatch should refresh subscriptions after the TTL expires")
|
||||
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
|
||||
})
|
||||
|
||||
t.Run("PrunesStaleSubscriptionsFromCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
rawStore, _ := dbtestutil.NewDB(t)
|
||||
store := &countingWebpushStore{Store: rawStore}
|
||||
var okCalls atomic.Int32
|
||||
var goneCalls atomic.Int32
|
||||
manager, _, okServerURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
|
||||
okCalls.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
|
||||
|
||||
goneServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
goneCalls.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusGone)
|
||||
}))
|
||||
defer goneServer.Close()
|
||||
|
||||
user := dbgen.User(t, rawStore, database.User{})
|
||||
okSubscription, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: okServerURL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: goneServer.URL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := randomWebpushMessage(t)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(1), store.getCallCount(), "stale subscription cleanup should not force a second DB read within the TTL")
|
||||
require.Equal(t, int32(2), okCalls.Load(), "the healthy endpoint should receive both dispatches")
|
||||
require.Equal(t, int32(1), goneCalls.Load(), "the stale endpoint should be pruned from the cache after the first dispatch")
|
||||
|
||||
subscriptions, err := rawStore.GetWebpushSubscriptionsByUserID(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, subscriptions, 1, "only the healthy subscription should remain")
|
||||
require.Equal(t, okSubscription.ID, subscriptions[0].ID)
|
||||
})
|
||||
}
|
||||
|
||||
func randomWebpushMessage(t testing.TB) codersdk.WebpushMessage {
|
||||
@@ -386,21 +244,16 @@ func assertWebpushPayload(t testing.TB, r *http.Request) {
|
||||
assert.Error(t, json.NewDecoder(r.Body).Decode(io.Discard))
|
||||
}
|
||||
|
||||
// setupPushTest creates a common test setup for webpush notification tests.
|
||||
// setupPushTest creates a common test setup for webpush notification tests
|
||||
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
|
||||
t.Helper()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
return setupPushTestWithOptions(ctx, t, db, handlerFunc)
|
||||
}
|
||||
|
||||
func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Store, handlerFunc func(w http.ResponseWriter, r *http.Request), opts ...webpush.Option) (webpush.Dispatcher, database.Store, string) {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(handlerFunc))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...)
|
||||
manager, err := webpush.New(ctx, &logger, db, "http://example.com")
|
||||
require.NoError(t, err, "Failed to create webpush manager")
|
||||
|
||||
return manager, db, server.URL
|
||||
|
||||
+7
-18
@@ -35,42 +35,31 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
_, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
var handlerCalls atomic.Int32
|
||||
handlerCalled := make(chan bool, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
handlerCalls.Add(1)
|
||||
handlerCalled <- true
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Seed the dispatcher cache with an empty subscription set. Creating the
|
||||
// subscription should invalidate that entry so the next dispatch sees the new
|
||||
// subscription immediately.
|
||||
err := memberClient.PostTestWebpushMessage(ctx)
|
||||
require.NoError(t, err, "test webpush message without a subscription")
|
||||
require.Zero(t, handlerCalls.Load(), "a user without subscriptions should not receive a push")
|
||||
|
||||
err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
|
||||
err := memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
AuthKey: validEndpointAuthKey,
|
||||
P256DHKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err, "create webpush subscription")
|
||||
require.Equal(t, int32(1), handlerCalls.Load(), "subscription validation should hit the endpoint once")
|
||||
require.True(t, <-handlerCalled, "handler should have been called")
|
||||
|
||||
err = memberClient.PostTestWebpushMessage(ctx)
|
||||
require.NoError(t, err, "test webpush message after subscribing")
|
||||
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate empty cache entries after subscribing")
|
||||
require.NoError(t, err, "test webpush message")
|
||||
require.True(t, <-handlerCalled, "handler should have been called again")
|
||||
|
||||
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
})
|
||||
require.NoError(t, err, "delete webpush subscription")
|
||||
|
||||
err = memberClient.PostTestWebpushMessage(ctx)
|
||||
require.NoError(t, err, "test webpush message after unsubscribing")
|
||||
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate cached subscriptions after unsubscribing")
|
||||
|
||||
// Deleting the subscription for a non-existent endpoint should return a 404.
|
||||
// Deleting the subscription for a non-existent endpoint should return a 404
|
||||
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
})
|
||||
|
||||
+126
-49
@@ -989,48 +989,108 @@ type workspaceBuildsData struct {
|
||||
|
||||
func (api *API) workspaceBuildsData(ctx context.Context, workspaceBuilds []database.WorkspaceBuild) (workspaceBuildsData, error) {
|
||||
jobIDs := make([]uuid.UUID, 0, len(workspaceBuilds))
|
||||
for _, build := range workspaceBuilds {
|
||||
jobIDs = append(jobIDs, build.JobID)
|
||||
}
|
||||
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
||||
IDs: jobIDs,
|
||||
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
||||
})
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return workspaceBuildsData{}, xerrors.Errorf("get provisioner jobs: %w", err)
|
||||
}
|
||||
pendingJobIDs := []uuid.UUID{}
|
||||
for _, job := range jobs {
|
||||
if job.ProvisionerJob.JobStatus == database.ProvisionerJobStatusPending {
|
||||
pendingJobIDs = append(pendingJobIDs, job.ProvisionerJob.ID)
|
||||
}
|
||||
}
|
||||
|
||||
pendingJobProvisioners, err := api.Database.GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx, pendingJobIDs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return workspaceBuildsData{}, xerrors.Errorf("get provisioner daemons: %w", err)
|
||||
}
|
||||
|
||||
templateVersionIDs := make([]uuid.UUID, 0, len(workspaceBuilds))
|
||||
for _, build := range workspaceBuilds {
|
||||
jobIDs = append(jobIDs, build.JobID)
|
||||
templateVersionIDs = append(templateVersionIDs, build.TemplateVersionID)
|
||||
}
|
||||
|
||||
// nolint:gocritic // Getting template versions by ID is a system function.
|
||||
templateVersions, err := api.Database.GetTemplateVersionsByIDs(dbauthz.AsSystemRestricted(ctx), templateVersionIDs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return workspaceBuildsData{}, xerrors.Errorf("get template versions: %w", err)
|
||||
// Phase A: Fetch jobs, template versions, and resources in parallel.
|
||||
// These three queries depend only on the build list and can run
|
||||
// concurrently.
|
||||
var (
|
||||
jobs []database.ProvisionerJob
|
||||
templateVersions []database.TemplateVersion
|
||||
resources []database.WorkspaceResource
|
||||
)
|
||||
var eg errgroup.Group
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
jobs, err = api.Database.GetProvisionerJobsByIDs(ctx, jobIDs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("get provisioner jobs: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
// nolint:gocritic // Getting template versions by ID is a system function.
|
||||
templateVersions, err = api.Database.GetTemplateVersionsByIDs(dbauthz.AsSystemRestricted(ctx), templateVersionIDs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("get template versions: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
// nolint:gocritic // Getting workspace resources by job ID is a system function.
|
||||
resources, err = api.Database.GetWorkspaceResourcesByJobIDs(dbauthz.AsSystemRestricted(ctx), jobIDs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("get workspace resources by job: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err := eg.Wait(); err != nil {
|
||||
return workspaceBuildsData{}, err
|
||||
}
|
||||
|
||||
// nolint:gocritic // Getting workspace resources by job ID is a system function.
|
||||
resources, err := api.Database.GetWorkspaceResourcesByJobIDs(dbauthz.AsSystemRestricted(ctx), jobIDs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return workspaceBuildsData{}, xerrors.Errorf("get workspace resources by job: %w", err)
|
||||
// Phase B: Get queue position and eligible daemons for pending
|
||||
// jobs only. The queue position query is expensive (cross-join
|
||||
// with all pending jobs and active daemons), so we only run it
|
||||
// for the small number of actually-pending jobs.
|
||||
var pendingJobIDs []uuid.UUID
|
||||
for _, job := range jobs {
|
||||
if job.JobStatus == database.ProvisionerJobStatusPending {
|
||||
pendingJobIDs = append(pendingJobIDs, job.ID)
|
||||
}
|
||||
}
|
||||
|
||||
var queuePositionRows []database.GetProvisionerJobsByIDsWithQueuePositionRow
|
||||
if len(pendingJobIDs) > 0 {
|
||||
var err error
|
||||
queuePositionRows, err = api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
||||
IDs: pendingJobIDs,
|
||||
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
||||
})
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return workspaceBuildsData{}, xerrors.Errorf("get provisioner jobs queue position: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var pendingJobProvisioners []database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow
|
||||
if len(pendingJobIDs) > 0 {
|
||||
var err error
|
||||
pendingJobProvisioners, err = api.Database.GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx, pendingJobIDs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return workspaceBuildsData{}, xerrors.Errorf("get provisioner daemons: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge job rows with queue position information so downstream
|
||||
// consumers see the same type they expect.
|
||||
queuePositionByID := make(map[uuid.UUID]database.GetProvisionerJobsByIDsWithQueuePositionRow, len(queuePositionRows))
|
||||
for _, qpRow := range queuePositionRows {
|
||||
queuePositionByID[qpRow.ID] = qpRow
|
||||
}
|
||||
jobsWithQueuePosition := make([]database.GetProvisionerJobsByIDsWithQueuePositionRow, 0, len(jobs))
|
||||
for _, job := range jobs {
|
||||
row := database.GetProvisionerJobsByIDsWithQueuePositionRow{
|
||||
ID: job.ID,
|
||||
CreatedAt: job.CreatedAt,
|
||||
ProvisionerJob: job,
|
||||
QueuePosition: 0,
|
||||
QueueSize: 0,
|
||||
}
|
||||
if qpRow, ok := queuePositionByID[job.ID]; ok {
|
||||
row.QueuePosition = qpRow.QueuePosition
|
||||
row.QueueSize = qpRow.QueueSize
|
||||
}
|
||||
jobsWithQueuePosition = append(jobsWithQueuePosition, row)
|
||||
}
|
||||
|
||||
if len(resources) == 0 {
|
||||
return workspaceBuildsData{
|
||||
jobs: jobs,
|
||||
jobs: jobsWithQueuePosition,
|
||||
templateVersions: templateVersions,
|
||||
provisionerDaemons: pendingJobProvisioners,
|
||||
}, nil
|
||||
@@ -1041,21 +1101,38 @@ func (api *API) workspaceBuildsData(ctx context.Context, workspaceBuilds []datab
|
||||
resourceIDs = append(resourceIDs, resource.ID)
|
||||
}
|
||||
|
||||
// nolint:gocritic // Getting workspace resource metadata by resource ID is a system function.
|
||||
metadata, err := api.Database.GetWorkspaceResourceMetadataByResourceIDs(dbauthz.AsSystemRestricted(ctx), resourceIDs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return workspaceBuildsData{}, xerrors.Errorf("fetching resource metadata: %w", err)
|
||||
// Phase C: Fetch metadata and agents in parallel (both depend
|
||||
// on resourceIDs which we just computed).
|
||||
var (
|
||||
metadata []database.WorkspaceResourceMetadatum
|
||||
agents []database.WorkspaceAgent
|
||||
)
|
||||
var eg2 errgroup.Group
|
||||
eg2.Go(func() error {
|
||||
var err error
|
||||
// nolint:gocritic // Getting workspace resource metadata by resource ID is a system function.
|
||||
metadata, err = api.Database.GetWorkspaceResourceMetadataByResourceIDs(dbauthz.AsSystemRestricted(ctx), resourceIDs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("fetching resource metadata: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg2.Go(func() error {
|
||||
var err error
|
||||
// nolint:gocritic // Getting workspace agents by resource IDs is a system function.
|
||||
agents, err = api.Database.GetWorkspaceAgentsByResourceIDs(dbauthz.AsSystemRestricted(ctx), resourceIDs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("get workspace agents: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err := eg2.Wait(); err != nil {
|
||||
return workspaceBuildsData{}, err
|
||||
}
|
||||
|
||||
// nolint:gocritic // Getting workspace agents by resource IDs is a system function.
|
||||
agents, err := api.Database.GetWorkspaceAgentsByResourceIDs(dbauthz.AsSystemRestricted(ctx), resourceIDs)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return workspaceBuildsData{}, xerrors.Errorf("get workspace agents: %w", err)
|
||||
}
|
||||
|
||||
if len(resources) == 0 {
|
||||
if len(agents) == 0 {
|
||||
return workspaceBuildsData{
|
||||
jobs: jobs,
|
||||
jobs: jobsWithQueuePosition,
|
||||
templateVersions: templateVersions,
|
||||
resources: resources,
|
||||
metadata: metadata,
|
||||
@@ -1074,23 +1151,23 @@ func (api *API) workspaceBuildsData(ctx context.Context, workspaceBuilds []datab
|
||||
logSources []database.WorkspaceAgentLogSource
|
||||
)
|
||||
|
||||
var eg errgroup.Group
|
||||
eg.Go(func() (err error) {
|
||||
var eg3 errgroup.Group
|
||||
eg3.Go(func() (err error) {
|
||||
// nolint:gocritic // Getting workspace apps by agent IDs is a system function.
|
||||
apps, err = api.Database.GetWorkspaceAppsByAgentIDs(dbauthz.AsSystemRestricted(ctx), agentIDs)
|
||||
return err
|
||||
})
|
||||
eg.Go(func() (err error) {
|
||||
eg3.Go(func() (err error) {
|
||||
// nolint:gocritic // Getting workspace scripts by agent IDs is a system function.
|
||||
scripts, err = api.Database.GetWorkspaceAgentScriptsByAgentIDs(dbauthz.AsSystemRestricted(ctx), agentIDs)
|
||||
return err
|
||||
})
|
||||
eg.Go(func() error {
|
||||
eg3.Go(func() (err error) {
|
||||
// nolint:gocritic // Getting workspace agent log sources by agent IDs is a system function.
|
||||
logSources, err = api.Database.GetWorkspaceAgentLogSourcesByAgentIDs(dbauthz.AsSystemRestricted(ctx), agentIDs)
|
||||
return err
|
||||
})
|
||||
err = eg.Wait()
|
||||
err := eg3.Wait()
|
||||
if err != nil {
|
||||
return workspaceBuildsData{}, err
|
||||
}
|
||||
@@ -1107,7 +1184,7 @@ func (api *API) workspaceBuildsData(ctx context.Context, workspaceBuilds []datab
|
||||
}
|
||||
|
||||
return workspaceBuildsData{
|
||||
jobs: jobs,
|
||||
jobs: jobsWithQueuePosition,
|
||||
templateVersions: templateVersions,
|
||||
resources: resources,
|
||||
metadata: metadata,
|
||||
|
||||
@@ -2487,11 +2487,7 @@ func (api *API) patchWorkspaceACL(rw http.ResponseWriter, r *http.Request) {
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
} else {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
}
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2570,7 +2566,7 @@ func (api *API) allowWorkspaceSharing(ctx context.Context, rw http.ResponseWrite
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return false
|
||||
}
|
||||
if org.ShareableWorkspaceOwners == database.ShareableWorkspaceOwnersNone {
|
||||
if org.WorkspaceSharingDisabled {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Workspace sharing is disabled for this organization.",
|
||||
})
|
||||
|
||||
@@ -517,14 +517,14 @@ const (
|
||||
)
|
||||
|
||||
type AgentMetric struct {
|
||||
Name string `json:"name" validate:"required"`
|
||||
Type AgentMetricType `json:"type" enums:"counter,gauge" validate:"required"`
|
||||
Value float64 `json:"value" validate:"required"`
|
||||
Name string `json:"name" validate:"required"`
|
||||
Type AgentMetricType `json:"type" validate:"required" enums:"counter,gauge"`
|
||||
Value float64 `json:"value" validate:"required"`
|
||||
Labels []AgentMetricLabel `json:"labels,omitempty"`
|
||||
}
|
||||
|
||||
type AgentMetricLabel struct {
|
||||
Name string `json:"name" validate:"required"`
|
||||
Name string `json:"name" validate:"required"`
|
||||
Value string `json:"value" validate:"required"`
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
type AWSInstanceIdentityToken struct {
|
||||
Signature string `json:"signature" validate:"required"`
|
||||
Document string `json:"document" validate:"required"`
|
||||
Document string `json:"document" validate:"required"`
|
||||
}
|
||||
|
||||
// AWSSessionTokenExchanger exchanges AWS instance metadata for a Coder session token.
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
type AzureInstanceIdentityToken struct {
|
||||
Signature string `json:"signature" validate:"required"`
|
||||
Encoding string `json:"encoding" validate:"required"`
|
||||
Encoding string `json:"encoding" validate:"required"`
|
||||
}
|
||||
|
||||
// AzureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token.
|
||||
|
||||
+13
-13
@@ -12,42 +12,42 @@ import (
|
||||
)
|
||||
|
||||
type AIBridgeInterception struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
APIKeyID *string `json:"api_key_id"`
|
||||
Initiator MinimalUser `json:"initiator"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Client *string `json:"client"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
StartedAt time.Time `json:"started_at" format:"date-time"`
|
||||
EndedAt *time.Time `json:"ended_at" format:"date-time"`
|
||||
StartedAt time.Time `json:"started_at" format:"date-time"`
|
||||
EndedAt *time.Time `json:"ended_at" format:"date-time"`
|
||||
TokenUsages []AIBridgeTokenUsage `json:"token_usages"`
|
||||
UserPrompts []AIBridgeUserPrompt `json:"user_prompts"`
|
||||
ToolUsages []AIBridgeToolUsage `json:"tool_usages"`
|
||||
}
|
||||
|
||||
type AIBridgeTokenUsage struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
|
||||
ProviderResponseID string `json:"provider_response_id"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
}
|
||||
|
||||
type AIBridgeUserPrompt struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
|
||||
ProviderResponseID string `json:"provider_response_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
}
|
||||
|
||||
type AIBridgeToolUsage struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
|
||||
ProviderResponseID string `json:"provider_response_id"`
|
||||
ServerURL string `json:"server_url"`
|
||||
Tool string `json:"tool"`
|
||||
@@ -55,7 +55,7 @@ type AIBridgeToolUsage struct {
|
||||
Injected bool `json:"injected"`
|
||||
InvocationError string `json:"invocation_error"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
}
|
||||
|
||||
type AIBridgeListInterceptionsResponse struct {
|
||||
@@ -73,7 +73,7 @@ type AIBridgeListInterceptionsFilter struct {
|
||||
// Initiator is a user ID, username, or "me".
|
||||
Initiator string `json:"initiator,omitempty"`
|
||||
StartedBefore time.Time `json:"started_before,omitempty" format:"date-time"`
|
||||
StartedAfter time.Time `json:"started_after,omitempty" format:"date-time"`
|
||||
StartedAfter time.Time `json:"started_after,omitempty" format:"date-time"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Client string `json:"client,omitempty"`
|
||||
|
||||
+32
-32
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
// CreateTaskRequest represents the request to create a new task.
|
||||
type CreateTaskRequest struct {
|
||||
TemplateVersionID uuid.UUID `json:"template_version_id" format:"uuid"`
|
||||
TemplateVersionID uuid.UUID `json:"template_version_id" format:"uuid"`
|
||||
TemplateVersionPresetID uuid.UUID `json:"template_version_preset_id,omitempty" format:"uuid"`
|
||||
Input string `json:"input"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -98,39 +98,39 @@ const (
|
||||
|
||||
// Task represents a task.
|
||||
type Task struct {
|
||||
ID uuid.UUID `json:"id" table:"id" format:"uuid"`
|
||||
OrganizationID uuid.UUID `json:"organization_id" table:"organization id" format:"uuid"`
|
||||
OwnerID uuid.UUID `json:"owner_id" table:"owner id" format:"uuid"`
|
||||
OwnerName string `json:"owner_name" table:"owner name"`
|
||||
OwnerAvatarURL string `json:"owner_avatar_url,omitempty" table:"owner avatar url"`
|
||||
Name string `json:"name" table:"name,default_sort"`
|
||||
DisplayName string `json:"display_name" table:"display_name"`
|
||||
TemplateID uuid.UUID `json:"template_id" table:"template id" format:"uuid"`
|
||||
TemplateVersionID uuid.UUID `json:"template_version_id" table:"template version id" format:"uuid"`
|
||||
TemplateName string `json:"template_name" table:"template name"`
|
||||
TemplateDisplayName string `json:"template_display_name" table:"template display name"`
|
||||
TemplateIcon string `json:"template_icon" table:"template icon"`
|
||||
WorkspaceID uuid.NullUUID `json:"workspace_id" table:"workspace id" format:"uuid"`
|
||||
WorkspaceName string `json:"workspace_name" table:"workspace name"`
|
||||
WorkspaceStatus WorkspaceStatus `json:"workspace_status,omitempty" table:"workspace status" enums:"pending,starting,running,stopping,stopped,failed,canceling,canceled,deleting,deleted"`
|
||||
ID uuid.UUID `json:"id" format:"uuid" table:"id"`
|
||||
OrganizationID uuid.UUID `json:"organization_id" format:"uuid" table:"organization id"`
|
||||
OwnerID uuid.UUID `json:"owner_id" format:"uuid" table:"owner id"`
|
||||
OwnerName string `json:"owner_name" table:"owner name"`
|
||||
OwnerAvatarURL string `json:"owner_avatar_url,omitempty" table:"owner avatar url"`
|
||||
Name string `json:"name" table:"name,default_sort"`
|
||||
DisplayName string `json:"display_name" table:"display_name"`
|
||||
TemplateID uuid.UUID `json:"template_id" format:"uuid" table:"template id"`
|
||||
TemplateVersionID uuid.UUID `json:"template_version_id" format:"uuid" table:"template version id"`
|
||||
TemplateName string `json:"template_name" table:"template name"`
|
||||
TemplateDisplayName string `json:"template_display_name" table:"template display name"`
|
||||
TemplateIcon string `json:"template_icon" table:"template icon"`
|
||||
WorkspaceID uuid.NullUUID `json:"workspace_id" format:"uuid" table:"workspace id"`
|
||||
WorkspaceName string `json:"workspace_name" table:"workspace name"`
|
||||
WorkspaceStatus WorkspaceStatus `json:"workspace_status,omitempty" enums:"pending,starting,running,stopping,stopped,failed,canceling,canceled,deleting,deleted" table:"workspace status"`
|
||||
WorkspaceBuildNumber int32 `json:"workspace_build_number,omitempty" table:"workspace build number"`
|
||||
WorkspaceAgentID uuid.NullUUID `json:"workspace_agent_id" table:"workspace agent id" format:"uuid"`
|
||||
WorkspaceAgentLifecycle *WorkspaceAgentLifecycle `json:"workspace_agent_lifecycle" table:"workspace agent lifecycle"`
|
||||
WorkspaceAgentHealth *WorkspaceAgentHealth `json:"workspace_agent_health" table:"workspace agent health"`
|
||||
WorkspaceAppID uuid.NullUUID `json:"workspace_app_id" table:"workspace app id" format:"uuid"`
|
||||
InitialPrompt string `json:"initial_prompt" table:"initial prompt"`
|
||||
Status TaskStatus `json:"status" table:"status" enums:"pending,initializing,active,paused,unknown,error"`
|
||||
CurrentState *TaskStateEntry `json:"current_state" table:"cs,recursive_inline,empty_nil"`
|
||||
CreatedAt time.Time `json:"created_at" table:"created at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" table:"updated at" format:"date-time"`
|
||||
WorkspaceAgentID uuid.NullUUID `json:"workspace_agent_id" format:"uuid" table:"workspace agent id"`
|
||||
WorkspaceAgentLifecycle *WorkspaceAgentLifecycle `json:"workspace_agent_lifecycle" table:"workspace agent lifecycle"`
|
||||
WorkspaceAgentHealth *WorkspaceAgentHealth `json:"workspace_agent_health" table:"workspace agent health"`
|
||||
WorkspaceAppID uuid.NullUUID `json:"workspace_app_id" format:"uuid" table:"workspace app id"`
|
||||
InitialPrompt string `json:"initial_prompt" table:"initial prompt"`
|
||||
Status TaskStatus `json:"status" enums:"pending,initializing,active,paused,unknown,error" table:"status"`
|
||||
CurrentState *TaskStateEntry `json:"current_state" table:"cs,recursive_inline,empty_nil"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time" table:"created at"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time" table:"updated at"`
|
||||
}
|
||||
|
||||
// TaskStateEntry represents a single entry in the task's state history.
|
||||
type TaskStateEntry struct {
|
||||
Timestamp time.Time `json:"timestamp" table:"-" format:"date-time"`
|
||||
State TaskState `json:"state" table:"state" enum:"working,idle,completed,failed"`
|
||||
Message string `json:"message" table:"message"`
|
||||
URI string `json:"uri" table:"-"`
|
||||
Timestamp time.Time `json:"timestamp" format:"date-time" table:"-"`
|
||||
State TaskState `json:"state" enum:"working,idle,completed,failed" table:"state"`
|
||||
Message string `json:"message" table:"message"`
|
||||
URI string `json:"uri" table:"-"`
|
||||
}
|
||||
|
||||
// TasksFilter filters the list of tasks.
|
||||
@@ -387,10 +387,10 @@ const (
|
||||
|
||||
// TaskLogEntry represents a single log entry for a task.
|
||||
type TaskLogEntry struct {
|
||||
ID int `json:"id" table:"id"`
|
||||
ID int `json:"id" table:"id"`
|
||||
Content string `json:"content" table:"content"`
|
||||
Type TaskLogType `json:"type" table:"type" enum:"input,output"`
|
||||
Time time.Time `json:"time" table:"time,default_sort" format:"date-time"`
|
||||
Type TaskLogType `json:"type" enum:"input,output" table:"type"`
|
||||
Time time.Time `json:"time" format:"date-time" table:"time,default_sort"`
|
||||
}
|
||||
|
||||
// TaskLogsResponse contains task logs and metadata. When snapshot is false,
|
||||
|
||||
+10
-10
@@ -12,17 +12,17 @@ import (
|
||||
|
||||
// APIKey: do not ever return the HashedSecret
|
||||
type APIKey struct {
|
||||
ID string `json:"id" validate:"required"`
|
||||
UserID uuid.UUID `json:"user_id" format:"uuid" validate:"required"`
|
||||
LastUsed time.Time `json:"last_used" format:"date-time" validate:"required"`
|
||||
ExpiresAt time.Time `json:"expires_at" format:"date-time" validate:"required"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time" validate:"required"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time" validate:"required"`
|
||||
LoginType LoginType `json:"login_type" enums:"password,github,oidc,token" validate:"required"`
|
||||
Scope APIKeyScope `json:"scope" enums:"all,application_connect"` // Deprecated: use Scopes instead.
|
||||
ID string `json:"id" validate:"required"`
|
||||
UserID uuid.UUID `json:"user_id" validate:"required" format:"uuid"`
|
||||
LastUsed time.Time `json:"last_used" validate:"required" format:"date-time"`
|
||||
ExpiresAt time.Time `json:"expires_at" validate:"required" format:"date-time"`
|
||||
CreatedAt time.Time `json:"created_at" validate:"required" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" validate:"required" format:"date-time"`
|
||||
LoginType LoginType `json:"login_type" validate:"required" enums:"password,github,oidc,token"`
|
||||
Scope APIKeyScope `json:"scope" enums:"all,application_connect"` // Deprecated: use Scopes instead.
|
||||
Scopes []APIKeyScope `json:"scopes"`
|
||||
TokenName string `json:"token_name" validate:"required"`
|
||||
LifetimeSeconds int64 `json:"lifetime_seconds" validate:"required"`
|
||||
TokenName string `json:"token_name" validate:"required"`
|
||||
LifetimeSeconds int64 `json:"lifetime_seconds" validate:"required"`
|
||||
AllowList []APIAllowListTarget `json:"allow_list"`
|
||||
}
|
||||
|
||||
|
||||
+11
-11
@@ -178,13 +178,13 @@ type AuditDiffField struct {
|
||||
}
|
||||
|
||||
type AuditLog struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
RequestID uuid.UUID `json:"request_id" format:"uuid"`
|
||||
Time time.Time `json:"time" format:"date-time"`
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
RequestID uuid.UUID `json:"request_id" format:"uuid"`
|
||||
Time time.Time `json:"time" format:"date-time"`
|
||||
IP netip.Addr `json:"ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
ResourceType ResourceType `json:"resource_type"`
|
||||
ResourceID uuid.UUID `json:"resource_id" format:"uuid"`
|
||||
ResourceID uuid.UUID `json:"resource_id" format:"uuid"`
|
||||
// ResourceTarget is the name of the resource.
|
||||
ResourceTarget string `json:"resource_target"`
|
||||
ResourceIcon string `json:"resource_icon"`
|
||||
@@ -215,14 +215,14 @@ type AuditLogResponse struct {
|
||||
}
|
||||
|
||||
type CreateTestAuditLogRequest struct {
|
||||
Action AuditAction `json:"action,omitempty" enums:"create,write,delete,start,stop"`
|
||||
ResourceType ResourceType `json:"resource_type,omitempty" enums:"template,template_version,user,workspace,workspace_build,git_ssh_key,auditable_group"`
|
||||
ResourceID uuid.UUID `json:"resource_id,omitempty" format:"uuid"`
|
||||
Action AuditAction `json:"action,omitempty" enums:"create,write,delete,start,stop"`
|
||||
ResourceType ResourceType `json:"resource_type,omitempty" enums:"template,template_version,user,workspace,workspace_build,git_ssh_key,auditable_group"`
|
||||
ResourceID uuid.UUID `json:"resource_id,omitempty" format:"uuid"`
|
||||
AdditionalFields json.RawMessage `json:"additional_fields,omitempty"`
|
||||
Time time.Time `json:"time,omitempty" format:"date-time"`
|
||||
BuildReason BuildReason `json:"build_reason,omitempty" enums:"autostart,autostop,initiator"`
|
||||
OrganizationID uuid.UUID `json:"organization_id,omitempty" format:"uuid"`
|
||||
RequestID uuid.UUID `json:"request_id,omitempty" format:"uuid"`
|
||||
Time time.Time `json:"time,omitempty" format:"date-time"`
|
||||
BuildReason BuildReason `json:"build_reason,omitempty" enums:"autostart,autostop,initiator"`
|
||||
OrganizationID uuid.UUID `json:"organization_id,omitempty" format:"uuid"`
|
||||
RequestID uuid.UUID `json:"request_id,omitempty" format:"uuid"`
|
||||
}
|
||||
|
||||
// AuditLogs retrieves audit logs from the given page.
|
||||
|
||||
+182
-544
File diff suppressed because it is too large
Load Diff
@@ -1,15 +1,8 @@
|
||||
package codersdk_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
@@ -62,81 +55,6 @@ func TestChatModelProviderOptions_UnmarshalJSON_ParsesPlainProviderPayloads(t *t
|
||||
)
|
||||
}
|
||||
|
||||
func TestChatUsageLimitExceededFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ExtractsTyped409", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
want := codersdk.ChatUsageLimitExceededResponse{
|
||||
Response: codersdk.Response{Message: "Chat usage limit exceeded."},
|
||||
SpentMicros: 123,
|
||||
LimitMicros: 456,
|
||||
ResetsAt: time.Date(2026, time.March, 16, 12, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodPost, r.Method)
|
||||
require.Equal(t, "/api/experimental/chats", r.URL.Path)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusConflict)
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(want))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
serverURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := codersdk.New(serverURL)
|
||||
_, err = client.CreateChat(context.Background(), codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
}},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
sdkErr, ok := codersdk.AsError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
require.Equal(t, want.Message, sdkErr.Message)
|
||||
|
||||
limitErr := codersdk.ChatUsageLimitExceededFrom(err)
|
||||
require.NotNil(t, limitErr)
|
||||
require.Equal(t, want, *limitErr)
|
||||
})
|
||||
|
||||
t.Run("ReturnsNilForNonLimitErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Nil(t, codersdk.ChatUsageLimitExceededFrom(codersdk.NewError(http.StatusConflict, codersdk.Response{Message: "plain conflict"})))
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.Response{Message: "Invalid request."}))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
serverURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := codersdk.New(serverURL)
|
||||
_, err = client.CreateChat(context.Background(), codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
}},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
sdkErr, ok := codersdk.AsError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
require.Nil(t, codersdk.ChatUsageLimitExceededFrom(err))
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatMessagePart_StripInternal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -193,85 +111,6 @@ func TestChatMessagePart_StripInternal(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestChatMessagePartVariantTags validates the `variants` struct tags
|
||||
// on ChatMessagePart fields. Every field must either declare variant
|
||||
// membership or be explicitly excluded, and every known part type
|
||||
// must appear in at least one tag.
|
||||
//
|
||||
// If this test fails, edit the variants struct tags on ChatMessagePart
|
||||
// in codersdk/chats.go.
|
||||
func TestChatMessagePartVariantTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const editHint = "edit the variants struct tags on ChatMessagePart in codersdk/chats.go"
|
||||
|
||||
// Fields intentionally excluded from all generated variants.
|
||||
// If you add a new field to ChatMessagePart, either add a
|
||||
// variants tag or add it here with a comment explaining why.
|
||||
excludedFields := map[string]string{
|
||||
"type": "discriminant, added automatically by codegen",
|
||||
"signature": "added in #22290, never populated by any code path",
|
||||
"result_delta": "added in #22290, never populated by any code path",
|
||||
"provider_metadata": "internal only, stripped by db2sdk before API responses",
|
||||
}
|
||||
|
||||
knownTypes := make(map[codersdk.ChatMessagePartType]bool)
|
||||
for _, pt := range codersdk.AllChatMessagePartTypes() {
|
||||
knownTypes[pt] = true
|
||||
}
|
||||
|
||||
// Parse all variants tags from the struct and validate them.
|
||||
typ := reflect.TypeOf(codersdk.ChatMessagePart{})
|
||||
coveredTypes := make(map[codersdk.ChatMessagePartType]bool)
|
||||
hasRequired := make(map[codersdk.ChatMessagePartType]bool)
|
||||
|
||||
for i := range typ.NumField() {
|
||||
f := typ.Field(i)
|
||||
jsonTag := f.Tag.Get("json")
|
||||
if jsonTag == "" || jsonTag == "-" {
|
||||
continue
|
||||
}
|
||||
jsonName, _, _ := strings.Cut(jsonTag, ",")
|
||||
|
||||
varTag := f.Tag.Get("variants")
|
||||
if varTag == "" {
|
||||
assert.Contains(t, excludedFields, jsonName,
|
||||
"field %s (json:%q) has no variants tag and is not in excludedFields; %s",
|
||||
f.Name, jsonName, editHint)
|
||||
continue
|
||||
}
|
||||
|
||||
assert.NotEqual(t, "type", jsonName,
|
||||
"the discriminant field must not have a variants tag; %s", editHint)
|
||||
|
||||
for _, entry := range strings.Split(varTag, ",") {
|
||||
isOptional := strings.HasSuffix(entry, "?")
|
||||
typeLit := codersdk.ChatMessagePartType(strings.TrimSuffix(entry, "?"))
|
||||
|
||||
assert.True(t, knownTypes[typeLit],
|
||||
"field %s variants tag references unknown type %q; %s",
|
||||
f.Name, typeLit, editHint)
|
||||
|
||||
coveredTypes[typeLit] = true
|
||||
if !isOptional {
|
||||
hasRequired[typeLit] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Every known type must appear in at least one variants tag.
|
||||
for pt := range knownTypes {
|
||||
assert.True(t, coveredTypes[pt],
|
||||
"ChatMessagePartType %q is not referenced by any variants tag; %s", pt, editHint)
|
||||
}
|
||||
|
||||
// Every variant must have at least one required field.
|
||||
for pt := range coveredTypes {
|
||||
assert.True(t, hasRequired[pt],
|
||||
"variant %q has no required fields (all have ? suffix); %s", pt, editHint)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelCostConfig_LegacyNumericJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+1
-1
@@ -587,7 +587,7 @@ type Response struct {
|
||||
|
||||
// ValidationError represents a scoped error to a user input.
|
||||
type ValidationError struct {
|
||||
Field string `json:"field" validate:"required"`
|
||||
Field string `json:"field" validate:"required"`
|
||||
Detail string `json:"detail" validate:"required"`
|
||||
}
|
||||
|
||||
|
||||
@@ -12,12 +12,12 @@ import (
|
||||
)
|
||||
|
||||
type ConnectionLog struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
ConnectTime time.Time `json:"connect_time" format:"date-time"`
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
ConnectTime time.Time `json:"connect_time" format:"date-time"`
|
||||
Organization MinimalOrganization `json:"organization"`
|
||||
WorkspaceOwnerID uuid.UUID `json:"workspace_owner_id" format:"uuid"`
|
||||
WorkspaceOwnerID uuid.UUID `json:"workspace_owner_id" format:"uuid"`
|
||||
WorkspaceOwnerUsername string `json:"workspace_owner_username"`
|
||||
WorkspaceID uuid.UUID `json:"workspace_id" format:"uuid"`
|
||||
WorkspaceID uuid.UUID `json:"workspace_id" format:"uuid"`
|
||||
WorkspaceName string `json:"workspace_name"`
|
||||
AgentName string `json:"agent_name"`
|
||||
IP *netip.Addr `json:"ip,omitempty"`
|
||||
|
||||
+185
-185
@@ -386,8 +386,8 @@ type Feature struct {
|
||||
|
||||
type UsagePeriod struct {
|
||||
IssuedAt time.Time `json:"issued_at" format:"date-time"`
|
||||
Start time.Time `json:"start" format:"date-time"`
|
||||
End time.Time `json:"end" format:"date-time"`
|
||||
Start time.Time `json:"start" format:"date-time"`
|
||||
End time.Time `json:"end" format:"date-time"`
|
||||
}
|
||||
|
||||
// Compare compares two features and returns an integer representing
|
||||
@@ -499,7 +499,7 @@ type Entitlements struct {
|
||||
HasLicense bool `json:"has_license"`
|
||||
Trial bool `json:"trial"`
|
||||
RequireTelemetry bool `json:"require_telemetry"`
|
||||
RefreshedAt time.Time `json:"refreshed_at" format:"date-time"`
|
||||
RefreshedAt time.Time `json:"refreshed_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// AddFeature will add the feature to the entitlements iff it expands
|
||||
@@ -579,71 +579,71 @@ type DeploymentValues struct {
|
||||
DocsURL serpent.URL `json:"docs_url,omitempty"`
|
||||
RedirectToAccessURL serpent.Bool `json:"redirect_to_access_url,omitempty"`
|
||||
// HTTPAddress is a string because it may be set to zero to disable.
|
||||
HTTPAddress serpent.String `json:"http_address,omitempty" typescript:",notnull"`
|
||||
HTTPAddress serpent.String `json:"http_address,omitempty" typescript:",notnull"`
|
||||
AutobuildPollInterval serpent.Duration `json:"autobuild_poll_interval,omitempty"`
|
||||
JobReaperDetectorInterval serpent.Duration `json:"job_hang_detector_interval,omitempty"`
|
||||
DERP DERP `json:"derp,omitempty" typescript:",notnull"`
|
||||
Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"`
|
||||
Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"`
|
||||
CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"`
|
||||
EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"`
|
||||
PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"`
|
||||
PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxOpen serpent.Int64 `json:"pg_conn_max_open,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxIdle serpent.String `json:"pg_conn_max_idle,omitempty" typescript:",notnull"`
|
||||
OAuth2 OAuth2Config `json:"oauth2,omitempty" typescript:",notnull"`
|
||||
OIDC OIDCConfig `json:"oidc,omitempty" typescript:",notnull"`
|
||||
Telemetry TelemetryConfig `json:"telemetry,omitempty" typescript:",notnull"`
|
||||
TLS TLSConfig `json:"tls,omitempty" typescript:",notnull"`
|
||||
Trace TraceConfig `json:"trace,omitempty" typescript:",notnull"`
|
||||
HTTPCookies HTTPCookieConfig `json:"http_cookies,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurity serpent.Int64 `json:"strict_transport_security,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurityOptions serpent.StringArray `json:"strict_transport_security_options,omitempty" typescript:",notnull"`
|
||||
SSHKeygenAlgorithm serpent.String `json:"ssh_keygen_algorithm,omitempty" typescript:",notnull"`
|
||||
MetricsCacheRefreshInterval serpent.Duration `json:"metrics_cache_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentStatRefreshInterval serpent.Duration `json:"agent_stat_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentFallbackTroubleshootingURL serpent.URL `json:"agent_fallback_troubleshooting_url,omitempty" typescript:",notnull"`
|
||||
BrowserOnly serpent.Bool `json:"browser_only,omitempty" typescript:",notnull"`
|
||||
SCIMAPIKey serpent.String `json:"scim_api_key,omitempty" typescript:",notnull"`
|
||||
ExternalTokenEncryptionKeys serpent.StringArray `json:"external_token_encryption_keys,omitempty" typescript:",notnull"`
|
||||
Provisioner ProvisionerConfig `json:"provisioner,omitempty" typescript:",notnull"`
|
||||
RateLimit RateLimitConfig `json:"rate_limit,omitempty" typescript:",notnull"`
|
||||
Experiments serpent.StringArray `json:"experiments,omitempty" typescript:",notnull"`
|
||||
UpdateCheck serpent.Bool `json:"update_check,omitempty" typescript:",notnull"`
|
||||
Swagger SwaggerConfig `json:"swagger,omitempty" typescript:",notnull"`
|
||||
Logging LoggingConfig `json:"logging,omitempty" typescript:",notnull"`
|
||||
Dangerous DangerousConfig `json:"dangerous,omitempty" typescript:",notnull"`
|
||||
DisablePathApps serpent.Bool `json:"disable_path_apps,omitempty" typescript:",notnull"`
|
||||
Sessions SessionLifetime `json:"session_lifetime,omitempty" typescript:",notnull"`
|
||||
DisablePasswordAuth serpent.Bool `json:"disable_password_auth,omitempty" typescript:",notnull"`
|
||||
Support SupportConfig `json:"support,omitempty" typescript:",notnull"`
|
||||
EnableAuthzRecording serpent.Bool `json:"enable_authz_recording,omitempty" typescript:",notnull"`
|
||||
ExternalAuthConfigs serpent.Struct[[]ExternalAuthConfig] `json:"external_auth,omitempty" typescript:",notnull"`
|
||||
DERP DERP `json:"derp,omitempty" typescript:",notnull"`
|
||||
Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"`
|
||||
Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"`
|
||||
ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"`
|
||||
CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"`
|
||||
EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"`
|
||||
PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"`
|
||||
PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxOpen serpent.Int64 `json:"pg_conn_max_open,omitempty" typescript:",notnull"`
|
||||
PostgresConnMaxIdle serpent.String `json:"pg_conn_max_idle,omitempty" typescript:",notnull"`
|
||||
OAuth2 OAuth2Config `json:"oauth2,omitempty" typescript:",notnull"`
|
||||
OIDC OIDCConfig `json:"oidc,omitempty" typescript:",notnull"`
|
||||
Telemetry TelemetryConfig `json:"telemetry,omitempty" typescript:",notnull"`
|
||||
TLS TLSConfig `json:"tls,omitempty" typescript:",notnull"`
|
||||
Trace TraceConfig `json:"trace,omitempty" typescript:",notnull"`
|
||||
HTTPCookies HTTPCookieConfig `json:"http_cookies,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurity serpent.Int64 `json:"strict_transport_security,omitempty" typescript:",notnull"`
|
||||
StrictTransportSecurityOptions serpent.StringArray `json:"strict_transport_security_options,omitempty" typescript:",notnull"`
|
||||
SSHKeygenAlgorithm serpent.String `json:"ssh_keygen_algorithm,omitempty" typescript:",notnull"`
|
||||
MetricsCacheRefreshInterval serpent.Duration `json:"metrics_cache_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentStatRefreshInterval serpent.Duration `json:"agent_stat_refresh_interval,omitempty" typescript:",notnull"`
|
||||
AgentFallbackTroubleshootingURL serpent.URL `json:"agent_fallback_troubleshooting_url,omitempty" typescript:",notnull"`
|
||||
BrowserOnly serpent.Bool `json:"browser_only,omitempty" typescript:",notnull"`
|
||||
SCIMAPIKey serpent.String `json:"scim_api_key,omitempty" typescript:",notnull"`
|
||||
ExternalTokenEncryptionKeys serpent.StringArray `json:"external_token_encryption_keys,omitempty" typescript:",notnull"`
|
||||
Provisioner ProvisionerConfig `json:"provisioner,omitempty" typescript:",notnull"`
|
||||
RateLimit RateLimitConfig `json:"rate_limit,omitempty" typescript:",notnull"`
|
||||
Experiments serpent.StringArray `json:"experiments,omitempty" typescript:",notnull"`
|
||||
UpdateCheck serpent.Bool `json:"update_check,omitempty" typescript:",notnull"`
|
||||
Swagger SwaggerConfig `json:"swagger,omitempty" typescript:",notnull"`
|
||||
Logging LoggingConfig `json:"logging,omitempty" typescript:",notnull"`
|
||||
Dangerous DangerousConfig `json:"dangerous,omitempty" typescript:",notnull"`
|
||||
DisablePathApps serpent.Bool `json:"disable_path_apps,omitempty" typescript:",notnull"`
|
||||
Sessions SessionLifetime `json:"session_lifetime,omitempty" typescript:",notnull"`
|
||||
DisablePasswordAuth serpent.Bool `json:"disable_password_auth,omitempty" typescript:",notnull"`
|
||||
Support SupportConfig `json:"support,omitempty" typescript:",notnull"`
|
||||
EnableAuthzRecording serpent.Bool `json:"enable_authz_recording,omitempty" typescript:",notnull"`
|
||||
ExternalAuthConfigs serpent.Struct[[]ExternalAuthConfig] `json:"external_auth,omitempty" typescript:",notnull"`
|
||||
ExternalAuthGithubDefaultProviderEnable serpent.Bool `json:"external_auth_github_default_provider_enable,omitempty" typescript:",notnull"`
|
||||
SSHConfig SSHConfig `json:"config_ssh,omitempty" typescript:",notnull"`
|
||||
WgtunnelHost serpent.String `json:"wgtunnel_host,omitempty" typescript:",notnull"`
|
||||
DisableOwnerWorkspaceExec serpent.Bool `json:"disable_owner_workspace_exec,omitempty" typescript:",notnull"`
|
||||
DisableWorkspaceSharing serpent.Bool `json:"disable_workspace_sharing,omitempty" typescript:",notnull"`
|
||||
ProxyHealthStatusInterval serpent.Duration `json:"proxy_health_status_interval,omitempty" typescript:",notnull"`
|
||||
EnableTerraformDebugMode serpent.Bool `json:"enable_terraform_debug_mode,omitempty" typescript:",notnull"`
|
||||
UserQuietHoursSchedule UserQuietHoursScheduleConfig `json:"user_quiet_hours_schedule,omitempty" typescript:",notnull"`
|
||||
WebTerminalRenderer serpent.String `json:"web_terminal_renderer,omitempty" typescript:",notnull"`
|
||||
AllowWorkspaceRenames serpent.Bool `json:"allow_workspace_renames,omitempty" typescript:",notnull"`
|
||||
Healthcheck HealthcheckConfig `json:"healthcheck,omitempty" typescript:",notnull"`
|
||||
Retention RetentionConfig `json:"retention,omitempty" typescript:",notnull"`
|
||||
CLIUpgradeMessage serpent.String `json:"cli_upgrade_message,omitempty" typescript:",notnull"`
|
||||
TermsOfServiceURL serpent.String `json:"terms_of_service_url,omitempty" typescript:",notnull"`
|
||||
Notifications NotificationsConfig `json:"notifications,omitempty" typescript:",notnull"`
|
||||
AdditionalCSPPolicy serpent.StringArray `json:"additional_csp_policy,omitempty" typescript:",notnull"`
|
||||
WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"`
|
||||
Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"`
|
||||
HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"`
|
||||
SSHConfig SSHConfig `json:"config_ssh,omitempty" typescript:",notnull"`
|
||||
WgtunnelHost serpent.String `json:"wgtunnel_host,omitempty" typescript:",notnull"`
|
||||
DisableOwnerWorkspaceExec serpent.Bool `json:"disable_owner_workspace_exec,omitempty" typescript:",notnull"`
|
||||
DisableWorkspaceSharing serpent.Bool `json:"disable_workspace_sharing,omitempty" typescript:",notnull"`
|
||||
ProxyHealthStatusInterval serpent.Duration `json:"proxy_health_status_interval,omitempty" typescript:",notnull"`
|
||||
EnableTerraformDebugMode serpent.Bool `json:"enable_terraform_debug_mode,omitempty" typescript:",notnull"`
|
||||
UserQuietHoursSchedule UserQuietHoursScheduleConfig `json:"user_quiet_hours_schedule,omitempty" typescript:",notnull"`
|
||||
WebTerminalRenderer serpent.String `json:"web_terminal_renderer,omitempty" typescript:",notnull"`
|
||||
AllowWorkspaceRenames serpent.Bool `json:"allow_workspace_renames,omitempty" typescript:",notnull"`
|
||||
Healthcheck HealthcheckConfig `json:"healthcheck,omitempty" typescript:",notnull"`
|
||||
Retention RetentionConfig `json:"retention,omitempty" typescript:",notnull"`
|
||||
CLIUpgradeMessage serpent.String `json:"cli_upgrade_message,omitempty" typescript:",notnull"`
|
||||
TermsOfServiceURL serpent.String `json:"terms_of_service_url,omitempty" typescript:",notnull"`
|
||||
Notifications NotificationsConfig `json:"notifications,omitempty" typescript:",notnull"`
|
||||
AdditionalCSPPolicy serpent.StringArray `json:"additional_csp_policy,omitempty" typescript:",notnull"`
|
||||
WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"`
|
||||
Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"`
|
||||
HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"`
|
||||
AI AIConfig `json:"ai,omitempty"`
|
||||
StatsCollection StatsCollectionConfig `json:"stats_collection,omitempty" typescript:",notnull"`
|
||||
StatsCollection StatsCollectionConfig `json:"stats_collection,omitempty" typescript:",notnull"`
|
||||
|
||||
Config serpent.YAMLConfigPath `json:"config,omitempty" typescript:",notnull"`
|
||||
Config serpent.YAMLConfigPath `json:"config,omitempty" typescript:",notnull"`
|
||||
WriteConfig serpent.Bool `json:"write_config,omitempty" typescript:",notnull"`
|
||||
|
||||
// Deprecated: Use HTTPAddress or TLS.Address instead.
|
||||
@@ -726,19 +726,19 @@ type DERP struct {
|
||||
}
|
||||
|
||||
type DERPServerConfig struct {
|
||||
Enable serpent.Bool `json:"enable" typescript:",notnull"`
|
||||
RegionID serpent.Int64 `json:"region_id" typescript:",notnull"`
|
||||
RegionCode serpent.String `json:"region_code" typescript:",notnull"`
|
||||
RegionName serpent.String `json:"region_name" typescript:",notnull"`
|
||||
Enable serpent.Bool `json:"enable" typescript:",notnull"`
|
||||
RegionID serpent.Int64 `json:"region_id" typescript:",notnull"`
|
||||
RegionCode serpent.String `json:"region_code" typescript:",notnull"`
|
||||
RegionName serpent.String `json:"region_name" typescript:",notnull"`
|
||||
STUNAddresses serpent.StringArray `json:"stun_addresses" typescript:",notnull"`
|
||||
RelayURL serpent.URL `json:"relay_url" typescript:",notnull"`
|
||||
RelayURL serpent.URL `json:"relay_url" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type DERPConfig struct {
|
||||
BlockDirect serpent.Bool `json:"block_direct" typescript:",notnull"`
|
||||
BlockDirect serpent.Bool `json:"block_direct" typescript:",notnull"`
|
||||
ForceWebSockets serpent.Bool `json:"force_websockets" typescript:",notnull"`
|
||||
URL serpent.String `json:"url" typescript:",notnull"`
|
||||
Path serpent.String `json:"path" typescript:",notnull"`
|
||||
URL serpent.String `json:"url" typescript:",notnull"`
|
||||
Path serpent.String `json:"path" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type UsageStatsConfig struct {
|
||||
@@ -750,15 +750,15 @@ type StatsCollectionConfig struct {
|
||||
}
|
||||
|
||||
type PrometheusConfig struct {
|
||||
Enable serpent.Bool `json:"enable" typescript:",notnull"`
|
||||
Address serpent.HostPort `json:"address" typescript:",notnull"`
|
||||
CollectAgentStats serpent.Bool `json:"collect_agent_stats" typescript:",notnull"`
|
||||
CollectDBMetrics serpent.Bool `json:"collect_db_metrics" typescript:",notnull"`
|
||||
Enable serpent.Bool `json:"enable" typescript:",notnull"`
|
||||
Address serpent.HostPort `json:"address" typescript:",notnull"`
|
||||
CollectAgentStats serpent.Bool `json:"collect_agent_stats" typescript:",notnull"`
|
||||
CollectDBMetrics serpent.Bool `json:"collect_db_metrics" typescript:",notnull"`
|
||||
AggregateAgentStatsBy serpent.StringArray `json:"aggregate_agent_stats_by" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type PprofConfig struct {
|
||||
Enable serpent.Bool `json:"enable" typescript:",notnull"`
|
||||
Enable serpent.Bool `json:"enable" typescript:",notnull"`
|
||||
Address serpent.HostPort `json:"address" typescript:",notnull"`
|
||||
}
|
||||
|
||||
@@ -767,32 +767,32 @@ type OAuth2Config struct {
|
||||
}
|
||||
|
||||
type OAuth2GithubConfig struct {
|
||||
ClientID serpent.String `json:"client_id" typescript:",notnull"`
|
||||
ClientSecret serpent.String `json:"client_secret" typescript:",notnull"`
|
||||
DeviceFlow serpent.Bool `json:"device_flow" typescript:",notnull"`
|
||||
ClientID serpent.String `json:"client_id" typescript:",notnull"`
|
||||
ClientSecret serpent.String `json:"client_secret" typescript:",notnull"`
|
||||
DeviceFlow serpent.Bool `json:"device_flow" typescript:",notnull"`
|
||||
DefaultProviderEnable serpent.Bool `json:"default_provider_enable" typescript:",notnull"`
|
||||
AllowedOrgs serpent.StringArray `json:"allowed_orgs" typescript:",notnull"`
|
||||
AllowedTeams serpent.StringArray `json:"allowed_teams" typescript:",notnull"`
|
||||
AllowSignups serpent.Bool `json:"allow_signups" typescript:",notnull"`
|
||||
AllowEveryone serpent.Bool `json:"allow_everyone" typescript:",notnull"`
|
||||
EnterpriseBaseURL serpent.String `json:"enterprise_base_url" typescript:",notnull"`
|
||||
AllowedOrgs serpent.StringArray `json:"allowed_orgs" typescript:",notnull"`
|
||||
AllowedTeams serpent.StringArray `json:"allowed_teams" typescript:",notnull"`
|
||||
AllowSignups serpent.Bool `json:"allow_signups" typescript:",notnull"`
|
||||
AllowEveryone serpent.Bool `json:"allow_everyone" typescript:",notnull"`
|
||||
EnterpriseBaseURL serpent.String `json:"enterprise_base_url" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
AllowSignups serpent.Bool `json:"allow_signups" typescript:",notnull"`
|
||||
ClientID serpent.String `json:"client_id" typescript:",notnull"`
|
||||
ClientID serpent.String `json:"client_id" typescript:",notnull"`
|
||||
ClientSecret serpent.String `json:"client_secret" typescript:",notnull"`
|
||||
// ClientKeyFile & ClientCertFile are used in place of ClientSecret for PKI auth.
|
||||
ClientKeyFile serpent.String `json:"client_key_file" typescript:",notnull"`
|
||||
ClientCertFile serpent.String `json:"client_cert_file" typescript:",notnull"`
|
||||
EmailDomain serpent.StringArray `json:"email_domain" typescript:",notnull"`
|
||||
IssuerURL serpent.String `json:"issuer_url" typescript:",notnull"`
|
||||
Scopes serpent.StringArray `json:"scopes" typescript:",notnull"`
|
||||
ClientKeyFile serpent.String `json:"client_key_file" typescript:",notnull"`
|
||||
ClientCertFile serpent.String `json:"client_cert_file" typescript:",notnull"`
|
||||
EmailDomain serpent.StringArray `json:"email_domain" typescript:",notnull"`
|
||||
IssuerURL serpent.String `json:"issuer_url" typescript:",notnull"`
|
||||
Scopes serpent.StringArray `json:"scopes" typescript:",notnull"`
|
||||
IgnoreEmailVerified serpent.Bool `json:"ignore_email_verified" typescript:",notnull"`
|
||||
UsernameField serpent.String `json:"username_field" typescript:",notnull"`
|
||||
NameField serpent.String `json:"name_field" typescript:",notnull"`
|
||||
EmailField serpent.String `json:"email_field" typescript:",notnull"`
|
||||
AuthURLParams serpent.Struct[map[string]string] `json:"auth_url_params" typescript:",notnull"`
|
||||
UsernameField serpent.String `json:"username_field" typescript:",notnull"`
|
||||
NameField serpent.String `json:"name_field" typescript:",notnull"`
|
||||
EmailField serpent.String `json:"email_field" typescript:",notnull"`
|
||||
AuthURLParams serpent.Struct[map[string]string] `json:"auth_url_params" typescript:",notnull"`
|
||||
// IgnoreUserInfo & UserInfoFromAccessToken are mutually exclusive. Only 1
|
||||
// can be set to true. Ideally this would be an enum with 3 states, ['none',
|
||||
// 'userinfo', 'access_token']. However, for backward compatibility,
|
||||
@@ -804,21 +804,21 @@ type OIDCConfig struct {
|
||||
// endpoint. This assumes the access token is a valid JWT with a set of claims to
|
||||
// be merged with the id_token.
|
||||
UserInfoFromAccessToken serpent.Bool `json:"source_user_info_from_access_token" typescript:",notnull"`
|
||||
OrganizationField serpent.String `json:"organization_field" typescript:",notnull"`
|
||||
OrganizationMapping serpent.Struct[map[string][]uuid.UUID] `json:"organization_mapping" typescript:",notnull"`
|
||||
OrganizationAssignDefault serpent.Bool `json:"organization_assign_default" typescript:",notnull"`
|
||||
GroupAutoCreate serpent.Bool `json:"group_auto_create" typescript:",notnull"`
|
||||
GroupRegexFilter serpent.Regexp `json:"group_regex_filter" typescript:",notnull"`
|
||||
GroupAllowList serpent.StringArray `json:"group_allow_list" typescript:",notnull"`
|
||||
GroupField serpent.String `json:"groups_field" typescript:",notnull"`
|
||||
GroupMapping serpent.Struct[map[string]string] `json:"group_mapping" typescript:",notnull"`
|
||||
UserRoleField serpent.String `json:"user_role_field" typescript:",notnull"`
|
||||
UserRoleMapping serpent.Struct[map[string][]string] `json:"user_role_mapping" typescript:",notnull"`
|
||||
UserRolesDefault serpent.StringArray `json:"user_roles_default" typescript:",notnull"`
|
||||
SignInText serpent.String `json:"sign_in_text" typescript:",notnull"`
|
||||
IconURL serpent.URL `json:"icon_url" typescript:",notnull"`
|
||||
SignupsDisabledText serpent.String `json:"signups_disabled_text" typescript:",notnull"`
|
||||
SkipIssuerChecks serpent.Bool `json:"skip_issuer_checks" typescript:",notnull"`
|
||||
OrganizationField serpent.String `json:"organization_field" typescript:",notnull"`
|
||||
OrganizationMapping serpent.Struct[map[string][]uuid.UUID] `json:"organization_mapping" typescript:",notnull"`
|
||||
OrganizationAssignDefault serpent.Bool `json:"organization_assign_default" typescript:",notnull"`
|
||||
GroupAutoCreate serpent.Bool `json:"group_auto_create" typescript:",notnull"`
|
||||
GroupRegexFilter serpent.Regexp `json:"group_regex_filter" typescript:",notnull"`
|
||||
GroupAllowList serpent.StringArray `json:"group_allow_list" typescript:",notnull"`
|
||||
GroupField serpent.String `json:"groups_field" typescript:",notnull"`
|
||||
GroupMapping serpent.Struct[map[string]string] `json:"group_mapping" typescript:",notnull"`
|
||||
UserRoleField serpent.String `json:"user_role_field" typescript:",notnull"`
|
||||
UserRoleMapping serpent.Struct[map[string][]string] `json:"user_role_mapping" typescript:",notnull"`
|
||||
UserRolesDefault serpent.StringArray `json:"user_roles_default" typescript:",notnull"`
|
||||
SignInText serpent.String `json:"sign_in_text" typescript:",notnull"`
|
||||
IconURL serpent.URL `json:"icon_url" typescript:",notnull"`
|
||||
SignupsDisabledText serpent.String `json:"signups_disabled_text" typescript:",notnull"`
|
||||
SkipIssuerChecks serpent.Bool `json:"skip_issuer_checks" typescript:",notnull"`
|
||||
|
||||
// RedirectURL is optional, defaulting to 'ACCESS_URL'. Only useful in niche
|
||||
// situations where the OIDC callback domain is different from the ACCESS_URL
|
||||
@@ -828,38 +828,38 @@ type OIDCConfig struct {
|
||||
|
||||
type TelemetryConfig struct {
|
||||
Enable serpent.Bool `json:"enable" typescript:",notnull"`
|
||||
Trace serpent.Bool `json:"trace" typescript:",notnull"`
|
||||
URL serpent.URL `json:"url" typescript:",notnull"`
|
||||
Trace serpent.Bool `json:"trace" typescript:",notnull"`
|
||||
URL serpent.URL `json:"url" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type TLSConfig struct {
|
||||
Enable serpent.Bool `json:"enable" typescript:",notnull"`
|
||||
Address serpent.HostPort `json:"address" typescript:",notnull"`
|
||||
RedirectHTTP serpent.Bool `json:"redirect_http" typescript:",notnull"`
|
||||
CertFiles serpent.StringArray `json:"cert_file" typescript:",notnull"`
|
||||
ClientAuth serpent.String `json:"client_auth" typescript:",notnull"`
|
||||
ClientCAFile serpent.String `json:"client_ca_file" typescript:",notnull"`
|
||||
KeyFiles serpent.StringArray `json:"key_file" typescript:",notnull"`
|
||||
MinVersion serpent.String `json:"min_version" typescript:",notnull"`
|
||||
ClientCertFile serpent.String `json:"client_cert_file" typescript:",notnull"`
|
||||
ClientKeyFile serpent.String `json:"client_key_file" typescript:",notnull"`
|
||||
SupportedCiphers serpent.StringArray `json:"supported_ciphers" typescript:",notnull"`
|
||||
Enable serpent.Bool `json:"enable" typescript:",notnull"`
|
||||
Address serpent.HostPort `json:"address" typescript:",notnull"`
|
||||
RedirectHTTP serpent.Bool `json:"redirect_http" typescript:",notnull"`
|
||||
CertFiles serpent.StringArray `json:"cert_file" typescript:",notnull"`
|
||||
ClientAuth serpent.String `json:"client_auth" typescript:",notnull"`
|
||||
ClientCAFile serpent.String `json:"client_ca_file" typescript:",notnull"`
|
||||
KeyFiles serpent.StringArray `json:"key_file" typescript:",notnull"`
|
||||
MinVersion serpent.String `json:"min_version" typescript:",notnull"`
|
||||
ClientCertFile serpent.String `json:"client_cert_file" typescript:",notnull"`
|
||||
ClientKeyFile serpent.String `json:"client_key_file" typescript:",notnull"`
|
||||
SupportedCiphers serpent.StringArray `json:"supported_ciphers" typescript:",notnull"`
|
||||
AllowInsecureCiphers serpent.Bool `json:"allow_insecure_ciphers" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type TraceConfig struct {
|
||||
Enable serpent.Bool `json:"enable" typescript:",notnull"`
|
||||
Enable serpent.Bool `json:"enable" typescript:",notnull"`
|
||||
HoneycombAPIKey serpent.String `json:"honeycomb_api_key" typescript:",notnull"`
|
||||
CaptureLogs serpent.Bool `json:"capture_logs" typescript:",notnull"`
|
||||
DataDog serpent.Bool `json:"data_dog" typescript:",notnull"`
|
||||
CaptureLogs serpent.Bool `json:"capture_logs" typescript:",notnull"`
|
||||
DataDog serpent.Bool `json:"data_dog" typescript:",notnull"`
|
||||
}
|
||||
|
||||
const cookieHostPrefix = "__Host-"
|
||||
|
||||
type HTTPCookieConfig struct {
|
||||
Secure serpent.Bool `json:"secure_auth_cookie,omitempty" typescript:",notnull"`
|
||||
SameSite string `json:"same_site,omitempty" typescript:",notnull"`
|
||||
EnableHostPrefix bool `json:"host_prefix,omitempty" typescript:",notnull"`
|
||||
SameSite string `json:"same_site,omitempty" typescript:",notnull"`
|
||||
EnableHostPrefix bool `json:"host_prefix,omitempty" typescript:",notnull"`
|
||||
}
|
||||
|
||||
// cookiesToPrefix is the set of cookies that should be prefixed with the host prefix if EnableHostPrefix is true.
|
||||
@@ -953,23 +953,23 @@ func (cfg HTTPCookieConfig) HTTPSameSite() http.SameSite {
|
||||
|
||||
type ExternalAuthConfig struct {
|
||||
// Type is the type of external auth config.
|
||||
Type string `json:"type" yaml:"type"`
|
||||
Type string `json:"type" yaml:"type"`
|
||||
ClientID string `json:"client_id" yaml:"client_id"`
|
||||
ClientSecret string `json:"-" yaml:"client_secret"`
|
||||
ClientSecret string `json:"-" yaml:"client_secret"`
|
||||
// ID is a unique identifier for the auth config.
|
||||
// It defaults to `type` when not provided.
|
||||
ID string `json:"id" yaml:"id"`
|
||||
AuthURL string `json:"auth_url" yaml:"auth_url"`
|
||||
TokenURL string `json:"token_url" yaml:"token_url"`
|
||||
ValidateURL string `json:"validate_url" yaml:"validate_url"`
|
||||
RevokeURL string `json:"revoke_url" yaml:"revoke_url"`
|
||||
AppInstallURL string `json:"app_install_url" yaml:"app_install_url"`
|
||||
ID string `json:"id" yaml:"id"`
|
||||
AuthURL string `json:"auth_url" yaml:"auth_url"`
|
||||
TokenURL string `json:"token_url" yaml:"token_url"`
|
||||
ValidateURL string `json:"validate_url" yaml:"validate_url"`
|
||||
RevokeURL string `json:"revoke_url" yaml:"revoke_url"`
|
||||
AppInstallURL string `json:"app_install_url" yaml:"app_install_url"`
|
||||
AppInstallationsURL string `json:"app_installations_url" yaml:"app_installations_url"`
|
||||
NoRefresh bool `json:"no_refresh" yaml:"no_refresh"`
|
||||
Scopes []string `json:"scopes" yaml:"scopes"`
|
||||
ExtraTokenKeys []string `json:"-" yaml:"extra_token_keys"`
|
||||
DeviceFlow bool `json:"device_flow" yaml:"device_flow"`
|
||||
DeviceCodeURL string `json:"device_code_url" yaml:"device_code_url"`
|
||||
NoRefresh bool `json:"no_refresh" yaml:"no_refresh"`
|
||||
Scopes []string `json:"scopes" yaml:"scopes"`
|
||||
ExtraTokenKeys []string `json:"-" yaml:"extra_token_keys"`
|
||||
DeviceFlow bool `json:"device_flow" yaml:"device_flow"`
|
||||
DeviceCodeURL string `json:"device_code_url" yaml:"device_code_url"`
|
||||
// Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.
|
||||
MCPURL string `json:"mcp_url" yaml:"mcp_url"`
|
||||
// Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.
|
||||
@@ -998,17 +998,17 @@ type ExternalAuthConfig struct {
|
||||
|
||||
type ProvisionerConfig struct {
|
||||
// Daemons is the number of built-in terraform provisioners.
|
||||
Daemons serpent.Int64 `json:"daemons" typescript:",notnull"`
|
||||
DaemonTypes serpent.StringArray `json:"daemon_types" typescript:",notnull"`
|
||||
DaemonPollInterval serpent.Duration `json:"daemon_poll_interval" typescript:",notnull"`
|
||||
DaemonPollJitter serpent.Duration `json:"daemon_poll_jitter" typescript:",notnull"`
|
||||
Daemons serpent.Int64 `json:"daemons" typescript:",notnull"`
|
||||
DaemonTypes serpent.StringArray `json:"daemon_types" typescript:",notnull"`
|
||||
DaemonPollInterval serpent.Duration `json:"daemon_poll_interval" typescript:",notnull"`
|
||||
DaemonPollJitter serpent.Duration `json:"daemon_poll_jitter" typescript:",notnull"`
|
||||
ForceCancelInterval serpent.Duration `json:"force_cancel_interval" typescript:",notnull"`
|
||||
DaemonPSK serpent.String `json:"daemon_psk" typescript:",notnull"`
|
||||
DaemonPSK serpent.String `json:"daemon_psk" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type RateLimitConfig struct {
|
||||
DisableAll serpent.Bool `json:"disable_all" typescript:",notnull"`
|
||||
API serpent.Int64 `json:"api" typescript:",notnull"`
|
||||
API serpent.Int64 `json:"api" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type SwaggerConfig struct {
|
||||
@@ -1016,20 +1016,20 @@ type SwaggerConfig struct {
|
||||
}
|
||||
|
||||
type LoggingConfig struct {
|
||||
Filter serpent.StringArray `json:"log_filter" typescript:",notnull"`
|
||||
Human serpent.String `json:"human" typescript:",notnull"`
|
||||
JSON serpent.String `json:"json" typescript:",notnull"`
|
||||
Filter serpent.StringArray `json:"log_filter" typescript:",notnull"`
|
||||
Human serpent.String `json:"human" typescript:",notnull"`
|
||||
JSON serpent.String `json:"json" typescript:",notnull"`
|
||||
Stackdriver serpent.String `json:"stackdriver" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type DangerousConfig struct {
|
||||
AllowPathAppSharing serpent.Bool `json:"allow_path_app_sharing" typescript:",notnull"`
|
||||
AllowPathAppSharing serpent.Bool `json:"allow_path_app_sharing" typescript:",notnull"`
|
||||
AllowPathAppSiteOwnerAccess serpent.Bool `json:"allow_path_app_site_owner_access" typescript:",notnull"`
|
||||
AllowAllCors serpent.Bool `json:"allow_all_cors" typescript:",notnull"`
|
||||
AllowAllCors serpent.Bool `json:"allow_all_cors" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type UserQuietHoursScheduleConfig struct {
|
||||
DefaultSchedule serpent.String `json:"default_schedule" typescript:",notnull"`
|
||||
DefaultSchedule serpent.String `json:"default_schedule" typescript:",notnull"`
|
||||
AllowUserCustom serpent.Bool `json:"allow_user_custom" typescript:",notnull"`
|
||||
// TODO: add WindowDuration and the ability to postpone max_deadline by this
|
||||
// amount
|
||||
@@ -1038,7 +1038,7 @@ type UserQuietHoursScheduleConfig struct {
|
||||
|
||||
// HealthcheckConfig contains configuration for healthchecks.
|
||||
type HealthcheckConfig struct {
|
||||
Refresh serpent.Duration `json:"refresh" typescript:",notnull"`
|
||||
Refresh serpent.Duration `json:"refresh" typescript:",notnull"`
|
||||
ThresholdDatabase serpent.Duration `json:"threshold_database" typescript:",notnull"`
|
||||
}
|
||||
|
||||
@@ -4001,54 +4001,54 @@ Write out the current server config as YAML to stdout.`,
|
||||
}
|
||||
|
||||
type AIBridgeConfig struct {
|
||||
Enabled serpent.Bool `json:"enabled" typescript:",notnull"`
|
||||
OpenAI AIBridgeOpenAIConfig `json:"openai" typescript:",notnull"`
|
||||
Enabled serpent.Bool `json:"enabled" typescript:",notnull"`
|
||||
OpenAI AIBridgeOpenAIConfig `json:"openai" typescript:",notnull"`
|
||||
Anthropic AIBridgeAnthropicConfig `json:"anthropic" typescript:",notnull"`
|
||||
Bedrock AIBridgeBedrockConfig `json:"bedrock" typescript:",notnull"`
|
||||
Bedrock AIBridgeBedrockConfig `json:"bedrock" typescript:",notnull"`
|
||||
// Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.
|
||||
InjectCoderMCPTools serpent.Bool `json:"inject_coder_mcp_tools" typescript:",notnull"`
|
||||
Retention serpent.Duration `json:"retention" typescript:",notnull"`
|
||||
MaxConcurrency serpent.Int64 `json:"max_concurrency" typescript:",notnull"`
|
||||
RateLimit serpent.Int64 `json:"rate_limit" typescript:",notnull"`
|
||||
StructuredLogging serpent.Bool `json:"structured_logging" typescript:",notnull"`
|
||||
SendActorHeaders serpent.Bool `json:"send_actor_headers" typescript:",notnull"`
|
||||
Retention serpent.Duration `json:"retention" typescript:",notnull"`
|
||||
MaxConcurrency serpent.Int64 `json:"max_concurrency" typescript:",notnull"`
|
||||
RateLimit serpent.Int64 `json:"rate_limit" typescript:",notnull"`
|
||||
StructuredLogging serpent.Bool `json:"structured_logging" typescript:",notnull"`
|
||||
SendActorHeaders serpent.Bool `json:"send_actor_headers" typescript:",notnull"`
|
||||
// Circuit breaker protects against cascading failures from upstream AI
|
||||
// provider rate limits (429, 503, 529 overloaded).
|
||||
CircuitBreakerEnabled serpent.Bool `json:"circuit_breaker_enabled" typescript:",notnull"`
|
||||
CircuitBreakerEnabled serpent.Bool `json:"circuit_breaker_enabled" typescript:",notnull"`
|
||||
CircuitBreakerFailureThreshold serpent.Int64 `json:"circuit_breaker_failure_threshold" typescript:",notnull"`
|
||||
CircuitBreakerInterval serpent.Duration `json:"circuit_breaker_interval" typescript:",notnull"`
|
||||
CircuitBreakerTimeout serpent.Duration `json:"circuit_breaker_timeout" typescript:",notnull"`
|
||||
CircuitBreakerMaxRequests serpent.Int64 `json:"circuit_breaker_max_requests" typescript:",notnull"`
|
||||
CircuitBreakerInterval serpent.Duration `json:"circuit_breaker_interval" typescript:",notnull"`
|
||||
CircuitBreakerTimeout serpent.Duration `json:"circuit_breaker_timeout" typescript:",notnull"`
|
||||
CircuitBreakerMaxRequests serpent.Int64 `json:"circuit_breaker_max_requests" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIBridgeOpenAIConfig struct {
|
||||
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
|
||||
Key serpent.String `json:"key" typescript:",notnull"`
|
||||
Key serpent.String `json:"key" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIBridgeAnthropicConfig struct {
|
||||
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
|
||||
Key serpent.String `json:"key" typescript:",notnull"`
|
||||
Key serpent.String `json:"key" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIBridgeBedrockConfig struct {
|
||||
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
|
||||
Region serpent.String `json:"region" typescript:",notnull"`
|
||||
AccessKey serpent.String `json:"access_key" typescript:",notnull"`
|
||||
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
|
||||
Region serpent.String `json:"region" typescript:",notnull"`
|
||||
AccessKey serpent.String `json:"access_key" typescript:",notnull"`
|
||||
AccessKeySecret serpent.String `json:"access_key_secret" typescript:",notnull"`
|
||||
Model serpent.String `json:"model" typescript:",notnull"`
|
||||
SmallFastModel serpent.String `json:"small_fast_model" typescript:",notnull"`
|
||||
Model serpent.String `json:"model" typescript:",notnull"`
|
||||
SmallFastModel serpent.String `json:"small_fast_model" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIBridgeProxyConfig struct {
|
||||
Enabled serpent.Bool `json:"enabled" typescript:",notnull"`
|
||||
ListenAddr serpent.String `json:"listen_addr" typescript:",notnull"`
|
||||
TLSCertFile serpent.String `json:"tls_cert_file" typescript:",notnull"`
|
||||
TLSKeyFile serpent.String `json:"tls_key_file" typescript:",notnull"`
|
||||
MITMCertFile serpent.String `json:"cert_file" typescript:",notnull"`
|
||||
MITMKeyFile serpent.String `json:"key_file" typescript:",notnull"`
|
||||
DomainAllowlist serpent.StringArray `json:"domain_allowlist" typescript:",notnull"`
|
||||
UpstreamProxy serpent.String `json:"upstream_proxy" typescript:",notnull"`
|
||||
Enabled serpent.Bool `json:"enabled" typescript:",notnull"`
|
||||
ListenAddr serpent.String `json:"listen_addr" typescript:",notnull"`
|
||||
TLSCertFile serpent.String `json:"tls_cert_file" typescript:",notnull"`
|
||||
TLSKeyFile serpent.String `json:"tls_key_file" typescript:",notnull"`
|
||||
MITMCertFile serpent.String `json:"cert_file" typescript:",notnull"`
|
||||
MITMKeyFile serpent.String `json:"key_file" typescript:",notnull"`
|
||||
DomainAllowlist serpent.StringArray `json:"domain_allowlist" typescript:",notnull"`
|
||||
UpstreamProxy serpent.String `json:"upstream_proxy" typescript:",notnull"`
|
||||
UpstreamProxyCA serpent.String `json:"upstream_proxy_ca" typescript:",notnull"`
|
||||
}
|
||||
|
||||
@@ -4062,11 +4062,11 @@ type SupportConfig struct {
|
||||
}
|
||||
|
||||
type LinkConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Target string `json:"target" yaml:"target"`
|
||||
Icon string `json:"icon" enums:"bug,chat,docs,star" yaml:"icon"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Target string `json:"target" yaml:"target"`
|
||||
Icon string `json:"icon" yaml:"icon" enums:"bug,chat,docs,star"`
|
||||
|
||||
Location string `json:"location,omitempty" enums:"navbar,dropdown" yaml:"location,omitempty"`
|
||||
Location string `json:"location,omitempty" yaml:"location,omitempty" enums:"navbar,dropdown"`
|
||||
}
|
||||
|
||||
// Validate checks cross-field constraints for deployment values.
|
||||
@@ -4574,7 +4574,7 @@ type CryptoKey struct {
|
||||
Secret string `json:"secret"`
|
||||
DeletesAt time.Time `json:"deletes_at" format:"date-time"`
|
||||
Sequence int32 `json:"sequence"`
|
||||
StartsAt time.Time `json:"starts_at" format:"date-time"`
|
||||
StartsAt time.Time `json:"starts_at" format:"date-time"`
|
||||
}
|
||||
|
||||
func (c CryptoKey) CanSign(now time.Time) bool {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user