Compare commits

..

3 Commits

Author SHA1 Message Date
Jon Ayers 801b467b75 make gen 2026-03-17 18:56:29 +00:00
Jon Ayers 192b25e30a fix(dbauthz): implement GetProvisionerJobsByIDs instead of panic stub 2026-03-17 00:11:01 +00:00
Jon Ayers 8c7111fe4a perf(coderd): split queue position query and parallelize workspaceBuildsData
Add GetProvisionerJobsByIDs query that fetches jobs without the expensive
queue position cross-join. Restructure workspaceBuildsData() to:

- Phase A: Run jobs, template versions, and resources queries in parallel
- Phase B: Only call GetProvisionerJobsByIDsWithQueuePosition for pending
  jobs (typically 0-5 instead of thousands), then merge results
- Phase C: Run metadata and agents queries in parallel

This dramatically reduces query time for 'coder list' with thousands of
workspaces since the expensive queue position query now processes only
pending jobs instead of all jobs.
2026-03-17 00:07:36 +00:00
243 changed files with 5328 additions and 14858 deletions
-10
View File
@@ -98,15 +98,6 @@ linters-settings:
# stdlib port.
checks: ["all", "-SA1019"]
tagalign:
align: true
sort: true
strict: true
order:
- json
- db
- table
goimports:
local-prefixes: coder.com,cdr.dev,go.coder.com,github.com/cdr,github.com/coder
@@ -271,7 +262,6 @@ linters:
# - wastedassign
- staticcheck
- tagalign
# In Go, it's possible for a package to test its internal functionality
# without testing any exported functions. This is enabled to promote
# decomposing a package before testing its internals. A function caller
+11 -6
View File
@@ -136,10 +136,18 @@ endif
# the search path so that these exclusions match.
FIND_EXCLUSIONS= \
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
# Source files used for make targets, evaluated on use.
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
# Same as GO_SRC_FILES but excluding certain files that have problematic
# Makefile dependencies (e.g. pnpm).
MOST_GO_SRC_FILES := $(shell \
find . \
$(FIND_EXCLUSIONS) \
-type f \
-name '*.go' \
-not -name '*_test.go' \
-not -wholename './agent/agentcontainers/dcspec/dcspec_gen.go' \
)
# All the shell files in the repo, excluding ignored files.
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
@@ -506,10 +514,7 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
cp "$<" "$$output_file"
.PHONY: install
# Only wildcard the go files in the develop directory to avoid rebuilds
# when project files are changed. Technically changes to some imports may
# not be detected, but it's unlikely to cause any issues.
build/.bin/develop: go.mod go.sum $(wildcard scripts/develop/*.go)
build/.bin/develop: go.mod go.sum $(GO_SRC_FILES)
CGO_ENABLED=0 go build -o $@ ./scripts/develop
BOLD := $(shell tput bold 2>/dev/null)
+1 -1
View File
@@ -389,7 +389,7 @@ func (a *agent) init() {
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
desktop := agentdesktop.NewPortableDesktop(
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
a.logger.Named("desktop"), a.execer, a.scriptDataDir,
)
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
a.reconnectingPTYServer = reconnectingpty.NewServer(
+169 -24
View File
@@ -2,9 +2,13 @@ package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
@@ -20,6 +24,28 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
portableDesktopVersion = "v0.0.4"
downloadRetries = 3
downloadRetryDelay = time.Second
)
// platformBinaries maps GOARCH to download URL and expected SHA-256
// digest for each supported platform.
var platformBinaries = map[string]struct {
URL string
SHA256 string
}{
"amd64": {
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-x64",
SHA256: "a04e05e6c7d6f2e6b3acbf1729a7b21271276300b4fee321f4ffee6136538317",
},
"arm64": {
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-arm64",
SHA256: "b8cb9142dc32d46a608f25229cbe8168ff2a3aadc54253c74ff54cd347e16ca6",
},
}
// portableDesktopOutput is the JSON output from
// `portabledesktop up --json`.
type portableDesktopOutput struct {
@@ -52,31 +78,43 @@ type screenshotOutput struct {
// portableDesktop implements Desktop by shelling out to the
// portabledesktop CLI via agentexec.Execer.
type portableDesktop struct {
logger slog.Logger
execer agentexec.Execer
scriptBinDir string // coder script bin directory
logger slog.Logger
execer agentexec.Execer
dataDir string // agent's ScriptDataDir, used for binary caching
mu sync.Mutex
session *desktopSession // nil until started
binPath string // resolved path to binary, cached
closed bool
// httpClient is used for downloading the binary. If nil,
// http.DefaultClient is used.
httpClient *http.Client
}
// NewPortableDesktop creates a Desktop backed by the portabledesktop
// CLI binary, using execer to spawn child processes. scriptBinDir is
// the coder script bin directory checked for the binary.
// CLI binary, using execer to spawn child processes. dataDir is used
// to cache the downloaded binary.
func NewPortableDesktop(
logger slog.Logger,
execer agentexec.Execer,
scriptBinDir string,
dataDir string,
) Desktop {
return &portableDesktop{
logger: logger,
execer: execer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: execer,
dataDir: dataDir,
}
}
// httpDo returns the HTTP client to use for downloads.
func (p *portableDesktop) httpDo() *http.Client {
if p.httpClient != nil {
return p.httpClient
}
return http.DefaultClient
}
// Start launches the desktop session (idempotent).
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
p.mu.Lock()
@@ -361,8 +399,8 @@ func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, e
return string(out), nil
}
// ensureBinary resolves the portabledesktop binary from PATH or the
// coder script bin directory. It must be called while p.mu is held.
// ensureBinary resolves or downloads the portabledesktop binary. It
// must be called while p.mu is held.
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
if p.binPath != "" {
return nil
@@ -377,23 +415,130 @@ func (p *portableDesktop) ensureBinary(ctx context.Context) error {
return nil
}
// 2. Check the coder script bin directory.
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
// On Windows, permission bits don't indicate executability,
// so accept any regular file.
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
p.logger.Info(ctx, "found portabledesktop in script bin directory",
slog.F("path", scriptBinPath),
// 2. Platform checks.
if runtime.GOOS != "linux" {
return xerrors.New("portabledesktop is only supported on Linux")
}
bin, ok := platformBinaries[runtime.GOARCH]
if !ok {
return xerrors.Errorf("unsupported architecture for portabledesktop: %s", runtime.GOARCH)
}
// 3. Check cache.
cacheDir := filepath.Join(p.dataDir, "portabledesktop", bin.SHA256)
cachedPath := filepath.Join(cacheDir, "portabledesktop")
if info, err := os.Stat(cachedPath); err == nil && !info.IsDir() {
// Verify it is executable.
if info.Mode()&0o100 != 0 {
p.logger.Info(ctx, "using cached portabledesktop binary",
slog.F("path", cachedPath),
)
p.binPath = scriptBinPath
p.binPath = cachedPath
return nil
}
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
slog.F("path", scriptBinPath),
slog.F("mode", info.Mode().String()),
}
// 4. Download with retry.
p.logger.Info(ctx, "downloading portabledesktop binary",
slog.F("url", bin.URL),
slog.F("version", portableDesktopVersion),
slog.F("arch", runtime.GOARCH),
)
var lastErr error
for attempt := range downloadRetries {
if err := downloadBinary(ctx, p.httpDo(), bin.URL, bin.SHA256, cachedPath); err != nil {
lastErr = err
p.logger.Warn(ctx, "download attempt failed",
slog.F("attempt", attempt+1),
slog.F("max_attempts", downloadRetries),
slog.Error(err),
)
if attempt < downloadRetries-1 {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(downloadRetryDelay):
}
}
continue
}
p.binPath = cachedPath
p.logger.Info(ctx, "downloaded portabledesktop binary",
slog.F("path", cachedPath),
)
return nil
}
return xerrors.Errorf("download portabledesktop after %d attempts: %w", downloadRetries, lastErr)
}
// downloadBinary fetches a binary from url, verifies its SHA-256
// digest matches expectedSHA256, and atomically writes it to destPath.
func downloadBinary(ctx context.Context, client *http.Client, url, expectedSHA256, destPath string) error {
if err := os.MkdirAll(filepath.Dir(destPath), 0o700); err != nil {
return xerrors.Errorf("create cache directory: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return xerrors.Errorf("create HTTP request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return xerrors.Errorf("HTTP GET %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return xerrors.Errorf("HTTP GET %s: status %d", url, resp.StatusCode)
}
// Write to a temp file in the same directory so the final rename
// is atomic on the same filesystem.
tmpFile, err := os.CreateTemp(filepath.Dir(destPath), "portabledesktop-download-*")
if err != nil {
return xerrors.Errorf("create temp file: %w", err)
}
tmpPath := tmpFile.Name()
// Clean up the temp file on any error path.
success := false
defer func() {
if !success {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
}
}()
// Stream the response body while computing SHA-256.
hasher := sha256.New()
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
return xerrors.Errorf("download body: %w", err)
}
if err := tmpFile.Close(); err != nil {
return xerrors.Errorf("close temp file: %w", err)
}
// Verify digest.
actualSHA256 := hex.EncodeToString(hasher.Sum(nil))
if actualSHA256 != expectedSHA256 {
return xerrors.Errorf(
"SHA-256 mismatch: expected %s, got %s",
expectedSHA256, actualSHA256,
)
}
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
if err := os.Chmod(tmpPath, 0o700); err != nil {
return xerrors.Errorf("chmod: %w", err)
}
if err := os.Rename(tmpPath, destPath); err != nil {
return xerrors.Errorf("rename to final path: %w", err)
}
success = true
return nil
}
@@ -2,6 +2,11 @@ package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
@@ -72,6 +77,7 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
// The "up" script prints the JSON line then sleeps until
// the context is canceled (simulating a long-running process).
@@ -82,13 +88,13 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
}
ctx := t.Context()
ctx := context.Background()
cfg, err := pd.Start(ctx)
require.NoError(t, err)
@@ -105,6 +111,7 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -113,13 +120,13 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
cfg1, err := pd.Start(ctx)
require.NoError(t, err)
@@ -147,6 +154,7 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -155,13 +163,13 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
require.NoError(t, err)
@@ -172,6 +180,7 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -180,13 +189,13 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
_, err := pd.Screenshot(ctx, ScreenshotOptions{
TargetWidth: 800,
TargetHeight: 600,
@@ -278,13 +287,13 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(t.Context(), pd)
err := tt.invoke(context.Background(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
@@ -363,13 +372,13 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(t.Context(), pd)
err := tt.invoke(context.Background(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
@@ -395,13 +404,13 @@ func TestPortableDesktop_CursorPosition(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
x, y, err := pd.CursorPosition(t.Context())
x, y, err := pd.CursorPosition(context.Background())
require.NoError(t, err)
assert.Equal(t, 100, x)
assert.Equal(t, 200, y)
@@ -419,13 +428,13 @@ func TestPortableDesktop_Close(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
_, err := pd.Start(ctx)
require.NoError(t, err)
@@ -448,6 +457,81 @@ func TestPortableDesktop_Close(t *testing.T) {
assert.Contains(t, err.Error(), "desktop is closed")
}
// --- downloadBinary tests ---
func TestDownloadBinary_Success(t *testing.T) {
t.Parallel()
binaryContent := []byte("#!/bin/sh\necho portable\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
err := downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
require.NoError(t, err)
// Verify the file exists and has correct content.
got, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
// Verify executable permissions.
info, err := os.Stat(destPath)
require.NoError(t, err)
assert.NotZero(t, info.Mode()&0o700, "binary should be executable")
}
func TestDownloadBinary_ChecksumMismatch(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("real binary content"))
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
err := downloadBinary(context.Background(), srv.Client(), srv.URL, wrongSHA, destPath)
require.Error(t, err)
assert.Contains(t, err.Error(), "SHA-256 mismatch")
// The destination file should not exist (temp file cleaned up).
_, statErr := os.Stat(destPath)
assert.True(t, os.IsNotExist(statErr), "dest file should not exist after checksum failure")
// No leftover temp files in the directory.
entries, err := os.ReadDir(destDir)
require.NoError(t, err)
assert.Empty(t, entries, "no leftover temp files should remain")
}
func TestDownloadBinary_HTTPError(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
err := downloadBinary(context.Background(), srv.Client(), srv.URL, "irrelevant", destPath)
require.Error(t, err)
assert.Contains(t, err.Error(), "status 404")
}
// --- ensureBinary tests ---
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
@@ -457,89 +541,173 @@ func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
// immediately without doing any work.
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(),
binPath: "/already/set",
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: t.TempDir(),
binPath: "/already/set",
}
err := pd.ensureBinary(t.Context())
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
assert.Equal(t, "/already/set", pd.binPath)
}
func TestEnsureBinary_UsesScriptBinDir(t *testing.T) {
func TestEnsureBinary_UsesCachedBinary(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
require.NoError(t, os.Chmod(binPath, 0o755))
bin, ok := platformBinaries[runtime.GOARCH]
if !ok {
t.Skipf("no platformBinary entry for %s", runtime.GOARCH)
}
dataDir := t.TempDir()
cacheDir := filepath.Join(dataDir, "portabledesktop", bin.SHA256)
require.NoError(t, os.MkdirAll(cacheDir, 0o700))
cachedPath := filepath.Join(cacheDir, "portabledesktop")
require.NoError(t, os.WriteFile(cachedPath, []byte("#!/bin/sh\n"), 0o600))
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: dataDir,
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
assert.Equal(t, binPath, pd.binPath)
assert.Equal(t, cachedPath, pd.binPath)
}
func TestEnsureBinary_ScriptBinDirNotExecutable(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Windows does not support Unix permission bits")
}
func TestEnsureBinary_Downloads(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
// environment and we override the package-level platformBinaries.
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
// Write without execute permission.
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
_ = binPath
binaryContent := []byte("#!/bin/sh\necho downloaded\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
// Save and restore platformBinaries for this test.
origBinaries := platformBinaries
platformBinaries = map[string]struct {
URL string
SHA256 string
}{
runtime.GOARCH: {
URL: srv.URL + "/portabledesktop",
SHA256: expectedSHA,
},
}
t.Cleanup(func() { platformBinaries = origBinaries })
dataDir := t.TempDir()
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: dataDir,
httpClient: srv.Client(),
}
// Clear PATH so LookPath won't find a real binary.
// Ensure PATH doesn't contain a real portabledesktop binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
expectedPath := filepath.Join(dataDir, "portabledesktop", expectedSHA, "portabledesktop")
assert.Equal(t, expectedPath, pd.binPath)
// Verify the downloaded file has correct content.
got, err := os.ReadFile(expectedPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
}
func TestEnsureBinary_NotFound(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
func TestEnsureBinary_RetriesOnFailure(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(), // empty directory
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
binaryContent := []byte("#!/bin/sh\necho retried\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
var mu sync.Mutex
attempt := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
mu.Lock()
current := attempt
attempt++
mu.Unlock()
// Fail the first 2 attempts, succeed on the third.
if current < 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
// Test downloadBinary directly to avoid time.Sleep in
// ensureBinary's retry loop. We call it 3 times to simulate
// what ensureBinary would do.
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
var lastErr error
for i := range 3 {
lastErr = downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
if lastErr == nil {
break
}
if i < 2 {
// In the real code, ensureBinary sleeps here.
// We skip the sleep in tests.
continue
}
}
require.NoError(t, lastErr, "download should succeed on the third attempt")
got, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
mu.Lock()
assert.Equal(t, 3, attempt, "server should have been hit 3 times")
mu.Unlock()
}
// Ensure that portableDesktop satisfies the Desktop interface at
// compile time. This uses the unexported type so it lives in the
// internal test package.
var _ Desktop = (*portableDesktop)(nil)
// Silence the linter about unused imports — agentexec.DefaultExecer
// is used in TestEnsureBinary_UsesCachedBinPath and others, and
// fmt.Sscanf is used indirectly via the implementation.
var (
_ = agentexec.DefaultExecer
_ = fmt.Sprintf
)
+8 -8
View File
@@ -140,16 +140,16 @@ func (c *Client) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateAppS
// SyncStatusResponse contains the status information for a unit.
type SyncStatusResponse struct {
UnitName unit.ID `json:"unit_name" table:"unit,default_sort"`
Status unit.Status `json:"status" table:"status"`
IsReady bool `json:"is_ready" table:"ready"`
Dependencies []DependencyInfo `json:"dependencies" table:"dependencies"`
UnitName unit.ID `table:"unit,default_sort" json:"unit_name"`
Status unit.Status `table:"status" json:"status"`
IsReady bool `table:"ready" json:"is_ready"`
Dependencies []DependencyInfo `table:"dependencies" json:"dependencies"`
}
// DependencyInfo contains information about a unit dependency.
type DependencyInfo struct {
DependsOn unit.ID `json:"depends_on" table:"depends on,default_sort"`
RequiredStatus unit.Status `json:"required_status" table:"required status"`
CurrentStatus unit.Status `json:"current_status" table:"current status"`
IsSatisfied bool `json:"is_satisfied" table:"satisfied"`
DependsOn unit.ID `table:"depends on,default_sort" json:"depends_on"`
RequiredStatus unit.Status `table:"required status" json:"required_status"`
CurrentStatus unit.Status `table:"current status" json:"current_status"`
IsSatisfied bool `table:"satisfied" json:"is_satisfied"`
}
+2 -2
View File
@@ -17,8 +17,8 @@ import (
func NewLicenseFormatter() *cliui.OutputFormatter {
type tableLicense struct {
ID int32 `table:"id,default_sort"`
UUID uuid.UUID `table:"uuid" format:"uuid"`
UploadedAt time.Time `table:"uploaded at" format:"date-time"`
UUID uuid.UUID `table:"uuid" format:"uuid"`
UploadedAt time.Time `table:"uploaded at" format:"date-time"`
// Features is the formatted string for the license claims.
// Used for the table view.
Features string `table:"features"`
+2 -2
View File
@@ -214,7 +214,7 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
} else {
updated, err = client.CreateOrganizationRole(ctx, customRole)
if err != nil {
return xerrors.Errorf("create role: %w", err)
return xerrors.Errorf("patch role: %w", err)
}
}
@@ -524,7 +524,7 @@ type roleTableRow struct {
Name string `table:"name,default_sort"`
DisplayName string `table:"display name"`
OrganizationID string `table:"organization id"`
SitePermissions string `table:"site permissions"`
SitePermissions string ` table:"site permissions"`
// map[<org_id>] -> Permissions
OrganizationPermissions string `table:"organization permissions"`
UserPermissions string `table:"user permissions"`
+2 -2
View File
@@ -32,9 +32,9 @@ func (r *RootCmd) provisionerJobs() *serpent.Command {
func (r *RootCmd) provisionerJobsList() *serpent.Command {
type provisionerJobRow struct {
codersdk.ProvisionerJob ` table:"provisioner_job,recursive_inline,nosort"`
codersdk.ProvisionerJob `table:"provisioner_job,recursive_inline,nosort"`
OrganizationName string `json:"organization_name" table:"organization"`
Queue string `json:"-" table:"queue"`
Queue string `json:"-" table:"queue"`
}
var (
+1 -1
View File
@@ -31,7 +31,7 @@ func (r *RootCmd) Provisioners() *serpent.Command {
func (r *RootCmd) provisionerList() *serpent.Command {
type provisionerDaemonRow struct {
codersdk.ProvisionerDaemon ` table:"provisioner_daemon,recursive_inline"`
codersdk.ProvisionerDaemon `table:"provisioner_daemon,recursive_inline"`
OrganizationName string `json:"organization_name" table:"organization"`
}
var (
+3 -3
View File
@@ -350,11 +350,11 @@ func displaySchedule(ws codersdk.Workspace, out io.Writer) error {
// scheduleListRow is a row in the schedule list.
// this is required for proper JSON output.
type scheduleListRow struct {
WorkspaceName string `json:"workspace" table:"workspace,default_sort"`
StartsAt string `json:"starts_at" table:"starts at"`
WorkspaceName string `json:"workspace" table:"workspace,default_sort"`
StartsAt string `json:"starts_at" table:"starts at"`
StartsNext string `json:"starts_next" table:"starts next"`
StopsAfter string `json:"stops_after" table:"stops after"`
StopsNext string `json:"stops_next" table:"stops next"`
StopsNext string `json:"stops_next" table:"stops next"`
}
func scheduleListRowFromWorkspace(now time.Time, workspace codersdk.Workspace) scheduleListRow {
+4 -4
View File
@@ -284,9 +284,9 @@ func (*RootCmd) statDisk(fs afero.Fs) *serpent.Command {
}
type statsRow struct {
HostCPU *clistat.Result `json:"host_cpu" table:"host cpu,default_sort"`
HostMemory *clistat.Result `json:"host_memory" table:"host memory"`
Disk *clistat.Result `json:"home_disk" table:"home disk"`
ContainerCPU *clistat.Result `json:"container_cpu" table:"container cpu"`
HostCPU *clistat.Result `json:"host_cpu" table:"host cpu,default_sort"`
HostMemory *clistat.Result `json:"host_memory" table:"host memory"`
Disk *clistat.Result `json:"home_disk" table:"home disk"`
ContainerCPU *clistat.Result `json:"container_cpu" table:"container cpu"`
ContainerMemory *clistat.Result `json:"container_memory" table:"container memory"`
}
+1 -1
View File
@@ -155,7 +155,7 @@ func taskWatchIsEnded(task codersdk.Task) bool {
}
type taskStatusRow struct {
codersdk.Task ` table:"r,recursive_inline"`
codersdk.Task `table:"r,recursive_inline"`
ChangedAgo string `json:"-" table:"state changed"`
Healthy bool `json:"-" table:"healthy"`
}
+1 -1
View File
@@ -138,7 +138,7 @@ func (r *RootCmd) templateVersionsList() *serpent.Command {
type templateVersionRow struct {
// For json format:
TemplateVersion codersdk.TemplateVersion ` table:"-"`
TemplateVersion codersdk.TemplateVersion `table:"-"`
ActiveJSON bool `json:"active" table:"-"`
// For table format:
-4
View File
@@ -24,10 +24,6 @@ OPTIONS:
-p, --password string
Specifies a password for the new user.
--service-account bool
Create a user account intended to be used by a service or as an
intermediary rather than by a human.
-u, --username string
Specifies a username for the new user.
+7 -7
View File
@@ -161,14 +161,14 @@ type tokenListRow struct {
codersdk.APIKey `table:"-"`
// For table format:
ID string `json:"-" table:"id,default_sort"`
ID string `json:"-" table:"id,default_sort"`
TokenName string `json:"token_name" table:"name"`
Scopes string `json:"-" table:"scopes"`
Allow string `json:"-" table:"allow list"`
LastUsed time.Time `json:"-" table:"last used"`
ExpiresAt time.Time `json:"-" table:"expires at"`
CreatedAt time.Time `json:"-" table:"created at"`
Owner string `json:"-" table:"owner"`
Scopes string `json:"-" table:"scopes"`
Allow string `json:"-" table:"allow list"`
LastUsed time.Time `json:"-" table:"last used"`
ExpiresAt time.Time `json:"-" table:"expires at"`
CreatedAt time.Time `json:"-" table:"created at"`
Owner string `json:"-" table:"owner"`
}
func tokenListRowFromToken(token codersdk.APIKeyWithOwner) tokenListRow {
+12 -37
View File
@@ -17,14 +17,13 @@ import (
func (r *RootCmd) userCreate() *serpent.Command {
var (
email string
username string
name string
password string
disableLogin bool
loginType string
serviceAccount bool
orgContext = NewOrganizationContext()
email string
username string
name string
password string
disableLogin bool
loginType string
orgContext = NewOrganizationContext()
)
cmd := &serpent.Command{
Use: "create",
@@ -33,23 +32,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
if serviceAccount {
switch {
case loginType != "":
return xerrors.New("You cannot use --login-type with --service-account")
case password != "":
return xerrors.New("You cannot use --password with --service-account")
case email != "":
return xerrors.New("You cannot use --email with --service-account")
case disableLogin:
return xerrors.New("You cannot use --disable-login with --service-account")
}
}
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
client, err := r.InitClient(inv)
if err != nil {
return err
@@ -77,7 +59,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
return err
}
}
if email == "" && !serviceAccount {
if email == "" {
email, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Email:",
Validate: func(s string) error {
@@ -105,7 +87,10 @@ func (r *RootCmd) userCreate() *serpent.Command {
}
}
userLoginType := codersdk.LoginTypePassword
if disableLogin || serviceAccount {
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
if disableLogin {
userLoginType = codersdk.LoginTypeNone
} else if loginType != "" {
userLoginType = codersdk.LoginType(loginType)
@@ -126,7 +111,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
Password: password,
OrganizationIDs: []uuid.UUID{organization.ID},
UserLoginType: userLoginType,
ServiceAccount: serviceAccount,
})
if err != nil {
return err
@@ -143,10 +127,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
case codersdk.LoginTypeOIDC:
authenticationMethod = `Login is authenticated through the configured OIDC provider.`
}
if serviceAccount {
email = "n/a"
authenticationMethod = "Service accounts must authenticate with a token and cannot log in."
}
_, _ = fmt.Fprintln(inv.Stderr, `A new user has been created!
Share the instructions below to get them started.
@@ -214,11 +194,6 @@ Create a workspace `+pretty.Sprint(cliui.DefaultStyles.Code, "coder create")+`!
)),
Value: serpent.StringOf(&loginType),
},
{
Flag: "service-account",
Description: "Create a user account intended to be used by a service or as an intermediary rather than by a human.",
Value: serpent.BoolOf(&serviceAccount),
},
}
orgContext.AttachOptions(cmd)
-53
View File
@@ -8,7 +8,6 @@ import (
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
@@ -125,56 +124,4 @@ func TestUserCreate(t *testing.T) {
assert.Equal(t, args[5], created.Username)
assert.Empty(t, created.Name)
})
tests := []struct {
name string
args []string
err string
}{
{
name: "ServiceAccount",
args: []string{"--service-account", "-u", "dean"},
},
{
name: "ServiceAccountLoginType",
args: []string{"--service-account", "-u", "dean", "--login-type", "none"},
err: "You cannot use --login-type with --service-account",
},
{
name: "ServiceAccountDisableLogin",
args: []string{"--service-account", "-u", "dean", "--disable-login"},
err: "You cannot use --disable-login with --service-account",
},
{
name: "ServiceAccountEmail",
args: []string{"--service-account", "-u", "dean", "--email", "dean@coder.com"},
err: "You cannot use --email with --service-account",
},
{
name: "ServiceAccountPassword",
args: []string{"--service-account", "-u", "dean", "--password", "1n5ecureP4ssw0rd!"},
err: "You cannot use --password with --service-account",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
coderdtest.CreateFirstUser(t, client)
inv, root := clitest.New(t, append([]string{"users", "create"}, tt.args...)...)
clitest.SetupConfig(t, client, root)
err := inv.Run()
if tt.err == "" {
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitShort)
created, err := client.User(ctx, "dean")
require.NoError(t, err)
assert.Equal(t, codersdk.LoginTypeNone, created.LoginType)
} else {
require.Error(t, err)
require.ErrorContains(t, err, tt.err)
}
})
}
}
+6 -6
View File
@@ -11,13 +11,13 @@ import (
)
type whoamiRow struct {
URL string `json:"url" table:"URL,default_sort"`
Username string `json:"username" table:"Username"`
UserID string `json:"user_id" table:"ID"`
OrganizationIDs string `json:"-" table:"Orgs"`
URL string `json:"url" table:"URL,default_sort"`
Username string `json:"username" table:"Username"`
UserID string `json:"user_id" table:"ID"`
OrganizationIDs string `json:"-" table:"Orgs"`
OrganizationIDsJSON []string `json:"organization_ids" table:"-"`
Roles string `json:"-" table:"Roles"`
RolesJSON map[string][]string `json:"roles" table:"-"`
Roles string `json:"-" table:"Roles"`
RolesJSON map[string][]string `json:"roles" table:"-"`
}
func (r whoamiRow) String() string {
+2 -52
View File
@@ -4819,7 +4819,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
}
}
],
@@ -4827,7 +4827,7 @@ const docTemplate = `{
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
}
}
}
@@ -18310,9 +18310,6 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -18723,19 +18720,6 @@ const docTemplate = `{
}
}
},
"codersdk.ShareableWorkspaceOwners": {
"type": "string",
"enum": [
"none",
"everyone",
"service_accounts"
],
"x-enum-varnames": [
"ShareableWorkspaceOwnersNone",
"ShareableWorkspaceOwnersEveryone",
"ShareableWorkspaceOwnersServiceAccounts"
]
},
"codersdk.SharedWorkspaceActor": {
"type": "object",
"properties": {
@@ -19634,9 +19618,6 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -20347,21 +20328,7 @@ const docTemplate = `{
"codersdk.UpdateWorkspaceSharingSettingsRequest": {
"type": "object",
"properties": {
"shareable_workspace_owners": {
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
"enum": [
"none",
"everyone",
"service_accounts"
],
"allOf": [
{
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
}
]
},
"sharing_disabled": {
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use ` + "`" + `ShareableWorkspaceOwners` + "`" + ` instead",
"type": "boolean"
}
}
@@ -20483,9 +20450,6 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -22205,21 +22169,7 @@ const docTemplate = `{
"codersdk.WorkspaceSharingSettings": {
"type": "object",
"properties": {
"shareable_workspace_owners": {
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
"enum": [
"none",
"everyone",
"service_accounts"
],
"allOf": [
{
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
}
]
},
"sharing_disabled": {
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use ` + "`" + `ShareableWorkspaceOwners` + "`" + ` instead",
"type": "boolean"
},
"sharing_globally_disabled": {
+2 -40
View File
@@ -4262,7 +4262,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
}
}
],
@@ -4270,7 +4270,7 @@
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
}
}
}
@@ -16708,9 +16708,6 @@
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -17112,15 +17109,6 @@
}
}
},
"codersdk.ShareableWorkspaceOwners": {
"type": "string",
"enum": ["none", "everyone", "service_accounts"],
"x-enum-varnames": [
"ShareableWorkspaceOwnersNone",
"ShareableWorkspaceOwnersEveryone",
"ShareableWorkspaceOwnersServiceAccounts"
]
},
"codersdk.SharedWorkspaceActor": {
"type": "object",
"properties": {
@@ -17982,9 +17970,6 @@
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -18660,17 +18645,7 @@
"codersdk.UpdateWorkspaceSharingSettingsRequest": {
"type": "object",
"properties": {
"shareable_workspace_owners": {
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
"enum": ["none", "everyone", "service_accounts"],
"allOf": [
{
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
}
]
},
"sharing_disabled": {
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use `ShareableWorkspaceOwners` instead",
"type": "boolean"
}
}
@@ -18774,9 +18749,6 @@
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -20412,17 +20384,7 @@
"codersdk.WorkspaceSharingSettings": {
"type": "object",
"properties": {
"shareable_workspace_owners": {
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
"enum": ["none", "everyone", "service_accounts"],
"allOf": [
{
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
}
]
},
"sharing_disabled": {
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use `ShareableWorkspaceOwners` instead",
"type": "boolean"
},
"sharing_globally_disabled": {
+272 -484
View File
File diff suppressed because it is too large Load Diff
-139
View File
@@ -2,20 +2,13 @@ package chatd
import (
"context"
"sync"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
)
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
@@ -91,135 +84,3 @@ func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
require.ErrorContains(t, err, loadErr.Error())
require.Equal(t, chat, refreshed)
}
func TestResolveInstructionsReusesTurnLocalWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: uuid.New(),
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
workspacesdk.LSResponse{},
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
).Times(1)
conn.EXPECT().ReadFile(
gomock.Any(),
"/home/coder/project/AGENTS.md",
int64(0),
int64(maxInstructionFileBytes+1),
).Return(
nil,
"",
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
).Times(1)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := &Server{
db: db,
logger: logger,
instructionCache: make(map[uuid.UUID]cachedInstruction),
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return conn, func() {}, nil
},
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
instruction := server.resolveInstructions(
ctx,
chat,
workspaceCtx.getWorkspaceAgent,
workspaceCtx.getWorkspaceConn,
)
require.Contains(t, instruction, "Operating System: linux")
require.Contains(t, instruction, "Working Directory: /home/coder/project")
}
func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
initialAgent := database.WorkspaceAgent{ID: uuid.New()}
refreshedAgent := database.WorkspaceAgent{ID: uuid.New()}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{initialAgent}, nil),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{refreshedAgent}, nil),
)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
var dialed []uuid.UUID
server := &Server{db: db}
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialed = append(dialed, agentID)
if agentID == initialAgent.ID {
return nil, nil, xerrors.New("dial failed")
}
return conn, func() {}, nil
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.NoError(t, err)
require.Same(t, conn, gotConn)
require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed)
}
+5 -580
View File
@@ -374,72 +374,6 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
require.Len(t, messages, 1)
}
func TestSendMessageQueuesWhenWaitingWithQueuedBacklog(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "queue-when-waiting-with-backlog",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("older queued"),
})
require.NoError(t, err)
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
ChatID: chat.ID,
Content: queuedContent,
})
require.NoError(t, err)
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
})
require.NoError(t, err)
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("newer queued")},
})
require.NoError(t, err)
require.True(t, result.Queued)
require.NotNil(t, result.QueuedMessage)
require.Equal(t, database.ChatStatusWaiting, result.Chat.Status)
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, queued, 2)
olderSDK := db2sdk.ChatQueuedMessage(queued[0])
require.Len(t, olderSDK.Content, 1)
require.Equal(t, "older queued", olderSDK.Content[0].Text)
newerSDK := db2sdk.ChatQueuedMessage(queued[1])
require.Len(t, newerSDK.Content, 1)
require.Equal(t, "newer queued", newerSDK.Content[0].Text)
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Len(t, messages, 1)
}
func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) {
t.Parallel()
@@ -591,457 +525,6 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
require.False(t, chatFromDB.WorkerID.Valid)
}
func TestCreateChatRejectsWhenUsageLimitReached(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
Enabled: true,
DefaultLimitMicros: 100,
Period: string(codersdk.ChatUsageLimitPeriodDay),
})
require.NoError(t, err)
existingChat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
Title: "existing-limit-chat",
LastModelConfigID: model.ID,
})
require.NoError(t, err)
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("assistant"),
})
require.NoError(t, err)
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
ChatID: existingChat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleAssistant,
ContentVersion: chatprompt.CurrentContentVersion,
Content: assistantContent,
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{},
OutputTokens: sql.NullInt64{},
TotalTokens: sql.NullInt64{},
ReasoningTokens: sql.NullInt64{},
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
})
require.NoError(t, err)
beforeChats, err := db.GetChats(ctx, database.GetChatsParams{
OwnerID: user.ID,
AfterID: uuid.Nil,
OffsetOpt: 0,
LimitOpt: 100,
})
require.NoError(t, err)
require.Len(t, beforeChats, 1)
_, err = replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "over-limit",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.Error(t, err)
var limitErr *chatd.UsageLimitExceededError
require.ErrorAs(t, err, &limitErr)
require.Equal(t, int64(100), limitErr.LimitMicros)
require.Equal(t, int64(100), limitErr.ConsumedMicros)
afterChats, err := db.GetChats(ctx, database.GetChatsParams{
OwnerID: user.ID,
AfterID: uuid.Nil,
OffsetOpt: 0,
LimitOpt: 100,
})
require.NoError(t, err)
require.Len(t, afterChats, len(beforeChats))
}
func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
Enabled: true,
DefaultLimitMicros: 100,
Period: string(codersdk.ChatUsageLimitPeriodDay),
})
require.NoError(t, err)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "queued-limit-reached",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
queuedResult, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
})
require.NoError(t, err)
require.True(t, queuedResult.Queued)
require.NotNil(t, queuedResult.QueuedMessage)
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("assistant"),
})
require.NoError(t, err)
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleAssistant,
ContentVersion: chatprompt.CurrentContentVersion,
Content: assistantContent,
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{},
OutputTokens: sql.NullInt64{},
TotalTokens: sql.NullInt64{},
ReasoningTokens: sql.NullInt64{},
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
})
require.NoError(t, err)
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
})
require.NoError(t, err)
result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
ChatID: chat.ID,
QueuedMessageID: queuedResult.QueuedMessage.ID,
CreatedBy: user.ID,
})
require.NoError(t, err)
require.Equal(t, database.ChatMessageRoleUser, result.PromotedMessage.Role)
chat, err = db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusPending, chat.Status)
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Empty(t, queued)
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Len(t, messages, 3)
require.Equal(t, database.ChatMessageRoleUser, messages[2].Role)
}
func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
Enabled: true,
DefaultLimitMicros: 100,
Period: string(codersdk.ChatUsageLimitPeriodDay),
})
require.NoError(t, err)
streamStarted := make(chan struct{})
interrupted := make(chan struct{})
allowFinish := make(chan struct{})
var requestCount atomic.Int32
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
return chattest.OpenAINonStreamingResponse("title")
}
if requestCount.Add(1) == 1 {
chunks := make(chan chattest.OpenAIChunk, 1)
go func() {
defer close(chunks)
chunks <- chattest.OpenAITextChunks("partial")[0]
select {
case <-streamStarted:
default:
close(streamStarted)
}
<-req.Context().Done()
select {
case <-interrupted:
default:
close(interrupted)
}
<-allowFinish
}()
return chattest.OpenAIResponse{StreamingChunks: chunks}
}
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("done")...,
)
})
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := chatd.New(chatd.Config{
Logger: logger,
Database: db,
ReplicaID: uuid.New(),
Pubsub: ps,
PendingChatAcquireInterval: 10 * time.Millisecond,
InFlightChatStaleAfter: testutil.WaitSuperLong,
})
t.Cleanup(func() {
require.NoError(t, server.Close())
})
user, model := seedChatDependencies(ctx, t, db)
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "interrupt-autopromote-limit",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
require.Eventually(t, func() bool {
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil {
return false
}
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
}, testutil.WaitMedium, testutil.IntervalFast)
require.Eventually(t, func() bool {
select {
case <-streamStarted:
return true
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
})
require.NoError(t, err)
require.True(t, queuedResult.Queued)
require.NotNil(t, queuedResult.QueuedMessage)
// Send "later queued" immediately after "queued" while the first
// message is still in chat_queued_messages. The existing backlog
// (len(existingQueued) > 0) guarantees this is queued regardless
// of chat status, avoiding a race where the auto-promoted "queued"
// message finishes processing before we can send this.
laterQueuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("later queued")},
})
require.NoError(t, err)
require.True(t, laterQueuedResult.Queued)
require.NotNil(t, laterQueuedResult.QueuedMessage)
require.Eventually(t, func() bool {
select {
case <-interrupted:
return true
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
spendChat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{},
ParentChatID: uuid.NullUUID{},
RootChatID: uuid.NullUUID{},
LastModelConfigID: model.ID,
Title: "other-spend",
Mode: database.NullChatMode{},
})
require.NoError(t, err)
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("spent elsewhere"),
})
require.NoError(t, err)
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
ChatID: spendChat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleAssistant,
ContentVersion: chatprompt.CurrentContentVersion,
Content: assistantContent,
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{},
OutputTokens: sql.NullInt64{},
TotalTokens: sql.NullInt64{},
ReasoningTokens: sql.NullInt64{},
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
})
require.NoError(t, err)
close(allowFinish)
require.Eventually(t, func() bool {
queued, dbErr := db.GetChatQueuedMessages(ctx, chat.ID)
if dbErr != nil || len(queued) != 0 {
return false
}
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil || fromDB.Status != database.ChatStatusWaiting {
return false
}
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
if dbErr != nil {
return false
}
userTexts := make([]string, 0, 3)
for _, message := range messages {
if message.Role != database.ChatMessageRoleUser {
continue
}
sdkMessage := db2sdk.ChatMessage(message)
if len(sdkMessage.Content) != 1 {
continue
}
userTexts = append(userTexts, sdkMessage.Content[0].Text)
}
if len(userTexts) != 3 {
return false
}
return userTexts[0] == "hello" && userTexts[1] == "queued" && userTexts[2] == "later queued"
}, testutil.WaitLong, testutil.IntervalFast)
}
func TestEditMessageRejectsWhenUsageLimitReached(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
Enabled: true,
DefaultLimitMicros: 100,
Period: string(codersdk.ChatUsageLimitPeriodDay),
})
require.NoError(t, err)
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "edit-limit-reached",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("original")},
})
require.NoError(t, err)
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Len(t, messages, 1)
editedMessageID := messages[0].ID
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("assistant"),
})
require.NoError(t, err)
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleAssistant,
ContentVersion: chatprompt.CurrentContentVersion,
Content: assistantContent,
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{},
OutputTokens: sql.NullInt64{},
TotalTokens: sql.NullInt64{},
ReasoningTokens: sql.NullInt64{},
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
})
require.NoError(t, err)
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
ChatID: chat.ID,
EditedMessageID: editedMessageID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
})
require.Error(t, err)
var limitErr *chatd.UsageLimitExceededError
require.ErrorAs(t, err, &limitErr)
require.Equal(t, int64(100), limitErr.LimitMicros)
require.Equal(t, int64(100), limitErr.ConsumedMicros)
messages, err = db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Len(t, messages, 2)
originalMessage := db2sdk.ChatMessage(messages[0])
require.Len(t, originalMessage.Content, 1)
require.Equal(t, "original", originalMessage.Content[0].Text)
}
func TestEditMessageRejectsMissingMessage(t *testing.T) {
t.Parallel()
@@ -2239,10 +1722,12 @@ func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
const assistantText = "I have completed the task successfully and all tests are passing now."
const summaryText = "Completed task and verified all tests pass."
var nonStreamingRequests atomic.Int32
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
nonStreamingRequests.Add(1)
// Non-streaming calls are used for title
// generation and push summary generation.
// Return the summary text for both — the title
// result is irrelevant to this test.
return chattest.OpenAINonStreamingResponse(summaryText)
}
return chattest.OpenAIStreamingResponse(
@@ -2289,63 +1774,6 @@ func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
"push body should be the LLM-generated summary")
require.NotEqual(t, "Agent has finished running.", msg.Body,
"push body should not use the default fallback text")
require.Equal(t, int32(1), nonStreamingRequests.Load(),
"expected exactly one non-streaming request for push summary generation")
}
func TestSuccessfulChatSendsWebPushFallbackWithoutSummaryForEmptyAssistantText(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
var nonStreamingRequests atomic.Int32
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
nonStreamingRequests.Add(1)
return chattest.OpenAINonStreamingResponse("unexpected summary request")
}
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks(" ")...,
)
})
mockPush := &mockWebpushDispatcher{}
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := chatd.New(chatd.Config{
Logger: logger,
Database: db,
ReplicaID: uuid.New(),
Pubsub: ps,
PendingChatAcquireInterval: 10 * time.Millisecond,
InFlightChatStaleAfter: testutil.WaitSuperLong,
WebpushDispatcher: mockPush,
})
t.Cleanup(func() {
require.NoError(t, server.Close())
})
user, model := seedChatDependencies(ctx, t, db)
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
_, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "empty-summary-push-test",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")},
})
require.NoError(t, err)
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
return mockPush.dispatchCount.Load() >= 1
}, testutil.IntervalFast)
msg := mockPush.getLastMessage()
require.Equal(t, "Agent has finished running.", msg.Body,
"push body should fall back when the final assistant text is empty")
require.Equal(t, int32(0), nonStreamingRequests.Load(),
"push summary should not be requested when final assistant text has no usable text")
}
func TestComputerUseSubagentToolsAndModel(t *testing.T) {
@@ -2517,9 +1945,6 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) {
})
require.NoError(t, err)
err = db.UpsertChatDesktopEnabled(ctx, true)
require.NoError(t, err)
// Build workspace + agent records so getWorkspaceConn can
// resolve the agent for the computer use child.
org := dbgen.Organization(t, db, database.Organization{})
@@ -2681,7 +2106,7 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) {
// 6. Verify the child chat has Mode = computer_use in
// the DB.
allChats, err := db.GetChats(ctx, database.GetChatsParams{
allChats, err := db.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
OwnerID: user.ID,
})
require.NoError(t, err)
@@ -82,7 +82,7 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
options := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelReasoningOptions{
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(true),
},
Provider: &codersdk.ChatModelOpenRouterProvider{
@@ -92,7 +92,7 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
}
defaults := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelReasoningOptions{
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(false),
Exclude: boolPtr(true),
MaxTokens: int64Ptr(123),
+3 -3
View File
@@ -78,9 +78,9 @@ type ProcessToolOptions struct {
// ExecuteArgs are the parameters accepted by the execute tool.
type ExecuteArgs struct {
Command string `json:"command" description:"The shell command to execute."`
Timeout *string `json:"timeout,omitempty" description:"Timeout duration (e.g. '30s', '5m'). Default is 10s. Only applies to foreground commands."`
WorkDir *string `json:"workdir,omitempty" description:"Working directory for the command."`
Command string `json:"command" description:"The shell command to execute."`
Timeout *string `json:"timeout,omitempty" description:"Timeout duration (e.g. '30s', '5m'). Default is 10s. Only applies to foreground commands."`
WorkDir *string `json:"workdir,omitempty" description:"Working directory for the command."`
RunInBackground *bool `json:"run_in_background,omitempty" description:"Run this command in the background without blocking. Use for long-running processes like dev servers, file watchers, or builds that run longer than 5 seconds. Do NOT use shell & to background processes — it will not work correctly. Always use this parameter instead."`
}
-2
View File
@@ -62,7 +62,6 @@ func (p *Server) maybeGenerateChatTitle(
messages []database.ChatMessage,
fallbackModel fantasy.LanguageModel,
keys chatprovider.ProviderAPIKeys,
generatedTitle *generatedChatTitle,
logger slog.Logger,
) {
input, ok := titleInput(chat, messages)
@@ -112,7 +111,6 @@ func (p *Server) maybeGenerateChatTitle(
return
}
chat.Title = title
generatedTitle.Store(title)
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
return
}
+3 -10
View File
@@ -84,14 +84,6 @@ func (p *Server) isAnthropicConfigured(ctx context.Context) bool {
return false
}
func (p *Server) isDesktopEnabled(ctx context.Context) bool {
enabled, err := p.db.GetChatDesktopEnabled(ctx)
if err != nil {
return false
}
return enabled
}
func (p *Server) subagentTools(ctx context.Context, currentChat func() database.Chat) []fantasy.AgentTool {
tools := []fantasy.AgentTool{
fantasy.NewAgentTool(
@@ -261,8 +253,9 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
}
// Only include the computer use tool when an Anthropic
// provider is configured and desktop is enabled.
if p.isAnthropicConfigured(ctx) && p.isDesktopEnabled(ctx) {
// provider is configured, since it requires an Anthropic
// model.
if p.isAnthropicConfigured(ctx) {
tools = append(tools, fantasy.NewAgentTool(
"spawn_computer_use_agent",
"Spawn a dedicated computer use agent that can see the desktop "+
+3 -37
View File
@@ -15,7 +15,6 @@ import (
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
@@ -145,20 +144,14 @@ func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
return nil
}
func chatdTestContext(t *testing.T) context.Context {
t.Helper()
return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong))
}
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// No Anthropic key in ProviderAPIKeys.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
ctx := chatdTestContext(t)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedInternalChatDeps(ctx, t, db)
// Create a root parent chat.
@@ -183,13 +176,12 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// Provide an Anthropic key so the provider check passes.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := chatdTestContext(t)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedInternalChatDeps(ctx, t, db)
// Create a root parent chat.
@@ -240,42 +232,16 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
assert.Contains(t, resp.Content, "delegated chats cannot create child subagents")
}
// TestSpawnComputerUseAgent_DesktopDisabled verifies that the
// spawn_computer_use_agent tool is omitted from the subagent tool set when
// chat desktop support has not been enabled, even though an Anthropic
// provider key is configured.
func TestSpawnComputerUseAgent_DesktopDisabled(t *testing.T) {
	t.Parallel()
	db, ps := dbtestutil.NewDB(t)
	// An Anthropic key is provided so the provider check passes; desktop is
	// deliberately left disabled (no UpsertChatDesktopEnabled call), making
	// it the only gate under test.
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
		Anthropic: "test-anthropic-key",
	})
	ctx := chatdTestContext(t)
	user, model := seedInternalChatDeps(ctx, t, db)
	// Create a root (parent) chat to build the tool set against.
	parent, err := server.CreateChat(ctx, CreateOptions{
		OwnerID:            user.ID,
		Title:              "parent-desktop-disabled",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)
	parentChat, err := db.GetChatByID(ctx, parent.ID)
	require.NoError(t, err)
	// Build the subagent tools for the parent chat and assert the computer
	// use tool was filtered out.
	tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
	tool := findToolByName(tools, "spawn_computer_use_agent")
	assert.Nil(t, tool, "spawn_computer_use_agent tool must be omitted when desktop is disabled")
}
func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// Provide an Anthropic key so the tool can proceed.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := chatdTestContext(t)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedInternalChatDeps(ctx, t, db)
// The parent uses an OpenAI model.
-128
View File
@@ -1,128 +0,0 @@
package chatd
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
)
// ComputeUsagePeriodBounds returns the UTC-aligned start and end bounds for the
// active usage-limit period containing now. Start is inclusive, end exclusive.
func ComputeUsagePeriodBounds(now time.Time, period codersdk.ChatUsageLimitPeriod) (start, end time.Time) {
	t := now.UTC()
	midnight := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
	switch period {
	case codersdk.ChatUsageLimitPeriodDay:
		// The period is the current UTC calendar day.
		return midnight, midnight.AddDate(0, 0, 1)
	case codersdk.ChatUsageLimitPeriodWeek:
		// ISO 8601 weeks always start on Monday, so stepping backward one
		// day at a time never crosses an ISO-week boundary.
		start = midnight
		for start.Weekday() != time.Monday {
			start = start.AddDate(0, 0, -1)
		}
		return start, start.AddDate(0, 0, 7)
	case codersdk.ChatUsageLimitPeriodMonth:
		// The period is the current UTC calendar month.
		first := time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
		return first, first.AddDate(0, 1, 0)
	default:
		panic(fmt.Sprintf("unknown chat usage limit period: %q", period))
	}
}
// ResolveUsageLimitStatus resolves the current usage-limit status for userID.
// A nil status with a nil error means usage limits are disabled.
//
// Note: two concurrent messages from the same user can both pass the limit
// check if processed in parallel, allowing brief overage. This is acceptable
// because:
//   - Cost is only known after the LLM API returns.
//   - Overage is bounded by message cost × concurrency.
//   - Fail-open is the deliberate design choice for this feature.
//
// Architecture note: today this path enforces one period globally
// (day/week/month) from config. To support simultaneous periods, add nullable
// daily/weekly/monthly_limit_micros columns on override tables, where NULL
// means no limit for that period. Then scan spend once over the widest active
// window with conditional SUMs for each period and compare each spend/limit
// pair Go-side, blocking on whichever period is tightest.
func ResolveUsageLimitStatus(ctx context.Context, db database.Store, userID uuid.UUID, now time.Time) (*codersdk.ChatUsageLimitStatus, error) {
	//nolint:gocritic // AsChatd provides narrowly-scoped daemon access for
	// deployment config reads and cross-user chat spend aggregation.
	daemonCtx := dbauthz.AsChatd(ctx)

	cfg, err := db.GetChatUsageLimitConfig(daemonCtx)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// No config row has ever been written: limits are disabled.
		return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
	case err != nil:
		return nil, err
	case !cfg.Enabled:
		return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
	}

	sdkPeriod, valid := mapDBPeriodToSDK(cfg.Period)
	if !valid {
		return nil, xerrors.Errorf("invalid chat usage limit period %q", cfg.Period)
	}

	// A single query resolves the effective limit:
	// individual override > group limit > global default.
	limit, err := db.ResolveUserChatSpendLimit(daemonCtx, userID)
	if err != nil {
		return nil, err
	}
	// A negative limit means limits are disabled (shouldn't happen since we
	// checked Enabled above, but handle gracefully).
	if limit < 0 {
		return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
	}

	periodStart, periodEnd := ComputeUsagePeriodBounds(now, sdkPeriod)
	spent, err := db.GetUserChatSpendInPeriod(daemonCtx, database.GetUserChatSpendInPeriodParams{
		UserID:    userID,
		StartTime: periodStart,
		EndTime:   periodEnd,
	})
	if err != nil {
		return nil, err
	}

	return &codersdk.ChatUsageLimitStatus{
		IsLimited:        true,
		Period:           sdkPeriod,
		SpendLimitMicros: &limit,
		CurrentSpend:     spent,
		PeriodStart:      periodStart,
		PeriodEnd:        periodEnd,
	}, nil
}
// mapDBPeriodToSDK converts a stored database period string into its SDK
// counterpart, reporting whether the value was recognized.
func mapDBPeriodToSDK(dbPeriod string) (codersdk.ChatUsageLimitPeriod, bool) {
	known := []codersdk.ChatUsageLimitPeriod{
		codersdk.ChatUsageLimitPeriodDay,
		codersdk.ChatUsageLimitPeriodWeek,
		codersdk.ChatUsageLimitPeriodMonth,
	}
	for _, p := range known {
		if string(p) == dbPeriod {
			return p, true
		}
	}
	return "", false
}
-132
View File
@@ -1,132 +0,0 @@
package chatd //nolint:testpackage // Keeps chatd unit tests in the package.
import (
"testing"
"time"
"github.com/coder/coder/v2/codersdk"
)
// TestComputeUsagePeriodBounds exercises ComputeUsagePeriodBounds across all
// three period kinds, covering exact boundaries (midnight, first/last day of
// a period), an ISO-week year boundary, leap-year February, and a non-UTC
// input whose UTC instant falls on the following calendar day.
func TestComputeUsagePeriodBounds(t *testing.T) {
	t.Parallel()
	// A non-UTC zone used to confirm bounds are derived from the UTC
	// instant, not the wall-clock date of the input's location.
	newYork, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Fatalf("load America/New_York: %v", err)
	}
	tests := []struct {
		name      string
		now       time.Time                     // instant being classified
		period    codersdk.ChatUsageLimitPeriod // period kind under test
		wantStart time.Time                     // expected inclusive start (UTC)
		wantEnd   time.Time                     // expected exclusive end (UTC)
	}{
		{
			name:      "day/mid_day",
			now:       time.Date(2025, time.June, 15, 14, 30, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodDay,
			wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			// Exactly at the period start: the instant belongs to the
			// period it begins, not the previous one.
			name:      "day/midnight_exactly",
			now:       time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodDay,
			wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "day/end_of_day",
			now:       time.Date(2025, time.June, 15, 23, 59, 59, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodDay,
			wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			// 2025-06-11 is a Wednesday; the ISO week starts Monday 06-09.
			name:      "week/wednesday",
			now:       time.Date(2025, time.June, 11, 10, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodWeek,
			wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "week/monday",
			now:       time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodWeek,
			wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			// Sunday late evening still belongs to the week that started
			// the previous Monday.
			name:      "week/sunday",
			now:       time.Date(2025, time.June, 15, 23, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodWeek,
			wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			// The ISO week containing 2024-12-31 starts Monday 2024-12-30
			// and ends in 2025.
			name:      "week/year_boundary",
			now:       time.Date(2024, time.December, 31, 12, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodWeek,
			wantStart: time.Date(2024, time.December, 30, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.January, 6, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "month/mid_month",
			now:       time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodMonth,
			wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "month/first_day",
			now:       time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodMonth,
			wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "month/last_day",
			now:       time.Date(2025, time.June, 30, 23, 59, 59, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodMonth,
			wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
		},
		{
			// Short month: end bound is March 1, not a fixed 30/31 days.
			name:      "month/february",
			now:       time.Date(2025, time.February, 15, 12, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodMonth,
			wantStart: time.Date(2025, time.February, 1, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.March, 1, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "month/leap_year_february",
			now:       time.Date(2024, time.February, 29, 12, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodMonth,
			wantStart: time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2024, time.March, 1, 0, 0, 0, 0, time.UTC),
		},
		{
			// 22:00 in New York is already past midnight UTC, so the UTC
			// day is June 16, not June 15.
			name:      "day/non_utc_timezone",
			now:       time.Date(2025, time.June, 15, 22, 0, 0, 0, newYork),
			period:    codersdk.ChatUsageLimitPeriodDay,
			wantStart: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 17, 0, 0, 0, 0, time.UTC),
		},
	}
	for _, tc := range tests {
		tc := tc // capture the range variable for the parallel subtest
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			start, end := ComputeUsagePeriodBounds(tc.now, tc.period)
			if !start.Equal(tc.wantStart) {
				t.Errorf("start: got %v, want %v", start, tc.wantStart)
			}
			if !end.Equal(tc.wantEnd) {
				t.Errorf("end: got %v, want %v", end, tc.wantEnd)
			}
		})
	}
}
+60 -626
View File
@@ -82,30 +82,6 @@ type chatDiffReference struct {
RepositoryRef *chatRepositoryRef
}
// writeChatUsageLimitExceeded renders a 409 Conflict response carrying the
// user's consumed spend, the limit that was hit, and when the period resets.
func writeChatUsageLimitExceeded(
	ctx context.Context,
	rw http.ResponseWriter,
	limitErr *chatd.UsageLimitExceededError,
) {
	payload := codersdk.ChatUsageLimitExceededResponse{
		Response: codersdk.Response{
			Message: "Chat usage limit exceeded.",
		},
		SpentMicros: limitErr.ConsumedMicros,
		LimitMicros: limitErr.LimitMicros,
		ResetsAt:    limitErr.PeriodEnd,
	}
	httpapi.Write(ctx, rw, http.StatusConflict, payload)
}
// maybeWriteLimitErr writes a usage-limit-exceeded response when err wraps a
// *chatd.UsageLimitExceededError and reports whether it did. Callers should
// stop handling the request when it returns true.
func maybeWriteLimitErr(ctx context.Context, rw http.ResponseWriter, err error) bool {
	var limitErr *chatd.UsageLimitExceededError
	if !errors.As(err, &limitErr) {
		return false
	}
	writeChatUsageLimitExceeded(ctx, rw, limitErr)
	return true
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@@ -189,7 +165,7 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
return
}
params := database.GetChatsParams{
params := database.GetChatsByOwnerIDParams{
OwnerID: apiKey.UserID,
Archived: searchParams.Archived,
AfterID: paginationParams.AfterID,
@@ -199,7 +175,7 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
LimitOpt: int32(paginationParams.Limit),
}
chats, err := api.Database.GetChats(ctx, params)
chats, err := api.Database.GetChatsByOwnerID(ctx, params)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to list chats.",
@@ -292,9 +268,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
InitialUserContent: contentBlocks,
})
if err != nil {
if maybeWriteLimitErr(ctx, rw, err) {
return
}
if database.IsForeignKeyViolation(
err,
database.ForeignKeyChatsLastModelConfigID,
@@ -458,12 +431,7 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
chatBreakdowns = append(chatBreakdowns, convertChatCostChatBreakdown(chat))
}
usageStatus, err := chatd.ResolveUsageLimitStatus(ctx, api.Database, targetUser.ID, time.Now())
if err != nil {
api.Logger.Warn(ctx, "failed to resolve usage limit status", slog.Error(err))
}
response := codersdk.ChatCostSummary{
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatCostSummary{
StartDate: startDate,
EndDate: endDate,
TotalCostMicros: summary.TotalCostMicros,
@@ -475,12 +443,7 @@ func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
TotalCacheCreationTokens: summary.TotalCacheCreationTokens,
ByModel: modelBreakdowns,
ByChat: chatBreakdowns,
}
if usageStatus != nil {
response.UsageLimit = usageStatus
}
httpapi.Write(ctx, rw, http.StatusOK, response)
})
}
func (api *API) chatCostUsers(rw http.ResponseWriter, r *http.Request) {
@@ -584,445 +547,6 @@ func (api *API) chatCostUsers(rw http.ResponseWriter, r *http.Request) {
})
}
// @Summary Get chat usage limit config
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
// getChatUsageLimitConfig returns the deployment-wide chat usage limit
// configuration together with all per-user and per-group overrides and the
// count of enabled models that have no pricing configured.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) getChatUsageLimitConfig(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	// Reading the limit config requires deployment-config read access.
	if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
		httpapi.Forbidden(rw)
		return
	}
	// sql.ErrNoRows is tolerated here: a missing row means limits were never
	// configured, and the response keeps its zero values for those fields.
	config, configErr := api.Database.GetChatUsageLimitConfig(ctx)
	if configErr != nil && !errors.Is(configErr, sql.ErrNoRows) {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to get chat usage limit config.",
			Detail:  configErr.Error(),
		})
		return
	}
	overrideRows, err := api.Database.ListChatUsageLimitOverrides(ctx)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to list chat usage limit overrides.",
			Detail:  err.Error(),
		})
		return
	}
	groupOverrides, err := api.Database.ListChatUsageLimitGroupOverrides(ctx)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to list group usage limit overrides.",
			Detail:  err.Error(),
		})
		return
	}
	unpricedModelCount, err := api.Database.CountEnabledModelsWithoutPricing(ctx)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to count unpriced chat models.",
			Detail:  err.Error(),
		})
		return
	}
	response := codersdk.ChatUsageLimitConfigResponse{
		ChatUsageLimitConfig: codersdk.ChatUsageLimitConfig{},
		UnpricedModelCount:   unpricedModelCount,
		Overrides:            make([]codersdk.ChatUsageLimitOverride, 0, len(overrideRows)),
		GroupOverrides:       make([]codersdk.ChatUsageLimitGroupOverride, 0, len(groupOverrides)),
	}
	// Only populate config-derived fields when a config row actually exists,
	// and only expose the limit while enforcement is enabled.
	if configErr == nil {
		response.Period = codersdk.ChatUsageLimitPeriod(config.Period)
		response.UpdatedAt = config.UpdatedAt
		if config.Enabled {
			response.SpendLimitMicros = ptr.Ref(config.DefaultLimitMicros)
		}
	}
	for _, row := range overrideRows {
		response.Overrides = append(response.Overrides, codersdk.ChatUsageLimitOverride{
			UserID:           row.UserID,
			Username:         row.Username,
			Name:             row.Name,
			AvatarURL:        row.AvatarURL,
			SpendLimitMicros: nullInt64Ptr(row.SpendLimitMicros),
		})
	}
	for _, glo := range groupOverrides {
		response.GroupOverrides = append(response.GroupOverrides, codersdk.ChatUsageLimitGroupOverride{
			GroupID:          glo.GroupID,
			GroupName:        glo.GroupName,
			GroupDisplayName: glo.GroupDisplayName,
			GroupAvatarURL:   glo.GroupAvatarUrl,
			MemberCount:      glo.MemberCount,
			SpendLimitMicros: nullInt64Ptr(glo.SpendLimitMicros),
		})
	}
	httpapi.Write(ctx, rw, http.StatusOK, response)
}
// @Summary Update chat usage limit config
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
// updateChatUsageLimitConfig enables, disables, or adjusts the deployment-wide
// chat spend limit. A nil SpendLimitMicros disables enforcement (persisting
// the supplied period, defaulting to month); a positive value enables
// enforcement for the given period.
func (api *API) updateChatUsageLimitConfig(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
		httpapi.Forbidden(rw)
		return
	}
	var req codersdk.ChatUsageLimitConfig
	if !httpapi.Read(ctx, rw, r, &req) {
		return
	}
	// The zero value already encodes "disabled with no limit"; each branch
	// below only sets the fields that differ from it.
	var params database.UpsertChatUsageLimitConfigParams
	if req.SpendLimitMicros == nil {
		// Disabling: the period is optional, but must be valid when given.
		if req.Period != "" && !req.Period.Valid() {
			httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
				Message: "Invalid chat usage limit period.",
				Detail:  "Period must be one of: day, week, month.",
			})
			return
		}
		params.Period = string(req.Period)
		// Store month when no period was provided.
		if params.Period == "" {
			params.Period = string(codersdk.ChatUsageLimitPeriodMonth)
		}
	} else {
		if *req.SpendLimitMicros <= 0 {
			httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
				Message: "Invalid chat usage limit spend limit.",
				Detail:  "Spend limit must be greater than 0.",
			})
			return
		}
		if !req.Period.Valid() {
			httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
				Message: "Invalid chat usage limit period.",
				Detail:  "Period must be one of: day, week, month.",
			})
			return
		}
		params.Enabled = true
		params.DefaultLimitMicros = *req.SpendLimitMicros
		params.Period = string(req.Period)
	}
	config, err := api.Database.UpsertChatUsageLimitConfig(ctx, params)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to update chat usage limit config.",
			Detail:  err.Error(),
		})
		return
	}
	response := codersdk.ChatUsageLimitConfig{
		Period:    codersdk.ChatUsageLimitPeriod(config.Period),
		UpdatedAt: config.UpdatedAt,
	}
	// Only expose the limit while enforcement is enabled; a nil pointer
	// signals disabled limits to the client.
	if config.Enabled {
		response.SpendLimitMicros = ptr.Ref(config.DefaultLimitMicros)
	}
	httpapi.Write(ctx, rw, http.StatusOK, response)
}
// @Summary Get my chat usage limit status
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
// getMyChatUsageLimitStatus reports the usage-limit status of the
// authenticated user. No extra RBAC check is needed: the handler only ever
// reads the requesting user's own data via httpmw.APIKey(r).UserID.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) getMyChatUsageLimitStatus(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	status, err := chatd.ResolveUsageLimitStatus(ctx, api.Database, httpmw.APIKey(r).UserID, time.Now())
	switch {
	case err != nil:
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to get chat usage limit status.",
			Detail:  err.Error(),
		})
	case status == nil:
		// A nil status means usage limits are disabled; report that
		// explicitly rather than returning an empty body.
		httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatUsageLimitStatus{IsLimited: false})
	default:
		httpapi.Write(ctx, rw, http.StatusOK, status)
	}
}
// @Summary Upsert chat usage limit override
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
// upsertChatUsageLimitOverride creates or updates a per-user spend limit
// override and echoes the resulting override with user metadata.
func (api *API) upsertChatUsageLimitOverride(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
		httpapi.Forbidden(rw)
		return
	}
	userID, ok := parseChatUsageLimitUserID(rw, r)
	if !ok {
		return
	}
	var req codersdk.UpsertChatUsageLimitOverrideRequest
	if !httpapi.Read(ctx, rw, r, &req) {
		return
	}
	if req.SpendLimitMicros <= 0 {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Invalid chat usage limit override.",
			Detail:  "Spend limit must be greater than 0.",
		})
		return
	}
	// Look up the user both to 404 on unknown IDs and to populate the
	// response payload.
	user, err := api.Database.GetUserByID(ctx, userID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Use the shared helper so the not-found response matches
			// deleteChatUsageLimitOverride for the same condition.
			writeChatUsageLimitUserNotFound(ctx, rw)
			return
		}
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to look up chat usage limit user.",
			Detail:  err.Error(),
		})
		return
	}
	_, err = api.Database.UpsertChatUsageLimitUserOverride(ctx, database.UpsertChatUsageLimitUserOverrideParams{
		UserID:           userID,
		SpendLimitMicros: req.SpendLimitMicros,
	})
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to upsert chat usage limit override.",
			Detail:  err.Error(),
		})
		return
	}
	httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatUsageLimitOverride{
		UserID:           user.ID,
		Username:         user.Username,
		Name:             user.Name,
		AvatarURL:        user.AvatarURL,
		SpendLimitMicros: nullInt64Ptr(sql.NullInt64{Int64: req.SpendLimitMicros, Valid: true}),
	})
}
// @Summary Delete chat usage limit override
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
// deleteChatUsageLimitOverride removes a per-user spend limit override. It
// distinguishes a missing user from a missing override via dedicated
// not-found helpers before deleting.
func (api *API) deleteChatUsageLimitOverride(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
		httpapi.Forbidden(rw)
		return
	}
	userID, ok := parseChatUsageLimitUserID(rw, r)
	if !ok {
		return
	}
	// Confirm the user exists before checking for an override.
	if _, err := api.Database.GetUserByID(ctx, userID); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			writeChatUsageLimitUserNotFound(ctx, rw)
			return
		}
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to look up chat usage limit user.",
			Detail:  err.Error(),
		})
		return
	}
	// Confirm an override exists so deleting a nonexistent one reports
	// not-found rather than silently succeeding.
	if _, err := api.Database.GetChatUsageLimitUserOverride(ctx, userID); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			writeChatUsageLimitOverrideNotFound(ctx, rw)
			return
		}
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to look up chat usage limit override.",
			Detail:  err.Error(),
		})
		return
	}
	if err := api.Database.DeleteChatUsageLimitUserOverride(ctx, userID); err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to delete chat usage limit override.",
			Detail:  err.Error(),
		})
		return
	}
	rw.WriteHeader(http.StatusNoContent)
}
// @Summary Upsert chat usage limit group override
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
// upsertChatUsageLimitGroupOverride creates or updates a per-group spend
// limit override and returns it with group metadata and the group's current
// (non-system) member count.
func (api *API) upsertChatUsageLimitGroupOverride(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
		httpapi.Forbidden(rw)
		return
	}
	groupIDStr := chi.URLParam(r, "group")
	groupID, err := uuid.Parse(groupIDStr)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Invalid group ID.",
			Detail:  err.Error(),
		})
		return
	}
	var req codersdk.UpdateChatUsageLimitGroupOverrideRequest
	if !httpapi.Read(ctx, rw, r, &req) {
		return
	}
	if req.SpendLimitMicros <= 0 {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Invalid chat usage limit group override.",
			Detail:  "Spend limit (in microdollars) must be greater than 0.",
		})
		return
	}
	// Look up the group both to 404 on unknown IDs and to populate the
	// response payload.
	group, err := api.Database.GetGroupByID(ctx, groupID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Use the shared helper so the not-found response matches the
			// member-count branch below and the delete handler.
			writeChatUsageLimitGroupNotFound(ctx, rw)
			return
		}
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to look up group details.",
			Detail:  err.Error(),
		})
		return
	}
	_, err = api.Database.UpsertChatUsageLimitGroupOverride(ctx, database.UpsertChatUsageLimitGroupOverrideParams{
		GroupID:          groupID,
		SpendLimitMicros: req.SpendLimitMicros,
	})
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to upsert group usage limit override.",
			Detail:  err.Error(),
		})
		return
	}
	memberCount, err := api.Database.GetGroupMembersCountByGroupID(ctx, database.GetGroupMembersCountByGroupIDParams{
		GroupID:       groupID,
		IncludeSystem: false,
	})
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			writeChatUsageLimitGroupNotFound(ctx, rw)
			return
		}
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to fetch group member count.",
			Detail:  err.Error(),
		})
		return
	}
	httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatUsageLimitGroupOverride{
		GroupID:          group.ID,
		GroupName:        group.Name,
		GroupDisplayName: group.DisplayName,
		GroupAvatarURL:   group.AvatarURL,
		MemberCount:      memberCount,
		SpendLimitMicros: nullInt64Ptr(sql.NullInt64{Int64: req.SpendLimitMicros, Valid: true}),
	})
}
// @Summary Delete chat usage limit group override
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
// deleteChatUsageLimitGroupOverride removes a per-group spend limit override.
// It distinguishes a missing group from a missing override via dedicated
// not-found helpers before deleting.
func (api *API) deleteChatUsageLimitGroupOverride(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
		httpapi.Forbidden(rw)
		return
	}
	groupIDStr := chi.URLParam(r, "group")
	groupID, err := uuid.Parse(groupIDStr)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Invalid group ID.",
			Detail:  err.Error(),
		})
		return
	}
	// Confirm the group exists before checking for an override.
	if _, err := api.Database.GetGroupByID(ctx, groupID); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			writeChatUsageLimitGroupNotFound(ctx, rw)
			return
		}
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to look up group details.",
			Detail:  err.Error(),
		})
		return
	}
	// Confirm an override exists so deleting a nonexistent one reports
	// not-found rather than silently succeeding.
	if _, err := api.Database.GetChatUsageLimitGroupOverride(ctx, groupID); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			writeChatUsageLimitGroupOverrideNotFound(ctx, rw)
			return
		}
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to look up group usage limit override.",
			Detail:  err.Error(),
		})
		return
	}
	if err := api.Database.DeleteChatUsageLimitGroupOverride(ctx, groupID); err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to delete group usage limit override.",
			Detail:  err.Error(),
		})
		return
	}
	rw.WriteHeader(http.StatusNoContent)
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
@@ -1369,58 +893,64 @@ func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) {
logger.Debug(ctx, "desktop Bicopy finished")
}
// patchChat updates a chat resource. Currently supports toggling the
// archived state via the Archived field.
func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) archiveChat(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
var req codersdk.UpdateChatRequest
if !httpapi.Read(ctx, rw, r, &req) {
if chat.Archived {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Chat is already archived.",
})
return
}
if req.Archived != nil {
archived := *req.Archived
if archived == chat.Archived {
state := "archived"
if !archived {
state = "not archived"
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Chat is already %s.", state),
})
return
}
var err error
// Use chatDaemon when available so it can notify
// active subscribers. Fall back to direct DB for the
// simple archive flag — no streaming state is involved.
if api.chatDaemon != nil {
err = api.chatDaemon.ArchiveChat(ctx, chat.ID)
} else {
err = api.Database.ArchiveChatByID(ctx, chat.ID)
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to archive chat.",
Detail: err.Error(),
})
return
}
var err error
// Use chatDaemon when available so it can notify active
// subscribers. Fall back to direct DB for the simple
// archive flag — no streaming state is involved.
if archived {
if api.chatDaemon != nil {
err = api.chatDaemon.ArchiveChat(ctx, chat)
} else {
err = api.Database.ArchiveChatByID(ctx, chat.ID)
}
} else {
if api.chatDaemon != nil {
err = api.chatDaemon.UnarchiveChat(ctx, chat)
} else {
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
}
}
if err != nil {
action := "archive"
if !archived {
action = "unarchive"
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: fmt.Sprintf("Failed to %s chat.", action),
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
func (api *API) unarchiveChat(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
if !chat.Archived {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Chat is not archived.",
})
return
}
var err error
// Use chatDaemon when available so it can notify
// active subscribers. Fall back to direct DB for the
// simple unarchive flag — no streaming state is involved.
if api.chatDaemon != nil {
err = api.chatDaemon.UnarchiveChat(ctx, chat.ID)
} else {
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to unarchive chat.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
@@ -1466,9 +996,6 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
},
)
if sendErr != nil {
if maybeWriteLimitErr(ctx, rw, sendErr) {
return
}
if xerrors.Is(sendErr, chatd.ErrMessageQueueFull) {
httpapi.Write(ctx, rw, http.StatusTooManyRequests, codersdk.Response{
Message: "Message queue is full.",
@@ -1541,10 +1068,6 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
Content: contentBlocks,
})
if editErr != nil {
if maybeWriteLimitErr(ctx, rw, editErr) {
return
}
switch {
case xerrors.Is(editErr, chatd.ErrEditedMessageNotFound):
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
@@ -1635,9 +1158,6 @@ func (api *API) promoteChatQueuedMessage(rw http.ResponseWriter, r *http.Request
})
if txErr != nil {
if maybeWriteLimitErr(ctx, rw, txErr) {
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to promote queued message.",
Detail: txErr.Error(),
@@ -2519,14 +2039,14 @@ func (api *API) getChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPrompt{
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPromptResponse{
SystemPrompt: prompt,
})
}
func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req codersdk.ChatSystemPrompt
var req codersdk.UpdateChatSystemPromptRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
@@ -2554,49 +2074,6 @@ func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNoContent)
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
func (api *API) getChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
enabled, err := api.Database.GetChatDesktopEnabled(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching desktop setting.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatDesktopEnabledResponse{
EnableDesktop: enabled,
})
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) putChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
httpapi.Forbidden(rw)
return
}
var req codersdk.UpdateChatDesktopEnabledRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if err := api.Database.UpsertChatDesktopEnabled(ctx, req.EnableDesktop); httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
} else if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error updating desktop setting.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
@@ -2619,7 +2096,7 @@ func (api *API) getUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
customPrompt = ""
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPrompt{
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPromptResponse{
CustomPrompt: customPrompt,
})
}
@@ -2631,7 +2108,7 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
apiKey = httpmw.APIKey(r)
)
var params codersdk.UserChatCustomPrompt
var params codersdk.UpdateUserChatCustomPromptRequest
if !httpapi.Read(ctx, rw, r, &params) {
return
}
@@ -2658,7 +2135,7 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPrompt{
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPromptResponse{
CustomPrompt: updatedConfig.Value,
})
}
@@ -3878,49 +3355,6 @@ func chatModelConfigToUpdateParams(
}
}
func nullInt64Ptr(n sql.NullInt64) *int64 {
if !n.Valid {
return nil
}
return &n.Int64
}
func writeChatUsageLimitUserNotFound(ctx context.Context, rw http.ResponseWriter) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "User not found.",
})
}
func writeChatUsageLimitOverrideNotFound(ctx context.Context, rw http.ResponseWriter) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Chat usage limit override not found.",
})
}
func writeChatUsageLimitGroupOverrideNotFound(ctx context.Context, rw http.ResponseWriter) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Chat usage limit group override not found.",
})
}
func writeChatUsageLimitGroupNotFound(ctx context.Context, rw http.ResponseWriter) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Group not found.",
})
}
func parseChatUsageLimitUserID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
userID, err := uuid.Parse(chi.URLParam(r, "user"))
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid chat usage limit user ID.",
Detail: err.Error(),
})
return uuid.Nil, false
}
return userID, true
}
func parseChatProviderID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
providerID, err := uuid.Parse(chi.URLParam(r, "providerConfig"))
if err != nil {
+12 -502
View File
@@ -2,7 +2,6 @@ package coderd_test
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
@@ -17,19 +16,15 @@ import (
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/externalauth"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
@@ -60,93 +55,6 @@ func newChatClientWithDatabase(t testing.TB) (*codersdk.Client, database.Store)
})
}
func requireChatUsageLimitExceededError(
t *testing.T,
err error,
wantSpentMicros int64,
wantLimitMicros int64,
wantResetsAt time.Time,
) *codersdk.ChatUsageLimitExceededResponse {
t.Helper()
sdkErr, ok := codersdk.AsError(err)
require.True(t, ok)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
require.Equal(t, "Chat usage limit exceeded.", sdkErr.Message)
limitErr := codersdk.ChatUsageLimitExceededFrom(err)
require.NotNil(t, limitErr)
require.Equal(t, "Chat usage limit exceeded.", limitErr.Message)
require.Equal(t, wantSpentMicros, limitErr.SpentMicros)
require.Equal(t, wantLimitMicros, limitErr.LimitMicros)
require.True(
t,
limitErr.ResetsAt.Equal(wantResetsAt),
"expected resets_at %s, got %s",
wantResetsAt.UTC().Format(time.RFC3339),
limitErr.ResetsAt.UTC().Format(time.RFC3339),
)
return limitErr
}
func enableDailyChatUsageLimit(
ctx context.Context,
t *testing.T,
db database.Store,
limitMicros int64,
) time.Time {
t.Helper()
_, err := db.UpsertChatUsageLimitConfig(
dbauthz.AsSystemRestricted(ctx),
database.UpsertChatUsageLimitConfigParams{
Enabled: true,
DefaultLimitMicros: limitMicros,
Period: string(codersdk.ChatUsageLimitPeriodDay),
},
)
require.NoError(t, err)
_, periodEnd := chatd.ComputeUsagePeriodBounds(time.Now(), codersdk.ChatUsageLimitPeriodDay)
return periodEnd
}
func insertAssistantCostMessage(
ctx context.Context,
t *testing.T,
db database.Store,
chatID uuid.UUID,
modelConfigID uuid.UUID,
totalCostMicros int64,
) {
t.Helper()
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("assistant"),
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chatID,
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
Role: database.ChatMessageRoleAssistant,
ContentVersion: chatprompt.CurrentContentVersion,
Content: assistantContent,
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{},
OutputTokens: sql.NullInt64{},
TotalTokens: sql.NullInt64{},
ReasoningTokens: sql.NullInt64{},
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{Int64: totalCostMicros, Valid: true},
})
require.NoError(t, err)
}
func TestPostChats(t *testing.T) {
t.Parallel()
@@ -417,33 +325,6 @@ func TestPostChats(t *testing.T) {
require.Equal(t, "Invalid input part.", sdkErr.Message)
require.Equal(t, `content[0].type "image" is not supported.`, sdkErr.Detail)
})
t.Run("UsageLimitExceeded", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
existingChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "existing-limit-chat",
})
require.NoError(t, err)
insertAssistantCostMessage(ctx, t, db, existingChat.ID, modelConfig.ID, 100)
_, err = client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "over limit",
}},
})
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
})
}
func TestListChats(t *testing.T) {
@@ -1688,7 +1569,7 @@ func TestArchiveChat(t *testing.T) {
require.NoError(t, err)
require.Len(t, chatsBeforeArchive, 2)
err = client.UpdateChat(ctx, chatToArchive.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
err = client.ArchiveChat(ctx, chatToArchive.ID)
require.NoError(t, err)
// Default (no filter) returns only non-archived chats.
@@ -1722,7 +1603,7 @@ func TestArchiveChat(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
err := client.ArchiveChat(ctx, uuid.New())
requireSDKError(t, err, http.StatusNotFound)
})
@@ -1765,7 +1646,7 @@ func TestArchiveChat(t *testing.T) {
require.NoError(t, err)
// Archive the parent via the API.
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
err = client.ArchiveChat(ctx, parentChat.ID)
require.NoError(t, err)
// archived:false should exclude the entire archived family.
@@ -1812,7 +1693,7 @@ func TestUnarchiveChat(t *testing.T) {
require.NoError(t, err)
// Archive the chat first.
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
err = client.ArchiveChat(ctx, chat.ID)
require.NoError(t, err)
// Verify it's archived.
@@ -1823,7 +1704,7 @@ func TestUnarchiveChat(t *testing.T) {
require.Len(t, archivedChats, 1)
require.True(t, archivedChats[0].Archived)
// Unarchive the chat.
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
err = client.UnarchiveChat(ctx, chat.ID)
require.NoError(t, err)
// Verify it's no longer archived.
@@ -1862,9 +1743,10 @@ func TestUnarchiveChat(t *testing.T) {
require.NoError(t, err)
// Trying to unarchive a non-archived chat should fail.
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
err = client.UnarchiveChat(ctx, chat.ID)
requireSDKError(t, err, http.StatusBadRequest)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
@@ -1872,7 +1754,7 @@ func TestUnarchiveChat(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
err := client.UnarchiveChat(ctx, uuid.New())
requireSDKError(t, err, http.StatusNotFound)
})
}
@@ -2001,34 +1883,6 @@ func TestPostChatMessages(t *testing.T) {
require.Equal(t, "content[0].text cannot be empty.", sdkErr.Detail)
})
t.Run("UsageLimitExceeded", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
_ = coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "initial message for usage-limit test",
}},
})
require.NoError(t, err)
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
_, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "over limit",
}},
})
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
})
t.Run("ChatNotFound", func(t *testing.T) {
t.Parallel()
@@ -2789,46 +2643,6 @@ func TestPatchChatMessage(t *testing.T) {
require.True(t, foundFileInChat, "chat should preserve file_id after edit")
})
t.Run("UsageLimitExceeded", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
_ = coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "hello before edit",
}},
})
require.NoError(t, err)
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
var userMessageID int64
for _, message := range messagesResult.Messages {
if message.Role == codersdk.ChatMessageRoleUser {
userMessageID = message.ID
break
}
}
require.NotZero(t, userMessageID)
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
_, err = client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "edited over limit",
}},
})
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
})
t.Run("MessageNotFound", func(t *testing.T) {
t.Parallel()
@@ -3538,81 +3352,6 @@ func TestPromoteChatQueuedMessage(t *testing.T) {
}
})
t.Run("PromotesAlreadyQueuedMessageAfterLimitReached", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
enableDailyChatUsageLimit(ctx, t, db, 100)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "promote queued usage limit",
})
require.NoError(t, err)
const queuedText = "queued message for promote route"
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{
codersdk.ChatMessageText(queuedText),
})
require.NoError(t, err)
queuedMessage, err := db.InsertChatQueuedMessage(
dbauthz.AsSystemRestricted(ctx),
database.InsertChatQueuedMessageParams{
ChatID: chat.ID,
Content: queuedContent,
},
)
require.NoError(t, err)
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
_, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
})
require.NoError(t, err)
promoteRes, err := client.Request(
ctx,
http.MethodPost,
fmt.Sprintf("/api/experimental/chats/%s/queue/%d/promote", chat.ID, queuedMessage.ID),
nil,
)
require.NoError(t, err)
defer promoteRes.Body.Close()
require.Equal(t, http.StatusOK, promoteRes.StatusCode)
var promoted codersdk.ChatMessage
err = json.NewDecoder(promoteRes.Body).Decode(&promoted)
require.NoError(t, err)
require.NotZero(t, promoted.ID)
require.Equal(t, chat.ID, promoted.ChatID)
require.Equal(t, codersdk.ChatMessageRoleUser, promoted.Role)
foundPromotedText := false
for _, part := range promoted.Content {
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == queuedText {
foundPromotedText = true
break
}
}
require.True(t, foundPromotedText)
queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID)
require.NoError(t, err)
for _, queued := range queuedMessages {
require.NotEqual(t, queuedMessage.ID, queued.ID)
}
})
t.Run("InvalidQueuedMessageID", func(t *testing.T) {
t.Parallel()
@@ -3644,133 +3383,6 @@ func TestPromoteChatQueuedMessage(t *testing.T) {
})
}
func TestChatUsageLimitOverrideRoutes(t *testing.T) {
t.Parallel()
t.Run("UpsertUserOverrideRequiresPositiveSpendLimit", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, _ := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
res, err := client.Request(
ctx,
http.MethodPut,
fmt.Sprintf("/api/experimental/chats/usage-limits/overrides/%s", member.ID),
map[string]any{},
)
require.NoError(t, err)
defer res.Body.Close()
err = codersdk.ReadBodyAsError(res)
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid chat usage limit override.", sdkErr.Message)
require.Equal(t, "Spend limit must be greater than 0.", sdkErr.Detail)
})
t.Run("UpsertUserOverrideMissingUser", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
_, err := client.UpsertChatUsageLimitOverride(ctx, uuid.New(), codersdk.UpsertChatUsageLimitOverrideRequest{
SpendLimitMicros: 7_000_000,
})
sdkErr := requireSDKError(t, err, http.StatusNotFound)
require.Equal(t, "User not found.", sdkErr.Message)
})
t.Run("DeleteUserOverrideMissingUser", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
err := client.DeleteChatUsageLimitOverride(ctx, uuid.New())
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "User not found.", sdkErr.Message)
})
t.Run("DeleteUserOverrideMissingOverride", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, client)
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
err := client.DeleteChatUsageLimitOverride(ctx, member.ID)
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Chat usage limit override not found.", sdkErr.Message)
})
t.Run("UpsertGroupOverrideIncludesMemberCount", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
group := dbgen.Group(t, db, database.Group{OrganizationID: firstUser.OrganizationID})
dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: member.ID})
dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: database.PrebuildsSystemUserID})
override, err := client.UpsertChatUsageLimitGroupOverride(ctx, group.ID, codersdk.UpsertChatUsageLimitGroupOverrideRequest{
SpendLimitMicros: 7_000_000,
})
require.NoError(t, err)
require.Equal(t, group.ID, override.GroupID)
require.EqualValues(t, 1, override.MemberCount)
require.NotNil(t, override.SpendLimitMicros)
require.EqualValues(t, 7_000_000, *override.SpendLimitMicros)
config, err := client.GetChatUsageLimitConfig(ctx)
require.NoError(t, err)
var listed *codersdk.ChatUsageLimitGroupOverride
for i := range config.GroupOverrides {
if config.GroupOverrides[i].GroupID == group.ID {
listed = &config.GroupOverrides[i]
break
}
}
require.NotNil(t, listed)
require.EqualValues(t, 1, listed.MemberCount)
})
t.Run("UpsertGroupOverrideMissingGroup", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
_, err := client.UpsertChatUsageLimitGroupOverride(ctx, uuid.New(), codersdk.UpsertChatUsageLimitGroupOverrideRequest{
SpendLimitMicros: 7_000_000,
})
sdkErr := requireSDKError(t, err, http.StatusNotFound)
require.Equal(t, "Group not found.", sdkErr.Message)
})
t.Run("DeleteGroupOverrideMissingOverride", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
group := dbgen.Group(t, db, database.Group{OrganizationID: firstUser.OrganizationID})
err := client.DeleteChatUsageLimitGroupOverride(ctx, group.ID)
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Chat usage limit group override not found.", sdkErr.Message)
})
}
func TestPostChatFile(t *testing.T) {
t.Parallel()
@@ -4512,7 +4124,7 @@ func TestChatSystemPrompt(t *testing.T) {
t.Run("AdminCanSet", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
SystemPrompt: "You are a helpful coding assistant.",
})
require.NoError(t, err)
@@ -4526,7 +4138,7 @@ func TestChatSystemPrompt(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
// Unset by sending an empty string.
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
SystemPrompt: "",
})
require.NoError(t, err)
@@ -4539,7 +4151,7 @@ func TestChatSystemPrompt(t *testing.T) {
t.Run("NonAdminFails", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
SystemPrompt: "This should fail.",
})
requireSDKError(t, err, http.StatusNotFound)
@@ -4560,7 +4172,7 @@ func TestChatSystemPrompt(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
tooLong := strings.Repeat("a", 131073)
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
SystemPrompt: tooLong,
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
@@ -4568,108 +4180,6 @@ func TestChatSystemPrompt(t *testing.T) {
})
}
func TestChatDesktopEnabled(t *testing.T) {
t.Parallel()
t.Run("ReturnsFalseWhenUnset", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
resp, err := adminClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.False(t, resp.EnableDesktop)
})
t.Run("AdminCanSetTrue", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
require.NoError(t, err)
resp, err := adminClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.True(t, resp.EnableDesktop)
})
t.Run("AdminCanSetFalse", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
// Set true first, then set false.
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
require.NoError(t, err)
err = adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: false,
})
require.NoError(t, err)
resp, err := adminClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.False(t, resp.EnableDesktop)
})
t.Run("NonAdminCanRead", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
require.NoError(t, err)
resp, err := memberClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.True(t, resp.EnableDesktop)
})
t.Run("NonAdminWriteFails", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
err := memberClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
requireSDKError(t, err, http.StatusForbidden)
})
t.Run("UnauthenticatedFails", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
anonClient := codersdk.New(adminClient.URL)
_, err := anonClient.GetChatDesktopEnabled(ctx)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
})
}
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
t.Helper()
+2 -16
View File
@@ -1157,8 +1157,6 @@ func New(options *Options) *API {
r.Route("/config", func(r chi.Router) {
r.Get("/system-prompt", api.getChatSystemPrompt)
r.Put("/system-prompt", api.putChatSystemPrompt)
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
r.Get("/user-prompt", api.getUserChatCustomPrompt)
r.Put("/user-prompt", api.putUserChatCustomPrompt)
})
@@ -1180,25 +1178,13 @@ func New(options *Options) *API {
r.Delete("/", api.deleteChatModelConfig)
})
})
r.Route("/usage-limits", func(r chi.Router) {
r.Get("/", api.getChatUsageLimitConfig)
r.Put("/", api.updateChatUsageLimitConfig)
r.Get("/status", api.getMyChatUsageLimitStatus)
r.Route("/overrides/{user}", func(r chi.Router) {
r.Put("/", api.upsertChatUsageLimitOverride)
r.Delete("/", api.deleteChatUsageLimitOverride)
})
r.Route("/group-overrides/{group}", func(r chi.Router) {
r.Put("/", api.upsertChatUsageLimitGroupOverride)
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
})
})
r.Route("/{chat}", func(r chi.Router) {
r.Use(httpmw.ExtractChatParam(options.Database))
r.Get("/", api.getChat)
r.Get("/git/watch", api.watchChatGit)
r.Get("/desktop", api.watchChatDesktop)
r.Patch("/", api.patchChat)
r.Post("/archive", api.archiveChat)
r.Post("/unarchive", api.unarchiveChat)
r.Get("/messages", api.getChatMessages)
r.Post("/messages", api.postChatMessages)
r.Patch("/messages/{message}", api.patchChatMessage)
-9
View File
@@ -879,15 +879,6 @@ func createAnotherUserRetry(t testing.TB, client *codersdk.Client, organizationI
m(&req)
}
// Service accounts cannot have a password or email and must
// use login_type=none. Enforce this after mutators so callers
// only need to set ServiceAccount=true.
if req.ServiceAccount {
req.Password = ""
req.Email = ""
req.UserLoginType = codersdk.LoginTypeNone
}
user, err := client.CreateUserWithOrgs(context.Background(), req)
var apiError *codersdk.Error
// If the user already exists by username or email conflict, try again up to "retries" times.
+18 -23
View File
@@ -6,27 +6,22 @@ type CheckConstraint string
// CheckConstraint enums.
const (
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersChatSpendLimitMicrosCheck CheckConstraint = "users_chat_spend_limit_micros_check" // users
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
)
+7 -8
View File
@@ -195,14 +195,13 @@ func MinimalUserFromVisibleUser(user database.VisibleUser) codersdk.MinimalUser
func ReducedUser(user database.User) codersdk.ReducedUser {
return codersdk.ReducedUser{
MinimalUser: MinimalUser(user),
Email: user.Email,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
LastSeenAt: user.LastSeenAt,
Status: codersdk.UserStatus(user.Status),
LoginType: codersdk.LoginType(user.LoginType),
IsServiceAccount: user.IsServiceAccount,
MinimalUser: MinimalUser(user),
Email: user.Email,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
LastSeenAt: user.LastSeenAt,
Status: codersdk.UserStatus(user.Status),
LoginType: codersdk.LoginType(user.LoginType),
}
}
+11 -136
View File
@@ -1264,7 +1264,7 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID uuid.UUID, added, re
// System roles are stored in the database but have a fixed, code-defined
// meaning. Do not rewrite the name for them so the static "who can assign
// what" mapping applies.
if !rolestore.IsSystemRoleName(roleName.Name) {
if !rbac.SystemRoleName(roleName.Name) {
// To support a dynamic mapping of what roles can assign what, we need
// to store this in the database. For now, just use a static role so
// owners and org admins can assign roles.
@@ -1726,13 +1726,6 @@ func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountCon
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
}
func (q *querier) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return 0, err
}
return q.db.CountEnabledModelsWithoutPricing(ctx)
}
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
return nil, err
@@ -1861,20 +1854,6 @@ func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.Dele
return q.db.DeleteChatQueuedMessage(ctx, arg)
}
func (q *querier) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatUsageLimitGroupOverride(ctx, groupID)
}
func (q *querier) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatUsageLimitUserOverride(ctx, userID)
}
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -2145,12 +2124,12 @@ func (q *querier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) erro
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
}
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, params database.DeleteWorkspaceACLsByOrganizationParams) error {
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
// This is a system-only function.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.DeleteWorkspaceACLsByOrganization(ctx, params)
return q.db.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
}
func (q *querier) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
@@ -2482,17 +2461,6 @@ func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCo
return q.db.GetChatCostSummary(ctx, arg)
}
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
// The desktop-enabled flag is a deployment-wide setting read by any
// authenticated chat user and by chatd when deciding whether to expose
// computer-use tooling. We only require that an explicit actor is present
// in the context so unauthenticated calls fail closed.
if _, ok := ActorFromContext(ctx); !ok {
return false, ErrNoActor
}
return q.db.GetChatDesktopEnabled(ctx)
}
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
@@ -2643,33 +2611,8 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
return q.db.GetChatSystemPrompt(ctx)
}
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatUsageLimitConfig{}, err
}
return q.db.GetChatUsageLimitConfig(ctx)
}
func (q *querier) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.GetChatUsageLimitGroupOverrideRow{}, err
}
return q.db.GetChatUsageLimitGroupOverride(ctx, groupID)
}
func (q *querier) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.GetChatUsageLimitUserOverrideRow{}, err
}
return q.db.GetChatUsageLimitUserOverride(ctx, userID)
}
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceChat.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.GetAuthorizedChats(ctx, arg, prep)
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
@@ -3337,6 +3280,12 @@ func (q *querier) GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uui
return q.db.GetProvisionerJobTimingsByJobID(ctx, jobID)
}
func (q *querier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) {
// TODO: Remove this once we have a proper rbac check for provisioner jobs.
// Details in https://github.com/coder/coder/issues/16160
return q.db.GetProvisionerJobsByIDs(ctx, ids)
}
func (q *querier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
// TODO: Remove this once we have a proper rbac check for provisioner jobs.
// Details in https://github.com/coder/coder/issues/16160
@@ -3822,13 +3771,6 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
return q.db.GetUserChatCustomPrompt(ctx, userID)
}
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
return 0, err
}
return q.db.GetUserChatSpendInPeriod(ctx, arg)
}
func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return 0, err
@@ -3836,13 +3778,6 @@ func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64,
return q.db.GetUserCount(ctx, includeSystem)
}
func (q *querier) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
return 0, err
}
return q.db.GetUserGroupSpendLimit(ctx, userID)
}
func (q *querier) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
// Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms.
if err := q.authorizeContext(ctx, policy.ActionViewInsights, rbac.ResourceTemplate); err != nil {
@@ -4512,13 +4447,6 @@ func (q *querier) InsertAIBridgeInterception(ctx context.Context, arg database.I
return insert(q.log, q.auth, rbac.ResourceAibridgeInterception.WithOwner(arg.InitiatorID.String()), q.db.InsertAIBridgeInterception)(ctx, arg)
}
func (q *querier) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
return database.AIBridgeModelThought{}, err
}
return q.db.InsertAIBridgeModelThought(ctx, arg)
}
func (q *querier) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
// All aibridge_token_usages records belong to the initiator of their associated interception.
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
@@ -5208,20 +5136,6 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
}
func (q *querier) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.ListChatUsageLimitGroupOverrides(ctx)
}
func (q *querier) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.ListChatUsageLimitOverrides(ctx)
}
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
}
@@ -5341,13 +5255,6 @@ func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveU
return q.db.RemoveUserFromGroups(ctx, arg)
}
func (q *querier) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
return 0, err
}
return q.db.ResolveUserChatSpendLimit(ctx, userID)
}
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
@@ -6559,13 +6466,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
return q.db.UpsertBoundaryUsageStats(ctx, arg)
}
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.UpsertChatDesktopEnabled(ctx, enableDesktop)
}
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
// Authorize update on the parent chat.
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
@@ -6597,27 +6497,6 @@ func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) erro
return q.db.UpsertChatSystemPrompt(ctx, value)
}
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatUsageLimitConfig{}, err
}
return q.db.UpsertChatUsageLimitConfig(ctx, arg)
}
func (q *querier) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.UpsertChatUsageLimitGroupOverrideRow{}, err
}
return q.db.UpsertChatUsageLimitGroupOverride(ctx, arg)
}
func (q *querier) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.UpsertChatUsageLimitUserOverrideRow{}, err
}
return q.db.UpsertChatUsageLimitUserOverride(ctx, arg)
}
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return database.ConnectionLog{}, err
@@ -6900,7 +6779,3 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
// database.Store interface, so dbauthz needs to implement it.
return q.ListAIBridgeModels(ctx, arg)
}
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
return q.GetChats(ctx, arg)
}
+10 -177
View File
@@ -513,10 +513,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatCostSummary(gomock.Any(), arg).Return(row, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(row)
}))
s.Run("CountEnabledModelsWithoutPricing", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().CountEnabledModelsWithoutPricing(gomock.Any()).Return(int64(3), nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(int64(3))
}))
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
@@ -618,17 +614,12 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
}))
s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
params := database.GetChatsParams{}
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params).Asserts()
}))
s.Run("GetAuthorizedChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
params := database.GetChatsParams{}
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
// No asserts here because it re-routes through GetChats which uses SQLFilter.
check.Args(params, emptyPreparedAuthorized{}).Asserts()
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
c1 := testutil.Fake(s.T(), faker, database.Chat{})
c2 := testutil.Fake(s.T(), faker, database.Chat{})
params := database.GetChatsByOwnerIDParams{OwnerID: c1.OwnerID}
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), params).Return([]database.Chat{c1, c2}, nil).AnyTimes()
check.Args(params).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
}))
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
@@ -641,10 +632,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatSystemPrompt(gomock.Any()).Return("prompt", nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("GetChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
@@ -854,146 +841,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("UpsertChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetUserChatSpendInPeriod", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetUserChatSpendInPeriodParams{
UserID: uuid.New(),
StartTime: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndTime: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
}
spend := int64(123)
dbm.EXPECT().GetUserChatSpendInPeriod(gomock.Any(), arg).Return(spend, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.UserID.String()), policy.ActionRead).Returns(spend)
}))
s.Run("GetUserGroupSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
limit := int64(456)
dbm.EXPECT().GetUserGroupSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
}))
s.Run("ResolveUserChatSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
limit := int64(789)
dbm.EXPECT().ResolveUserChatSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
}))
s.Run("GetChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
now := dbtime.Now()
config := database.ChatUsageLimitConfig{
ID: 1,
Singleton: true,
Enabled: true,
DefaultLimitMicros: 1_000_000,
Period: "monthly",
CreatedAt: now,
UpdatedAt: now,
}
dbm.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(config, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
groupID := uuid.New()
override := database.GetChatUsageLimitGroupOverrideRow{
GroupID: groupID,
SpendLimitMicros: sql.NullInt64{Int64: 2_000_000, Valid: true},
}
dbm.EXPECT().GetChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(override, nil).AnyTimes()
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
}))
s.Run("GetChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
override := database.GetChatUsageLimitUserOverrideRow{
UserID: userID,
SpendLimitMicros: sql.NullInt64{Int64: 3_000_000, Valid: true},
}
dbm.EXPECT().GetChatUsageLimitUserOverride(gomock.Any(), userID).Return(override, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
}))
s.Run("ListChatUsageLimitGroupOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
overrides := []database.ListChatUsageLimitGroupOverridesRow{{
GroupID: uuid.New(),
GroupName: "group-name",
GroupDisplayName: "Group Name",
GroupAvatarUrl: "https://example.com/group.png",
SpendLimitMicros: sql.NullInt64{Int64: 4_000_000, Valid: true},
MemberCount: 5,
}}
dbm.EXPECT().ListChatUsageLimitGroupOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
}))
s.Run("ListChatUsageLimitOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
overrides := []database.ListChatUsageLimitOverridesRow{{
UserID: uuid.New(),
Username: "usage-limit-user",
Name: "Usage Limit User",
AvatarURL: "https://example.com/avatar.png",
SpendLimitMicros: sql.NullInt64{Int64: 5_000_000, Valid: true},
}}
dbm.EXPECT().ListChatUsageLimitOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
}))
s.Run("UpsertChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
now := dbtime.Now()
arg := database.UpsertChatUsageLimitConfigParams{
Enabled: true,
DefaultLimitMicros: 6_000_000,
Period: "monthly",
}
config := database.ChatUsageLimitConfig{
ID: 1,
Singleton: true,
Enabled: arg.Enabled,
DefaultLimitMicros: arg.DefaultLimitMicros,
Period: arg.Period,
CreatedAt: now,
UpdatedAt: now,
}
dbm.EXPECT().UpsertChatUsageLimitConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpsertChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.UpsertChatUsageLimitGroupOverrideParams{
SpendLimitMicros: 7_000_000,
GroupID: uuid.New(),
}
override := database.UpsertChatUsageLimitGroupOverrideRow{
GroupID: arg.GroupID,
Name: "group",
DisplayName: "Group",
AvatarURL: "",
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
}
dbm.EXPECT().UpsertChatUsageLimitGroupOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
}))
s.Run("UpsertChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.UpsertChatUsageLimitUserOverrideParams{
SpendLimitMicros: 8_000_000,
UserID: uuid.New(),
}
override := database.UpsertChatUsageLimitUserOverrideRow{
UserID: arg.UserID,
Username: "user",
Name: "User",
AvatarURL: "",
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
}
dbm.EXPECT().UpsertChatUsageLimitUserOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
}))
s.Run("DeleteChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
groupID := uuid.New()
dbm.EXPECT().DeleteChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(nil).AnyTimes()
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
}
func (s *MethodTestSuite) TestFile() {
@@ -1493,7 +1340,7 @@ func (s *MethodTestSuite) TestOrganization() {
org := testutil.Fake(s.T(), faker, database.Organization{})
arg := database.UpdateOrganizationWorkspaceSharingSettingsParams{
ID: org.ID,
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
WorkspaceSharingDisabled: true,
}
dbm.EXPECT().GetOrganizationByID(gomock.Any(), org.ID).Return(org, nil).AnyTimes()
dbm.EXPECT().UpdateOrganizationWorkspaceSharingSettings(gomock.Any(), arg).Return(org, nil).AnyTimes()
@@ -2412,12 +2259,9 @@ func (s *MethodTestSuite) TestWorkspace() {
check.Args(w.ID).Asserts(w, policy.ActionShare)
}))
s.Run("DeleteWorkspaceACLsByOrganization", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.DeleteWorkspaceACLsByOrganizationParams{
OrganizationID: uuid.New(),
ExcludeServiceAccounts: false,
}
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
orgID := uuid.New()
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), orgID).Return(nil).AnyTimes()
check.Args(orgID).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{})
@@ -5183,17 +5027,6 @@ func (s *MethodTestSuite) TestAIBridge() {
check.Args(params).Asserts(intc, policy.ActionCreate)
}))
s.Run("InsertAIBridgeModelThought", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
params := database.InsertAIBridgeModelThoughtParams{InterceptionID: intc.ID}
expected := testutil.Fake(s.T(), faker, database.AIBridgeModelThought{InterceptionID: intc.ID})
db.EXPECT().InsertAIBridgeModelThought(gomock.Any(), params).Return(expected, nil).AnyTimes()
check.Args(params).Asserts(intc, policy.ActionUpdate)
}))
s.Run("InsertAIBridgeTokenUsage", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
+1 -2
View File
@@ -29,7 +29,6 @@ import (
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/rbac/regosql"
"github.com/coder/coder/v2/coderd/rbac/rolestore"
"github.com/coder/coder/v2/coderd/util/slice"
)
@@ -144,7 +143,7 @@ func (s *MethodTestSuite) Mocked(testCaseF func(dmb *dbmock.MockStore, faker *go
UUID: pair.OrganizationID,
Valid: pair.OrganizationID != uuid.Nil,
},
IsSystem: rolestore.IsSystemRoleName(pair.Name),
IsSystem: rbac.SystemRoleName(pair.Name),
ID: uuid.New(),
})
}
+25 -17
View File
@@ -650,26 +650,34 @@ func Organization(t testing.TB, db database.Store, orig database.Organization) d
})
require.NoError(t, err, "insert organization")
// Populate the placeholder system roles (created by DB
// trigger/migration) so org members have expected permissions.
//nolint:gocritic // ReconcileSystemRole needs the system:update
// Populate the placeholder organization-member system role (created by
// DB trigger/migration) so org members have expected permissions.
//nolint:gocritic // ReconcileOrgMemberRole needs the system:update
// permission that `genCtx` does not have.
sysCtx := dbauthz.AsSystemRestricted(genCtx)
for roleName := range rolestore.SystemRoleNames {
role := database.CustomRole{
Name: roleName,
OrganizationID: uuid.NullUUID{UUID: org.ID, Valid: true},
}
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
if errors.Is(err, sql.ErrNoRows) {
// The trigger that creates the placeholder role didn't run (e.g.,
// triggers were disabled in the test). Create the role manually.
err = rolestore.CreateSystemRole(sysCtx, db, org, roleName)
require.NoError(t, err, "create role "+roleName)
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
}
require.NoError(t, err, "reconcile role "+roleName)
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
Name: rbac.RoleOrgMember(),
OrganizationID: uuid.NullUUID{
UUID: org.ID,
Valid: true,
},
}, org.WorkspaceSharingDisabled)
if errors.Is(err, sql.ErrNoRows) {
// The trigger that creates the placeholder role didn't run (e.g.,
// triggers were disabled in the test). Create the role manually.
err = rolestore.CreateOrgMemberRole(sysCtx, db, org)
require.NoError(t, err, "create organization-member role")
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
Name: rbac.RoleOrgMember(),
OrganizationID: uuid.NullUUID{
UUID: org.ID,
Valid: true,
},
}, org.WorkspaceSharingDisabled)
}
require.NoError(t, err, "reconcile organization-member role")
return org
}
+14 -152
View File
@@ -288,14 +288,6 @@ func (m queryMetricsStore) CountConnectionLogs(ctx context.Context, arg database
return r0, r1
}
func (m queryMetricsStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountEnabledModelsWithoutPricing(ctx)
m.queryLatencies.WithLabelValues("CountEnabledModelsWithoutPricing").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountEnabledModelsWithoutPricing").Inc()
return r0, r1
}
func (m queryMetricsStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
start := time.Now()
r0, r1 := m.s.CountInProgressPrebuilds(ctx)
@@ -416,22 +408,6 @@ func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg data
return r0
}
func (m queryMetricsStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatUsageLimitGroupOverride(ctx, groupID)
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitGroupOverride").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatUsageLimitUserOverride(ctx, userID)
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitUserOverride").Inc()
return r0
}
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
@@ -696,11 +672,10 @@ func (m queryMetricsStore) DeleteWorkspaceACLByID(ctx context.Context, id uuid.U
return r0
}
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, arg)
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
m.queryLatencies.WithLabelValues("DeleteWorkspaceACLsByOrganization").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteWorkspaceACLsByOrganization").Inc()
return r0
}
@@ -1048,14 +1023,6 @@ func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.
return r0, r1
}
func (m queryMetricsStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetChatDesktopEnabled(ctx)
m.queryLatencies.WithLabelValues("GetChatDesktopEnabled").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDesktopEnabled").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
@@ -1176,35 +1143,11 @@ func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, err
return r0, r1
}
func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetChatUsageLimitConfig(ctx)
m.queryLatencies.WithLabelValues("GetChatUsageLimitConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatUsageLimitGroupOverride(ctx, groupID)
m.queryLatencies.WithLabelValues("GetChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitGroupOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatUsageLimitUserOverride(ctx, userID)
m.queryLatencies.WithLabelValues("GetChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitUserOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetChats(ctx, arg)
m.queryLatencies.WithLabelValues("GetChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChats").Inc()
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
return r0, r1
}
@@ -1888,6 +1831,14 @@ func (m queryMetricsStore) GetProvisionerJobTimingsByJobID(ctx context.Context,
return r0, r1
}
func (m queryMetricsStore) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) {
start := time.Now()
r0, r1 := m.s.GetProvisionerJobsByIDs(ctx, ids)
m.queryLatencies.WithLabelValues("GetProvisionerJobsByIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetProvisionerJobsByIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
start := time.Now()
r0, r1 := m.s.GetProvisionerJobsByIDsWithQueuePosition(ctx, arg)
@@ -2328,14 +2279,6 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
return r0, r1
}
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
m.queryLatencies.WithLabelValues("GetUserChatSpendInPeriod").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatSpendInPeriod").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserCount(ctx, includeSystem)
@@ -2344,14 +2287,6 @@ func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool)
return r0, r1
}
func (m queryMetricsStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserGroupSpendLimit(ctx, userID)
m.queryLatencies.WithLabelValues("GetUserGroupSpendLimit").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserGroupSpendLimit").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
start := time.Now()
r0, r1 := m.s.GetUserLatencyInsights(ctx, arg)
@@ -2960,14 +2895,6 @@ func (m queryMetricsStore) InsertAIBridgeInterception(ctx context.Context, arg d
return r0, r1
}
func (m queryMetricsStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
start := time.Now()
r0, r1 := m.s.InsertAIBridgeModelThought(ctx, arg)
m.queryLatencies.WithLabelValues("InsertAIBridgeModelThought").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertAIBridgeModelThought").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
start := time.Now()
r0, r1 := m.s.InsertAIBridgeTokenUsage(ctx, arg)
@@ -3592,22 +3519,6 @@ func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.
return r0, r1
}
func (m queryMetricsStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
start := time.Now()
r0, r1 := m.s.ListChatUsageLimitGroupOverrides(ctx)
m.queryLatencies.WithLabelValues("ListChatUsageLimitGroupOverrides").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitGroupOverrides").Inc()
return r0, r1
}
func (m queryMetricsStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
start := time.Now()
r0, r1 := m.s.ListChatUsageLimitOverrides(ctx)
m.queryLatencies.WithLabelValues("ListChatUsageLimitOverrides").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitOverrides").Inc()
return r0, r1
}
func (m queryMetricsStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
start := time.Now()
r0, r1 := m.s.ListProvisionerKeysByOrganization(ctx, organizationID)
@@ -3720,14 +3631,6 @@ func (m queryMetricsStore) RemoveUserFromGroups(ctx context.Context, arg databas
return r0, r1
}
func (m queryMetricsStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
start := time.Now()
r0, r1 := m.s.ResolveUserChatSpendLimit(ctx, userID)
m.queryLatencies.WithLabelValues("ResolveUserChatSpendLimit").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ResolveUserChatSpendLimit").Inc()
return r0, r1
}
func (m queryMetricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
start := time.Now()
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
@@ -3980,7 +3883,6 @@ func (m queryMetricsStore) UpdateOrganizationWorkspaceSharingSettings(ctx contex
start := time.Now()
r0, r1 := m.s.UpdateOrganizationWorkspaceSharingSettings(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateOrganizationWorkspaceSharingSettings").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateOrganizationWorkspaceSharingSettings").Inc()
return r0, r1
}
@@ -4560,14 +4462,6 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
return r0, r1
}
func (m queryMetricsStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
start := time.Now()
r0 := m.s.UpsertChatDesktopEnabled(ctx, enableDesktop)
m.queryLatencies.WithLabelValues("UpsertChatDesktopEnabled").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDesktopEnabled").Inc()
return r0
}
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
@@ -4592,30 +4486,6 @@ func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value str
return r0
}
func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatUsageLimitGroupOverride(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitGroupOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatUsageLimitUserOverride(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitUserOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
start := time.Now()
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
@@ -4911,11 +4781,3 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeModels").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAuthorizedChats").Inc()
return r0, r1
}
+25 -277
View File
@@ -424,21 +424,6 @@ func (mr *MockStoreMockRecorder) CountConnectionLogs(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountConnectionLogs", reflect.TypeOf((*MockStore)(nil).CountConnectionLogs), ctx, arg)
}
// CountEnabledModelsWithoutPricing mocks base method.
func (m *MockStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountEnabledModelsWithoutPricing", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountEnabledModelsWithoutPricing indicates an expected call of CountEnabledModelsWithoutPricing.
func (mr *MockStoreMockRecorder) CountEnabledModelsWithoutPricing(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEnabledModelsWithoutPricing", reflect.TypeOf((*MockStore)(nil).CountEnabledModelsWithoutPricing), ctx)
}
// CountInProgressPrebuilds mocks base method.
func (m *MockStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
m.ctrl.T.Helper()
@@ -654,34 +639,6 @@ func (mr *MockStoreMockRecorder) DeleteChatQueuedMessage(ctx, arg any) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).DeleteChatQueuedMessage), ctx, arg)
}
// DeleteChatUsageLimitGroupOverride mocks base method.
func (m *MockStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatUsageLimitGroupOverride", ctx, groupID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatUsageLimitGroupOverride indicates an expected call of DeleteChatUsageLimitGroupOverride.
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitGroupOverride), ctx, groupID)
}
// DeleteChatUsageLimitUserOverride mocks base method.
func (m *MockStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatUsageLimitUserOverride", ctx, userID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatUsageLimitUserOverride indicates an expected call of DeleteChatUsageLimitUserOverride.
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitUserOverride), ctx, userID)
}
// DeleteCryptoKey mocks base method.
func (m *MockStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
m.ctrl.T.Helper()
@@ -1155,17 +1112,17 @@ func (mr *MockStoreMockRecorder) DeleteWorkspaceACLByID(ctx, id any) *gomock.Cal
}
// DeleteWorkspaceACLsByOrganization mocks base method.
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, arg)
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, organizationID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteWorkspaceACLsByOrganization indicates an expected call of DeleteWorkspaceACLsByOrganization.
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, arg any) *gomock.Call {
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, organizationID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, organizationID)
}
// DeleteWorkspaceAgentPortShare mocks base method.
@@ -1731,21 +1688,6 @@ func (mr *MockStoreMockRecorder) GetAuthorizedAuditLogsOffset(ctx, arg, prepared
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedAuditLogsOffset", reflect.TypeOf((*MockStore)(nil).GetAuthorizedAuditLogsOffset), ctx, arg, prepared)
}
// GetAuthorizedChats mocks base method.
func (m *MockStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthorizedChats", ctx, arg, prepared)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAuthorizedChats indicates an expected call of GetAuthorizedChats.
func (mr *MockStoreMockRecorder) GetAuthorizedChats(ctx, arg, prepared any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedChats", reflect.TypeOf((*MockStore)(nil).GetAuthorizedChats), ctx, arg, prepared)
}
// GetAuthorizedConnectionLogsOffset mocks base method.
func (m *MockStore) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
m.ctrl.T.Helper()
@@ -1911,21 +1853,6 @@ func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
}
// GetChatDesktopEnabled mocks base method.
func (m *MockStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDesktopEnabled", ctx)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDesktopEnabled indicates an expected call of GetChatDesktopEnabled.
func (mr *MockStoreMockRecorder) GetChatDesktopEnabled(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).GetChatDesktopEnabled), ctx)
}
// GetChatDiffStatusByChatID mocks base method.
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
@@ -2151,64 +2078,19 @@ func (mr *MockStoreMockRecorder) GetChatSystemPrompt(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatSystemPrompt), ctx)
}
// GetChatUsageLimitConfig mocks base method.
func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
// GetChatsByOwnerID mocks base method.
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatUsageLimitConfig", ctx)
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatUsageLimitConfig indicates an expected call of GetChatUsageLimitConfig.
func (mr *MockStoreMockRecorder) GetChatUsageLimitConfig(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitConfig), ctx)
}
// GetChatUsageLimitGroupOverride mocks base method.
func (m *MockStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatUsageLimitGroupOverride", ctx, groupID)
ret0, _ := ret[0].(database.GetChatUsageLimitGroupOverrideRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatUsageLimitGroupOverride indicates an expected call of GetChatUsageLimitGroupOverride.
func (mr *MockStoreMockRecorder) GetChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitGroupOverride), ctx, groupID)
}
// GetChatUsageLimitUserOverride mocks base method.
func (m *MockStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatUsageLimitUserOverride", ctx, userID)
ret0, _ := ret[0].(database.GetChatUsageLimitUserOverrideRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatUsageLimitUserOverride indicates an expected call of GetChatUsageLimitUserOverride.
func (mr *MockStoreMockRecorder) GetChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitUserOverride), ctx, userID)
}
// GetChats mocks base method.
func (m *MockStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChats", ctx, arg)
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, arg)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChats indicates an expected call of GetChats.
func (mr *MockStoreMockRecorder) GetChats(ctx, arg any) *gomock.Call {
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChats", reflect.TypeOf((*MockStore)(nil).GetChats), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, arg)
}
// GetConnectionLogsOffset mocks base method.
@@ -3486,6 +3368,21 @@ func (mr *MockStoreMockRecorder) GetProvisionerJobTimingsByJobID(ctx, jobID any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobTimingsByJobID", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobTimingsByJobID), ctx, jobID)
}
// GetProvisionerJobsByIDs mocks base method.
func (m *MockStore) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProvisionerJobsByIDs", ctx, ids)
ret0, _ := ret[0].([]database.ProvisionerJob)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProvisionerJobsByIDs indicates an expected call of GetProvisionerJobsByIDs.
func (mr *MockStoreMockRecorder) GetProvisionerJobsByIDs(ctx, ids any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobsByIDs", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobsByIDs), ctx, ids)
}
// GetProvisionerJobsByIDsWithQueuePosition mocks base method.
func (m *MockStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
m.ctrl.T.Helper()
@@ -4341,21 +4238,6 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
}
// GetUserChatSpendInPeriod mocks base method.
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserChatSpendInPeriod", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserChatSpendInPeriod indicates an expected call of GetUserChatSpendInPeriod.
func (mr *MockStoreMockRecorder) GetUserChatSpendInPeriod(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatSpendInPeriod", reflect.TypeOf((*MockStore)(nil).GetUserChatSpendInPeriod), ctx, arg)
}
// GetUserCount mocks base method.
func (m *MockStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
m.ctrl.T.Helper()
@@ -4371,21 +4253,6 @@ func (mr *MockStoreMockRecorder) GetUserCount(ctx, includeSystem any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserCount", reflect.TypeOf((*MockStore)(nil).GetUserCount), ctx, includeSystem)
}
// GetUserGroupSpendLimit mocks base method.
func (m *MockStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserGroupSpendLimit", ctx, userID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserGroupSpendLimit indicates an expected call of GetUserGroupSpendLimit.
func (mr *MockStoreMockRecorder) GetUserGroupSpendLimit(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserGroupSpendLimit", reflect.TypeOf((*MockStore)(nil).GetUserGroupSpendLimit), ctx, userID)
}
// GetUserLatencyInsights mocks base method.
func (m *MockStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
m.ctrl.T.Helper()
@@ -5540,21 +5407,6 @@ func (mr *MockStoreMockRecorder) InsertAIBridgeInterception(ctx, arg any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeInterception", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeInterception), ctx, arg)
}
// InsertAIBridgeModelThought mocks base method.
func (m *MockStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertAIBridgeModelThought", ctx, arg)
ret0, _ := ret[0].(database.AIBridgeModelThought)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertAIBridgeModelThought indicates an expected call of InsertAIBridgeModelThought.
func (mr *MockStoreMockRecorder) InsertAIBridgeModelThought(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeModelThought", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeModelThought), ctx, arg)
}
// InsertAIBridgeTokenUsage mocks base method.
func (m *MockStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
m.ctrl.T.Helper()
@@ -6740,36 +6592,6 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
}
// ListChatUsageLimitGroupOverrides mocks base method.
func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListChatUsageLimitGroupOverrides", ctx)
ret0, _ := ret[0].([]database.ListChatUsageLimitGroupOverridesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListChatUsageLimitGroupOverrides indicates an expected call of ListChatUsageLimitGroupOverrides.
func (mr *MockStoreMockRecorder) ListChatUsageLimitGroupOverrides(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitGroupOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitGroupOverrides), ctx)
}
// ListChatUsageLimitOverrides mocks base method.
func (m *MockStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListChatUsageLimitOverrides", ctx)
ret0, _ := ret[0].([]database.ListChatUsageLimitOverridesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListChatUsageLimitOverrides indicates an expected call of ListChatUsageLimitOverrides.
func (mr *MockStoreMockRecorder) ListChatUsageLimitOverrides(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitOverrides), ctx)
}
// ListProvisionerKeysByOrganization mocks base method.
func (m *MockStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
m.ctrl.T.Helper()
@@ -7008,21 +6830,6 @@ func (mr *MockStoreMockRecorder) RemoveUserFromGroups(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), ctx, arg)
}
// ResolveUserChatSpendLimit mocks base method.
func (m *MockStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ResolveUserChatSpendLimit", ctx, userID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ResolveUserChatSpendLimit indicates an expected call of ResolveUserChatSpendLimit.
func (mr *MockStoreMockRecorder) ResolveUserChatSpendLimit(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveUserChatSpendLimit", reflect.TypeOf((*MockStore)(nil).ResolveUserChatSpendLimit), ctx, userID)
}
// RevokeDBCryptKey mocks base method.
func (m *MockStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
m.ctrl.T.Helper()
@@ -8525,20 +8332,6 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
}
// UpsertChatDesktopEnabled mocks base method.
func (m *MockStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatDesktopEnabled", ctx, enableDesktop)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertChatDesktopEnabled indicates an expected call of UpsertChatDesktopEnabled.
func (mr *MockStoreMockRecorder) UpsertChatDesktopEnabled(ctx, enableDesktop any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).UpsertChatDesktopEnabled), ctx, enableDesktop)
}
// UpsertChatDiffStatus mocks base method.
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
@@ -8583,51 +8376,6 @@ func (mr *MockStoreMockRecorder) UpsertChatSystemPrompt(ctx, value any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatSystemPrompt), ctx, value)
}
// UpsertChatUsageLimitConfig mocks base method.
func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatUsageLimitConfig", ctx, arg)
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatUsageLimitConfig indicates an expected call of UpsertChatUsageLimitConfig.
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitConfig(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitConfig), ctx, arg)
}
// UpsertChatUsageLimitGroupOverride mocks base method.
func (m *MockStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatUsageLimitGroupOverride", ctx, arg)
ret0, _ := ret[0].(database.UpsertChatUsageLimitGroupOverrideRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatUsageLimitGroupOverride indicates an expected call of UpsertChatUsageLimitGroupOverride.
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitGroupOverride(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitGroupOverride), ctx, arg)
}
// UpsertChatUsageLimitUserOverride mocks base method.
// Generated gomock stub: forwards to the controller and returns the values
// the test configured via the recorder.
func (m *MockStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "UpsertChatUsageLimitUserOverride", ctx, arg)
	// Lenient assertions: unset returns decay to zero values rather than panic.
	ret0, _ := ret[0].(database.UpsertChatUsageLimitUserOverrideRow)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// UpsertChatUsageLimitUserOverride indicates an expected call of UpsertChatUsageLimitUserOverride.
// Returns a *gomock.Call for chaining Return/Times/Do expectations.
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitUserOverride(ctx, arg any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitUserOverride), ctx, arg)
}
// UpsertConnectionLog mocks base method.
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
m.ctrl.T.Helper()
+5 -73
View File
@@ -512,12 +512,6 @@ CREATE TYPE resource_type AS ENUM (
'ai_seat'
);
CREATE TYPE shareable_workspace_owners AS ENUM (
'none',
'everyone',
'service_accounts'
);
CREATE TYPE startup_script_behavior AS ENUM (
'blocking',
'non-blocking'
@@ -798,7 +792,7 @@ BEGIN
END;
$$;
CREATE FUNCTION insert_organization_system_roles() RETURNS trigger
CREATE FUNCTION insert_org_member_system_role() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
@@ -813,8 +807,7 @@ BEGIN
is_system,
created_at,
updated_at
) VALUES
(
) VALUES (
'organization-member',
'',
NEW.id,
@@ -825,18 +818,6 @@ BEGIN
true,
NOW(),
NOW()
),
(
'organization-service-account',
'',
NEW.id,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
true,
NOW(),
NOW()
);
RETURN NEW;
END;
@@ -1105,15 +1086,6 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
CREATE TABLE aibridge_model_thoughts (
interception_id uuid NOT NULL,
content text NOT NULL,
metadata jsonb,
created_at timestamp with time zone NOT NULL
);
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
CREATE TABLE aibridge_token_usages (
id uuid NOT NULL,
interception_id uuid NOT NULL,
@@ -1346,28 +1318,6 @@ CREATE SEQUENCE chat_queued_messages_id_seq
ALTER SEQUENCE chat_queued_messages_id_seq OWNED BY chat_queued_messages.id;
CREATE TABLE chat_usage_limit_config (
id bigint NOT NULL,
singleton boolean DEFAULT true NOT NULL,
enabled boolean DEFAULT false NOT NULL,
default_limit_micros bigint DEFAULT 0 NOT NULL,
period text DEFAULT 'month'::text NOT NULL,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
CONSTRAINT chat_usage_limit_config_default_limit_micros_check CHECK ((default_limit_micros >= 0)),
CONSTRAINT chat_usage_limit_config_period_check CHECK ((period = ANY (ARRAY['day'::text, 'week'::text, 'month'::text]))),
CONSTRAINT chat_usage_limit_config_singleton_check CHECK (singleton)
);
CREATE SEQUENCE chat_usage_limit_config_id_seq
START WITH 1
INCREMENT BY 1
NO MINVALUE
NO MAXVALUE
CACHE 1;
ALTER SEQUENCE chat_usage_limit_config_id_seq OWNED BY chat_usage_limit_config.id;
CREATE TABLE chats (
id uuid DEFAULT gen_random_uuid() NOT NULL,
owner_id uuid NOT NULL,
@@ -1524,9 +1474,7 @@ CREATE TABLE groups (
avatar_url text DEFAULT ''::text NOT NULL,
quota_allowance integer DEFAULT 0 NOT NULL,
display_name text DEFAULT ''::text NOT NULL,
source group_source DEFAULT 'user'::group_source NOT NULL,
chat_spend_limit_micros bigint,
CONSTRAINT groups_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0)))
source group_source DEFAULT 'user'::group_source NOT NULL
);
COMMENT ON COLUMN groups.display_name IS 'Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.';
@@ -1561,9 +1509,7 @@ CREATE TABLE users (
one_time_passcode_expires_at timestamp with time zone,
is_system boolean DEFAULT false NOT NULL,
is_service_account boolean DEFAULT false NOT NULL,
chat_spend_limit_micros bigint,
CONSTRAINT one_time_passcode_set CHECK ((((hashed_one_time_passcode IS NULL) AND (one_time_passcode_expires_at IS NULL)) OR ((hashed_one_time_passcode IS NOT NULL) AND (one_time_passcode_expires_at IS NOT NULL)))),
CONSTRAINT users_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0))),
CONSTRAINT users_email_not_empty CHECK (((is_service_account = true) = (email = ''::text))),
CONSTRAINT users_service_account_login_type CHECK (((is_service_account = false) OR (login_type = 'none'::login_type))),
CONSTRAINT users_username_min_length CHECK ((length(username) >= 1))
@@ -1851,11 +1797,9 @@ CREATE TABLE organizations (
display_name text NOT NULL,
icon text DEFAULT ''::text NOT NULL,
deleted boolean DEFAULT false NOT NULL,
shareable_workspace_owners shareable_workspace_owners DEFAULT 'everyone'::shareable_workspace_owners NOT NULL
workspace_sharing_disabled boolean DEFAULT false NOT NULL
);
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
CREATE TABLE parameter_schemas (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
@@ -3212,8 +3156,6 @@ ALTER TABLE ONLY chat_messages ALTER COLUMN id SET DEFAULT nextval('chat_message
ALTER TABLE ONLY chat_queued_messages ALTER COLUMN id SET DEFAULT nextval('chat_queued_messages_id_seq'::regclass);
ALTER TABLE ONLY chat_usage_limit_config ALTER COLUMN id SET DEFAULT nextval('chat_usage_limit_config_id_seq'::regclass);
ALTER TABLE ONLY licenses ALTER COLUMN id SET DEFAULT nextval('licenses_id_seq'::regclass);
ALTER TABLE ONLY provisioner_job_logs ALTER COLUMN id SET DEFAULT nextval('provisioner_job_logs_id_seq'::regclass);
@@ -3274,12 +3216,6 @@ ALTER TABLE ONLY chat_providers
ALTER TABLE ONLY chat_queued_messages
ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_usage_limit_config
ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_usage_limit_config
ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
@@ -3592,8 +3528,6 @@ CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptio
CREATE INDEX idx_aibridge_interceptions_thread_root_id ON aibridge_interceptions USING btree (thread_root_id);
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts USING btree (interception_id);
CREATE INDEX idx_aibridge_token_usages_interception_id ON aibridge_token_usages USING btree (interception_id);
CREATE INDEX idx_aibridge_token_usages_provider_response_id ON aibridge_token_usages USING btree (provider_response_id);
@@ -3634,8 +3568,6 @@ CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USIN
CREATE INDEX idx_chat_messages_created_at ON chat_messages USING btree (created_at);
CREATE INDEX idx_chat_messages_owner_spend ON chat_messages USING btree (chat_id, created_at) WHERE (total_cost_micros IS NOT NULL);
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled);
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs USING btree (provider);
@@ -3884,7 +3816,7 @@ CREATE TRIGGER trigger_delete_oauth2_provider_app_token AFTER DELETE ON oauth2_p
CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted();
CREATE TRIGGER trigger_insert_organization_system_roles AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_organization_system_roles();
CREATE TRIGGER trigger_insert_org_member_system_role AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_org_member_system_role();
CREATE TRIGGER trigger_nullify_next_start_at_on_workspace_autostart_modificati AFTER UPDATE ON workspaces FOR EACH ROW EXECUTE FUNCTION nullify_next_start_at_on_workspace_autostart_modification();
@@ -26,7 +26,6 @@ func TestCustomQueriesSyncedRowScan(t *testing.T) {
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
"GetWorkspaces": "GetAuthorizedWorkspaces",
"GetUsers": "GetAuthorizedUsers",
"GetChats": "GetAuthorizedChats",
}
// Scan custom
@@ -1,4 +0,0 @@
-- Down migration: remove chat usage spend-limit support.
-- Reverses the paired up migration: drops the spend-aggregation index, the
-- per-group and per-user override columns, and the singleton config table.
DROP INDEX IF EXISTS idx_chat_messages_owner_spend;
ALTER TABLE groups DROP COLUMN IF EXISTS chat_spend_limit_micros;
ALTER TABLE users DROP COLUMN IF EXISTS chat_spend_limit_micros;
DROP TABLE IF EXISTS chat_usage_limit_config;
@@ -1,32 +0,0 @@
-- Up migration: chat usage spend limits.
-- Adds a singleton global config table plus per-user and per-group override
-- columns, and an index to make spend aggregation fast.

-- 1. Singleton config table
CREATE TABLE chat_usage_limit_config (
	id BIGSERIAL PRIMARY KEY,
	-- Only one row allowed (enforced by CHECK).
	singleton BOOLEAN NOT NULL DEFAULT TRUE CHECK (singleton),
	UNIQUE (singleton),
	enabled BOOLEAN NOT NULL DEFAULT FALSE,
	-- Limit per user per period, in micro-dollars (1 USD = 1,000,000).
	default_limit_micros BIGINT NOT NULL DEFAULT 0
		CHECK (default_limit_micros >= 0),
	-- Period length: 'day', 'week', or 'month'.
	period TEXT NOT NULL DEFAULT 'month'
		CHECK (period IN ('day', 'week', 'month')),
	created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
	updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Seed a single disabled row so reads never return empty.
INSERT INTO chat_usage_limit_config (singleton) VALUES (TRUE);

-- 2. Per-user overrides (inline on users table).
-- NULL means "no individual override"; positive values are enforced by CHECK.
ALTER TABLE users ADD COLUMN chat_spend_limit_micros BIGINT DEFAULT NULL
	CHECK (chat_spend_limit_micros IS NULL OR chat_spend_limit_micros > 0);

-- 3. Per-group overrides (inline on groups table).
ALTER TABLE groups ADD COLUMN chat_spend_limit_micros BIGINT DEFAULT NULL
	CHECK (chat_spend_limit_micros IS NULL OR chat_spend_limit_micros > 0);

-- Speed up per-user spend aggregation in the usage-limit hot path.
CREATE INDEX idx_chat_messages_owner_spend
	ON chat_messages (chat_id, created_at)
	WHERE total_cost_micros IS NOT NULL;
@@ -1,3 +0,0 @@
-- Down migration: remove the AI Bridge model-thoughts audit table and its index.
DROP INDEX idx_aibridge_model_thoughts_interception_id;
DROP TABLE aibridge_model_thoughts;
@@ -1,10 +0,0 @@
-- Up migration: audit table for model "thinking" content captured by AI Bridge.
CREATE TABLE aibridge_model_thoughts (
	interception_id UUID NOT NULL,
	content TEXT NOT NULL,
	-- Optional JSON details (nullable); presumably provider-specific — see fixture usage.
	metadata jsonb,
	created_at TIMESTAMPTZ NOT NULL
);

COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';

-- Reads are keyed by interception, so index that column.
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts(interception_id);
@@ -1,52 +0,0 @@
-- Down migration for 000443_three_options_for_allowed_workspace_sharing.
-- Reverts the three-option sharing enum back to the boolean flag and removes
-- the organization-service-account system role created by the up migration.
DELETE FROM custom_roles
WHERE name = 'organization-service-account' AND is_system = true;

ALTER TABLE organizations
ADD COLUMN workspace_sharing_disabled boolean NOT NULL DEFAULT false;

-- Migrate back: 'none' -> disabled, everything else -> enabled.
UPDATE organizations
SET workspace_sharing_disabled = true
WHERE shareable_workspace_owners = 'none';

ALTER TABLE organizations DROP COLUMN shareable_workspace_owners;

DROP TYPE shareable_workspace_owners;

-- Restore the original single-role trigger from migration 408.
DROP TRIGGER IF EXISTS trigger_insert_organization_system_roles ON organizations;
DROP FUNCTION IF EXISTS insert_organization_system_roles;

CREATE OR REPLACE FUNCTION insert_org_member_system_role() RETURNS trigger AS $$
BEGIN
	INSERT INTO custom_roles (
		name,
		display_name,
		organization_id,
		site_permissions,
		org_permissions,
		user_permissions,
		member_permissions,
		is_system,
		created_at,
		updated_at
	) VALUES (
		'organization-member',
		'',
		NEW.id,
		'[]'::jsonb,
		'[]'::jsonb,
		'[]'::jsonb,
		'[]'::jsonb,
		true,
		NOW(),
		NOW()
	);
	RETURN NEW;
END;
$$ LANGUAGE plpgsql;

CREATE TRIGGER trigger_insert_org_member_system_role
	AFTER INSERT ON organizations
	FOR EACH ROW
	EXECUTE FUNCTION insert_org_member_system_role();
@@ -1,101 +0,0 @@
-- Up migration 000443_three_options_for_allowed_workspace_sharing.
-- Replaces the boolean organizations.workspace_sharing_disabled flag with a
-- three-option enum and introduces the organization-service-account system role.
CREATE TYPE shareable_workspace_owners AS ENUM ('none', 'everyone', 'service_accounts');

ALTER TABLE organizations
ADD COLUMN shareable_workspace_owners shareable_workspace_owners NOT NULL DEFAULT 'everyone';

COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';

-- Migrate existing data from the boolean column.
UPDATE organizations
SET shareable_workspace_owners = 'none'
WHERE workspace_sharing_disabled = true;

ALTER TABLE organizations DROP COLUMN workspace_sharing_disabled;

-- Defensively rename any existing 'organization-service-account' roles
-- so they don't collide with the new system role.
UPDATE custom_roles
SET name = name || '-' || id::text
-- lower(name) is part of the existing unique index
WHERE lower(name) = 'organization-service-account';

-- Create skeleton organization-service-account system roles for all
-- existing organizations, mirroring what migration 408 did for
-- organization-member.
INSERT INTO custom_roles (
	name,
	display_name,
	organization_id,
	site_permissions,
	org_permissions,
	user_permissions,
	member_permissions,
	is_system,
	created_at,
	updated_at
)
SELECT
	'organization-service-account',
	'',
	id,
	'[]'::jsonb,
	'[]'::jsonb,
	'[]'::jsonb,
	'[]'::jsonb,
	true,
	NOW(),
	NOW()
FROM
	organizations;

-- Replace the single-role trigger with one that creates both system
-- roles when a new organization is inserted.
DROP TRIGGER IF EXISTS trigger_insert_org_member_system_role ON organizations;
DROP FUNCTION IF EXISTS insert_org_member_system_role;

CREATE OR REPLACE FUNCTION insert_organization_system_roles() RETURNS trigger AS $$
BEGIN
	INSERT INTO custom_roles (
		name,
		display_name,
		organization_id,
		site_permissions,
		org_permissions,
		user_permissions,
		member_permissions,
		is_system,
		created_at,
		updated_at
	) VALUES
	(
		'organization-member',
		'',
		NEW.id,
		'[]'::jsonb,
		'[]'::jsonb,
		'[]'::jsonb,
		'[]'::jsonb,
		true,
		NOW(),
		NOW()
	),
	(
		'organization-service-account',
		'',
		NEW.id,
		'[]'::jsonb,
		'[]'::jsonb,
		'[]'::jsonb,
		'[]'::jsonb,
		true,
		NOW(),
		NOW()
	);
	RETURN NEW;
END;
$$ LANGUAGE plpgsql;

CREATE TRIGGER trigger_insert_organization_system_roles
	AFTER INSERT ON organizations
	FOR EACH ROW
	EXECUTE FUNCTION insert_organization_system_roles();
@@ -1,28 +0,0 @@
-- Fixture for migration 000443_three_options_for_allowed_workspace_sharing.
-- Inserts a custom role named 'Organization-Service-Account' (mixed case)
-- to ensure the migration's case-insensitive rename catches it.
-- Note: is_system is false here — this simulates a pre-existing user-created
-- role, not the system role the migration adds.
INSERT INTO custom_roles (
	name,
	display_name,
	organization_id,
	site_permissions,
	org_permissions,
	user_permissions,
	member_permissions,
	is_system,
	created_at,
	updated_at
)
VALUES (
	'Organization-Service-Account',
	'User-created role',
	'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1',
	'[]'::jsonb,
	'[]'::jsonb,
	'[]'::jsonb,
	'[]'::jsonb,
	false,
	NOW(),
	NOW()
)
ON CONFLICT DO NOTHING;
@@ -1,5 +0,0 @@
-- Fixture: seed explicit chat spend-limit overrides on an existing user and
-- group so migration tests exercise the non-NULL override path.
UPDATE users SET chat_spend_limit_micros = 5000000
WHERE id = 'fc1511ef-4fcf-4a3b-98a1-8df64160e35a';

UPDATE groups SET chat_spend_limit_micros = 10000000
WHERE id = 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1';
@@ -1,13 +0,0 @@
-- Fixture: one aibridge_model_thoughts row referencing the interception
-- seeded by the aibridge migration, so later migrations run against data.
INSERT INTO
	aibridge_model_thoughts (
		interception_id,
		content,
		metadata,
		created_at
	)
VALUES (
	'be003e1e-b38f-43bf-847d-928074dd0aa8', -- from 000370_aibridge.up.sql
	'The user is asking about their workspaces. I should use the coder_list_workspaces tool to retrieve this information.',
	'{"source": "commentary"}',
	'2025-09-15 12:45:19.123456+00'
);
-64
View File
@@ -52,7 +52,6 @@ type customQuerier interface {
auditLogQuerier
connectionLogQuerier
aibridgeQuerier
chatQuerier
}
type templateQuerier interface {
@@ -452,7 +451,6 @@ func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams,
&i.OneTimePasscodeExpiresAt,
&i.IsSystem,
&i.IsServiceAccount,
&i.ChatSpendLimitMicros,
&i.Count,
); err != nil {
return nil, err
@@ -739,68 +737,6 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
return count, nil
}
// chatQuerier holds the hand-written chat queries that cannot be generated
// by sqlc because they splice an RBAC filter into the SQL at runtime.
type chatQuerier interface {
	GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]Chat, error)
}

// GetAuthorizedChats returns the chats matching arg that the prepared
// authorizer allows the caller to see. It compiles the RBAC filter to SQL,
// appends it to the generated getChats query, and scans the rows manually.
func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]Chat, error) {
	authorizedFilter, err := prepared.CompileToSQL(ctx, rbac.ConfigChats())
	if err != nil {
		return nil, xerrors.Errorf("compile authorized filter: %w", err)
	}
	// Splice the filter into the static query with an AND so it narrows,
	// never widens, the base predicate.
	filtered, err := insertAuthorizedFilter(getChats, fmt.Sprintf(" AND %s", authorizedFilter))
	if err != nil {
		return nil, xerrors.Errorf("insert authorized filter: %w", err)
	}
	// The name comment is for metric tracking
	query := fmt.Sprintf("-- name: GetAuthorizedChats :many\n%s", filtered)
	// Positional args must match the placeholder order in getChats.
	rows, err := q.db.QueryContext(ctx, query,
		arg.OwnerID,
		arg.Archived,
		arg.AfterID,
		arg.OffsetOpt,
		arg.LimitOpt,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Chat
	for rows.Next() {
		var i Chat
		// Scan order must mirror the SELECT column order of getChats.
		if err := rows.Scan(
			&i.ID,
			&i.OwnerID,
			&i.WorkspaceID,
			&i.Title,
			&i.Status,
			&i.WorkerID,
			&i.StartedAt,
			&i.HeartbeatAt,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.ParentChatID,
			&i.RootChatID,
			&i.LastModelConfigID,
			&i.Archived,
			&i.LastError,
			&i.Mode,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
type aibridgeQuerier interface {
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error)
CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error)
+12 -94
View File
@@ -3131,67 +3131,6 @@ func AllResourceTypeValues() []ResourceType {
}
}
// ShareableWorkspaceOwners mirrors the shareable_workspace_owners Postgres
// enum: it controls whose workspaces can be shared (none, everyone, or
// service_accounts). sqlc-generated.
type ShareableWorkspaceOwners string

const (
	ShareableWorkspaceOwnersNone            ShareableWorkspaceOwners = "none"
	ShareableWorkspaceOwnersEveryone        ShareableWorkspaceOwners = "everyone"
	ShareableWorkspaceOwnersServiceAccounts ShareableWorkspaceOwners = "service_accounts"
)

// Scan implements sql.Scanner, accepting the enum as either []byte or string.
// Note: it does not validate membership; use Valid() for that.
func (e *ShareableWorkspaceOwners) Scan(src interface{}) error {
	switch s := src.(type) {
	case []byte:
		*e = ShareableWorkspaceOwners(s)
	case string:
		*e = ShareableWorkspaceOwners(s)
	default:
		return fmt.Errorf("unsupported scan type for ShareableWorkspaceOwners: %T", src)
	}
	return nil
}

// NullShareableWorkspaceOwners is the nullable variant for columns that may
// hold SQL NULL.
type NullShareableWorkspaceOwners struct {
	ShareableWorkspaceOwners ShareableWorkspaceOwners `json:"shareable_workspace_owners"`
	Valid                    bool                     `json:"valid"` // Valid is true if ShareableWorkspaceOwners is not NULL
}

// Scan implements the Scanner interface.
func (ns *NullShareableWorkspaceOwners) Scan(value interface{}) error {
	if value == nil {
		ns.ShareableWorkspaceOwners, ns.Valid = "", false
		return nil
	}
	ns.Valid = true
	return ns.ShareableWorkspaceOwners.Scan(value)
}

// Value implements the driver Valuer interface.
func (ns NullShareableWorkspaceOwners) Value() (driver.Value, error) {
	if !ns.Valid {
		return nil, nil
	}
	return string(ns.ShareableWorkspaceOwners), nil
}

// Valid reports whether e is one of the declared enum members.
func (e ShareableWorkspaceOwners) Valid() bool {
	switch e {
	case ShareableWorkspaceOwnersNone,
		ShareableWorkspaceOwnersEveryone,
		ShareableWorkspaceOwnersServiceAccounts:
		return true
	}
	return false
}

// AllShareableWorkspaceOwnersValues returns every declared enum member.
func AllShareableWorkspaceOwnersValues() []ShareableWorkspaceOwners {
	return []ShareableWorkspaceOwners{
		ShareableWorkspaceOwnersNone,
		ShareableWorkspaceOwnersEveryone,
		ShareableWorkspaceOwnersServiceAccounts,
	}
}
type StartupScriptBehavior string
const (
@@ -4038,14 +3977,6 @@ type AIBridgeInterception struct {
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
}
// Audit log of model thinking in intercepted requests in AI Bridge
type AIBridgeModelThought struct {
	// Interception this thought belongs to.
	InterceptionID uuid.UUID `db:"interception_id" json:"interception_id"`
	Content        string    `db:"content" json:"content"`
	// Nullable JSON blob; presumably provider-specific details — confirm against writers.
	Metadata  pqtype.NullRawMessage `db:"metadata" json:"metadata"`
	CreatedAt time.Time             `db:"created_at" json:"created_at"`
}
// Audit log of tokens used by intercepted requests in AI Bridge
type AIBridgeTokenUsage struct {
ID uuid.UUID `db:"id" json:"id"`
@@ -4265,16 +4196,6 @@ type ChatQueuedMessage struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
// ChatUsageLimitConfig is the singleton global chat spend-limit configuration
// (the table enforces a single row via a CHECK + UNIQUE on singleton).
type ChatUsageLimitConfig struct {
	ID int64 `db:"id" json:"id"`
	// Always true; enforced at the database level to keep one row.
	Singleton bool `db:"singleton" json:"singleton"`
	Enabled   bool `db:"enabled" json:"enabled"`
	// Default per-user limit per period, in micro-dollars (1 USD = 1,000,000).
	DefaultLimitMicros int64 `db:"default_limit_micros" json:"default_limit_micros"`
	// One of 'day', 'week', or 'month' (database CHECK constraint).
	Period    string    `db:"period" json:"period"`
	CreatedAt time.Time `db:"created_at" json:"created_at"`
	UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
type ConnectionLog struct {
ID uuid.UUID `db:"id" json:"id"`
ConnectTime time.Time `db:"connect_time" json:"connect_time"`
@@ -4387,8 +4308,7 @@ type Group struct {
// Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.
DisplayName string `db:"display_name" json:"display_name"`
// Source indicates how the group was created. It can be created by a user manually, or through some system process like OIDC group sync.
Source GroupSource `db:"source" json:"source"`
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
Source GroupSource `db:"source" json:"source"`
}
// Joins group members with user information, organization ID, group name. Includes both regular group members and organization members (as part of the "Everyone" group).
@@ -4596,17 +4516,16 @@ type OAuth2ProviderAppToken struct {
}
type Organization struct {
ID uuid.UUID `db:"id" json:"id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
IsDefault bool `db:"is_default" json:"is_default"`
DisplayName string `db:"display_name" json:"display_name"`
Icon string `db:"icon" json:"icon"`
Deleted bool `db:"deleted" json:"deleted"`
// Controls whose workspaces can be shared: none, everyone, or service_accounts.
ShareableWorkspaceOwners ShareableWorkspaceOwners `db:"shareable_workspace_owners" json:"shareable_workspace_owners"`
ID uuid.UUID `db:"id" json:"id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
IsDefault bool `db:"is_default" json:"is_default"`
DisplayName string `db:"display_name" json:"display_name"`
Icon string `db:"icon" json:"icon"`
Deleted bool `db:"deleted" json:"deleted"`
WorkspaceSharingDisabled bool `db:"workspace_sharing_disabled" json:"workspace_sharing_disabled"`
}
type OrganizationMember struct {
@@ -5159,8 +5078,7 @@ type User struct {
// Determines if a user is a system user, and therefore cannot login or perform normal actions
IsSystem bool `db:"is_system" json:"is_system"`
// Determines if a user is an admin-managed account that cannot login
IsServiceAccount bool `db:"is_service_account" json:"is_service_account"`
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
IsServiceAccount bool `db:"is_service_account" json:"is_service_account"`
}
type UserConfig struct {
+3 -28
View File
@@ -77,9 +77,6 @@ type sqlcQuerier interface {
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error)
// Counts enabled, non-deleted model configs that lack both input and
// output pricing in their JSONB options.cost configuration.
CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error)
// CountInProgressPrebuilds returns the number of in-progress prebuilds, grouped by preset ID and transition.
// Prebuild considered in-progress if it's in the "pending", "starting", "stopping", or "deleting" state.
CountInProgressPrebuilds(ctx context.Context) ([]CountInProgressPrebuildsRow, error)
@@ -102,8 +99,6 @@ type sqlcQuerier interface {
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error
DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error
DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error)
DeleteCustomRole(ctx context.Context, arg DeleteCustomRoleParams) error
DeleteExpiredAPIKeys(ctx context.Context, arg DeleteExpiredAPIKeysParams) (int64, error)
@@ -150,7 +145,7 @@ type sqlcQuerier interface {
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
DeleteWorkspaceACLsByOrganization(ctx context.Context, arg DeleteWorkspaceACLsByOrganizationParams) error
DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error
DeleteWorkspaceAgentPortShare(ctx context.Context, arg DeleteWorkspaceAgentPortShareParams) error
DeleteWorkspaceAgentPortSharesByTemplate(ctx context.Context, templateID uuid.UUID) error
DeleteWorkspaceSubAgentByID(ctx context.Context, id uuid.UUID) error
@@ -234,7 +229,6 @@ type sqlcQuerier interface {
// Aggregate cost summary for a single user within a date range.
// Only counts assistant-role messages.
GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error)
GetChatDesktopEnabled(ctx context.Context) (bool, error)
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
@@ -250,10 +244,7 @@ type sqlcQuerier interface {
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
GetChatSystemPrompt(ctx context.Context) (string, error)
GetChatUsageLimitConfig(ctx context.Context) (ChatUsageLimitConfig, error)
GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (GetChatUsageLimitGroupOverrideRow, error)
GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (GetChatUsageLimitUserOverrideRow, error)
GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, error)
GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerIDParams) ([]Chat, error)
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
@@ -387,6 +378,7 @@ type sqlcQuerier interface {
// Blocks until the row is available for update.
GetProvisionerJobByIDWithLock(ctx context.Context, id uuid.UUID) (ProvisionerJob, error)
GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uuid.UUID) ([]ProvisionerJobTiming, error)
GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error)
GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg GetProvisionerJobsByIDsWithQueuePositionParams) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error)
GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner(ctx context.Context, arg GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerParams) ([]GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow, error)
GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error)
@@ -505,11 +497,7 @@ type sqlcQuerier interface {
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
// Returns the minimum (most restrictive) group limit for a user.
// Returns -1 if the user has no group limits applied.
GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error)
// GetUserLatencyInsights returns the median and 95th percentile connection
// latency that users have experienced. The result can be filtered on
// template_ids, meaning only user data from workspaces based on those templates
@@ -613,7 +601,6 @@ type sqlcQuerier interface {
GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]GetWorkspacesEligibleForTransitionRow, error)
GetWorkspacesForWorkspaceMetrics(ctx context.Context) ([]GetWorkspacesForWorkspaceMetricsRow, error)
InsertAIBridgeInterception(ctx context.Context, arg InsertAIBridgeInterceptionParams) (AIBridgeInterception, error)
InsertAIBridgeModelThought(ctx context.Context, arg InsertAIBridgeModelThoughtParams) (AIBridgeModelThought, error)
InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIBridgeTokenUsageParams) (AIBridgeTokenUsage, error)
InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBridgeToolUsageParams) (AIBridgeToolUsage, error)
InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIBridgeUserPromptParams) (AIBridgeUserPrompt, error)
@@ -710,8 +697,6 @@ type sqlcQuerier interface {
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
ListChatUsageLimitGroupOverrides(ctx context.Context) ([]ListChatUsageLimitGroupOverridesRow, error)
ListChatUsageLimitOverrides(ctx context.Context) ([]ListChatUsageLimitOverridesRow, error)
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
@@ -732,12 +717,6 @@ type sqlcQuerier interface {
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error)
// Resolves the effective spend limit for a user using the hierarchy:
// 1. Individual user override (highest priority)
// 2. Minimum group limit across all user's groups
// 3. Global default from config
// Returns -1 if limits are not enabled.
ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error)
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
// Note that this selects from the CTE, not the original table. The CTE is named
// the same as the original table to trick sqlc into reusing the existing struct
@@ -866,13 +845,9 @@ type sqlcQuerier interface {
// cumulative values for unique counts (accurate period totals). Request counts
// are always deltas, accumulated in DB. Returns true if insert, false if update.
UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error)
UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
UpsertChatSystemPrompt(ctx context.Context, value string) error
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
// The default proxy is implied and not actually stored in the database.
// So we need to store it's configuration here for display purposes.
+58 -391
View File
@@ -1235,230 +1235,6 @@ func TestGetAuthorizedWorkspacesAndAgentsByOwnerID(t *testing.T) {
})
}
func TestGetAuthorizedChats(t *testing.T) {
t.Parallel()
if testing.Short() {
t.SkipNow()
}
sqlDB := testSQLDB(t)
err := migrations.Up(sqlDB)
require.NoError(t, err)
db := database.New(sqlDB)
authorizer := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
// Create users with different roles.
owner := dbgen.User(t, db, database.User{
RBACRoles: []string{rbac.RoleOwner().String()},
})
member := dbgen.User(t, db, database.User{})
secondMember := dbgen.User(t, db, database.User{})
// Create FK dependencies: a chat provider and model config.
ctx := testutil.Context(t, testutil.WaitMedium)
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
})
require.NoError(t, err)
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "test-model",
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 80,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
// Create 3 chats owned by owner.
for i := range 3 {
_, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: owner.ID,
LastModelConfigID: modelCfg.ID,
Title: fmt.Sprintf("owner chat %d", i+1),
})
require.NoError(t, err)
}
// Create 2 chats owned by member.
for i := range 2 {
_, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: member.ID,
LastModelConfigID: modelCfg.ID,
Title: fmt.Sprintf("member chat %d", i+1),
})
require.NoError(t, err)
}
t.Run("sqlQuerier", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
// Member should only see their own 2 chats.
memberSubject, _, err := httpmw.UserRBACSubject(ctx, db, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
preparedMember, err := authorizer.Prepare(ctx, memberSubject, policy.ActionRead, rbac.ResourceChat.Type)
require.NoError(t, err)
memberRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedMember)
require.NoError(t, err)
require.Len(t, memberRows, 2)
for _, row := range memberRows {
require.Equal(t, member.ID, row.OwnerID, "member should only see own chats")
}
// Owner should see at least the 5 pre-created chats (site-wide
// access). Parallel subtests may add more, so use GreaterOrEqual.
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, db, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
preparedOwner, err := authorizer.Prepare(ctx, ownerSubject, policy.ActionRead, rbac.ResourceChat.Type)
require.NoError(t, err)
ownerRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOwner)
require.NoError(t, err)
require.GreaterOrEqual(t, len(ownerRows), 5)
// secondMember has no chats and should see 0.
secondSubject, _, err := httpmw.UserRBACSubject(ctx, db, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
preparedSecond, err := authorizer.Prepare(ctx, secondSubject, policy.ActionRead, rbac.ResourceChat.Type)
require.NoError(t, err)
secondRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedSecond)
require.NoError(t, err)
require.Len(t, secondRows, 0)
// Org admin should NOT see other users' chats — chats are
// not org-scoped resources.
orgs, err := db.GetOrganizations(ctx, database.GetOrganizationsParams{})
require.NoError(t, err)
require.NotEmpty(t, orgs)
orgAdmin := dbgen.User(t, db, database.User{})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: orgAdmin.ID,
OrganizationID: orgs[0].ID,
Roles: []string{rbac.RoleOrgAdmin()},
})
orgAdminSubject, _, err := httpmw.UserRBACSubject(ctx, db, orgAdmin.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
preparedOrgAdmin, err := authorizer.Prepare(ctx, orgAdminSubject, policy.ActionRead, rbac.ResourceChat.Type)
require.NoError(t, err)
orgAdminRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOrgAdmin)
require.NoError(t, err)
require.Len(t, orgAdminRows, 0, "org admin with no chats should see 0 chats")
// OwnerID filter: member queries their own chats.
memberFilterSelf, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
OwnerID: member.ID,
}, preparedMember)
require.NoError(t, err)
require.Len(t, memberFilterSelf, 2)
// OwnerID filter: member queries owner's chats → sees 0.
memberFilterOwner, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
OwnerID: owner.ID,
}, preparedMember)
require.NoError(t, err)
require.Len(t, memberFilterOwner, 0)
})
t.Run("dbauthz", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
authzdb := dbauthz.New(db, authorizer, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
// As member: should see only own 2 chats.
memberSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
memberCtx := dbauthz.As(ctx, memberSubject)
memberRows, err := authzdb.GetChats(memberCtx, database.GetChatsParams{})
require.NoError(t, err)
require.Len(t, memberRows, 2)
for _, row := range memberRows {
require.Equal(t, member.ID, row.OwnerID, "member should only see own chats")
}
// As owner: should see at least the 5 pre-created chats.
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
ownerCtx := dbauthz.As(ctx, ownerSubject)
ownerRows, err := authzdb.GetChats(ownerCtx, database.GetChatsParams{})
require.NoError(t, err)
require.GreaterOrEqual(t, len(ownerRows), 5)
// As secondMember: should see 0 chats.
secondSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
secondCtx := dbauthz.As(ctx, secondSubject)
secondRows, err := authzdb.GetChats(secondCtx, database.GetChatsParams{})
require.NoError(t, err)
require.Len(t, secondRows, 0)
})
t.Run("pagination", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
// Use a dedicated user for pagination to avoid interference
// with the other parallel subtests.
paginationUser := dbgen.User(t, db, database.User{})
for i := range 7 {
_, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: paginationUser.ID,
LastModelConfigID: modelCfg.ID,
Title: fmt.Sprintf("pagination chat %d", i+1),
})
require.NoError(t, err)
}
pagUserSubject, _, err := httpmw.UserRBACSubject(ctx, db, paginationUser.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
preparedMember, err := authorizer.Prepare(ctx, pagUserSubject, policy.ActionRead, rbac.ResourceChat.Type)
require.NoError(t, err)
// Fetch first page with limit=2.
page1, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
LimitOpt: 2,
}, preparedMember)
require.NoError(t, err)
require.Len(t, page1, 2)
for _, row := range page1 {
require.Equal(t, paginationUser.ID, row.OwnerID, "paginated results must belong to pagination user")
}
// Fetch remaining pages and collect all chat IDs.
allIDs := make(map[uuid.UUID]struct{})
for _, row := range page1 {
allIDs[row.ID] = struct{}{}
}
offset := int32(2)
for {
page, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
LimitOpt: 2,
OffsetOpt: offset,
}, preparedMember)
require.NoError(t, err)
for _, row := range page {
require.Equal(t, paginationUser.ID, row.OwnerID, "paginated results must belong to pagination user")
allIDs[row.ID] = struct{}{}
}
if len(page) < 2 {
break
}
offset += int32(len(page)) //nolint:gosec // Test code, pagination values are small.
}
// All 7 member chats should be accounted for with no leakage.
require.Len(t, allIDs, 7, "pagination should return all member chats exactly once")
})
}
func TestInsertWorkspaceAgentLogs(t *testing.T) {
t.Parallel()
if testing.Short() {
@@ -2655,42 +2431,6 @@ func TestDeleteCustomRoleDoesNotDeleteSystemRole(t *testing.T) {
require.True(t, roles[0].IsSystem)
}
func TestGetAuthorizationUserRolesImpliedOrgRole(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
org := dbgen.Organization(t, db, database.Organization{})
regularUser := dbgen.User(t, db, database.User{})
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org.ID,
UserID: regularUser.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org.ID,
UserID: saUser.ID,
})
ctx := testutil.Context(t, testutil.WaitShort)
wantMember := rbac.RoleOrgMember() + ":" + org.ID.String()
wantSA := rbac.RoleOrgServiceAccount() + ":" + org.ID.String()
// Regular users get the implied organization-member role.
regularRoles, err := db.GetAuthorizationUserRoles(ctx, regularUser.ID)
require.NoError(t, err)
require.Contains(t, regularRoles.Roles, wantMember)
require.NotContains(t, regularRoles.Roles, wantSA)
// Service accounts get the implied organization-service-account role.
saRoles, err := db.GetAuthorizationUserRoles(ctx, saUser.ID)
require.NoError(t, err)
require.Contains(t, saRoles.Roles, wantSA)
require.NotContains(t, saRoles.Roles, wantMember)
}
func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
t.Parallel()
@@ -2701,155 +2441,82 @@ func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
updated, err := db.UpdateOrganizationWorkspaceSharingSettings(ctx, database.UpdateOrganizationWorkspaceSharingSettingsParams{
ID: org.ID,
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
WorkspaceSharingDisabled: true,
UpdatedAt: dbtime.Now(),
})
require.NoError(t, err)
require.Equal(t, database.ShareableWorkspaceOwnersNone, updated.ShareableWorkspaceOwners)
require.True(t, updated.WorkspaceSharingDisabled)
got, err := db.GetOrganizationByID(ctx, org.ID)
require.NoError(t, err)
require.Equal(t, database.ShareableWorkspaceOwnersNone, got.ShareableWorkspaceOwners)
require.True(t, got.WorkspaceSharingDisabled)
}
func TestDeleteWorkspaceACLsByOrganization(t *testing.T) {
t.Parallel()
t.Run("DeletesAll", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
org1 := dbgen.Organization(t, db, database.Organization{})
org2 := dbgen.Organization(t, db, database.Organization{})
db, _ := dbtestutil.NewDB(t)
org1 := dbgen.Organization(t, db, database.Organization{})
org2 := dbgen.Organization(t, db, database.Organization{})
owner1 := dbgen.User(t, db, database.User{})
owner2 := dbgen.User(t, db, database.User{})
sharedUser := dbgen.User(t, db, database.User{})
sharedGroup := dbgen.Group(t, db, database.Group{
OrganizationID: org1.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org1.ID,
UserID: owner1.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org2.ID,
UserID: owner2.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org1.ID,
UserID: sharedUser.ID,
})
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: owner1.ID,
OrganizationID: org1.ID,
UserACL: database.WorkspaceACL{
sharedUser.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
GroupACL: database.WorkspaceACL{
sharedGroup.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: owner2.ID,
OrganizationID: org2.ID,
UserACL: database.WorkspaceACL{
uuid.NewString(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
ctx := testutil.Context(t, testutil.WaitShort)
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
OrganizationID: org1.ID,
ExcludeServiceAccounts: false,
})
require.NoError(t, err)
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
require.NoError(t, err)
require.Empty(t, got1.UserACL)
require.Empty(t, got1.GroupACL)
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
require.NoError(t, err)
require.NotEmpty(t, got2.UserACL)
owner1 := dbgen.User(t, db, database.User{})
owner2 := dbgen.User(t, db, database.User{})
sharedUser := dbgen.User(t, db, database.User{})
sharedGroup := dbgen.Group(t, db, database.Group{
OrganizationID: org1.ID,
})
t.Run("ExcludesServiceAccounts", func(t *testing.T) {
t.Parallel()
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org1.ID,
UserID: owner1.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org2.ID,
UserID: owner2.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org1.ID,
UserID: sharedUser.ID,
})
db, _ := dbtestutil.NewDB(t)
org := dbgen.Organization(t, db, database.Organization{})
regularUser := dbgen.User(t, db, database.User{})
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
sharedUser := dbgen.User(t, db, database.User{})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org.ID,
UserID: regularUser.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org.ID,
UserID: saUser.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org.ID,
UserID: sharedUser.ID,
})
regularWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: regularUser.ID,
OrganizationID: org.ID,
UserACL: database.WorkspaceACL{
sharedUser.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
saWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: saUser.ID,
OrganizationID: org.ID,
UserACL: database.WorkspaceACL{
sharedUser.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
ctx := testutil.Context(t, testutil.WaitShort)
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
OrganizationID: org.ID,
ExcludeServiceAccounts: true,
})
require.NoError(t, err)
// Regular user workspace ACLs should be cleared.
gotRegular, err := db.GetWorkspaceByID(ctx, regularWS.ID)
require.NoError(t, err)
require.Empty(t, gotRegular.UserACL)
// Service account workspace ACLs should be preserved.
gotSA, err := db.GetWorkspaceByID(ctx, saWS.ID)
require.NoError(t, err)
require.Equal(t, database.WorkspaceACL{
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: owner1.ID,
OrganizationID: org1.ID,
UserACL: database.WorkspaceACL{
sharedUser.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
}, gotSA.UserACL)
})
},
GroupACL: database.WorkspaceACL{
sharedGroup.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: owner2.ID,
OrganizationID: org2.ID,
UserACL: database.WorkspaceACL{
uuid.NewString(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
ctx := testutil.Context(t, testutil.WaitShort)
err := db.DeleteWorkspaceACLsByOrganization(ctx, org1.ID)
require.NoError(t, err)
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
require.NoError(t, err)
require.Empty(t, got1.UserACL)
require.Empty(t, got1.GroupACL)
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
require.NoError(t, err)
require.NotEmpty(t, got2.UserACL)
}
func TestAuthorizedAuditLogs(t *testing.T) {
File diff suppressed because it is too large Load Diff
-14
View File
@@ -53,14 +53,6 @@ INSERT INTO aibridge_tool_usages (
)
RETURNING *;
-- name: InsertAIBridgeModelThought :one
INSERT INTO aibridge_model_thoughts (
interception_id, content, metadata, created_at
) VALUES (
@interception_id, @content, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
)
RETURNING *;
-- name: GetAIBridgeInterceptionByID :one
SELECT
*
@@ -370,11 +362,6 @@ WITH
WHERE started_at < @before_time::timestamp with time zone
),
-- CTEs are executed in order.
model_thoughts AS (
DELETE FROM aibridge_model_thoughts
WHERE interception_id IN (SELECT id FROM to_delete)
RETURNING 1
),
tool_usages AS (
DELETE FROM aibridge_tool_usages
WHERE interception_id IN (SELECT id FROM to_delete)
@@ -397,7 +384,6 @@ WITH
)
-- Cumulative count.
SELECT (
(SELECT COUNT(*) FROM model_thoughts) +
(SELECT COUNT(*) FROM tool_usages) +
(SELECT COUNT(*) FROM token_usages) +
(SELECT COUNT(*) FROM user_prompts) +
+2 -132
View File
@@ -113,16 +113,13 @@ ORDER BY
created_at ASC,
id ASC;
-- name: GetChats :many
-- name: GetChatsByOwnerID :many
SELECT
*
FROM
chats
WHERE
CASE
WHEN @owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN chats.owner_id = @owner_id
ELSE true
END
owner_id = @owner_id::uuid
AND CASE
WHEN sqlc.narg('archived') :: boolean IS NULL THEN true
ELSE chats.archived = sqlc.narg('archived') :: boolean
@@ -146,8 +143,6 @@ WHERE
)
ELSE true
END
-- Authorize Filter clause will be injected below in GetAuthorizedChats
-- @authorize_filter
ORDER BY
-- Deterministic and consistent ordering of all rows, even if they share
-- a timestamp. This is to ensure consistent pagination.
@@ -705,128 +700,3 @@ LIMIT
sqlc.arg('page_limit')::int
OFFSET
sqlc.arg('page_offset')::int;
-- name: GetChatUsageLimitConfig :one
SELECT * FROM chat_usage_limit_config WHERE singleton = TRUE LIMIT 1;
-- name: UpsertChatUsageLimitConfig :one
INSERT INTO chat_usage_limit_config (singleton, enabled, default_limit_micros, period, updated_at)
VALUES (TRUE, @enabled::boolean, @default_limit_micros::bigint, @period::text, NOW())
ON CONFLICT (singleton) DO UPDATE SET
enabled = EXCLUDED.enabled,
default_limit_micros = EXCLUDED.default_limit_micros,
period = EXCLUDED.period,
updated_at = NOW()
RETURNING *;
-- name: ListChatUsageLimitOverrides :many
SELECT u.id AS user_id, u.username, u.name, u.avatar_url,
u.chat_spend_limit_micros AS spend_limit_micros
FROM users u
WHERE u.chat_spend_limit_micros IS NOT NULL
ORDER BY u.username ASC;
-- name: UpsertChatUsageLimitUserOverride :one
UPDATE users
SET chat_spend_limit_micros = @spend_limit_micros::bigint
WHERE id = @user_id::uuid
RETURNING id AS user_id, username, name, avatar_url, chat_spend_limit_micros AS spend_limit_micros;
-- name: DeleteChatUsageLimitUserOverride :exec
UPDATE users SET chat_spend_limit_micros = NULL WHERE id = @user_id::uuid;
-- name: GetChatUsageLimitUserOverride :one
SELECT id AS user_id, chat_spend_limit_micros AS spend_limit_micros
FROM users
WHERE id = @user_id::uuid AND chat_spend_limit_micros IS NOT NULL;
-- name: GetUserChatSpendInPeriod :one
SELECT COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_spend_micros
FROM chat_messages cm
JOIN chats c ON c.id = cm.chat_id
WHERE c.owner_id = @user_id::uuid
AND cm.created_at >= @start_time::timestamptz
AND cm.created_at < @end_time::timestamptz
AND cm.total_cost_micros IS NOT NULL;
-- name: CountEnabledModelsWithoutPricing :one
-- Counts enabled, non-deleted model configs that lack both input and
-- output pricing in their JSONB options.cost configuration.
SELECT COUNT(*)::bigint AS count
FROM chat_model_configs
WHERE enabled = TRUE
AND deleted = FALSE
AND (
options->'cost' IS NULL
OR options->'cost' = 'null'::jsonb
OR (
(options->'cost'->>'input_price_per_million_tokens' IS NULL)
AND (options->'cost'->>'output_price_per_million_tokens' IS NULL)
)
);
-- name: ListChatUsageLimitGroupOverrides :many
SELECT
g.id AS group_id,
g.name AS group_name,
g.display_name AS group_display_name,
g.avatar_url AS group_avatar_url,
g.chat_spend_limit_micros AS spend_limit_micros,
(SELECT COUNT(*)
FROM group_members_expanded gme
WHERE gme.group_id = g.id
AND gme.user_is_system = FALSE) AS member_count
FROM groups g
WHERE g.chat_spend_limit_micros IS NOT NULL
ORDER BY g.name ASC;
-- name: UpsertChatUsageLimitGroupOverride :one
UPDATE groups
SET chat_spend_limit_micros = @spend_limit_micros::bigint
WHERE id = @group_id::uuid
RETURNING id AS group_id, name, display_name, avatar_url, chat_spend_limit_micros AS spend_limit_micros;
-- name: DeleteChatUsageLimitGroupOverride :exec
UPDATE groups SET chat_spend_limit_micros = NULL WHERE id = @group_id::uuid;
-- name: GetChatUsageLimitGroupOverride :one
SELECT id AS group_id, chat_spend_limit_micros AS spend_limit_micros
FROM groups
WHERE id = @group_id::uuid AND chat_spend_limit_micros IS NOT NULL;
-- name: GetUserGroupSpendLimit :one
-- Returns the minimum (most restrictive) group limit for a user.
-- Returns -1 if the user has no group limits applied.
SELECT COALESCE(MIN(g.chat_spend_limit_micros), -1)::bigint AS limit_micros
FROM groups g
JOIN group_members_expanded gme ON gme.group_id = g.id
WHERE gme.user_id = @user_id::uuid
AND g.chat_spend_limit_micros IS NOT NULL;
-- name: ResolveUserChatSpendLimit :one
-- Resolves the effective spend limit for a user using the hierarchy:
-- 1. Individual user override (highest priority)
-- 2. Minimum group limit across all user's groups
-- 3. Global default from config
-- Returns -1 if limits are not enabled.
SELECT CASE
-- If limits are disabled, return -1.
WHEN NOT cfg.enabled THEN -1
-- Individual override takes priority.
WHEN u.chat_spend_limit_micros IS NOT NULL THEN u.chat_spend_limit_micros
-- Group limit (minimum across all user's groups) is next.
WHEN gl.limit_micros IS NOT NULL THEN gl.limit_micros
-- Fall back to global default.
ELSE cfg.default_limit_micros
END::bigint AS effective_limit_micros
FROM chat_usage_limit_config cfg
CROSS JOIN users u
LEFT JOIN LATERAL (
SELECT MIN(g.chat_spend_limit_micros) AS limit_micros
FROM groups g
JOIN group_members_expanded gme ON gme.group_id = g.id
WHERE gme.user_id = @user_id::uuid
AND g.chat_spend_limit_micros IS NOT NULL
) gl ON TRUE
WHERE u.id = @user_id::uuid
LIMIT 1;
+1 -1
View File
@@ -147,7 +147,7 @@ WHERE
UPDATE
organizations
SET
shareable_workspace_owners = @shareable_workspace_owners,
workspace_sharing_disabled = @workspace_sharing_disabled,
updated_at = @updated_at
WHERE
id = @id
@@ -67,6 +67,14 @@ WHERE
id = $1
FOR UPDATE;
-- name: GetProvisionerJobsByIDs :many
SELECT
*
FROM
provisioner_jobs
WHERE
id = ANY(@ids :: uuid [ ]);
-- name: GetProvisionerJobsByIDsWithQueuePosition :many
WITH filtered_provisioner_jobs AS (
-- Step 1: Filter provisioner_jobs
-20
View File
@@ -140,23 +140,3 @@ SELECT
-- name: UpsertChatSystemPrompt :exec
INSERT INTO site_configs (key, value) VALUES ('agents_chat_system_prompt', $1)
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_system_prompt';
-- name: GetChatDesktopEnabled :one
SELECT
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop;
-- name: UpsertChatDesktopEnabled :exec
INSERT INTO site_configs (key, value)
VALUES (
'agents_desktop_enabled',
CASE
WHEN sqlc.arg(enable_desktop)::bool THEN 'true'
ELSE 'false'
END
)
ON CONFLICT (key) DO UPDATE
SET value = CASE
WHEN sqlc.arg(enable_desktop)::bool THEN 'true'
ELSE 'false'
END
WHERE site_configs.key = 'agents_desktop_enabled';
+2 -14
View File
@@ -391,21 +391,9 @@ SELECT
array_agg(org_roles || ':' || organization_members.organization_id::text)
FROM
organization_members,
-- All org members get an implied role for their orgs. Most members
-- get organization-member, but service accounts will get
-- organization-service-account instead. They're largely the same,
-- but having them be distinct means we can allow configuring
-- service-accounts to have slightly broader permissionssuch as
-- for workspace sharing.
-- All org_members get the organization-member role for their orgs
unnest(
array_append(
roles,
CASE WHEN users.is_service_account THEN
'organization-service-account'
ELSE
'organization-member'
END
)
array_append(roles, 'organization-member')
) AS org_roles
WHERE
user_id = users.id
+1 -7
View File
@@ -955,13 +955,7 @@ SET
group_acl = '{}'::jsonb,
user_acl = '{}'::jsonb
WHERE
organization_id = @organization_id
AND (
NOT @exclude_service_accounts::boolean
OR owner_id NOT IN (
SELECT id FROM users WHERE is_service_account = true
)
);
organization_id = @organization_id;
-- name: GetRegularWorkspaceCreateMetrics :many
-- Count regular workspaces: only those whose first successful 'start' build
-1
View File
@@ -235,7 +235,6 @@ sql:
aibridge_tool_usage: AIBridgeToolUsage
aibridge_token_usage: AIBridgeTokenUsage
aibridge_user_prompt: AIBridgeUserPrompt
aibridge_model_thought: AIBridgeModelThought
rules:
- name: do-not-use-public-schema-in-queries
message: "do not use public schema in queries"
+15 -15
View File
@@ -20,28 +20,28 @@ import (
// AuditOAuthConvertState is never stored in the database. It is stored in a cookie
// clientside as a JWT. This type is provided for audit logging purposes.
type AuditOAuthConvertState struct {
CreatedAt time.Time `json:"created_at" db:"created_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
// The time at which the state string expires, a merge request times out if the user does not perform it quick enough.
ExpiresAt time.Time `json:"expires_at" db:"expires_at"`
FromLoginType LoginType `json:"from_login_type" db:"from_login_type"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
FromLoginType LoginType `db:"from_login_type" json:"from_login_type"`
// The login type the user is converting to. Should be github or oidc.
ToLoginType LoginType `json:"to_login_type" db:"to_login_type"`
UserID uuid.UUID `json:"user_id" db:"user_id"`
ToLoginType LoginType `db:"to_login_type" json:"to_login_type"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
}
type HealthSettings struct {
ID uuid.UUID `json:"id" db:"id"`
DismissedHealthchecks []string `json:"dismissed_healthchecks" db:"dismissed_healthchecks"`
ID uuid.UUID `db:"id" json:"id"`
DismissedHealthchecks []string `db:"dismissed_healthchecks" json:"dismissed_healthchecks"`
}
type NotificationsSettings struct {
ID uuid.UUID `json:"id" db:"id"`
NotifierPaused bool `json:"notifier_paused" db:"notifier_paused"`
ID uuid.UUID `db:"id" json:"id"`
NotifierPaused bool `db:"notifier_paused" json:"notifier_paused"`
}
type PrebuildsSettings struct {
ID uuid.UUID `json:"id" db:"id"`
ReconciliationPaused bool `json:"reconciliation_paused" db:"reconciliation_paused"`
ID uuid.UUID `db:"id" json:"id"`
ReconciliationPaused bool `db:"reconciliation_paused" json:"reconciliation_paused"`
}
type Actions []policy.Action
@@ -237,9 +237,9 @@ func (a CustomRolePermission) String() string {
// NameOrganizationPair is used as a lookup tuple for custom role rows.
type NameOrganizationPair struct {
Name string `json:"name" db:"name"`
Name string `db:"name" json:"name"`
// OrganizationID if unset will assume a null column value
OrganizationID uuid.UUID `json:"organization_id" db:"organization_id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
}
func (*NameOrganizationPair) Scan(_ interface{}) error {
@@ -264,8 +264,8 @@ func (a NameOrganizationPair) Value() (driver.Value, error) {
// AgentIDNamePair is used as a result tuple for workspace and agent rows.
type AgentIDNamePair struct {
ID uuid.UUID `json:"id" db:"id"`
Name string `json:"name" db:"name"`
ID uuid.UUID `db:"id" json:"id"`
Name string `db:"name" json:"name"`
}
func (p *AgentIDNamePair) Scan(src interface{}) error {
-2
View File
@@ -22,8 +22,6 @@ const (
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
UniqueChatProvidersProviderKey UniqueConstraint = "chat_providers_provider_key" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
UniqueChatQueuedMessagesPkey UniqueConstraint = "chat_queued_messages_pkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
UniqueChatUsageLimitConfigPkey UniqueConstraint = "chat_usage_limit_config_pkey" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
UniqueChatUsageLimitConfigSingletonKey UniqueConstraint = "chat_usage_limit_config_singleton_key" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
UniqueChatsPkey UniqueConstraint = "chats_pkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
UniqueConnectionLogsPkey UniqueConstraint = "connection_logs_pkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence);
+3 -3
View File
@@ -48,8 +48,8 @@ type Store interface {
UpsertChatDiffStatusReference(
ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams,
) (database.ChatDiffStatus, error)
GetChats(
ctx context.Context, arg database.GetChatsParams,
GetChatsByOwnerID(
ctx context.Context, arg database.GetChatsByOwnerIDParams,
) ([]database.Chat, error)
}
@@ -250,7 +250,7 @@ func (w *Worker) MarkStale(
return
}
chats, err := w.store.GetChats(ctx, database.GetChatsParams{
chats, err := w.store.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
OwnerID: ownerID,
})
if err != nil {
+12 -11
View File
@@ -469,8 +469,8 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
require.Equal(t, ownerID, arg.OwnerID)
return []database.Chat{
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
@@ -478,12 +478,13 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
{ID: chatOther, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
}, nil
})
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
mu.Lock()
upsertRefCalls = append(upsertRefCalls, arg)
mu.Unlock()
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(2)
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
mu.Lock()
upsertRefCalls = append(upsertRefCalls, arg)
mu.Unlock()
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(2)
pub := func(_ context.Context, chatID uuid.UUID) error {
mu.Lock()
@@ -526,7 +527,7 @@ func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
Return([]database.Chat{
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
@@ -554,7 +555,7 @@ func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
Return([]database.Chat{
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
@@ -589,7 +590,7 @@ func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
Return(nil, fmt.Errorf("db error"))
mClock := quartz.NewMock(t)
-9
View File
@@ -688,15 +688,6 @@ func ConfigWithoutACL() regosql.ConvertConfig {
}
}
// ConfigChats is the configuration for converting rego to SQL when
// the target table is "chats", which has no organization_id or ACL
// columns.
func ConfigChats() regosql.ConvertConfig {
return regosql.ConvertConfig{
VariableConverter: regosql.ChatConverter(),
}
}
func ConfigWorkspaces() regosql.ConvertConfig {
return regosql.ConvertConfig{
VariableConverter: regosql.WorkspaceConverter(),
+4 -4
View File
@@ -1404,8 +1404,8 @@ func testAuthorize(t *testing.T, name string, subject Subject, sets ...[]authTes
// RoleByName won't resolve it here. Assume the default behavior: workspace
// sharing enabled.
func orgMemberRole(orgID uuid.UUID) Role {
settings := OrgSettings{ShareableWorkspaceOwners: ShareableWorkspaceOwnersEveryone}
perms := OrgMemberPermissions(settings)
workspaceSharingDisabled := false
orgPerms, memberPerms := OrgMemberPermissions(workspaceSharingDisabled)
return Role{
Identifier: ScopedRoleOrgMember(orgID),
DisplayName: "",
@@ -1413,8 +1413,8 @@ func orgMemberRole(orgID uuid.UUID) Role {
User: []Permission{},
ByOrgID: map[string]OrgPermissions{
orgID.String(): {
Org: perms.Org,
Member: perms.Member,
Org: orgPerms,
Member: memberPerms,
},
},
}
+2 -2
View File
@@ -38,8 +38,8 @@ type Object struct {
// Type is "workspace", "project", "app", etc
Type string `json:"type"`
ACLUserList map[string][]policy.Action `json:"acl_user_list"`
ACLGroupList map[string][]policy.Action `json:"acl_group_list"`
ACLUserList map[string][]policy.Action ` json:"acl_user_list"`
ACLGroupList map[string][]policy.Action ` json:"acl_group_list"`
}
// String is not perfect, but decent enough for human display
-16
View File
@@ -282,22 +282,6 @@ neq(input.object.owner, "");
p("'10d03e62-7703-4df5-a358-4f76577d4e2f' = id :: text") + " AND " + p("id :: text != ''") + " AND " + p("'' = ''"),
),
},
{
Name: "ChatOwnerMe",
Queries: []string{
`"me" = input.object.owner; input.object.owner != ""; input.object.org_owner = ""`,
},
ExpectedSQL: p(p("'me' = owner_id :: text") + " AND " + p("owner_id :: text != ''") + " AND " + p("'' = ''")),
VariableConverter: regosql.ChatConverter(),
},
{
Name: "ChatOrgScopedNeverMatches",
Queries: []string{
`input.object.org_owner = "org-id"`,
},
ExpectedSQL: p("'' = 'org-id'"),
VariableConverter: regosql.ChatConverter(),
},
}
for _, tc := range testCases {
-24
View File
@@ -126,30 +126,6 @@ func NoACLConverter() *sqltypes.VariableConverter {
return matcher
}
// ChatConverter should be used for the chats table, which has no
// organization_id, group_acl, or user_acl columns.
func ChatConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
// The chats table has no organization_id column. Map org_owner
// to a literal empty string so that:
// - User-level ownership checks (org_owner = '') activate correctly.
// - Org-scoped permissions never match (org_owner will never equal
// a real org UUID), which is intentional since chats are not
// org-scoped resources.
// Note: custom org roles that include "chat" permissions will
// silently have no effect because of this mapping.
sqltypes.StringVarMatcher("''", []string{"input", "object", "org_owner"}),
userOwnerMatcher(),
)
matcher.RegisterMatcher(
sqltypes.AlwaysFalse(groupACLMatcher(matcher)),
sqltypes.AlwaysFalse(userACLMatcher(matcher)),
)
return matcher
}
func DefaultVariableConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
+79 -150
View File
@@ -29,7 +29,6 @@ const (
orgAdmin string = "organization-admin"
orgMember string = "organization-member"
orgServiceAccount string = "organization-service-account"
orgAuditor string = "organization-auditor"
orgUserAdmin string = "organization-user-admin"
orgTemplateAdmin string = "organization-template-admin"
@@ -151,10 +150,6 @@ func RoleOrgMember() string {
return orgMember
}
func RoleOrgServiceAccount() string {
return orgServiceAccount
}
func RoleOrgAuditor() string {
return orgAuditor
}
@@ -234,16 +229,31 @@ func allPermsExcept(excepts ...Objecter) []Permission {
// https://github.com/coder/coder/issues/1194
var builtInRoles map[string]func(orgID uuid.UUID) Role
// systemRoles are roles that have migrated from builtInRoles to
// database storage. This migration is partial - permissions are still
// generated at runtime and reconciled to the database, rather than
// the database being the source of truth.
var systemRoles = map[string]struct{}{
RoleOrgMember(): {},
}
func SystemRoleName(name string) bool {
_, ok := systemRoles[name]
return ok
}
type RoleOptions struct {
NoOwnerWorkspaceExec bool
NoWorkspaceSharing bool
}
// ReservedRoleName exists because the database should only allow unique role
// names, but some roles are built in. So these names are reserved
// names, but some roles are built in or generated at runtime. So these names
// are reserved
func ReservedRoleName(name string) bool {
_, ok := builtInRoles[name]
return ok
_, isBuiltIn := builtInRoles[name]
_, isSystem := systemRoles[name]
return isBuiltIn || isSystem
}
// ReloadBuiltinRoles loads the static roles into the builtInRoles map.
@@ -928,32 +938,21 @@ func PermissionsEqual(a, b []Permission) bool {
return len(setA) == len(setB)
}
// OrgSettings carries organization-level settings that affect system
// role permissions. It lives in the rbac package to avoid a cyclic
// dependency with the database package. Callers in rolestore map
// database.Organization fields onto this struct.
type OrgSettings struct {
ShareableWorkspaceOwners ShareableWorkspaceOwners
}
type ShareableWorkspaceOwners string
const (
ShareableWorkspaceOwnersNone ShareableWorkspaceOwners = "none"
ShareableWorkspaceOwnersEveryone ShareableWorkspaceOwners = "everyone"
ShareableWorkspaceOwnersServiceAccounts ShareableWorkspaceOwners = "service_accounts"
)
// OrgRolePermissions holds the two permission sets that make up a
// system role: org-wide permissions and member-scoped permissions.
type OrgRolePermissions struct {
Org []Permission
Member []Permission
}
// OrgMemberPermissions returns the permissions for the organization-member
// system role, which can vary based on the organization's workspace sharing
// settings.
func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
// system role. The results are then stored in the database and can vary per
// organization based on the workspace_sharing_disabled setting.
// This is the source of truth for org-member permissions, used by:
// - the startup reconciliation routine, to keep permissions current with
// RBAC resources
// - the organization workspace sharing setting endpoint, when updating
// the setting
// - the org creation endpoint, when populating the organization-member
// system role created by the DB trigger
//
//nolint:revive // workspaceSharingDisabled is an org setting
func OrgMemberPermissions(workspaceSharingDisabled bool) (
orgPerms, memberPerms []Permission,
) {
// Organization-level permissions that all org members get.
orgPermMap := map[string][]policy.Action{
// All users can see provisioner daemons for workspace creation.
@@ -964,25 +963,58 @@ func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
ResourceAssignOrgRole.Type: {policy.ActionRead},
}
// In all modes of workspace sharing but `none`, members need to
// see other org members (including service accounts) to either
// share with them or get access to their shared workspaces,
// resolved through GET /users/{user}/workspace/{workspace}
if org.ShareableWorkspaceOwners != ShareableWorkspaceOwnersNone {
// When workspace sharing is enabled, members need to see other org members
// and groups to share workspaces with them.
if !workspaceSharingDisabled {
orgPermMap[ResourceOrganizationMember.Type] = []policy.Action{policy.ActionRead}
}
// When workspace sharing is open to members, they also need to
// see org groups to share with them.
if org.ShareableWorkspaceOwners == ShareableWorkspaceOwnersEveryone {
orgPermMap[ResourceGroup.Type] = []policy.Action{policy.ActionRead}
}
orgPerms := Permissions(orgPermMap)
orgPerms = Permissions(orgPermMap)
if org.ShareableWorkspaceOwners == ShareableWorkspaceOwnersNone {
// Member-scoped permissions (resources owned by the member).
// Uses allPermsExcept to automatically include permissions for new resources.
memberPerms = append(
allPermsExcept(
ResourceWorkspaceDormant,
ResourcePrebuiltWorkspace,
ResourceUser,
ResourceOrganizationMember,
),
Permissions(map[string][]policy.Action{
// Reduced permission set on dormant workspaces. No build,
// ssh, or exec.
ResourceWorkspaceDormant.Type: {
policy.ActionRead,
policy.ActionDelete,
policy.ActionCreate,
policy.ActionUpdate,
policy.ActionWorkspaceStop,
policy.ActionCreateAgent,
policy.ActionDeleteAgent,
policy.ActionUpdateAgent,
},
// Can read their own organization member record.
ResourceOrganizationMember.Type: {
policy.ActionRead,
},
// Users can create provisioner daemons scoped to themselves.
//
// TODO(geokat): copied from the original built-in role
// verbatim, but seems to be a no-op (not excepted above;
// plus no owner is set for the ProvisionerDaemon RBAC
// object).
ResourceProvisionerDaemon.Type: {
policy.ActionRead,
policy.ActionCreate,
policy.ActionUpdate,
},
})...,
)
if workspaceSharingDisabled {
// Org-level negation blocks sharing on ANY workspace in the
// org. This overrides any positive permission from other
// org. This overrides any positive permission from other
// roles, including org-admin.
orgPerms = append(orgPerms, Permission{
Negate: true,
@@ -991,108 +1023,5 @@ func OrgMemberPermissions(org OrgSettings) OrgRolePermissions {
})
}
// Uses allPermsExcept to automatically include permissions for new resources.
memberPerms := append(
allPermsExcept(
ResourceWorkspaceDormant,
ResourcePrebuiltWorkspace,
ResourceUser,
ResourceOrganizationMember,
),
Permissions(map[string][]policy.Action{
// Reduced permission set on dormant workspaces. No build,
// ssh, or exec.
ResourceWorkspaceDormant.Type: {
policy.ActionRead,
policy.ActionDelete,
policy.ActionCreate,
policy.ActionUpdate,
policy.ActionWorkspaceStop,
policy.ActionCreateAgent,
policy.ActionDeleteAgent,
policy.ActionUpdateAgent,
},
// Can read their own organization member record.
ResourceOrganizationMember.Type: {
policy.ActionRead,
},
})...,
)
if org.ShareableWorkspaceOwners != ShareableWorkspaceOwnersEveryone {
memberPerms = append(memberPerms, Permission{
Negate: true,
ResourceType: ResourceWorkspace.Type,
Action: policy.ActionShare,
})
}
return OrgRolePermissions{Org: orgPerms, Member: memberPerms}
}
// OrgServiceAccountPermissions returns the permissions for the
// organization-service-account system role, which can vary based on
// the organization's workspace sharing settings.
func OrgServiceAccountPermissions(org OrgSettings) OrgRolePermissions {
// Organization-level permissions that all org service accounts get.
orgPermMap := map[string][]policy.Action{
// All users can see provisioner daemons for workspace creation.
ResourceProvisionerDaemon.Type: {policy.ActionRead},
// All org members can read the organization.
ResourceOrganization.Type: {policy.ActionRead},
// Can read available roles.
ResourceAssignOrgRole.Type: {policy.ActionRead},
}
// When workspace sharing is enabled, service accounts need to see
// other org members and groups to share workspaces with them.
if org.ShareableWorkspaceOwners != ShareableWorkspaceOwnersNone {
orgPermMap[ResourceOrganizationMember.Type] = []policy.Action{policy.ActionRead}
orgPermMap[ResourceGroup.Type] = []policy.Action{policy.ActionRead}
}
orgPerms := Permissions(orgPermMap)
if org.ShareableWorkspaceOwners == ShareableWorkspaceOwnersNone {
// Org-level negation blocks sharing on ANY workspace in the
// org. If a service account has any other roles assigned,
// this negation will override any positive perms in them, too.
orgPerms = append(orgPerms, Permission{
Negate: true,
ResourceType: ResourceWorkspace.Type,
Action: policy.ActionShare,
})
}
// service account-scoped permissions (resources owned by the
// service account). Uses allPermsExcept to automatically include
// permissions for new resources.
memberPerms := append(
allPermsExcept(
ResourceWorkspaceDormant,
ResourcePrebuiltWorkspace,
ResourceUser,
ResourceOrganizationMember,
),
Permissions(map[string][]policy.Action{
// Reduced permission set on dormant workspaces. No build,
// ssh, or exec.
ResourceWorkspaceDormant.Type: {
policy.ActionRead,
policy.ActionDelete,
policy.ActionCreate,
policy.ActionUpdate,
policy.ActionWorkspaceStop,
policy.ActionCreateAgent,
policy.ActionDeleteAgent,
policy.ActionUpdateAgent,
},
// Can read their own organization member record.
ResourceOrganizationMember.Type: {
policy.ActionRead,
},
})...,
)
return OrgRolePermissions{Org: orgPerms, Member: memberPerms}
return orgPerms, memberPerms
}
+40 -54
View File
@@ -51,68 +51,54 @@ func TestBuiltInRoles(t *testing.T) {
}
}
// permissionGranted checks whether a permission list contains a
// matching entry for the target, accounting for wildcard actions.
// It does not evaluate negations that may override a positive grant.
func permissionGranted(perms []rbac.Permission, target rbac.Permission) bool {
return slices.ContainsFunc(perms, func(p rbac.Permission) bool {
return p.Negate == target.Negate &&
p.ResourceType == target.ResourceType &&
(p.Action == target.Action || p.Action == policy.WildcardSymbol)
})
}
func TestOrgSharingPermissions(t *testing.T) {
func TestSystemRolesAreReservedRoleNames(t *testing.T) {
t.Parallel()
tests := []struct {
name string
permsFunc func(rbac.OrgSettings) rbac.OrgRolePermissions
mode rbac.ShareableWorkspaceOwners
orgReadMembers bool
orgReadGroups bool
orgNegateShare bool
memberNegateShare bool
}{
{"Member/Everyone", rbac.OrgMemberPermissions, rbac.ShareableWorkspaceOwnersEveryone, true, true, false, false},
{"Member/None", rbac.OrgMemberPermissions, rbac.ShareableWorkspaceOwnersNone, false, false, true, true},
{"Member/ServiceAccounts", rbac.OrgMemberPermissions, rbac.ShareableWorkspaceOwnersServiceAccounts, true, false, false, true},
{"ServiceAccount/Everyone", rbac.OrgServiceAccountPermissions, rbac.ShareableWorkspaceOwnersEveryone, true, true, false, false},
{"ServiceAccount/None", rbac.OrgServiceAccountPermissions, rbac.ShareableWorkspaceOwnersNone, false, false, true, false},
{"ServiceAccount/ServiceAccounts", rbac.OrgServiceAccountPermissions, rbac.ShareableWorkspaceOwnersServiceAccounts, true, true, false, false},
}
require.True(t, rbac.ReservedRoleName(rbac.RoleOrgMember()))
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
func TestOrgMemberPermissions(t *testing.T) {
t.Parallel()
perms := tt.permsFunc(rbac.OrgSettings{
ShareableWorkspaceOwners: tt.mode,
})
t.Run("WorkspaceSharingEnabled", func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.orgReadMembers, permissionGranted(perms.Org, rbac.Permission{
ResourceType: rbac.ResourceOrganizationMember.Type,
Action: policy.ActionRead,
}), "org read members")
orgPerms, _ := rbac.OrgMemberPermissions(false)
assert.Equal(t, tt.orgReadGroups, permissionGranted(perms.Org, rbac.Permission{
ResourceType: rbac.ResourceGroup.Type,
Action: policy.ActionRead,
}), "org read groups")
require.True(t, slices.Contains(orgPerms, rbac.Permission{
ResourceType: rbac.ResourceOrganizationMember.Type,
Action: policy.ActionRead,
}))
require.True(t, slices.Contains(orgPerms, rbac.Permission{
ResourceType: rbac.ResourceGroup.Type,
Action: policy.ActionRead,
}))
require.False(t, slices.Contains(orgPerms, rbac.Permission{
Negate: true,
ResourceType: rbac.ResourceWorkspace.Type,
Action: policy.ActionShare,
}))
})
assert.Equal(t, tt.orgNegateShare, permissionGranted(perms.Org, rbac.Permission{
Negate: true,
ResourceType: rbac.ResourceWorkspace.Type,
Action: policy.ActionShare,
}), "org negate share")
t.Run("WorkspaceSharingDisabled", func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.memberNegateShare, permissionGranted(perms.Member, rbac.Permission{
Negate: true,
ResourceType: rbac.ResourceWorkspace.Type,
Action: policy.ActionShare,
}), "member negate share")
})
}
orgPerms, _ := rbac.OrgMemberPermissions(true)
require.False(t, slices.Contains(orgPerms, rbac.Permission{
ResourceType: rbac.ResourceOrganizationMember.Type,
Action: policy.ActionRead,
}))
require.False(t, slices.Contains(orgPerms, rbac.Permission{
ResourceType: rbac.ResourceGroup.Type,
Action: policy.ActionRead,
}))
require.True(t, slices.Contains(orgPerms, rbac.Permission{
Negate: true,
ResourceType: rbac.ResourceWorkspace.Type,
Action: policy.ActionShare,
}))
})
}
//nolint:tparallel,paralleltest
+59 -105
View File
@@ -2,7 +2,6 @@ package rolestore
import (
"context"
"maps"
"net/http"
"github.com/google/uuid"
@@ -162,28 +161,13 @@ func ConvertDBRole(dbRole database.CustomRole) (rbac.Role, error) {
return role, nil
}
// System roles are defined in code but stored in the database,
// allowing their permissions to be adjusted per-organization at
// runtime based on org settings (e.g. workspace sharing).
var systemRoles = map[string]permissionsFunc{
rbac.RoleOrgMember(): rbac.OrgMemberPermissions,
rbac.RoleOrgServiceAccount(): rbac.OrgServiceAccountPermissions,
}
// permissionsFunc produces the desired permissions for a system role
// given organization settings.
type permissionsFunc func(rbac.OrgSettings) rbac.OrgRolePermissions
func IsSystemRoleName(name string) bool {
_, ok := systemRoles[name]
return ok
}
var SystemRoleNames = maps.Keys(systemRoles)
// ReconcileSystemRoles ensures that every organization's system roles
// in the DB are up-to-date with the current RBAC definitions and
// organization settings.
// ReconcileSystemRoles ensures that every organization's org-member
// system role in the DB is up-to-date with permissions reflecting
// current RBAC resources and the organization's
// workspace_sharing_disabled setting. Uses PostgreSQL advisory lock
// (LockIDReconcileSystemRoles) to safely handle multi-instance
// deployments. Uses set-based comparison to avoid unnecessary
// database writes when permissions haven't changed.
func ReconcileSystemRoles(ctx context.Context, log slog.Logger, db database.Store) error {
return db.InTx(func(tx database.Store) error {
// Acquire advisory lock to prevent concurrent updates from
@@ -209,45 +193,36 @@ func ReconcileSystemRoles(ctx context.Context, log slog.Logger, db database.Stor
return xerrors.Errorf("fetch custom roles: %w", err)
}
// Index system roles by (org ID, role name) for quick lookup.
type orgRoleKey struct {
OrgID uuid.UUID
RoleName string
}
roleIndex := make(map[orgRoleKey]database.CustomRole)
// Find org-member roles and index by organization ID for quick lookup.
rolesByOrg := make(map[uuid.UUID]database.CustomRole)
for _, role := range customRoles {
if role.IsSystem && IsSystemRoleName(role.Name) && role.OrganizationID.Valid {
roleIndex[orgRoleKey{role.OrganizationID.UUID, role.Name}] = role
if role.IsSystem && role.Name == rbac.RoleOrgMember() && role.OrganizationID.Valid {
rolesByOrg[role.OrganizationID.UUID] = role
}
}
for _, org := range orgs {
for roleName := range systemRoles {
role, exists := roleIndex[orgRoleKey{org.ID, roleName}]
if !exists {
// Something is very wrong: the role should have been
// created by the db trigger or migration. Log loudly and
// try creating it as a last-ditch effort before giving up.
log.Critical(ctx, "missing system role; trying to re-create",
slog.F("organization_id", org.ID),
slog.F("role_name", roleName))
role, exists := rolesByOrg[org.ID]
if !exists {
// Something is very wrong: the role should have been created by the
// database trigger or migration. Log loudly and try creating it as
// a last-ditch effort before giving up.
log.Critical(ctx, "missing organization-member system role; trying to re-create",
slog.F("organization_id", org.ID))
err := CreateSystemRole(ctx, tx, org, roleName)
if err != nil {
return xerrors.Errorf("create missing %s system role for organization %s: %w",
roleName, org.ID, err)
}
// Nothing more to do; the new role's permissions are
// up-to-date.
continue
if err := CreateOrgMemberRole(ctx, tx, org); err != nil {
return xerrors.Errorf("create missing organization-member role for organization %s: %w",
org.ID, err)
}
_, _, err := ReconcileSystemRole(ctx, tx, role, org)
if err != nil {
return xerrors.Errorf("reconcile %s system role for organization %s: %w",
roleName, org.ID, err)
}
// Nothing more to do; the new role's permissions are up-to-date.
continue
}
_, _, err := ReconcileOrgMemberRole(ctx, tx, role, org.WorkspaceSharingDisabled)
if err != nil {
return xerrors.Errorf("reconcile organization-member role for organization %s: %w",
org.ID, err)
}
}
@@ -255,30 +230,28 @@ func ReconcileSystemRoles(ctx context.Context, log slog.Logger, db database.Stor
}, nil)
}
// ReconcileSystemRole compares the given role's permissions against
// the desired permissions produced by the permissions function based
// on the organization's settings. If they differ, the DB row is
// updated. Uses set-based comparison so permission ordering doesn't
// matter. Returns the correct role and a boolean indicating whether
// the reconciliation was necessary.
//
// IMPORTANT: Callers must hold database.LockIDReconcileSystemRoles
// for the duration of the enclosing transaction.
func ReconcileSystemRole(
// ReconcileOrgMemberRole ensures passed-in org-member role's perms
// are correct (current) and stored in the DB. Uses set-based
// comparison to avoid unnecessary database writes when permissions
// haven't changed. Returns the correct role and a boolean indicating
// whether the reconciliation was necessary.
// NOTE: Callers must acquire `database.LockIDReconcileSystemRoles` at
// the start of the transaction and hold it for the transactions
// duration. This prevents concurrent org-member reconciliation from
// racing and producing inconsistent writes.
func ReconcileOrgMemberRole(
ctx context.Context,
tx database.Store,
in database.CustomRole,
org database.Organization,
) (database.CustomRole, bool, error) {
permsFunc, ok := systemRoles[in.Name]
if !ok {
panic("dev error: no permissions function exists for role " + in.Name)
}
workspaceSharingDisabled bool,
) (
database.CustomRole, bool, error,
) {
// All fields except OrgPermissions and MemberPermissions will be the same.
out := in
// Paranoia check: we don't use these in custom roles yet.
// TODO(geokat): Have these as check constraints in DB for now?
out.SitePermissions = database.CustomRolePermissions{}
out.UserPermissions = database.CustomRolePermissions{}
out.DisplayName = ""
@@ -286,14 +259,15 @@ func ReconcileSystemRole(
inOrgPerms := ConvertDBPermissions(in.OrgPermissions)
inMemberPerms := ConvertDBPermissions(in.MemberPermissions)
outPerms := permsFunc(orgSettings(org))
outOrgPerms, outMemberPerms := rbac.OrgMemberPermissions(workspaceSharingDisabled)
match := rbac.PermissionsEqual(inOrgPerms, outPerms.Org) &&
rbac.PermissionsEqual(inMemberPerms, outPerms.Member)
// Compare using set-based comparison (order doesn't matter).
match := rbac.PermissionsEqual(inOrgPerms, outOrgPerms) &&
rbac.PermissionsEqual(inMemberPerms, outMemberPerms)
if !match {
out.OrgPermissions = ConvertPermissionsToDB(outPerms.Org)
out.MemberPermissions = ConvertPermissionsToDB(outPerms.Member)
out.OrgPermissions = ConvertPermissionsToDB(outOrgPerms)
out.MemberPermissions = ConvertPermissionsToDB(outMemberPerms)
_, err := tx.UpdateCustomRole(ctx, database.UpdateCustomRoleParams{
Name: out.Name,
@@ -305,50 +279,30 @@ func ReconcileSystemRole(
MemberPermissions: out.MemberPermissions,
})
if err != nil {
return out, !match, xerrors.Errorf("update %s system role for organization %s: %w",
in.Name, in.OrganizationID.UUID, err)
return out, !match, xerrors.Errorf("update organization-member custom role for organization %s: %w",
in.OrganizationID.UUID, err)
}
}
return out, !match, nil
}
// orgSettings maps database.Organization fields to the
// rbac.OrgSettings struct, bridging the database and rbac packages
// without introducing a circular dependency.
func orgSettings(org database.Organization) rbac.OrgSettings {
return rbac.OrgSettings{
ShareableWorkspaceOwners: rbac.ShareableWorkspaceOwners(org.ShareableWorkspaceOwners),
}
}
// CreateSystemRole inserts a new system role into the database with
// permissions produced by permsFunc based on the organization's current
// settings.
func CreateSystemRole(
ctx context.Context,
tx database.Store,
org database.Organization,
roleName string,
) error {
permsFunc, ok := systemRoles[roleName]
if !ok {
panic("dev error: no permissions function exists for role " + roleName)
}
perms := permsFunc(orgSettings(org))
// CreateOrgMemberRole creates an org-member system role for an organization.
func CreateOrgMemberRole(ctx context.Context, tx database.Store, org database.Organization) error {
orgPerms, memberPerms := rbac.OrgMemberPermissions(org.WorkspaceSharingDisabled)
_, err := tx.InsertCustomRole(ctx, database.InsertCustomRoleParams{
Name: roleName,
Name: rbac.RoleOrgMember(),
DisplayName: "",
OrganizationID: uuid.NullUUID{UUID: org.ID, Valid: true},
SitePermissions: database.CustomRolePermissions{},
OrgPermissions: ConvertPermissionsToDB(perms.Org),
OrgPermissions: ConvertPermissionsToDB(orgPerms),
UserPermissions: database.CustomRolePermissions{},
MemberPermissions: ConvertPermissionsToDB(perms.Member),
MemberPermissions: ConvertPermissionsToDB(memberPerms),
IsSystem: true,
})
if err != nil {
return xerrors.Errorf("insert %s role: %w", roleName, err)
return xerrors.Errorf("insert org-member role: %w", err)
}
return nil
+55 -71
View File
@@ -42,84 +42,68 @@ func TestExpandCustomRoleRoles(t *testing.T) {
require.Len(t, roles, 1, "role found")
}
func TestReconcileSystemRole(t *testing.T) {
func TestReconcileOrgMemberRole(t *testing.T) {
t.Parallel()
tests := []struct {
name string
roleName string
permsFunc func(rbac.OrgSettings) rbac.OrgRolePermissions
}{
{"OrgMember", rbac.RoleOrgMember(), rbac.OrgMemberPermissions},
{"ServiceAccount", rbac.RoleOrgServiceAccount(), rbac.OrgServiceAccountPermissions},
}
db, _ := dbtestutil.NewDB(t)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
org := dbgen.Organization(t, db, database.Organization{})
db, _ := dbtestutil.NewDB(t)
org := dbgen.Organization(t, db, database.Organization{})
ctx := testutil.Context(t, testutil.WaitShort)
ctx := testutil.Context(t, testutil.WaitShort)
existing, err := database.ExpectOne(db.CustomRoles(ctx, database.CustomRolesParams{
LookupRoles: []database.NameOrganizationPair{
{
Name: tt.roleName,
OrganizationID: org.ID,
},
},
IncludeSystemRoles: true,
}))
require.NoError(t, err)
existing, err := database.ExpectOne(db.CustomRoles(ctx, database.CustomRolesParams{
LookupRoles: []database.NameOrganizationPair{
{
Name: rbac.RoleOrgMember(),
OrganizationID: org.ID,
},
},
IncludeSystemRoles: true,
}))
require.NoError(t, err)
// Zero out permissions to simulate stale state.
_, err = db.UpdateCustomRole(ctx, database.UpdateCustomRoleParams{
Name: existing.Name,
OrganizationID: uuid.NullUUID{
UUID: org.ID,
Valid: true,
},
DisplayName: "",
SitePermissions: database.CustomRolePermissions{},
UserPermissions: database.CustomRolePermissions{},
OrgPermissions: database.CustomRolePermissions{},
MemberPermissions: database.CustomRolePermissions{},
})
require.NoError(t, err)
_, err = db.UpdateCustomRole(ctx, database.UpdateCustomRoleParams{
Name: existing.Name,
OrganizationID: uuid.NullUUID{
UUID: org.ID,
Valid: true,
},
DisplayName: "",
SitePermissions: database.CustomRolePermissions{},
UserPermissions: database.CustomRolePermissions{},
OrgPermissions: database.CustomRolePermissions{},
MemberPermissions: database.CustomRolePermissions{},
})
require.NoError(t, err)
stale := existing
stale.OrgPermissions = database.CustomRolePermissions{}
stale.MemberPermissions = database.CustomRolePermissions{}
stale := existing
stale.OrgPermissions = database.CustomRolePermissions{}
stale.MemberPermissions = database.CustomRolePermissions{}
reconciled, didUpdate, err := rolestore.ReconcileSystemRole(ctx, db, stale, org)
require.NoError(t, err)
require.True(t, didUpdate, "expected reconciliation to update stale permissions")
reconciled, didUpdate, err := rolestore.ReconcileOrgMemberRole(ctx, db, stale, org.WorkspaceSharingDisabled)
require.NoError(t, err)
require.True(t, didUpdate, "expected reconciliation to update stale permissions")
dbstored, err := database.ExpectOne(db.CustomRoles(ctx, database.CustomRolesParams{
LookupRoles: []database.NameOrganizationPair{
{
Name: tt.roleName,
OrganizationID: org.ID,
},
},
IncludeSystemRoles: true,
}))
require.NoError(t, err)
got, err := database.ExpectOne(db.CustomRoles(ctx, database.CustomRolesParams{
LookupRoles: []database.NameOrganizationPair{
{
Name: rbac.RoleOrgMember(),
OrganizationID: org.ID,
},
},
IncludeSystemRoles: true,
}))
require.NoError(t, err)
want := tt.permsFunc(rbac.OrgSettings{
ShareableWorkspaceOwners: rbac.ShareableWorkspaceOwners(org.ShareableWorkspaceOwners),
})
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(dbstored.OrgPermissions), want.Org))
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(dbstored.MemberPermissions), want.Member))
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(reconciled.OrgPermissions), want.Org))
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(reconciled.MemberPermissions), want.Member))
wantOrg, wantMember := rbac.OrgMemberPermissions(org.WorkspaceSharingDisabled)
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.OrgPermissions), wantOrg))
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.MemberPermissions), wantMember))
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(reconciled.OrgPermissions), wantOrg))
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(reconciled.MemberPermissions), wantMember))
_, didUpdate, err = rolestore.ReconcileSystemRole(ctx, db, reconciled, org)
require.NoError(t, err)
require.False(t, didUpdate, "expected no-op reconciliation when permissions are already current")
})
}
_, didUpdate, err = rolestore.ReconcileOrgMemberRole(ctx, db, reconciled, org.WorkspaceSharingDisabled)
require.NoError(t, err)
require.False(t, didUpdate, "expected no-op reconciliation when permissions are already current")
}
func TestReconcileSystemRoles(t *testing.T) {
@@ -134,7 +118,7 @@ func TestReconcileSystemRoles(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
_, err := sqlDB.ExecContext(ctx, "UPDATE organizations SET shareable_workspace_owners = 'none' WHERE id = $1", org2.ID)
_, err := sqlDB.ExecContext(ctx, "UPDATE organizations SET workspace_sharing_disabled = true WHERE id = $1", org2.ID)
require.NoError(t, err)
// Simulate a missing system role by bypassing the application's
@@ -179,9 +163,9 @@ func TestReconcileSystemRoles(t *testing.T) {
require.NoError(t, err)
require.True(t, got.IsSystem)
want := rbac.OrgMemberPermissions(rbac.OrgSettings{ShareableWorkspaceOwners: rbac.ShareableWorkspaceOwners(org.ShareableWorkspaceOwners)})
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.OrgPermissions), want.Org))
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.MemberPermissions), want.Member))
wantOrg, wantMember := rbac.OrgMemberPermissions(org.WorkspaceSharingDisabled)
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.OrgPermissions), wantOrg))
require.True(t, rbac.PermissionsEqual(rolestore.ConvertDBPermissions(got.MemberPermissions), wantMember))
}
assertOrgMemberRole(t, org1.ID)
+2 -2
View File
@@ -471,8 +471,8 @@ func Tasks(ctx context.Context, db database.Store, query string, actorID uuid.UU
//
// Supported query parameters:
// - archived: boolean (default: false, excludes archived chats unless explicitly set)
func Chats(query string) (database.GetChatsParams, []codersdk.ValidationError) {
filter := database.GetChatsParams{
func Chats(query string) (database.GetChatsByOwnerIDParams, []codersdk.ValidationError) {
filter := database.GetChatsByOwnerIDParams{
// Default to hiding archived chats.
Archived: sql.NullBool{Bool: false, Valid: true},
}
+4 -4
View File
@@ -1222,27 +1222,27 @@ func TestSearchChats(t *testing.T) {
testCases := []struct {
Name string
Query string
Expected database.GetChatsParams
Expected database.GetChatsByOwnerIDParams
ExpectedErrorContains string
}{
{
Name: "Empty",
Query: "",
Expected: database.GetChatsParams{
Expected: database.GetChatsByOwnerIDParams{
Archived: sql.NullBool{Bool: false, Valid: true},
},
},
{
Name: "ArchivedTrue",
Query: "archived:true",
Expected: database.GetChatsParams{
Expected: database.GetChatsByOwnerIDParams{
Archived: sql.NullBool{Bool: true, Valid: true},
},
},
{
Name: "ArchivedFalse",
Query: "archived:false",
Expected: database.GetChatsParams{
Expected: database.GetChatsByOwnerIDParams{
Archived: sql.NullBool{Bool: false, Valid: true},
},
},
-7
View File
@@ -12,7 +12,6 @@ import (
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/codersdk"
)
@@ -55,9 +54,6 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
})
return
}
if invalidator, ok := api.WebpushDispatcher.(webpush.SubscriptionCacheInvalidator); ok {
invalidator.InvalidateUser(user.ID)
}
rw.WriteHeader(http.StatusNoContent)
}
@@ -115,9 +111,6 @@ func (api *API) deleteUserWebpushSubscription(rw http.ResponseWriter, r *http.Re
})
return
}
if invalidator, ok := api.WebpushDispatcher.(webpush.SubscriptionCacheInvalidator); ok {
invalidator.InvalidateUser(user.ID)
}
rw.WriteHeader(http.StatusNoContent)
}
+7 -190
View File
@@ -9,23 +9,18 @@ import (
"net/http"
"slices"
"sync"
"time"
"github.com/SherClockHolmes/webpush-go"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"tailscale.com/util/singleflight"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
const defaultSubscriptionCacheTTL = 3 * time.Minute
// Dispatcher is an interface that can be used to dispatch
// web push notifications to clients such as browsers.
type Dispatcher interface {
@@ -38,36 +33,6 @@ type Dispatcher interface {
PublicKey() string
}
// SubscriptionCacheInvalidator is an optional interface that lets local
// subscription mutation handlers invalidate cached subscriptions.
type SubscriptionCacheInvalidator interface {
InvalidateUser(userID uuid.UUID)
}
type options struct {
clock quartz.Clock
subscriptionCacheTTL time.Duration
}
// Option configures optional behavior for a Webpusher.
type Option func(*options)
// WithClock sets the clock used by the subscription cache. Defaults to a real
// clock when not provided.
func WithClock(clock quartz.Clock) Option {
return func(o *options) {
o.clock = clock
}
}
// WithSubscriptionCacheTTL sets the in-memory subscription cache TTL. Defaults
// to three minutes when not provided or when given a non-positive duration.
func WithSubscriptionCacheTTL(ttl time.Duration) Option {
return func(o *options) {
o.subscriptionCacheTTL = ttl
}
}
// New creates a new Dispatcher to dispatch web push notifications.
//
// This is *not* integrated into the enqueue system unfortunately.
@@ -76,21 +41,7 @@ func WithSubscriptionCacheTTL(ttl time.Duration) Option {
// for updates inside of a workspace, which we want to be immediate.
//
// See: https://github.com/coder/internal/issues/528
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string, opts ...Option) (Dispatcher, error) {
cfg := options{
clock: quartz.NewReal(),
subscriptionCacheTTL: defaultSubscriptionCacheTTL,
}
for _, opt := range opts {
opt(&cfg)
}
if cfg.clock == nil {
cfg.clock = quartz.NewReal()
}
if cfg.subscriptionCacheTTL <= 0 {
cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
}
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string) (Dispatcher, error) {
keys, err := db.GetWebpushVAPIDKeys(ctx)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
@@ -112,23 +63,14 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
}
return &Webpusher{
vapidSub: vapidSub,
store: db,
log: log,
VAPIDPublicKey: keys.VapidPublicKey,
VAPIDPrivateKey: keys.VapidPrivateKey,
clock: cfg.clock,
subscriptionCacheTTL: cfg.subscriptionCacheTTL,
subscriptionCache: make(map[uuid.UUID]cachedSubscriptions),
subscriptionGenerations: make(map[uuid.UUID]uint64),
vapidSub: vapidSub,
store: db,
log: log,
VAPIDPublicKey: keys.VapidPublicKey,
VAPIDPrivateKey: keys.VapidPrivateKey,
}, nil
}
type cachedSubscriptions struct {
subscriptions []database.WebpushSubscription
expiresAt time.Time
}
type Webpusher struct {
store database.Store
log *slog.Logger
@@ -141,18 +83,10 @@ type Webpusher struct {
// the message payload.
VAPIDPublicKey string
VAPIDPrivateKey string
clock quartz.Clock
cacheMu sync.RWMutex
subscriptionCache map[uuid.UUID]cachedSubscriptions
subscriptionGenerations map[uuid.UUID]uint64
subscriptionCacheTTL time.Duration
subscriptionFetches singleflight.Group[string, []database.WebpushSubscription]
}
func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
subscriptions, err := n.subscriptionsForUser(ctx, userID)
subscriptions, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
if err != nil {
return xerrors.Errorf("get web push subscriptions by user ID: %w", err)
}
@@ -208,129 +142,12 @@ func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk
err = n.store.DeleteWebpushSubscriptions(dbauthz.AsNotifier(ctx), cleanupSubscriptions)
if err != nil {
n.log.Error(ctx, "failed to delete stale push subscriptions", slog.Error(err))
} else {
n.pruneSubscriptions(userID, cleanupSubscriptions)
}
}
return nil
}
func (n *Webpusher) subscriptionsForUser(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
if subscriptions, ok := n.cachedSubscriptions(userID); ok {
return subscriptions, nil
}
subscriptions, err, _ := n.subscriptionFetches.Do(userID.String(), func() ([]database.WebpushSubscription, error) {
if cached, ok := n.cachedSubscriptions(userID); ok {
return cached, nil
}
generation := n.subscriptionGeneration(userID)
fetched, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
if err != nil {
return nil, err
}
n.storeSubscriptions(userID, generation, fetched)
return slices.Clone(fetched), nil
})
if err != nil {
return nil, err
}
return slices.Clone(subscriptions), nil
}
func (n *Webpusher) cachedSubscriptions(userID uuid.UUID) ([]database.WebpushSubscription, bool) {
n.cacheMu.RLock()
entry, ok := n.subscriptionCache[userID]
n.cacheMu.RUnlock()
if !ok {
return nil, false
}
if n.clock.Now().Before(entry.expiresAt) {
return slices.Clone(entry.subscriptions), true
}
n.cacheMu.Lock()
if current, ok := n.subscriptionCache[userID]; ok && !n.clock.Now().Before(current.expiresAt) {
delete(n.subscriptionCache, userID)
}
n.cacheMu.Unlock()
return nil, false
}
func (n *Webpusher) subscriptionGeneration(userID uuid.UUID) uint64 {
n.cacheMu.RLock()
generation := n.subscriptionGenerations[userID]
n.cacheMu.RUnlock()
return generation
}
func (n *Webpusher) storeSubscriptions(userID uuid.UUID, generation uint64, subscriptions []database.WebpushSubscription) {
n.cacheMu.Lock()
defer n.cacheMu.Unlock()
if n.subscriptionGenerations[userID] != generation {
return
}
n.subscriptionCache[userID] = cachedSubscriptions{
subscriptions: slices.Clone(subscriptions),
expiresAt: n.clock.Now().Add(n.subscriptionCacheTTL),
}
}
func (n *Webpusher) pruneSubscriptions(userID uuid.UUID, staleIDs []uuid.UUID) {
if len(staleIDs) == 0 {
return
}
stale := make(map[uuid.UUID]struct{}, len(staleIDs))
for _, id := range staleIDs {
stale[id] = struct{}{}
}
n.cacheMu.Lock()
defer n.cacheMu.Unlock()
entry, ok := n.subscriptionCache[userID]
if !ok {
return
}
if !n.clock.Now().Before(entry.expiresAt) {
delete(n.subscriptionCache, userID)
return
}
filtered := make([]database.WebpushSubscription, 0, len(entry.subscriptions))
for _, subscription := range entry.subscriptions {
if _, shouldDelete := stale[subscription.ID]; shouldDelete {
continue
}
filtered = append(filtered, subscription)
}
if len(filtered) == 0 {
delete(n.subscriptionCache, userID)
return
}
entry.subscriptions = filtered
n.subscriptionCache[userID] = entry
}
// InvalidateUser clears the cached subscriptions for a user and advances
// its invalidation generation. Local subscribe and unsubscribe handlers call
// this after mutating subscriptions in the same process.
func (n *Webpusher) InvalidateUser(userID uuid.UUID) {
n.cacheMu.Lock()
delete(n.subscriptionCache, userID)
n.subscriptionGenerations[userID]++
n.cacheMu.Unlock()
n.subscriptionFetches.Forget(userID.String())
}
func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string, keys webpush.Keys) (int, []byte, error) {
// Copy the message to avoid modifying the original.
cpy := slices.Clone(msg)
+3 -150
View File
@@ -6,9 +6,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -23,7 +21,6 @@ import (
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
const (
@@ -31,20 +28,6 @@ const (
validEndpointP256dhKey = "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk="
)
type countingWebpushStore struct {
database.Store
getSubscriptionsCalls atomic.Int32
}
func (s *countingWebpushStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
s.getSubscriptionsCalls.Add(1)
return s.Store.GetWebpushSubscriptionsByUserID(ctx, userID)
}
func (s *countingWebpushStore) getCallCount() int32 {
return s.getSubscriptionsCalls.Load()
}
func TestPush(t *testing.T) {
t.Parallel()
@@ -233,131 +216,6 @@ func TestPush(t *testing.T) {
require.NoError(t, err)
assert.Empty(t, subscriptions, "No subscriptions should be returned")
})
t.Run("CachesSubscriptionsWithinTTL", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
rawStore, _ := dbtestutil.NewDB(t)
store := &countingWebpushStore{Store: rawStore}
var delivered atomic.Int32
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
delivered.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusOK)
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
user := dbgen.User(t, rawStore, database.User{})
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: serverURL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
msg := randomWebpushMessage(t)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
require.Equal(t, int32(1), store.getCallCount(), "subscriptions should be read once within the TTL")
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
})
t.Run("RefreshesSubscriptionsAfterTTLExpires", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
rawStore, _ := dbtestutil.NewDB(t)
store := &countingWebpushStore{Store: rawStore}
var delivered atomic.Int32
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
delivered.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusOK)
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
user := dbgen.User(t, rawStore, database.User{})
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: serverURL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
msg := randomWebpushMessage(t)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
clock.Advance(time.Minute)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
require.Equal(t, int32(2), store.getCallCount(), "dispatch should refresh subscriptions after the TTL expires")
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
})
t.Run("PrunesStaleSubscriptionsFromCache", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
rawStore, _ := dbtestutil.NewDB(t)
store := &countingWebpushStore{Store: rawStore}
var okCalls atomic.Int32
var goneCalls atomic.Int32
manager, _, okServerURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
okCalls.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusOK)
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
goneServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
goneCalls.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusGone)
}))
defer goneServer.Close()
user := dbgen.User(t, rawStore, database.User{})
okSubscription, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: okServerURL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
_, err = rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: goneServer.URL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
msg := randomWebpushMessage(t)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
require.Equal(t, int32(1), store.getCallCount(), "stale subscription cleanup should not force a second DB read within the TTL")
require.Equal(t, int32(2), okCalls.Load(), "the healthy endpoint should receive both dispatches")
require.Equal(t, int32(1), goneCalls.Load(), "the stale endpoint should be pruned from the cache after the first dispatch")
subscriptions, err := rawStore.GetWebpushSubscriptionsByUserID(ctx, user.ID)
require.NoError(t, err)
require.Len(t, subscriptions, 1, "only the healthy subscription should remain")
require.Equal(t, okSubscription.ID, subscriptions[0].ID)
})
}
func randomWebpushMessage(t testing.TB) codersdk.WebpushMessage {
@@ -386,21 +244,16 @@ func assertWebpushPayload(t testing.TB, r *http.Request) {
assert.Error(t, json.NewDecoder(r.Body).Decode(io.Discard))
}
// setupPushTest creates a common test setup for webpush notification tests.
// setupPushTest creates a common test setup for webpush notification tests
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
t.Helper()
db, _ := dbtestutil.NewDB(t)
return setupPushTestWithOptions(ctx, t, db, handlerFunc)
}
func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Store, handlerFunc func(w http.ResponseWriter, r *http.Request), opts ...webpush.Option) (webpush.Dispatcher, database.Store, string) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
db, _ := dbtestutil.NewDB(t)
server := httptest.NewServer(http.HandlerFunc(handlerFunc))
t.Cleanup(server.Close)
manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...)
manager, err := webpush.New(ctx, &logger, db, "http://example.com")
require.NoError(t, err, "Failed to create webpush manager")
return manager, db, server.URL
+7 -18
View File
@@ -35,42 +35,31 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
_, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
var handlerCalls atomic.Int32
handlerCalled := make(chan bool, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusCreated)
handlerCalls.Add(1)
handlerCalled <- true
}))
defer server.Close()
// Seed the dispatcher cache with an empty subscription set. Creating the
// subscription should invalidate that entry so the next dispatch sees the new
// subscription immediately.
err := memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message without a subscription")
require.Zero(t, handlerCalls.Load(), "a user without subscriptions should not receive a push")
err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
err := memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
Endpoint: server.URL,
AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey,
})
require.NoError(t, err, "create webpush subscription")
require.Equal(t, int32(1), handlerCalls.Load(), "subscription validation should hit the endpoint once")
require.True(t, <-handlerCalled, "handler should have been called")
err = memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message after subscribing")
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate empty cache entries after subscribing")
require.NoError(t, err, "test webpush message")
require.True(t, <-handlerCalled, "handler should have been called again")
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
Endpoint: server.URL,
})
require.NoError(t, err, "delete webpush subscription")
err = memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message after unsubscribing")
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate cached subscriptions after unsubscribing")
// Deleting the subscription for a non-existent endpoint should return a 404.
// Deleting the subscription for a non-existent endpoint should return a 404
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
Endpoint: server.URL,
})
+126 -49
View File
@@ -989,48 +989,108 @@ type workspaceBuildsData struct {
func (api *API) workspaceBuildsData(ctx context.Context, workspaceBuilds []database.WorkspaceBuild) (workspaceBuildsData, error) {
jobIDs := make([]uuid.UUID, 0, len(workspaceBuilds))
for _, build := range workspaceBuilds {
jobIDs = append(jobIDs, build.JobID)
}
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: jobIDs,
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return workspaceBuildsData{}, xerrors.Errorf("get provisioner jobs: %w", err)
}
pendingJobIDs := []uuid.UUID{}
for _, job := range jobs {
if job.ProvisionerJob.JobStatus == database.ProvisionerJobStatusPending {
pendingJobIDs = append(pendingJobIDs, job.ProvisionerJob.ID)
}
}
pendingJobProvisioners, err := api.Database.GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx, pendingJobIDs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return workspaceBuildsData{}, xerrors.Errorf("get provisioner daemons: %w", err)
}
templateVersionIDs := make([]uuid.UUID, 0, len(workspaceBuilds))
for _, build := range workspaceBuilds {
jobIDs = append(jobIDs, build.JobID)
templateVersionIDs = append(templateVersionIDs, build.TemplateVersionID)
}
// nolint:gocritic // Getting template versions by ID is a system function.
templateVersions, err := api.Database.GetTemplateVersionsByIDs(dbauthz.AsSystemRestricted(ctx), templateVersionIDs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return workspaceBuildsData{}, xerrors.Errorf("get template versions: %w", err)
// Phase A: Fetch jobs, template versions, and resources in parallel.
// These three queries depend only on the build list and can run
// concurrently.
var (
jobs []database.ProvisionerJob
templateVersions []database.TemplateVersion
resources []database.WorkspaceResource
)
var eg errgroup.Group
eg.Go(func() error {
var err error
jobs, err = api.Database.GetProvisionerJobsByIDs(ctx, jobIDs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get provisioner jobs: %w", err)
}
return nil
})
eg.Go(func() error {
var err error
// nolint:gocritic // Getting template versions by ID is a system function.
templateVersions, err = api.Database.GetTemplateVersionsByIDs(dbauthz.AsSystemRestricted(ctx), templateVersionIDs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get template versions: %w", err)
}
return nil
})
eg.Go(func() error {
var err error
// nolint:gocritic // Getting workspace resources by job ID is a system function.
resources, err = api.Database.GetWorkspaceResourcesByJobIDs(dbauthz.AsSystemRestricted(ctx), jobIDs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get workspace resources by job: %w", err)
}
return nil
})
if err := eg.Wait(); err != nil {
return workspaceBuildsData{}, err
}
// nolint:gocritic // Getting workspace resources by job ID is a system function.
resources, err := api.Database.GetWorkspaceResourcesByJobIDs(dbauthz.AsSystemRestricted(ctx), jobIDs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return workspaceBuildsData{}, xerrors.Errorf("get workspace resources by job: %w", err)
// Phase B: Get queue position and eligible daemons for pending
// jobs only. The queue position query is expensive (cross-join
// with all pending jobs and active daemons), so we only run it
// for the small number of actually-pending jobs.
var pendingJobIDs []uuid.UUID
for _, job := range jobs {
if job.JobStatus == database.ProvisionerJobStatusPending {
pendingJobIDs = append(pendingJobIDs, job.ID)
}
}
var queuePositionRows []database.GetProvisionerJobsByIDsWithQueuePositionRow
if len(pendingJobIDs) > 0 {
var err error
queuePositionRows, err = api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: pendingJobIDs,
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return workspaceBuildsData{}, xerrors.Errorf("get provisioner jobs queue position: %w", err)
}
}
var pendingJobProvisioners []database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow
if len(pendingJobIDs) > 0 {
var err error
pendingJobProvisioners, err = api.Database.GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx, pendingJobIDs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return workspaceBuildsData{}, xerrors.Errorf("get provisioner daemons: %w", err)
}
}
// Merge job rows with queue position information so downstream
// consumers see the same type they expect.
queuePositionByID := make(map[uuid.UUID]database.GetProvisionerJobsByIDsWithQueuePositionRow, len(queuePositionRows))
for _, qpRow := range queuePositionRows {
queuePositionByID[qpRow.ID] = qpRow
}
jobsWithQueuePosition := make([]database.GetProvisionerJobsByIDsWithQueuePositionRow, 0, len(jobs))
for _, job := range jobs {
row := database.GetProvisionerJobsByIDsWithQueuePositionRow{
ID: job.ID,
CreatedAt: job.CreatedAt,
ProvisionerJob: job,
QueuePosition: 0,
QueueSize: 0,
}
if qpRow, ok := queuePositionByID[job.ID]; ok {
row.QueuePosition = qpRow.QueuePosition
row.QueueSize = qpRow.QueueSize
}
jobsWithQueuePosition = append(jobsWithQueuePosition, row)
}
if len(resources) == 0 {
return workspaceBuildsData{
jobs: jobs,
jobs: jobsWithQueuePosition,
templateVersions: templateVersions,
provisionerDaemons: pendingJobProvisioners,
}, nil
@@ -1041,21 +1101,38 @@ func (api *API) workspaceBuildsData(ctx context.Context, workspaceBuilds []datab
resourceIDs = append(resourceIDs, resource.ID)
}
// nolint:gocritic // Getting workspace resource metadata by resource ID is a system function.
metadata, err := api.Database.GetWorkspaceResourceMetadataByResourceIDs(dbauthz.AsSystemRestricted(ctx), resourceIDs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return workspaceBuildsData{}, xerrors.Errorf("fetching resource metadata: %w", err)
// Phase C: Fetch metadata and agents in parallel (both depend
// on resourceIDs which we just computed).
var (
metadata []database.WorkspaceResourceMetadatum
agents []database.WorkspaceAgent
)
var eg2 errgroup.Group
eg2.Go(func() error {
var err error
// nolint:gocritic // Getting workspace resource metadata by resource ID is a system function.
metadata, err = api.Database.GetWorkspaceResourceMetadataByResourceIDs(dbauthz.AsSystemRestricted(ctx), resourceIDs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("fetching resource metadata: %w", err)
}
return nil
})
eg2.Go(func() error {
var err error
// nolint:gocritic // Getting workspace agents by resource IDs is a system function.
agents, err = api.Database.GetWorkspaceAgentsByResourceIDs(dbauthz.AsSystemRestricted(ctx), resourceIDs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get workspace agents: %w", err)
}
return nil
})
if err := eg2.Wait(); err != nil {
return workspaceBuildsData{}, err
}
// nolint:gocritic // Getting workspace agents by resource IDs is a system function.
agents, err := api.Database.GetWorkspaceAgentsByResourceIDs(dbauthz.AsSystemRestricted(ctx), resourceIDs)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return workspaceBuildsData{}, xerrors.Errorf("get workspace agents: %w", err)
}
if len(resources) == 0 {
if len(agents) == 0 {
return workspaceBuildsData{
jobs: jobs,
jobs: jobsWithQueuePosition,
templateVersions: templateVersions,
resources: resources,
metadata: metadata,
@@ -1074,23 +1151,23 @@ func (api *API) workspaceBuildsData(ctx context.Context, workspaceBuilds []datab
logSources []database.WorkspaceAgentLogSource
)
var eg errgroup.Group
eg.Go(func() (err error) {
var eg3 errgroup.Group
eg3.Go(func() (err error) {
// nolint:gocritic // Getting workspace apps by agent IDs is a system function.
apps, err = api.Database.GetWorkspaceAppsByAgentIDs(dbauthz.AsSystemRestricted(ctx), agentIDs)
return err
})
eg.Go(func() (err error) {
eg3.Go(func() (err error) {
// nolint:gocritic // Getting workspace scripts by agent IDs is a system function.
scripts, err = api.Database.GetWorkspaceAgentScriptsByAgentIDs(dbauthz.AsSystemRestricted(ctx), agentIDs)
return err
})
eg.Go(func() error {
eg3.Go(func() (err error) {
// nolint:gocritic // Getting workspace agent log sources by agent IDs is a system function.
logSources, err = api.Database.GetWorkspaceAgentLogSourcesByAgentIDs(dbauthz.AsSystemRestricted(ctx), agentIDs)
return err
})
err = eg.Wait()
err := eg3.Wait()
if err != nil {
return workspaceBuildsData{}, err
}
@@ -1107,7 +1184,7 @@ func (api *API) workspaceBuildsData(ctx context.Context, workspaceBuilds []datab
}
return workspaceBuildsData{
jobs: jobs,
jobs: jobsWithQueuePosition,
templateVersions: templateVersions,
resources: resources,
metadata: metadata,
+2 -6
View File
@@ -2487,11 +2487,7 @@ func (api *API) patchWorkspaceACL(rw http.ResponseWriter, r *http.Request) {
return nil
}, nil)
if err != nil {
if dbauthz.IsNotAuthorizedError(err) {
httpapi.Forbidden(rw)
} else {
httpapi.InternalServerError(rw, err)
}
httpapi.InternalServerError(rw, err)
return
}
@@ -2570,7 +2566,7 @@ func (api *API) allowWorkspaceSharing(ctx context.Context, rw http.ResponseWrite
httpapi.InternalServerError(rw, err)
return false
}
if org.ShareableWorkspaceOwners == database.ShareableWorkspaceOwnersNone {
if org.WorkspaceSharingDisabled {
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: "Workspace sharing is disabled for this organization.",
})
+4 -4
View File
@@ -517,14 +517,14 @@ const (
)
type AgentMetric struct {
Name string `json:"name" validate:"required"`
Type AgentMetricType `json:"type" enums:"counter,gauge" validate:"required"`
Value float64 `json:"value" validate:"required"`
Name string `json:"name" validate:"required"`
Type AgentMetricType `json:"type" validate:"required" enums:"counter,gauge"`
Value float64 `json:"value" validate:"required"`
Labels []AgentMetricLabel `json:"labels,omitempty"`
}
type AgentMetricLabel struct {
Name string `json:"name" validate:"required"`
Name string `json:"name" validate:"required"`
Value string `json:"value" validate:"required"`
}
+1 -1
View File
@@ -13,7 +13,7 @@ import (
type AWSInstanceIdentityToken struct {
Signature string `json:"signature" validate:"required"`
Document string `json:"document" validate:"required"`
Document string `json:"document" validate:"required"`
}
// AWSSessionTokenExchanger exchanges AWS instance metadata for a Coder session token.
+1 -1
View File
@@ -10,7 +10,7 @@ import (
type AzureInstanceIdentityToken struct {
Signature string `json:"signature" validate:"required"`
Encoding string `json:"encoding" validate:"required"`
Encoding string `json:"encoding" validate:"required"`
}
// AzureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token.
+13 -13
View File
@@ -12,42 +12,42 @@ import (
)
type AIBridgeInterception struct {
ID uuid.UUID `json:"id" format:"uuid"`
ID uuid.UUID `json:"id" format:"uuid"`
APIKeyID *string `json:"api_key_id"`
Initiator MinimalUser `json:"initiator"`
Provider string `json:"provider"`
Model string `json:"model"`
Client *string `json:"client"`
Metadata map[string]any `json:"metadata"`
StartedAt time.Time `json:"started_at" format:"date-time"`
EndedAt *time.Time `json:"ended_at" format:"date-time"`
StartedAt time.Time `json:"started_at" format:"date-time"`
EndedAt *time.Time `json:"ended_at" format:"date-time"`
TokenUsages []AIBridgeTokenUsage `json:"token_usages"`
UserPrompts []AIBridgeUserPrompt `json:"user_prompts"`
ToolUsages []AIBridgeToolUsage `json:"tool_usages"`
}
type AIBridgeTokenUsage struct {
ID uuid.UUID `json:"id" format:"uuid"`
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
ID uuid.UUID `json:"id" format:"uuid"`
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
ProviderResponseID string `json:"provider_response_id"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
Metadata map[string]any `json:"metadata"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
}
type AIBridgeUserPrompt struct {
ID uuid.UUID `json:"id" format:"uuid"`
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
ID uuid.UUID `json:"id" format:"uuid"`
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
ProviderResponseID string `json:"provider_response_id"`
Prompt string `json:"prompt"`
Metadata map[string]any `json:"metadata"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
}
type AIBridgeToolUsage struct {
ID uuid.UUID `json:"id" format:"uuid"`
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
ID uuid.UUID `json:"id" format:"uuid"`
InterceptionID uuid.UUID `json:"interception_id" format:"uuid"`
ProviderResponseID string `json:"provider_response_id"`
ServerURL string `json:"server_url"`
Tool string `json:"tool"`
@@ -55,7 +55,7 @@ type AIBridgeToolUsage struct {
Injected bool `json:"injected"`
InvocationError string `json:"invocation_error"`
Metadata map[string]any `json:"metadata"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
}
type AIBridgeListInterceptionsResponse struct {
@@ -73,7 +73,7 @@ type AIBridgeListInterceptionsFilter struct {
// Initiator is a user ID, username, or "me".
Initiator string `json:"initiator,omitempty"`
StartedBefore time.Time `json:"started_before,omitempty" format:"date-time"`
StartedAfter time.Time `json:"started_after,omitempty" format:"date-time"`
StartedAfter time.Time `json:"started_after,omitempty" format:"date-time"`
Provider string `json:"provider,omitempty"`
Model string `json:"model,omitempty"`
Client string `json:"client,omitempty"`
+32 -32
View File
@@ -14,7 +14,7 @@ import (
// CreateTaskRequest represents the request to create a new task.
type CreateTaskRequest struct {
TemplateVersionID uuid.UUID `json:"template_version_id" format:"uuid"`
TemplateVersionID uuid.UUID `json:"template_version_id" format:"uuid"`
TemplateVersionPresetID uuid.UUID `json:"template_version_preset_id,omitempty" format:"uuid"`
Input string `json:"input"`
Name string `json:"name,omitempty"`
@@ -98,39 +98,39 @@ const (
// Task represents a task.
type Task struct {
ID uuid.UUID `json:"id" table:"id" format:"uuid"`
OrganizationID uuid.UUID `json:"organization_id" table:"organization id" format:"uuid"`
OwnerID uuid.UUID `json:"owner_id" table:"owner id" format:"uuid"`
OwnerName string `json:"owner_name" table:"owner name"`
OwnerAvatarURL string `json:"owner_avatar_url,omitempty" table:"owner avatar url"`
Name string `json:"name" table:"name,default_sort"`
DisplayName string `json:"display_name" table:"display_name"`
TemplateID uuid.UUID `json:"template_id" table:"template id" format:"uuid"`
TemplateVersionID uuid.UUID `json:"template_version_id" table:"template version id" format:"uuid"`
TemplateName string `json:"template_name" table:"template name"`
TemplateDisplayName string `json:"template_display_name" table:"template display name"`
TemplateIcon string `json:"template_icon" table:"template icon"`
WorkspaceID uuid.NullUUID `json:"workspace_id" table:"workspace id" format:"uuid"`
WorkspaceName string `json:"workspace_name" table:"workspace name"`
WorkspaceStatus WorkspaceStatus `json:"workspace_status,omitempty" table:"workspace status" enums:"pending,starting,running,stopping,stopped,failed,canceling,canceled,deleting,deleted"`
ID uuid.UUID `json:"id" format:"uuid" table:"id"`
OrganizationID uuid.UUID `json:"organization_id" format:"uuid" table:"organization id"`
OwnerID uuid.UUID `json:"owner_id" format:"uuid" table:"owner id"`
OwnerName string `json:"owner_name" table:"owner name"`
OwnerAvatarURL string `json:"owner_avatar_url,omitempty" table:"owner avatar url"`
Name string `json:"name" table:"name,default_sort"`
DisplayName string `json:"display_name" table:"display_name"`
TemplateID uuid.UUID `json:"template_id" format:"uuid" table:"template id"`
TemplateVersionID uuid.UUID `json:"template_version_id" format:"uuid" table:"template version id"`
TemplateName string `json:"template_name" table:"template name"`
TemplateDisplayName string `json:"template_display_name" table:"template display name"`
TemplateIcon string `json:"template_icon" table:"template icon"`
WorkspaceID uuid.NullUUID `json:"workspace_id" format:"uuid" table:"workspace id"`
WorkspaceName string `json:"workspace_name" table:"workspace name"`
WorkspaceStatus WorkspaceStatus `json:"workspace_status,omitempty" enums:"pending,starting,running,stopping,stopped,failed,canceling,canceled,deleting,deleted" table:"workspace status"`
WorkspaceBuildNumber int32 `json:"workspace_build_number,omitempty" table:"workspace build number"`
WorkspaceAgentID uuid.NullUUID `json:"workspace_agent_id" table:"workspace agent id" format:"uuid"`
WorkspaceAgentLifecycle *WorkspaceAgentLifecycle `json:"workspace_agent_lifecycle" table:"workspace agent lifecycle"`
WorkspaceAgentHealth *WorkspaceAgentHealth `json:"workspace_agent_health" table:"workspace agent health"`
WorkspaceAppID uuid.NullUUID `json:"workspace_app_id" table:"workspace app id" format:"uuid"`
InitialPrompt string `json:"initial_prompt" table:"initial prompt"`
Status TaskStatus `json:"status" table:"status" enums:"pending,initializing,active,paused,unknown,error"`
CurrentState *TaskStateEntry `json:"current_state" table:"cs,recursive_inline,empty_nil"`
CreatedAt time.Time `json:"created_at" table:"created at" format:"date-time"`
UpdatedAt time.Time `json:"updated_at" table:"updated at" format:"date-time"`
WorkspaceAgentID uuid.NullUUID `json:"workspace_agent_id" format:"uuid" table:"workspace agent id"`
WorkspaceAgentLifecycle *WorkspaceAgentLifecycle `json:"workspace_agent_lifecycle" table:"workspace agent lifecycle"`
WorkspaceAgentHealth *WorkspaceAgentHealth `json:"workspace_agent_health" table:"workspace agent health"`
WorkspaceAppID uuid.NullUUID `json:"workspace_app_id" format:"uuid" table:"workspace app id"`
InitialPrompt string `json:"initial_prompt" table:"initial prompt"`
Status TaskStatus `json:"status" enums:"pending,initializing,active,paused,unknown,error" table:"status"`
CurrentState *TaskStateEntry `json:"current_state" table:"cs,recursive_inline,empty_nil"`
CreatedAt time.Time `json:"created_at" format:"date-time" table:"created at"`
UpdatedAt time.Time `json:"updated_at" format:"date-time" table:"updated at"`
}
// TaskStateEntry represents a single entry in the task's state history.
type TaskStateEntry struct {
Timestamp time.Time `json:"timestamp" table:"-" format:"date-time"`
State TaskState `json:"state" table:"state" enum:"working,idle,completed,failed"`
Message string `json:"message" table:"message"`
URI string `json:"uri" table:"-"`
Timestamp time.Time `json:"timestamp" format:"date-time" table:"-"`
State TaskState `json:"state" enum:"working,idle,completed,failed" table:"state"`
Message string `json:"message" table:"message"`
URI string `json:"uri" table:"-"`
}
// TasksFilter filters the list of tasks.
@@ -387,10 +387,10 @@ const (
// TaskLogEntry represents a single log entry for a task.
type TaskLogEntry struct {
ID int `json:"id" table:"id"`
ID int `json:"id" table:"id"`
Content string `json:"content" table:"content"`
Type TaskLogType `json:"type" table:"type" enum:"input,output"`
Time time.Time `json:"time" table:"time,default_sort" format:"date-time"`
Type TaskLogType `json:"type" enum:"input,output" table:"type"`
Time time.Time `json:"time" format:"date-time" table:"time,default_sort"`
}
// TaskLogsResponse contains task logs and metadata. When snapshot is false,
+10 -10
View File
@@ -12,17 +12,17 @@ import (
// APIKey: do not ever return the HashedSecret
type APIKey struct {
ID string `json:"id" validate:"required"`
UserID uuid.UUID `json:"user_id" format:"uuid" validate:"required"`
LastUsed time.Time `json:"last_used" format:"date-time" validate:"required"`
ExpiresAt time.Time `json:"expires_at" format:"date-time" validate:"required"`
CreatedAt time.Time `json:"created_at" format:"date-time" validate:"required"`
UpdatedAt time.Time `json:"updated_at" format:"date-time" validate:"required"`
LoginType LoginType `json:"login_type" enums:"password,github,oidc,token" validate:"required"`
Scope APIKeyScope `json:"scope" enums:"all,application_connect"` // Deprecated: use Scopes instead.
ID string `json:"id" validate:"required"`
UserID uuid.UUID `json:"user_id" validate:"required" format:"uuid"`
LastUsed time.Time `json:"last_used" validate:"required" format:"date-time"`
ExpiresAt time.Time `json:"expires_at" validate:"required" format:"date-time"`
CreatedAt time.Time `json:"created_at" validate:"required" format:"date-time"`
UpdatedAt time.Time `json:"updated_at" validate:"required" format:"date-time"`
LoginType LoginType `json:"login_type" validate:"required" enums:"password,github,oidc,token"`
Scope APIKeyScope `json:"scope" enums:"all,application_connect"` // Deprecated: use Scopes instead.
Scopes []APIKeyScope `json:"scopes"`
TokenName string `json:"token_name" validate:"required"`
LifetimeSeconds int64 `json:"lifetime_seconds" validate:"required"`
TokenName string `json:"token_name" validate:"required"`
LifetimeSeconds int64 `json:"lifetime_seconds" validate:"required"`
AllowList []APIAllowListTarget `json:"allow_list"`
}
+11 -11
View File
@@ -178,13 +178,13 @@ type AuditDiffField struct {
}
type AuditLog struct {
ID uuid.UUID `json:"id" format:"uuid"`
RequestID uuid.UUID `json:"request_id" format:"uuid"`
Time time.Time `json:"time" format:"date-time"`
ID uuid.UUID `json:"id" format:"uuid"`
RequestID uuid.UUID `json:"request_id" format:"uuid"`
Time time.Time `json:"time" format:"date-time"`
IP netip.Addr `json:"ip"`
UserAgent string `json:"user_agent"`
ResourceType ResourceType `json:"resource_type"`
ResourceID uuid.UUID `json:"resource_id" format:"uuid"`
ResourceID uuid.UUID `json:"resource_id" format:"uuid"`
// ResourceTarget is the name of the resource.
ResourceTarget string `json:"resource_target"`
ResourceIcon string `json:"resource_icon"`
@@ -215,14 +215,14 @@ type AuditLogResponse struct {
}
type CreateTestAuditLogRequest struct {
Action AuditAction `json:"action,omitempty" enums:"create,write,delete,start,stop"`
ResourceType ResourceType `json:"resource_type,omitempty" enums:"template,template_version,user,workspace,workspace_build,git_ssh_key,auditable_group"`
ResourceID uuid.UUID `json:"resource_id,omitempty" format:"uuid"`
Action AuditAction `json:"action,omitempty" enums:"create,write,delete,start,stop"`
ResourceType ResourceType `json:"resource_type,omitempty" enums:"template,template_version,user,workspace,workspace_build,git_ssh_key,auditable_group"`
ResourceID uuid.UUID `json:"resource_id,omitempty" format:"uuid"`
AdditionalFields json.RawMessage `json:"additional_fields,omitempty"`
Time time.Time `json:"time,omitempty" format:"date-time"`
BuildReason BuildReason `json:"build_reason,omitempty" enums:"autostart,autostop,initiator"`
OrganizationID uuid.UUID `json:"organization_id,omitempty" format:"uuid"`
RequestID uuid.UUID `json:"request_id,omitempty" format:"uuid"`
Time time.Time `json:"time,omitempty" format:"date-time"`
BuildReason BuildReason `json:"build_reason,omitempty" enums:"autostart,autostop,initiator"`
OrganizationID uuid.UUID `json:"organization_id,omitempty" format:"uuid"`
RequestID uuid.UUID `json:"request_id,omitempty" format:"uuid"`
}
// AuditLogs retrieves audit logs from the given page.
+182 -544
View File
File diff suppressed because it is too large Load Diff
-161
View File
@@ -1,15 +1,8 @@
package codersdk_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/shopspring/decimal"
@@ -62,81 +55,6 @@ func TestChatModelProviderOptions_UnmarshalJSON_ParsesPlainProviderPayloads(t *t
)
}
func TestChatUsageLimitExceededFrom(t *testing.T) {
t.Parallel()
t.Run("ExtractsTyped409", func(t *testing.T) {
t.Parallel()
want := codersdk.ChatUsageLimitExceededResponse{
Response: codersdk.Response{Message: "Chat usage limit exceeded."},
SpentMicros: 123,
LimitMicros: 456,
ResetsAt: time.Date(2026, time.March, 16, 12, 0, 0, 0, time.UTC),
}
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/api/experimental/chats", r.URL.Path)
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusConflict)
require.NoError(t, json.NewEncoder(rw).Encode(want))
}))
defer srv.Close()
serverURL, err := url.Parse(srv.URL)
require.NoError(t, err)
client := codersdk.New(serverURL)
_, err = client.CreateChat(context.Background(), codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
}},
})
require.Error(t, err)
sdkErr, ok := codersdk.AsError(err)
require.True(t, ok)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
require.Equal(t, want.Message, sdkErr.Message)
limitErr := codersdk.ChatUsageLimitExceededFrom(err)
require.NotNil(t, limitErr)
require.Equal(t, want, *limitErr)
})
t.Run("ReturnsNilForNonLimitErrors", func(t *testing.T) {
t.Parallel()
require.Nil(t, codersdk.ChatUsageLimitExceededFrom(codersdk.NewError(http.StatusConflict, codersdk.Response{Message: "plain conflict"})))
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusBadRequest)
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.Response{Message: "Invalid request."}))
}))
defer srv.Close()
serverURL, err := url.Parse(srv.URL)
require.NoError(t, err)
client := codersdk.New(serverURL)
_, err = client.CreateChat(context.Background(), codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
}},
})
require.Error(t, err)
sdkErr, ok := codersdk.AsError(err)
require.True(t, ok)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
require.Nil(t, codersdk.ChatUsageLimitExceededFrom(err))
})
}
func TestChatMessagePart_StripInternal(t *testing.T) {
t.Parallel()
@@ -193,85 +111,6 @@ func TestChatMessagePart_StripInternal(t *testing.T) {
})
}
// TestChatMessagePartVariantTags validates the `variants` struct tags
// on ChatMessagePart fields. Every field must either declare variant
// membership or be explicitly excluded, and every known part type
// must appear in at least one tag.
//
// If this test fails, edit the variants struct tags on ChatMessagePart
// in codersdk/chats.go.
func TestChatMessagePartVariantTags(t *testing.T) {
t.Parallel()
const editHint = "edit the variants struct tags on ChatMessagePart in codersdk/chats.go"
// Fields intentionally excluded from all generated variants.
// If you add a new field to ChatMessagePart, either add a
// variants tag or add it here with a comment explaining why.
excludedFields := map[string]string{
"type": "discriminant, added automatically by codegen",
"signature": "added in #22290, never populated by any code path",
"result_delta": "added in #22290, never populated by any code path",
"provider_metadata": "internal only, stripped by db2sdk before API responses",
}
knownTypes := make(map[codersdk.ChatMessagePartType]bool)
for _, pt := range codersdk.AllChatMessagePartTypes() {
knownTypes[pt] = true
}
// Parse all variants tags from the struct and validate them.
typ := reflect.TypeOf(codersdk.ChatMessagePart{})
coveredTypes := make(map[codersdk.ChatMessagePartType]bool)
hasRequired := make(map[codersdk.ChatMessagePartType]bool)
for i := range typ.NumField() {
f := typ.Field(i)
jsonTag := f.Tag.Get("json")
if jsonTag == "" || jsonTag == "-" {
continue
}
jsonName, _, _ := strings.Cut(jsonTag, ",")
varTag := f.Tag.Get("variants")
if varTag == "" {
assert.Contains(t, excludedFields, jsonName,
"field %s (json:%q) has no variants tag and is not in excludedFields; %s",
f.Name, jsonName, editHint)
continue
}
assert.NotEqual(t, "type", jsonName,
"the discriminant field must not have a variants tag; %s", editHint)
for _, entry := range strings.Split(varTag, ",") {
isOptional := strings.HasSuffix(entry, "?")
typeLit := codersdk.ChatMessagePartType(strings.TrimSuffix(entry, "?"))
assert.True(t, knownTypes[typeLit],
"field %s variants tag references unknown type %q; %s",
f.Name, typeLit, editHint)
coveredTypes[typeLit] = true
if !isOptional {
hasRequired[typeLit] = true
}
}
}
// Every known type must appear in at least one variants tag.
for pt := range knownTypes {
assert.True(t, coveredTypes[pt],
"ChatMessagePartType %q is not referenced by any variants tag; %s", pt, editHint)
}
// Every variant must have at least one required field.
for pt := range coveredTypes {
assert.True(t, hasRequired[pt],
"variant %q has no required fields (all have ? suffix); %s", pt, editHint)
}
}
func TestModelCostConfig_LegacyNumericJSON(t *testing.T) {
t.Parallel()
+1 -1
View File
@@ -587,7 +587,7 @@ type Response struct {
// ValidationError represents a scoped error to a user input.
type ValidationError struct {
Field string `json:"field" validate:"required"`
Field string `json:"field" validate:"required"`
Detail string `json:"detail" validate:"required"`
}
+4 -4
View File
@@ -12,12 +12,12 @@ import (
)
type ConnectionLog struct {
ID uuid.UUID `json:"id" format:"uuid"`
ConnectTime time.Time `json:"connect_time" format:"date-time"`
ID uuid.UUID `json:"id" format:"uuid"`
ConnectTime time.Time `json:"connect_time" format:"date-time"`
Organization MinimalOrganization `json:"organization"`
WorkspaceOwnerID uuid.UUID `json:"workspace_owner_id" format:"uuid"`
WorkspaceOwnerID uuid.UUID `json:"workspace_owner_id" format:"uuid"`
WorkspaceOwnerUsername string `json:"workspace_owner_username"`
WorkspaceID uuid.UUID `json:"workspace_id" format:"uuid"`
WorkspaceID uuid.UUID `json:"workspace_id" format:"uuid"`
WorkspaceName string `json:"workspace_name"`
AgentName string `json:"agent_name"`
IP *netip.Addr `json:"ip,omitempty"`
+185 -185
View File
@@ -386,8 +386,8 @@ type Feature struct {
type UsagePeriod struct {
IssuedAt time.Time `json:"issued_at" format:"date-time"`
Start time.Time `json:"start" format:"date-time"`
End time.Time `json:"end" format:"date-time"`
Start time.Time `json:"start" format:"date-time"`
End time.Time `json:"end" format:"date-time"`
}
// Compare compares two features and returns an integer representing
@@ -499,7 +499,7 @@ type Entitlements struct {
HasLicense bool `json:"has_license"`
Trial bool `json:"trial"`
RequireTelemetry bool `json:"require_telemetry"`
RefreshedAt time.Time `json:"refreshed_at" format:"date-time"`
RefreshedAt time.Time `json:"refreshed_at" format:"date-time"`
}
// AddFeature will add the feature to the entitlements iff it expands
@@ -579,71 +579,71 @@ type DeploymentValues struct {
DocsURL serpent.URL `json:"docs_url,omitempty"`
RedirectToAccessURL serpent.Bool `json:"redirect_to_access_url,omitempty"`
// HTTPAddress is a string because it may be set to zero to disable.
HTTPAddress serpent.String `json:"http_address,omitempty" typescript:",notnull"`
HTTPAddress serpent.String `json:"http_address,omitempty" typescript:",notnull"`
AutobuildPollInterval serpent.Duration `json:"autobuild_poll_interval,omitempty"`
JobReaperDetectorInterval serpent.Duration `json:"job_hang_detector_interval,omitempty"`
DERP DERP `json:"derp,omitempty" typescript:",notnull"`
Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"`
Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"`
ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"`
ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"`
CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"`
EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"`
PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"`
PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"`
PostgresConnMaxOpen serpent.Int64 `json:"pg_conn_max_open,omitempty" typescript:",notnull"`
PostgresConnMaxIdle serpent.String `json:"pg_conn_max_idle,omitempty" typescript:",notnull"`
OAuth2 OAuth2Config `json:"oauth2,omitempty" typescript:",notnull"`
OIDC OIDCConfig `json:"oidc,omitempty" typescript:",notnull"`
Telemetry TelemetryConfig `json:"telemetry,omitempty" typescript:",notnull"`
TLS TLSConfig `json:"tls,omitempty" typescript:",notnull"`
Trace TraceConfig `json:"trace,omitempty" typescript:",notnull"`
HTTPCookies HTTPCookieConfig `json:"http_cookies,omitempty" typescript:",notnull"`
StrictTransportSecurity serpent.Int64 `json:"strict_transport_security,omitempty" typescript:",notnull"`
StrictTransportSecurityOptions serpent.StringArray `json:"strict_transport_security_options,omitempty" typescript:",notnull"`
SSHKeygenAlgorithm serpent.String `json:"ssh_keygen_algorithm,omitempty" typescript:",notnull"`
MetricsCacheRefreshInterval serpent.Duration `json:"metrics_cache_refresh_interval,omitempty" typescript:",notnull"`
AgentStatRefreshInterval serpent.Duration `json:"agent_stat_refresh_interval,omitempty" typescript:",notnull"`
AgentFallbackTroubleshootingURL serpent.URL `json:"agent_fallback_troubleshooting_url,omitempty" typescript:",notnull"`
BrowserOnly serpent.Bool `json:"browser_only,omitempty" typescript:",notnull"`
SCIMAPIKey serpent.String `json:"scim_api_key,omitempty" typescript:",notnull"`
ExternalTokenEncryptionKeys serpent.StringArray `json:"external_token_encryption_keys,omitempty" typescript:",notnull"`
Provisioner ProvisionerConfig `json:"provisioner,omitempty" typescript:",notnull"`
RateLimit RateLimitConfig `json:"rate_limit,omitempty" typescript:",notnull"`
Experiments serpent.StringArray `json:"experiments,omitempty" typescript:",notnull"`
UpdateCheck serpent.Bool `json:"update_check,omitempty" typescript:",notnull"`
Swagger SwaggerConfig `json:"swagger,omitempty" typescript:",notnull"`
Logging LoggingConfig `json:"logging,omitempty" typescript:",notnull"`
Dangerous DangerousConfig `json:"dangerous,omitempty" typescript:",notnull"`
DisablePathApps serpent.Bool `json:"disable_path_apps,omitempty" typescript:",notnull"`
Sessions SessionLifetime `json:"session_lifetime,omitempty" typescript:",notnull"`
DisablePasswordAuth serpent.Bool `json:"disable_password_auth,omitempty" typescript:",notnull"`
Support SupportConfig `json:"support,omitempty" typescript:",notnull"`
EnableAuthzRecording serpent.Bool `json:"enable_authz_recording,omitempty" typescript:",notnull"`
ExternalAuthConfigs serpent.Struct[[]ExternalAuthConfig] `json:"external_auth,omitempty" typescript:",notnull"`
DERP DERP `json:"derp,omitempty" typescript:",notnull"`
Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"`
Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"`
ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"`
ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"`
CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"`
EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"`
PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"`
PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"`
PostgresConnMaxOpen serpent.Int64 `json:"pg_conn_max_open,omitempty" typescript:",notnull"`
PostgresConnMaxIdle serpent.String `json:"pg_conn_max_idle,omitempty" typescript:",notnull"`
OAuth2 OAuth2Config `json:"oauth2,omitempty" typescript:",notnull"`
OIDC OIDCConfig `json:"oidc,omitempty" typescript:",notnull"`
Telemetry TelemetryConfig `json:"telemetry,omitempty" typescript:",notnull"`
TLS TLSConfig `json:"tls,omitempty" typescript:",notnull"`
Trace TraceConfig `json:"trace,omitempty" typescript:",notnull"`
HTTPCookies HTTPCookieConfig `json:"http_cookies,omitempty" typescript:",notnull"`
StrictTransportSecurity serpent.Int64 `json:"strict_transport_security,omitempty" typescript:",notnull"`
StrictTransportSecurityOptions serpent.StringArray `json:"strict_transport_security_options,omitempty" typescript:",notnull"`
SSHKeygenAlgorithm serpent.String `json:"ssh_keygen_algorithm,omitempty" typescript:",notnull"`
MetricsCacheRefreshInterval serpent.Duration `json:"metrics_cache_refresh_interval,omitempty" typescript:",notnull"`
AgentStatRefreshInterval serpent.Duration `json:"agent_stat_refresh_interval,omitempty" typescript:",notnull"`
AgentFallbackTroubleshootingURL serpent.URL `json:"agent_fallback_troubleshooting_url,omitempty" typescript:",notnull"`
BrowserOnly serpent.Bool `json:"browser_only,omitempty" typescript:",notnull"`
SCIMAPIKey serpent.String `json:"scim_api_key,omitempty" typescript:",notnull"`
ExternalTokenEncryptionKeys serpent.StringArray `json:"external_token_encryption_keys,omitempty" typescript:",notnull"`
Provisioner ProvisionerConfig `json:"provisioner,omitempty" typescript:",notnull"`
RateLimit RateLimitConfig `json:"rate_limit,omitempty" typescript:",notnull"`
Experiments serpent.StringArray `json:"experiments,omitempty" typescript:",notnull"`
UpdateCheck serpent.Bool `json:"update_check,omitempty" typescript:",notnull"`
Swagger SwaggerConfig `json:"swagger,omitempty" typescript:",notnull"`
Logging LoggingConfig `json:"logging,omitempty" typescript:",notnull"`
Dangerous DangerousConfig `json:"dangerous,omitempty" typescript:",notnull"`
DisablePathApps serpent.Bool `json:"disable_path_apps,omitempty" typescript:",notnull"`
Sessions SessionLifetime `json:"session_lifetime,omitempty" typescript:",notnull"`
DisablePasswordAuth serpent.Bool `json:"disable_password_auth,omitempty" typescript:",notnull"`
Support SupportConfig `json:"support,omitempty" typescript:",notnull"`
EnableAuthzRecording serpent.Bool `json:"enable_authz_recording,omitempty" typescript:",notnull"`
ExternalAuthConfigs serpent.Struct[[]ExternalAuthConfig] `json:"external_auth,omitempty" typescript:",notnull"`
ExternalAuthGithubDefaultProviderEnable serpent.Bool `json:"external_auth_github_default_provider_enable,omitempty" typescript:",notnull"`
SSHConfig SSHConfig `json:"config_ssh,omitempty" typescript:",notnull"`
WgtunnelHost serpent.String `json:"wgtunnel_host,omitempty" typescript:",notnull"`
DisableOwnerWorkspaceExec serpent.Bool `json:"disable_owner_workspace_exec,omitempty" typescript:",notnull"`
DisableWorkspaceSharing serpent.Bool `json:"disable_workspace_sharing,omitempty" typescript:",notnull"`
ProxyHealthStatusInterval serpent.Duration `json:"proxy_health_status_interval,omitempty" typescript:",notnull"`
EnableTerraformDebugMode serpent.Bool `json:"enable_terraform_debug_mode,omitempty" typescript:",notnull"`
UserQuietHoursSchedule UserQuietHoursScheduleConfig `json:"user_quiet_hours_schedule,omitempty" typescript:",notnull"`
WebTerminalRenderer serpent.String `json:"web_terminal_renderer,omitempty" typescript:",notnull"`
AllowWorkspaceRenames serpent.Bool `json:"allow_workspace_renames,omitempty" typescript:",notnull"`
Healthcheck HealthcheckConfig `json:"healthcheck,omitempty" typescript:",notnull"`
Retention RetentionConfig `json:"retention,omitempty" typescript:",notnull"`
CLIUpgradeMessage serpent.String `json:"cli_upgrade_message,omitempty" typescript:",notnull"`
TermsOfServiceURL serpent.String `json:"terms_of_service_url,omitempty" typescript:",notnull"`
Notifications NotificationsConfig `json:"notifications,omitempty" typescript:",notnull"`
AdditionalCSPPolicy serpent.StringArray `json:"additional_csp_policy,omitempty" typescript:",notnull"`
WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"`
Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"`
HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"`
SSHConfig SSHConfig `json:"config_ssh,omitempty" typescript:",notnull"`
WgtunnelHost serpent.String `json:"wgtunnel_host,omitempty" typescript:",notnull"`
DisableOwnerWorkspaceExec serpent.Bool `json:"disable_owner_workspace_exec,omitempty" typescript:",notnull"`
DisableWorkspaceSharing serpent.Bool `json:"disable_workspace_sharing,omitempty" typescript:",notnull"`
ProxyHealthStatusInterval serpent.Duration `json:"proxy_health_status_interval,omitempty" typescript:",notnull"`
EnableTerraformDebugMode serpent.Bool `json:"enable_terraform_debug_mode,omitempty" typescript:",notnull"`
UserQuietHoursSchedule UserQuietHoursScheduleConfig `json:"user_quiet_hours_schedule,omitempty" typescript:",notnull"`
WebTerminalRenderer serpent.String `json:"web_terminal_renderer,omitempty" typescript:",notnull"`
AllowWorkspaceRenames serpent.Bool `json:"allow_workspace_renames,omitempty" typescript:",notnull"`
Healthcheck HealthcheckConfig `json:"healthcheck,omitempty" typescript:",notnull"`
Retention RetentionConfig `json:"retention,omitempty" typescript:",notnull"`
CLIUpgradeMessage serpent.String `json:"cli_upgrade_message,omitempty" typescript:",notnull"`
TermsOfServiceURL serpent.String `json:"terms_of_service_url,omitempty" typescript:",notnull"`
Notifications NotificationsConfig `json:"notifications,omitempty" typescript:",notnull"`
AdditionalCSPPolicy serpent.StringArray `json:"additional_csp_policy,omitempty" typescript:",notnull"`
WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"`
Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"`
HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"`
AI AIConfig `json:"ai,omitempty"`
StatsCollection StatsCollectionConfig `json:"stats_collection,omitempty" typescript:",notnull"`
StatsCollection StatsCollectionConfig `json:"stats_collection,omitempty" typescript:",notnull"`
Config serpent.YAMLConfigPath `json:"config,omitempty" typescript:",notnull"`
Config serpent.YAMLConfigPath `json:"config,omitempty" typescript:",notnull"`
WriteConfig serpent.Bool `json:"write_config,omitempty" typescript:",notnull"`
// Deprecated: Use HTTPAddress or TLS.Address instead.
@@ -726,19 +726,19 @@ type DERP struct {
}
type DERPServerConfig struct {
Enable serpent.Bool `json:"enable" typescript:",notnull"`
RegionID serpent.Int64 `json:"region_id" typescript:",notnull"`
RegionCode serpent.String `json:"region_code" typescript:",notnull"`
RegionName serpent.String `json:"region_name" typescript:",notnull"`
Enable serpent.Bool `json:"enable" typescript:",notnull"`
RegionID serpent.Int64 `json:"region_id" typescript:",notnull"`
RegionCode serpent.String `json:"region_code" typescript:",notnull"`
RegionName serpent.String `json:"region_name" typescript:",notnull"`
STUNAddresses serpent.StringArray `json:"stun_addresses" typescript:",notnull"`
RelayURL serpent.URL `json:"relay_url" typescript:",notnull"`
RelayURL serpent.URL `json:"relay_url" typescript:",notnull"`
}
type DERPConfig struct {
BlockDirect serpent.Bool `json:"block_direct" typescript:",notnull"`
BlockDirect serpent.Bool `json:"block_direct" typescript:",notnull"`
ForceWebSockets serpent.Bool `json:"force_websockets" typescript:",notnull"`
URL serpent.String `json:"url" typescript:",notnull"`
Path serpent.String `json:"path" typescript:",notnull"`
URL serpent.String `json:"url" typescript:",notnull"`
Path serpent.String `json:"path" typescript:",notnull"`
}
type UsageStatsConfig struct {
@@ -750,15 +750,15 @@ type StatsCollectionConfig struct {
}
type PrometheusConfig struct {
Enable serpent.Bool `json:"enable" typescript:",notnull"`
Address serpent.HostPort `json:"address" typescript:",notnull"`
CollectAgentStats serpent.Bool `json:"collect_agent_stats" typescript:",notnull"`
CollectDBMetrics serpent.Bool `json:"collect_db_metrics" typescript:",notnull"`
Enable serpent.Bool `json:"enable" typescript:",notnull"`
Address serpent.HostPort `json:"address" typescript:",notnull"`
CollectAgentStats serpent.Bool `json:"collect_agent_stats" typescript:",notnull"`
CollectDBMetrics serpent.Bool `json:"collect_db_metrics" typescript:",notnull"`
AggregateAgentStatsBy serpent.StringArray `json:"aggregate_agent_stats_by" typescript:",notnull"`
}
type PprofConfig struct {
Enable serpent.Bool `json:"enable" typescript:",notnull"`
Enable serpent.Bool `json:"enable" typescript:",notnull"`
Address serpent.HostPort `json:"address" typescript:",notnull"`
}
@@ -767,32 +767,32 @@ type OAuth2Config struct {
}
type OAuth2GithubConfig struct {
ClientID serpent.String `json:"client_id" typescript:",notnull"`
ClientSecret serpent.String `json:"client_secret" typescript:",notnull"`
DeviceFlow serpent.Bool `json:"device_flow" typescript:",notnull"`
ClientID serpent.String `json:"client_id" typescript:",notnull"`
ClientSecret serpent.String `json:"client_secret" typescript:",notnull"`
DeviceFlow serpent.Bool `json:"device_flow" typescript:",notnull"`
DefaultProviderEnable serpent.Bool `json:"default_provider_enable" typescript:",notnull"`
AllowedOrgs serpent.StringArray `json:"allowed_orgs" typescript:",notnull"`
AllowedTeams serpent.StringArray `json:"allowed_teams" typescript:",notnull"`
AllowSignups serpent.Bool `json:"allow_signups" typescript:",notnull"`
AllowEveryone serpent.Bool `json:"allow_everyone" typescript:",notnull"`
EnterpriseBaseURL serpent.String `json:"enterprise_base_url" typescript:",notnull"`
AllowedOrgs serpent.StringArray `json:"allowed_orgs" typescript:",notnull"`
AllowedTeams serpent.StringArray `json:"allowed_teams" typescript:",notnull"`
AllowSignups serpent.Bool `json:"allow_signups" typescript:",notnull"`
AllowEveryone serpent.Bool `json:"allow_everyone" typescript:",notnull"`
EnterpriseBaseURL serpent.String `json:"enterprise_base_url" typescript:",notnull"`
}
type OIDCConfig struct {
AllowSignups serpent.Bool `json:"allow_signups" typescript:",notnull"`
ClientID serpent.String `json:"client_id" typescript:",notnull"`
ClientID serpent.String `json:"client_id" typescript:",notnull"`
ClientSecret serpent.String `json:"client_secret" typescript:",notnull"`
// ClientKeyFile & ClientCertFile are used in place of ClientSecret for PKI auth.
ClientKeyFile serpent.String `json:"client_key_file" typescript:",notnull"`
ClientCertFile serpent.String `json:"client_cert_file" typescript:",notnull"`
EmailDomain serpent.StringArray `json:"email_domain" typescript:",notnull"`
IssuerURL serpent.String `json:"issuer_url" typescript:",notnull"`
Scopes serpent.StringArray `json:"scopes" typescript:",notnull"`
ClientKeyFile serpent.String `json:"client_key_file" typescript:",notnull"`
ClientCertFile serpent.String `json:"client_cert_file" typescript:",notnull"`
EmailDomain serpent.StringArray `json:"email_domain" typescript:",notnull"`
IssuerURL serpent.String `json:"issuer_url" typescript:",notnull"`
Scopes serpent.StringArray `json:"scopes" typescript:",notnull"`
IgnoreEmailVerified serpent.Bool `json:"ignore_email_verified" typescript:",notnull"`
UsernameField serpent.String `json:"username_field" typescript:",notnull"`
NameField serpent.String `json:"name_field" typescript:",notnull"`
EmailField serpent.String `json:"email_field" typescript:",notnull"`
AuthURLParams serpent.Struct[map[string]string] `json:"auth_url_params" typescript:",notnull"`
UsernameField serpent.String `json:"username_field" typescript:",notnull"`
NameField serpent.String `json:"name_field" typescript:",notnull"`
EmailField serpent.String `json:"email_field" typescript:",notnull"`
AuthURLParams serpent.Struct[map[string]string] `json:"auth_url_params" typescript:",notnull"`
// IgnoreUserInfo & UserInfoFromAccessToken are mutually exclusive. Only 1
// can be set to true. Ideally this would be an enum with 3 states, ['none',
// 'userinfo', 'access_token']. However, for backward compatibility,
@@ -804,21 +804,21 @@ type OIDCConfig struct {
// endpoint. This assumes the access token is a valid JWT with a set of claims to
// be merged with the id_token.
UserInfoFromAccessToken serpent.Bool `json:"source_user_info_from_access_token" typescript:",notnull"`
OrganizationField serpent.String `json:"organization_field" typescript:",notnull"`
OrganizationMapping serpent.Struct[map[string][]uuid.UUID] `json:"organization_mapping" typescript:",notnull"`
OrganizationAssignDefault serpent.Bool `json:"organization_assign_default" typescript:",notnull"`
GroupAutoCreate serpent.Bool `json:"group_auto_create" typescript:",notnull"`
GroupRegexFilter serpent.Regexp `json:"group_regex_filter" typescript:",notnull"`
GroupAllowList serpent.StringArray `json:"group_allow_list" typescript:",notnull"`
GroupField serpent.String `json:"groups_field" typescript:",notnull"`
GroupMapping serpent.Struct[map[string]string] `json:"group_mapping" typescript:",notnull"`
UserRoleField serpent.String `json:"user_role_field" typescript:",notnull"`
UserRoleMapping serpent.Struct[map[string][]string] `json:"user_role_mapping" typescript:",notnull"`
UserRolesDefault serpent.StringArray `json:"user_roles_default" typescript:",notnull"`
SignInText serpent.String `json:"sign_in_text" typescript:",notnull"`
IconURL serpent.URL `json:"icon_url" typescript:",notnull"`
SignupsDisabledText serpent.String `json:"signups_disabled_text" typescript:",notnull"`
SkipIssuerChecks serpent.Bool `json:"skip_issuer_checks" typescript:",notnull"`
OrganizationField serpent.String `json:"organization_field" typescript:",notnull"`
OrganizationMapping serpent.Struct[map[string][]uuid.UUID] `json:"organization_mapping" typescript:",notnull"`
OrganizationAssignDefault serpent.Bool `json:"organization_assign_default" typescript:",notnull"`
GroupAutoCreate serpent.Bool `json:"group_auto_create" typescript:",notnull"`
GroupRegexFilter serpent.Regexp `json:"group_regex_filter" typescript:",notnull"`
GroupAllowList serpent.StringArray `json:"group_allow_list" typescript:",notnull"`
GroupField serpent.String `json:"groups_field" typescript:",notnull"`
GroupMapping serpent.Struct[map[string]string] `json:"group_mapping" typescript:",notnull"`
UserRoleField serpent.String `json:"user_role_field" typescript:",notnull"`
UserRoleMapping serpent.Struct[map[string][]string] `json:"user_role_mapping" typescript:",notnull"`
UserRolesDefault serpent.StringArray `json:"user_roles_default" typescript:",notnull"`
SignInText serpent.String `json:"sign_in_text" typescript:",notnull"`
IconURL serpent.URL `json:"icon_url" typescript:",notnull"`
SignupsDisabledText serpent.String `json:"signups_disabled_text" typescript:",notnull"`
SkipIssuerChecks serpent.Bool `json:"skip_issuer_checks" typescript:",notnull"`
// RedirectURL is optional, defaulting to 'ACCESS_URL'. Only useful in niche
// situations where the OIDC callback domain is different from the ACCESS_URL
@@ -828,38 +828,38 @@ type OIDCConfig struct {
type TelemetryConfig struct {
Enable serpent.Bool `json:"enable" typescript:",notnull"`
Trace serpent.Bool `json:"trace" typescript:",notnull"`
URL serpent.URL `json:"url" typescript:",notnull"`
Trace serpent.Bool `json:"trace" typescript:",notnull"`
URL serpent.URL `json:"url" typescript:",notnull"`
}
type TLSConfig struct {
Enable serpent.Bool `json:"enable" typescript:",notnull"`
Address serpent.HostPort `json:"address" typescript:",notnull"`
RedirectHTTP serpent.Bool `json:"redirect_http" typescript:",notnull"`
CertFiles serpent.StringArray `json:"cert_file" typescript:",notnull"`
ClientAuth serpent.String `json:"client_auth" typescript:",notnull"`
ClientCAFile serpent.String `json:"client_ca_file" typescript:",notnull"`
KeyFiles serpent.StringArray `json:"key_file" typescript:",notnull"`
MinVersion serpent.String `json:"min_version" typescript:",notnull"`
ClientCertFile serpent.String `json:"client_cert_file" typescript:",notnull"`
ClientKeyFile serpent.String `json:"client_key_file" typescript:",notnull"`
SupportedCiphers serpent.StringArray `json:"supported_ciphers" typescript:",notnull"`
Enable serpent.Bool `json:"enable" typescript:",notnull"`
Address serpent.HostPort `json:"address" typescript:",notnull"`
RedirectHTTP serpent.Bool `json:"redirect_http" typescript:",notnull"`
CertFiles serpent.StringArray `json:"cert_file" typescript:",notnull"`
ClientAuth serpent.String `json:"client_auth" typescript:",notnull"`
ClientCAFile serpent.String `json:"client_ca_file" typescript:",notnull"`
KeyFiles serpent.StringArray `json:"key_file" typescript:",notnull"`
MinVersion serpent.String `json:"min_version" typescript:",notnull"`
ClientCertFile serpent.String `json:"client_cert_file" typescript:",notnull"`
ClientKeyFile serpent.String `json:"client_key_file" typescript:",notnull"`
SupportedCiphers serpent.StringArray `json:"supported_ciphers" typescript:",notnull"`
AllowInsecureCiphers serpent.Bool `json:"allow_insecure_ciphers" typescript:",notnull"`
}
type TraceConfig struct {
Enable serpent.Bool `json:"enable" typescript:",notnull"`
Enable serpent.Bool `json:"enable" typescript:",notnull"`
HoneycombAPIKey serpent.String `json:"honeycomb_api_key" typescript:",notnull"`
CaptureLogs serpent.Bool `json:"capture_logs" typescript:",notnull"`
DataDog serpent.Bool `json:"data_dog" typescript:",notnull"`
CaptureLogs serpent.Bool `json:"capture_logs" typescript:",notnull"`
DataDog serpent.Bool `json:"data_dog" typescript:",notnull"`
}
const cookieHostPrefix = "__Host-"
type HTTPCookieConfig struct {
Secure serpent.Bool `json:"secure_auth_cookie,omitempty" typescript:",notnull"`
SameSite string `json:"same_site,omitempty" typescript:",notnull"`
EnableHostPrefix bool `json:"host_prefix,omitempty" typescript:",notnull"`
SameSite string `json:"same_site,omitempty" typescript:",notnull"`
EnableHostPrefix bool `json:"host_prefix,omitempty" typescript:",notnull"`
}
// cookiesToPrefix is the set of cookies that should be prefixed with the host prefix if EnableHostPrefix is true.
@@ -953,23 +953,23 @@ func (cfg HTTPCookieConfig) HTTPSameSite() http.SameSite {
type ExternalAuthConfig struct {
// Type is the type of external auth config.
Type string `json:"type" yaml:"type"`
Type string `json:"type" yaml:"type"`
ClientID string `json:"client_id" yaml:"client_id"`
ClientSecret string `json:"-" yaml:"client_secret"`
ClientSecret string `json:"-" yaml:"client_secret"`
// ID is a unique identifier for the auth config.
// It defaults to `type` when not provided.
ID string `json:"id" yaml:"id"`
AuthURL string `json:"auth_url" yaml:"auth_url"`
TokenURL string `json:"token_url" yaml:"token_url"`
ValidateURL string `json:"validate_url" yaml:"validate_url"`
RevokeURL string `json:"revoke_url" yaml:"revoke_url"`
AppInstallURL string `json:"app_install_url" yaml:"app_install_url"`
ID string `json:"id" yaml:"id"`
AuthURL string `json:"auth_url" yaml:"auth_url"`
TokenURL string `json:"token_url" yaml:"token_url"`
ValidateURL string `json:"validate_url" yaml:"validate_url"`
RevokeURL string `json:"revoke_url" yaml:"revoke_url"`
AppInstallURL string `json:"app_install_url" yaml:"app_install_url"`
AppInstallationsURL string `json:"app_installations_url" yaml:"app_installations_url"`
NoRefresh bool `json:"no_refresh" yaml:"no_refresh"`
Scopes []string `json:"scopes" yaml:"scopes"`
ExtraTokenKeys []string `json:"-" yaml:"extra_token_keys"`
DeviceFlow bool `json:"device_flow" yaml:"device_flow"`
DeviceCodeURL string `json:"device_code_url" yaml:"device_code_url"`
NoRefresh bool `json:"no_refresh" yaml:"no_refresh"`
Scopes []string `json:"scopes" yaml:"scopes"`
ExtraTokenKeys []string `json:"-" yaml:"extra_token_keys"`
DeviceFlow bool `json:"device_flow" yaml:"device_flow"`
DeviceCodeURL string `json:"device_code_url" yaml:"device_code_url"`
// Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.
MCPURL string `json:"mcp_url" yaml:"mcp_url"`
// Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.
@@ -998,17 +998,17 @@ type ExternalAuthConfig struct {
type ProvisionerConfig struct {
// Daemons is the number of built-in terraform provisioners.
Daemons serpent.Int64 `json:"daemons" typescript:",notnull"`
DaemonTypes serpent.StringArray `json:"daemon_types" typescript:",notnull"`
DaemonPollInterval serpent.Duration `json:"daemon_poll_interval" typescript:",notnull"`
DaemonPollJitter serpent.Duration `json:"daemon_poll_jitter" typescript:",notnull"`
Daemons serpent.Int64 `json:"daemons" typescript:",notnull"`
DaemonTypes serpent.StringArray `json:"daemon_types" typescript:",notnull"`
DaemonPollInterval serpent.Duration `json:"daemon_poll_interval" typescript:",notnull"`
DaemonPollJitter serpent.Duration `json:"daemon_poll_jitter" typescript:",notnull"`
ForceCancelInterval serpent.Duration `json:"force_cancel_interval" typescript:",notnull"`
DaemonPSK serpent.String `json:"daemon_psk" typescript:",notnull"`
DaemonPSK serpent.String `json:"daemon_psk" typescript:",notnull"`
}
type RateLimitConfig struct {
DisableAll serpent.Bool `json:"disable_all" typescript:",notnull"`
API serpent.Int64 `json:"api" typescript:",notnull"`
API serpent.Int64 `json:"api" typescript:",notnull"`
}
type SwaggerConfig struct {
@@ -1016,20 +1016,20 @@ type SwaggerConfig struct {
}
type LoggingConfig struct {
Filter serpent.StringArray `json:"log_filter" typescript:",notnull"`
Human serpent.String `json:"human" typescript:",notnull"`
JSON serpent.String `json:"json" typescript:",notnull"`
Filter serpent.StringArray `json:"log_filter" typescript:",notnull"`
Human serpent.String `json:"human" typescript:",notnull"`
JSON serpent.String `json:"json" typescript:",notnull"`
Stackdriver serpent.String `json:"stackdriver" typescript:",notnull"`
}
type DangerousConfig struct {
AllowPathAppSharing serpent.Bool `json:"allow_path_app_sharing" typescript:",notnull"`
AllowPathAppSharing serpent.Bool `json:"allow_path_app_sharing" typescript:",notnull"`
AllowPathAppSiteOwnerAccess serpent.Bool `json:"allow_path_app_site_owner_access" typescript:",notnull"`
AllowAllCors serpent.Bool `json:"allow_all_cors" typescript:",notnull"`
AllowAllCors serpent.Bool `json:"allow_all_cors" typescript:",notnull"`
}
type UserQuietHoursScheduleConfig struct {
DefaultSchedule serpent.String `json:"default_schedule" typescript:",notnull"`
DefaultSchedule serpent.String `json:"default_schedule" typescript:",notnull"`
AllowUserCustom serpent.Bool `json:"allow_user_custom" typescript:",notnull"`
// TODO: add WindowDuration and the ability to postpone max_deadline by this
// amount
@@ -1038,7 +1038,7 @@ type UserQuietHoursScheduleConfig struct {
// HealthcheckConfig contains configuration for healthchecks.
type HealthcheckConfig struct {
Refresh serpent.Duration `json:"refresh" typescript:",notnull"`
Refresh serpent.Duration `json:"refresh" typescript:",notnull"`
ThresholdDatabase serpent.Duration `json:"threshold_database" typescript:",notnull"`
}
@@ -4001,54 +4001,54 @@ Write out the current server config as YAML to stdout.`,
}
type AIBridgeConfig struct {
Enabled serpent.Bool `json:"enabled" typescript:",notnull"`
OpenAI AIBridgeOpenAIConfig `json:"openai" typescript:",notnull"`
Enabled serpent.Bool `json:"enabled" typescript:",notnull"`
OpenAI AIBridgeOpenAIConfig `json:"openai" typescript:",notnull"`
Anthropic AIBridgeAnthropicConfig `json:"anthropic" typescript:",notnull"`
Bedrock AIBridgeBedrockConfig `json:"bedrock" typescript:",notnull"`
Bedrock AIBridgeBedrockConfig `json:"bedrock" typescript:",notnull"`
// Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.
InjectCoderMCPTools serpent.Bool `json:"inject_coder_mcp_tools" typescript:",notnull"`
Retention serpent.Duration `json:"retention" typescript:",notnull"`
MaxConcurrency serpent.Int64 `json:"max_concurrency" typescript:",notnull"`
RateLimit serpent.Int64 `json:"rate_limit" typescript:",notnull"`
StructuredLogging serpent.Bool `json:"structured_logging" typescript:",notnull"`
SendActorHeaders serpent.Bool `json:"send_actor_headers" typescript:",notnull"`
Retention serpent.Duration `json:"retention" typescript:",notnull"`
MaxConcurrency serpent.Int64 `json:"max_concurrency" typescript:",notnull"`
RateLimit serpent.Int64 `json:"rate_limit" typescript:",notnull"`
StructuredLogging serpent.Bool `json:"structured_logging" typescript:",notnull"`
SendActorHeaders serpent.Bool `json:"send_actor_headers" typescript:",notnull"`
// Circuit breaker protects against cascading failures from upstream AI
// provider rate limits (429, 503, 529 overloaded).
CircuitBreakerEnabled serpent.Bool `json:"circuit_breaker_enabled" typescript:",notnull"`
CircuitBreakerEnabled serpent.Bool `json:"circuit_breaker_enabled" typescript:",notnull"`
CircuitBreakerFailureThreshold serpent.Int64 `json:"circuit_breaker_failure_threshold" typescript:",notnull"`
CircuitBreakerInterval serpent.Duration `json:"circuit_breaker_interval" typescript:",notnull"`
CircuitBreakerTimeout serpent.Duration `json:"circuit_breaker_timeout" typescript:",notnull"`
CircuitBreakerMaxRequests serpent.Int64 `json:"circuit_breaker_max_requests" typescript:",notnull"`
CircuitBreakerInterval serpent.Duration `json:"circuit_breaker_interval" typescript:",notnull"`
CircuitBreakerTimeout serpent.Duration `json:"circuit_breaker_timeout" typescript:",notnull"`
CircuitBreakerMaxRequests serpent.Int64 `json:"circuit_breaker_max_requests" typescript:",notnull"`
}
type AIBridgeOpenAIConfig struct {
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
Key serpent.String `json:"key" typescript:",notnull"`
Key serpent.String `json:"key" typescript:",notnull"`
}
type AIBridgeAnthropicConfig struct {
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
Key serpent.String `json:"key" typescript:",notnull"`
Key serpent.String `json:"key" typescript:",notnull"`
}
type AIBridgeBedrockConfig struct {
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
Region serpent.String `json:"region" typescript:",notnull"`
AccessKey serpent.String `json:"access_key" typescript:",notnull"`
BaseURL serpent.String `json:"base_url" typescript:",notnull"`
Region serpent.String `json:"region" typescript:",notnull"`
AccessKey serpent.String `json:"access_key" typescript:",notnull"`
AccessKeySecret serpent.String `json:"access_key_secret" typescript:",notnull"`
Model serpent.String `json:"model" typescript:",notnull"`
SmallFastModel serpent.String `json:"small_fast_model" typescript:",notnull"`
Model serpent.String `json:"model" typescript:",notnull"`
SmallFastModel serpent.String `json:"small_fast_model" typescript:",notnull"`
}
type AIBridgeProxyConfig struct {
Enabled serpent.Bool `json:"enabled" typescript:",notnull"`
ListenAddr serpent.String `json:"listen_addr" typescript:",notnull"`
TLSCertFile serpent.String `json:"tls_cert_file" typescript:",notnull"`
TLSKeyFile serpent.String `json:"tls_key_file" typescript:",notnull"`
MITMCertFile serpent.String `json:"cert_file" typescript:",notnull"`
MITMKeyFile serpent.String `json:"key_file" typescript:",notnull"`
DomainAllowlist serpent.StringArray `json:"domain_allowlist" typescript:",notnull"`
UpstreamProxy serpent.String `json:"upstream_proxy" typescript:",notnull"`
Enabled serpent.Bool `json:"enabled" typescript:",notnull"`
ListenAddr serpent.String `json:"listen_addr" typescript:",notnull"`
TLSCertFile serpent.String `json:"tls_cert_file" typescript:",notnull"`
TLSKeyFile serpent.String `json:"tls_key_file" typescript:",notnull"`
MITMCertFile serpent.String `json:"cert_file" typescript:",notnull"`
MITMKeyFile serpent.String `json:"key_file" typescript:",notnull"`
DomainAllowlist serpent.StringArray `json:"domain_allowlist" typescript:",notnull"`
UpstreamProxy serpent.String `json:"upstream_proxy" typescript:",notnull"`
UpstreamProxyCA serpent.String `json:"upstream_proxy_ca" typescript:",notnull"`
}
@@ -4062,11 +4062,11 @@ type SupportConfig struct {
}
type LinkConfig struct {
Name string `json:"name" yaml:"name"`
Target string `json:"target" yaml:"target"`
Icon string `json:"icon" enums:"bug,chat,docs,star" yaml:"icon"`
Name string `json:"name" yaml:"name"`
Target string `json:"target" yaml:"target"`
Icon string `json:"icon" yaml:"icon" enums:"bug,chat,docs,star"`
Location string `json:"location,omitempty" enums:"navbar,dropdown" yaml:"location,omitempty"`
Location string `json:"location,omitempty" yaml:"location,omitempty" enums:"navbar,dropdown"`
}
// Validate checks cross-field constraints for deployment values.
@@ -4574,7 +4574,7 @@ type CryptoKey struct {
Secret string `json:"secret"`
DeletesAt time.Time `json:"deletes_at" format:"date-time"`
Sequence int32 `json:"sequence"`
StartsAt time.Time `json:"starts_at" format:"date-time"`
StartsAt time.Time `json:"starts_at" format:"date-time"`
}
func (c CryptoKey) CanSign(now time.Time) bool {

Some files were not shown because too many files have changed in this diff Show More