Compare commits
94 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 986d6a856d | |||
| 351ab6c7c7 | |||
| d3d0ea3622 | |||
| c0b71a1161 | |||
| 828c9e23f5 | |||
| a130a7dc97 | |||
| d6fef96d72 | |||
| 4dd8531f37 | |||
| 3bcb7de7c0 | |||
| 1e07ec49a6 | |||
| 84de391f26 | |||
| b83b93ea5c | |||
| 014e5b4f57 | |||
| fc3508dc60 | |||
| 8b4d35798a | |||
| d69dcf18de | |||
| fe82d0aeb9 | |||
| 81dba9da14 | |||
| 20ac96e68d | |||
| 677f90b78a | |||
| d697213373 | |||
| 62144d230f | |||
| 0d0c6c956d | |||
| 488ceb6e58 | |||
| 481c132135 | |||
| d42008e93d | |||
| aa3cee6410 | |||
| 4f566f92b5 | |||
| bd5b62c976 | |||
| 66f809388e | |||
| 563c00fb2c | |||
| 817fb4e67a | |||
| 2cf47ec384 | |||
| 11481d7bed | |||
| f3bf5baba0 | |||
| 9df7fda5f6 | |||
| 665db7bdeb | |||
| 903cfb183f | |||
| 49e5547c22 | |||
| f9c265ca6e | |||
| a65a31a5a3 | |||
| 22a4a33886 | |||
| d3c9469e13 | |||
| 91ec0f1484 | |||
| 6b76e30321 | |||
| 6fc9f195f1 | |||
| c2243addce | |||
| cd163d404b | |||
| 41d12b8aa3 | |||
| 497e1e6589 | |||
| b779c9ee33 | |||
| 144b32a4b6 | |||
| a40716b6fe | |||
| 635c5d52a8 | |||
| 075dfecd12 | |||
| fdb1205bdf | |||
| 33a47fced3 | |||
| ca5158f94a | |||
| b7e0f42591 | |||
| 41bd7acf66 | |||
| 87d4a29371 | |||
| a797a494ef | |||
| a33605df58 | |||
| 3c430a67fa | |||
| abee77ac2f | |||
| 7946dc6645 | |||
| eb828a6a86 | |||
| 4e2d7ffaa7 | |||
| 524bca4c87 | |||
| 365de3e367 | |||
| 5d0eb772da | |||
| 04fca84872 | |||
| 7cca2b6176 | |||
| 1031da9738 | |||
| b69631cb35 | |||
| 7b0aa31b55 | |||
| 93b9d70a9b | |||
| 6972d073a2 | |||
| 89bb5bb945 | |||
| b7eab35734 | |||
| 3f76f312e4 | |||
| abf59ee7a6 | |||
| cabb611fd9 | |||
| b2d8b67ff7 | |||
| c1884148f0 | |||
| 741af057dc | |||
| 32a894d4a7 | |||
| 4fdd48b3f5 | |||
| e94de0bdab | |||
| fa8693605f | |||
| af1be592cf | |||
| 6f97539122 | |||
| 530872873e | |||
| 115011bd70 |
@@ -1198,7 +1198,7 @@ jobs:
|
||||
make -j \
|
||||
build/coder_linux_{amd64,arm64,armv7} \
|
||||
build/coder_"$version"_windows_amd64.zip \
|
||||
build/coder_"$version"_linux_amd64.{tar.gz,deb}
|
||||
build/coder_"$version"_linux_{amd64,arm64,armv7}.{tar.gz,deb}
|
||||
env:
|
||||
# The Windows and Darwin slim binaries must be signed for Coder
|
||||
# Desktop to accept them.
|
||||
@@ -1216,11 +1216,28 @@ jobs:
|
||||
GCLOUD_ACCESS_TOKEN: ${{ steps.gcloud_auth.outputs.access_token }}
|
||||
JSIGN_PATH: /tmp/jsign-6.0.jar
|
||||
|
||||
# Free up disk space before building Docker images. The preceding
|
||||
# Build step produces ~2 GB of binaries and packages, the Go build
|
||||
# cache is ~1.3 GB, and node_modules is ~500 MB. Docker image
|
||||
# builds, pushes, and SBOM generation need headroom that isn't
|
||||
# available without reclaiming some of that space.
|
||||
- name: Clean up build cache
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
# Go caches are no longer needed — binaries are already compiled.
|
||||
go clean -cache -modcache
|
||||
# Remove .apk and .rpm packages that are not uploaded as
|
||||
# artifacts and were only built as make prerequisites.
|
||||
rm -f ./build/*.apk ./build/*.rpm
|
||||
|
||||
- name: Build Linux Docker images
|
||||
id: build-docker
|
||||
env:
|
||||
CODER_IMAGE_BASE: ghcr.io/coder/coder-preview
|
||||
DOCKER_CLI_EXPERIMENTAL: "enabled"
|
||||
# Skip building .deb/.rpm/.apk/.tar.gz as prerequisites for
|
||||
# the Docker image targets — they were already built above.
|
||||
DOCKER_IMAGE_NO_PREREQUISITES: "true"
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
|
||||
|
||||
@@ -23,6 +23,44 @@ permissions:
|
||||
concurrency: pr-${{ github.ref }}
|
||||
|
||||
jobs:
|
||||
community-label:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
if: >-
|
||||
${{
|
||||
github.event_name == 'pull_request_target' &&
|
||||
github.event.action == 'opened' &&
|
||||
github.event.pull_request.author_association != 'MEMBER' &&
|
||||
github.event.pull_request.author_association != 'COLLABORATOR' &&
|
||||
github.event.pull_request.author_association != 'OWNER'
|
||||
}}
|
||||
steps:
|
||||
- name: Add community label
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
script: |
|
||||
const params = {
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
}
|
||||
|
||||
const labels = context.payload.pull_request.labels.map((label) => label.name)
|
||||
if (labels.includes("community")) {
|
||||
console.log('PR already has "community" label.')
|
||||
return
|
||||
}
|
||||
|
||||
console.log(
|
||||
'Adding "community" label for author association "%s".',
|
||||
context.payload.pull_request.author_association,
|
||||
)
|
||||
await github.rest.issues.addLabels({
|
||||
...params,
|
||||
labels: ["community"],
|
||||
})
|
||||
|
||||
cla:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
|
||||
@@ -136,18 +136,10 @@ endif
|
||||
# the search path so that these exclusions match.
|
||||
FIND_EXCLUSIONS= \
|
||||
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
|
||||
|
||||
# Source files used for make targets, evaluated on use.
|
||||
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
|
||||
# Same as GO_SRC_FILES but excluding certain files that have problematic
|
||||
# Makefile dependencies (e.g. pnpm).
|
||||
MOST_GO_SRC_FILES := $(shell \
|
||||
find . \
|
||||
$(FIND_EXCLUSIONS) \
|
||||
-type f \
|
||||
-name '*.go' \
|
||||
-not -name '*_test.go' \
|
||||
-not -wholename './agent/agentcontainers/dcspec/dcspec_gen.go' \
|
||||
)
|
||||
|
||||
# All the shell files in the repo, excluding ignored files.
|
||||
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
|
||||
|
||||
@@ -514,7 +506,10 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
|
||||
cp "$<" "$$output_file"
|
||||
.PHONY: install
|
||||
|
||||
build/.bin/develop: go.mod go.sum $(GO_SRC_FILES)
|
||||
# Only wildcard the go files in the develop directory to avoid rebuilds
|
||||
# when project files are changd. Technically changes to some imports may
|
||||
# not be detected, but it's unlikely to cause any issues.
|
||||
build/.bin/develop: go.mod go.sum $(wildcard scripts/develop/*.go)
|
||||
CGO_ENABLED=0 go build -o $@ ./scripts/develop
|
||||
|
||||
BOLD := $(shell tput bold 2>/dev/null)
|
||||
|
||||
+1
-1
@@ -389,7 +389,7 @@ func (a *agent) init() {
|
||||
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
|
||||
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
|
||||
desktop := agentdesktop.NewPortableDesktop(
|
||||
a.logger.Named("desktop"), a.execer, a.scriptDataDir,
|
||||
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
|
||||
)
|
||||
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
|
||||
@@ -2,13 +2,9 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -24,28 +20,6 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
portableDesktopVersion = "v0.0.4"
|
||||
downloadRetries = 3
|
||||
downloadRetryDelay = time.Second
|
||||
)
|
||||
|
||||
// platformBinaries maps GOARCH to download URL and expected SHA-256
|
||||
// digest for each supported platform.
|
||||
var platformBinaries = map[string]struct {
|
||||
URL string
|
||||
SHA256 string
|
||||
}{
|
||||
"amd64": {
|
||||
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-x64",
|
||||
SHA256: "a04e05e6c7d6f2e6b3acbf1729a7b21271276300b4fee321f4ffee6136538317",
|
||||
},
|
||||
"arm64": {
|
||||
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-arm64",
|
||||
SHA256: "b8cb9142dc32d46a608f25229cbe8168ff2a3aadc54253c74ff54cd347e16ca6",
|
||||
},
|
||||
}
|
||||
|
||||
// portableDesktopOutput is the JSON output from
|
||||
// `portabledesktop up --json`.
|
||||
type portableDesktopOutput struct {
|
||||
@@ -78,43 +52,31 @@ type screenshotOutput struct {
|
||||
// portableDesktop implements Desktop by shelling out to the
|
||||
// portabledesktop CLI via agentexec.Execer.
|
||||
type portableDesktop struct {
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
dataDir string // agent's ScriptDataDir, used for binary caching
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
scriptBinDir string // coder script bin directory
|
||||
|
||||
mu sync.Mutex
|
||||
session *desktopSession // nil until started
|
||||
binPath string // resolved path to binary, cached
|
||||
closed bool
|
||||
|
||||
// httpClient is used for downloading the binary. If nil,
|
||||
// http.DefaultClient is used.
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewPortableDesktop creates a Desktop backed by the portabledesktop
|
||||
// CLI binary, using execer to spawn child processes. dataDir is used
|
||||
// to cache the downloaded binary.
|
||||
// CLI binary, using execer to spawn child processes. scriptBinDir is
|
||||
// the coder script bin directory checked for the binary.
|
||||
func NewPortableDesktop(
|
||||
logger slog.Logger,
|
||||
execer agentexec.Execer,
|
||||
dataDir string,
|
||||
scriptBinDir string,
|
||||
) Desktop {
|
||||
return &portableDesktop{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
dataDir: dataDir,
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
}
|
||||
}
|
||||
|
||||
// httpDo returns the HTTP client to use for downloads.
|
||||
func (p *portableDesktop) httpDo() *http.Client {
|
||||
if p.httpClient != nil {
|
||||
return p.httpClient
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
// Start launches the desktop session (idempotent).
|
||||
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
|
||||
p.mu.Lock()
|
||||
@@ -399,8 +361,8 @@ func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, e
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
// ensureBinary resolves or downloads the portabledesktop binary. It
|
||||
// must be called while p.mu is held.
|
||||
// ensureBinary resolves the portabledesktop binary from PATH or the
|
||||
// coder script bin directory. It must be called while p.mu is held.
|
||||
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
|
||||
if p.binPath != "" {
|
||||
return nil
|
||||
@@ -415,130 +377,23 @@ func (p *portableDesktop) ensureBinary(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. Platform checks.
|
||||
if runtime.GOOS != "linux" {
|
||||
return xerrors.New("portabledesktop is only supported on Linux")
|
||||
}
|
||||
bin, ok := platformBinaries[runtime.GOARCH]
|
||||
if !ok {
|
||||
return xerrors.Errorf("unsupported architecture for portabledesktop: %s", runtime.GOARCH)
|
||||
}
|
||||
|
||||
// 3. Check cache.
|
||||
cacheDir := filepath.Join(p.dataDir, "portabledesktop", bin.SHA256)
|
||||
cachedPath := filepath.Join(cacheDir, "portabledesktop")
|
||||
|
||||
if info, err := os.Stat(cachedPath); err == nil && !info.IsDir() {
|
||||
// Verify it is executable.
|
||||
if info.Mode()&0o100 != 0 {
|
||||
p.logger.Info(ctx, "using cached portabledesktop binary",
|
||||
slog.F("path", cachedPath),
|
||||
// 2. Check the coder script bin directory.
|
||||
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
|
||||
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
|
||||
// On Windows, permission bits don't indicate executability,
|
||||
// so accept any regular file.
|
||||
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
|
||||
p.logger.Info(ctx, "found portabledesktop in script bin directory",
|
||||
slog.F("path", scriptBinPath),
|
||||
)
|
||||
p.binPath = cachedPath
|
||||
p.binPath = scriptBinPath
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Download with retry.
|
||||
p.logger.Info(ctx, "downloading portabledesktop binary",
|
||||
slog.F("url", bin.URL),
|
||||
slog.F("version", portableDesktopVersion),
|
||||
slog.F("arch", runtime.GOARCH),
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
for attempt := range downloadRetries {
|
||||
if err := downloadBinary(ctx, p.httpDo(), bin.URL, bin.SHA256, cachedPath); err != nil {
|
||||
lastErr = err
|
||||
p.logger.Warn(ctx, "download attempt failed",
|
||||
slog.F("attempt", attempt+1),
|
||||
slog.F("max_attempts", downloadRetries),
|
||||
slog.Error(err),
|
||||
)
|
||||
if attempt < downloadRetries-1 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(downloadRetryDelay):
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
p.binPath = cachedPath
|
||||
p.logger.Info(ctx, "downloaded portabledesktop binary",
|
||||
slog.F("path", cachedPath),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
return xerrors.Errorf("download portabledesktop after %d attempts: %w", downloadRetries, lastErr)
|
||||
}
|
||||
|
||||
// downloadBinary fetches a binary from url, verifies its SHA-256
|
||||
// digest matches expectedSHA256, and atomically writes it to destPath.
|
||||
func downloadBinary(ctx context.Context, client *http.Client, url, expectedSHA256, destPath string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(destPath), 0o700); err != nil {
|
||||
return xerrors.Errorf("create cache directory: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create HTTP request: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("HTTP GET %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return xerrors.Errorf("HTTP GET %s: status %d", url, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Write to a temp file in the same directory so the final rename
|
||||
// is atomic on the same filesystem.
|
||||
tmpFile, err := os.CreateTemp(filepath.Dir(destPath), "portabledesktop-download-*")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
// Clean up the temp file on any error path.
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
// Stream the response body while computing SHA-256.
|
||||
hasher := sha256.New()
|
||||
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
|
||||
return xerrors.Errorf("download body: %w", err)
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return xerrors.Errorf("close temp file: %w", err)
|
||||
}
|
||||
|
||||
// Verify digest.
|
||||
actualSHA256 := hex.EncodeToString(hasher.Sum(nil))
|
||||
if actualSHA256 != expectedSHA256 {
|
||||
return xerrors.Errorf(
|
||||
"SHA-256 mismatch: expected %s, got %s",
|
||||
expectedSHA256, actualSHA256,
|
||||
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
|
||||
slog.F("path", scriptBinPath),
|
||||
slog.F("mode", info.Mode().String()),
|
||||
)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tmpPath, 0o700); err != nil {
|
||||
return xerrors.Errorf("chmod: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, destPath); err != nil {
|
||||
return xerrors.Errorf("rename to final path: %w", err)
|
||||
}
|
||||
|
||||
success = true
|
||||
return nil
|
||||
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
|
||||
}
|
||||
|
||||
@@ -2,11 +2,6 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -77,7 +72,6 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
// The "up" script prints the JSON line then sleeps until
|
||||
// the context is canceled (simulating a long-running process).
|
||||
@@ -88,13 +82,13 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
cfg, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -111,7 +105,6 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
@@ -120,13 +113,13 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
cfg1, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -154,7 +147,6 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
@@ -163,13 +155,13 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -180,7 +172,6 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
@@ -189,13 +180,13 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
_, err := pd.Screenshot(ctx, ScreenshotOptions{
|
||||
TargetWidth: 800,
|
||||
TargetHeight: 600,
|
||||
@@ -287,13 +278,13 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(context.Background(), pd)
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
@@ -372,13 +363,13 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(context.Background(), pd)
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
@@ -404,13 +395,13 @@ func TestPortableDesktop_CursorPosition(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
x, y, err := pd.CursorPosition(context.Background())
|
||||
x, y, err := pd.CursorPosition(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, x)
|
||||
assert.Equal(t, 200, y)
|
||||
@@ -428,13 +419,13 @@ func TestPortableDesktop_Close(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := t.Context()
|
||||
_, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -457,81 +448,6 @@ func TestPortableDesktop_Close(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "desktop is closed")
|
||||
}
|
||||
|
||||
// --- downloadBinary tests ---
|
||||
|
||||
func TestDownloadBinary_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho portable\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the file exists and has correct content.
|
||||
got, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
|
||||
// Verify executable permissions.
|
||||
info, err := os.Stat(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, info.Mode()&0o700, "binary should be executable")
|
||||
}
|
||||
|
||||
func TestDownloadBinary_ChecksumMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("real binary content"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, wrongSHA, destPath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "SHA-256 mismatch")
|
||||
|
||||
// The destination file should not exist (temp file cleaned up).
|
||||
_, statErr := os.Stat(destPath)
|
||||
assert.True(t, os.IsNotExist(statErr), "dest file should not exist after checksum failure")
|
||||
|
||||
// No leftover temp files in the directory.
|
||||
entries, err := os.ReadDir(destDir)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, entries, "no leftover temp files should remain")
|
||||
}
|
||||
|
||||
func TestDownloadBinary_HTTPError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, "irrelevant", destPath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "status 404")
|
||||
}
|
||||
|
||||
// --- ensureBinary tests ---
|
||||
|
||||
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
|
||||
@@ -541,173 +457,89 @@ func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
|
||||
// immediately without doing any work.
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "/already/set",
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "/already/set",
|
||||
}
|
||||
|
||||
err := pd.ensureBinary(context.Background())
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/already/set", pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_UsesCachedBinary(t *testing.T) {
|
||||
func TestEnsureBinary_UsesScriptBinDir(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
bin, ok := platformBinaries[runtime.GOARCH]
|
||||
if !ok {
|
||||
t.Skipf("no platformBinary entry for %s", runtime.GOARCH)
|
||||
}
|
||||
|
||||
dataDir := t.TempDir()
|
||||
cacheDir := filepath.Join(dataDir, "portabledesktop", bin.SHA256)
|
||||
require.NoError(t, os.MkdirAll(cacheDir, 0o700))
|
||||
|
||||
cachedPath := filepath.Join(cacheDir, "portabledesktop")
|
||||
require.NoError(t, os.WriteFile(cachedPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
scriptBinDir := t.TempDir()
|
||||
binPath := filepath.Join(scriptBinDir, "portabledesktop")
|
||||
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
require.NoError(t, os.Chmod(binPath, 0o755))
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: dataDir,
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(context.Background())
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cachedPath, pd.binPath)
|
||||
assert.Equal(t, binPath, pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_Downloads(t *testing.T) {
|
||||
func TestEnsureBinary_ScriptBinDirNotExecutable(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Windows does not support Unix permission bits")
|
||||
}
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment and we override the package-level platformBinaries.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
// environment.
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho downloaded\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
scriptBinDir := t.TempDir()
|
||||
binPath := filepath.Join(scriptBinDir, "portabledesktop")
|
||||
// Write without execute permission.
|
||||
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
_ = binPath
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Save and restore platformBinaries for this test.
|
||||
origBinaries := platformBinaries
|
||||
platformBinaries = map[string]struct {
|
||||
URL string
|
||||
SHA256 string
|
||||
}{
|
||||
runtime.GOARCH: {
|
||||
URL: srv.URL + "/portabledesktop",
|
||||
SHA256: expectedSHA,
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() { platformBinaries = origBinaries })
|
||||
|
||||
dataDir := t.TempDir()
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: dataDir,
|
||||
httpClient: srv.Client(),
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
}
|
||||
|
||||
// Ensure PATH doesn't contain a real portabledesktop binary.
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPath := filepath.Join(dataDir, "portabledesktop", expectedSHA, "portabledesktop")
|
||||
assert.Equal(t, expectedPath, pd.binPath)
|
||||
|
||||
// Verify the downloaded file has correct content.
|
||||
got, err := os.ReadFile(expectedPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
}
|
||||
|
||||
func TestEnsureBinary_RetriesOnFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
func TestEnsureBinary_NotFound(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: t.TempDir(), // empty directory
|
||||
}
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho retried\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
var mu sync.Mutex
|
||||
attempt := 0
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
mu.Lock()
|
||||
current := attempt
|
||||
attempt++
|
||||
mu.Unlock()
|
||||
|
||||
// Fail the first 2 attempts, succeed on the third.
|
||||
if current < 2 {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Test downloadBinary directly to avoid time.Sleep in
|
||||
// ensureBinary's retry loop. We call it 3 times to simulate
|
||||
// what ensureBinary would do.
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
var lastErr error
|
||||
for i := range 3 {
|
||||
lastErr = downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
|
||||
if lastErr == nil {
|
||||
break
|
||||
}
|
||||
if i < 2 {
|
||||
// In the real code, ensureBinary sleeps here.
|
||||
// We skip the sleep in tests.
|
||||
continue
|
||||
}
|
||||
}
|
||||
require.NoError(t, lastErr, "download should succeed on the third attempt")
|
||||
|
||||
got, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
|
||||
mu.Lock()
|
||||
assert.Equal(t, 3, attempt, "server should have been hit 3 times")
|
||||
mu.Unlock()
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
}
|
||||
|
||||
// Ensure that portableDesktop satisfies the Desktop interface at
|
||||
// compile time. This uses the unexported type so it lives in the
|
||||
// internal test package.
|
||||
var _ Desktop = (*portableDesktop)(nil)
|
||||
|
||||
// Silence the linter about unused imports — agentexec.DefaultExecer
|
||||
// is used in TestEnsureBinary_UsesCachedBinPath and others, and
|
||||
// fmt.Sscanf is used indirectly via the implementation.
|
||||
var (
|
||||
_ = agentexec.DefaultExecer
|
||||
_ = fmt.Sprintf
|
||||
)
|
||||
|
||||
+89
-38
@@ -447,13 +447,10 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
|
||||
content := string(data)
|
||||
|
||||
for _, edit := range edits {
|
||||
var ok bool
|
||||
content, ok = fuzzyReplace(content, edit.Search, edit.Replace)
|
||||
if !ok {
|
||||
api.logger.Warn(ctx, "edit search string not found, skipping",
|
||||
slog.F("path", path),
|
||||
slog.F("search_preview", truncate(edit.Search, 64)),
|
||||
)
|
||||
var err error
|
||||
content, err = fuzzyReplace(content, edit)
|
||||
if err != nil {
|
||||
return http.StatusBadRequest, xerrors.Errorf("edit %s: %w", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -480,51 +477,92 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// fuzzyReplace attempts to find `search` inside `content` and replace its first
|
||||
// occurrence with `replace`. It uses a cascading match strategy inspired by
|
||||
// fuzzyReplace attempts to find `search` inside `content` and replace it
|
||||
// with `replace`. It uses a cascading match strategy inspired by
|
||||
// openai/codex's apply_patch:
|
||||
//
|
||||
// 1. Exact substring match (byte-for-byte).
|
||||
// 2. Line-by-line match ignoring trailing whitespace on each line.
|
||||
// 3. Line-by-line match ignoring all leading/trailing whitespace (indentation-tolerant).
|
||||
// 3. Line-by-line match ignoring all leading/trailing whitespace
|
||||
// (indentation-tolerant).
|
||||
//
|
||||
// When a fuzzy match is found (passes 2 or 3), the replacement is still applied
|
||||
// at the byte offsets of the original content so that surrounding text (including
|
||||
// indentation of untouched lines) is preserved.
|
||||
// When edit.ReplaceAll is false (the default), the search string must
|
||||
// match exactly one location. If multiple matches are found, an error
|
||||
// is returned asking the caller to include more context or set
|
||||
// replace_all.
|
||||
//
|
||||
// Returns the (possibly modified) content and a bool indicating whether a match
|
||||
// was found.
|
||||
func fuzzyReplace(content, search, replace string) (string, bool) {
|
||||
// Pass 1 – exact substring (replace all occurrences).
|
||||
// When a fuzzy match is found (passes 2 or 3), the replacement is still
|
||||
// applied at the byte offsets of the original content so that surrounding
|
||||
// text (including indentation of untouched lines) is preserved.
|
||||
func fuzzyReplace(content string, edit workspacesdk.FileEdit) (string, error) {
|
||||
search := edit.Search
|
||||
replace := edit.Replace
|
||||
|
||||
// Pass 1 – exact substring match.
|
||||
if strings.Contains(content, search) {
|
||||
return strings.ReplaceAll(content, search, replace), true
|
||||
if edit.ReplaceAll {
|
||||
return strings.ReplaceAll(content, search, replace), nil
|
||||
}
|
||||
count := strings.Count(content, search)
|
||||
if count > 1 {
|
||||
return "", xerrors.Errorf("search string matches %d occurrences "+
|
||||
"(expected exactly 1). Include more surrounding "+
|
||||
"context to make the match unique, or set "+
|
||||
"replace_all to true", count)
|
||||
}
|
||||
// Exactly one match.
|
||||
return strings.Replace(content, search, replace, 1), nil
|
||||
}
|
||||
|
||||
// For line-level fuzzy matching we split both content and search into lines.
|
||||
// For line-level fuzzy matching we split both content and search
|
||||
// into lines.
|
||||
contentLines := strings.SplitAfter(content, "\n")
|
||||
searchLines := strings.SplitAfter(search, "\n")
|
||||
|
||||
// A trailing newline in the search produces an empty final element from
|
||||
// SplitAfter. Drop it so it doesn't interfere with line matching.
|
||||
// A trailing newline in the search produces an empty final element
|
||||
// from SplitAfter. Drop it so it doesn't interfere with line
|
||||
// matching.
|
||||
if len(searchLines) > 0 && searchLines[len(searchLines)-1] == "" {
|
||||
searchLines = searchLines[:len(searchLines)-1]
|
||||
}
|
||||
|
||||
// Pass 2 – trim trailing whitespace on each line.
|
||||
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
|
||||
trimRight := func(a, b string) bool {
|
||||
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
|
||||
}); ok {
|
||||
return spliceLines(contentLines, start, end, replace), true
|
||||
}
|
||||
|
||||
// Pass 3 – trim all leading and trailing whitespace (indentation-tolerant).
|
||||
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
|
||||
trimAll := func(a, b string) bool {
|
||||
return strings.TrimSpace(a) == strings.TrimSpace(b)
|
||||
}); ok {
|
||||
return spliceLines(contentLines, start, end, replace), true
|
||||
}
|
||||
|
||||
return content, false
|
||||
// Pass 2 – trim trailing whitespace on each line.
|
||||
if start, end, ok := seekLines(contentLines, searchLines, trimRight); ok {
|
||||
if !edit.ReplaceAll {
|
||||
if count := countLineMatches(contentLines, searchLines, trimRight); count > 1 {
|
||||
return "", xerrors.Errorf("search string matches %d occurrences "+
|
||||
"(expected exactly 1). Include more surrounding "+
|
||||
"context to make the match unique, or set "+
|
||||
"replace_all to true", count)
|
||||
}
|
||||
}
|
||||
return spliceLines(contentLines, start, end, replace), nil
|
||||
}
|
||||
|
||||
// Pass 3 – trim all leading and trailing whitespace
|
||||
// (indentation-tolerant).
|
||||
if start, end, ok := seekLines(contentLines, searchLines, trimAll); ok {
|
||||
if !edit.ReplaceAll {
|
||||
if count := countLineMatches(contentLines, searchLines, trimAll); count > 1 {
|
||||
return "", xerrors.Errorf("search string matches %d occurrences "+
|
||||
"(expected exactly 1). Include more surrounding "+
|
||||
"context to make the match unique, or set "+
|
||||
"replace_all to true", count)
|
||||
}
|
||||
}
|
||||
return spliceLines(contentLines, start, end, replace), nil
|
||||
}
|
||||
|
||||
return "", xerrors.New("search string not found in file. Verify the search " +
|
||||
"string matches the file content exactly, including whitespace " +
|
||||
"and indentation")
|
||||
}
|
||||
|
||||
// seekLines scans contentLines looking for a contiguous subsequence that matches
|
||||
@@ -549,6 +587,26 @@ outer:
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// countLineMatches counts how many non-overlapping contiguous
|
||||
// subsequences of contentLines match searchLines according to eq.
|
||||
func countLineMatches(contentLines, searchLines []string, eq func(a, b string) bool) int {
|
||||
count := 0
|
||||
if len(searchLines) == 0 || len(searchLines) > len(contentLines) {
|
||||
return count
|
||||
}
|
||||
outer:
|
||||
for i := 0; i <= len(contentLines)-len(searchLines); i++ {
|
||||
for j, sLine := range searchLines {
|
||||
if !eq(contentLines[i+j], sLine) {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
count++
|
||||
i += len(searchLines) - 1 // skip past this match
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// spliceLines replaces contentLines[start:end] with replacement text, returning
|
||||
// the full content as a single string.
|
||||
func spliceLines(contentLines []string, start, end int, replacement string) string {
|
||||
@@ -562,10 +620,3 @@ func spliceLines(contentLines []string, start, end int, replacement string) stri
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n] + "..."
|
||||
}
|
||||
|
||||
@@ -576,7 +576,9 @@ func TestEditFiles(t *testing.T) {
|
||||
expected: map[string]string{filepath.Join(tmpdir, "edit1"): "bar bar"},
|
||||
},
|
||||
{
|
||||
name: "EditEdit", // Edits affect previous edits.
|
||||
// When the second edit creates ambiguity (two "bar"
|
||||
// occurrences), it should fail.
|
||||
name: "EditEditAmbiguous",
|
||||
contents: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
|
||||
edits: []workspacesdk.FileEdits{
|
||||
{
|
||||
@@ -593,7 +595,33 @@ func TestEditFiles(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "qux qux"},
|
||||
errCode: http.StatusBadRequest,
|
||||
errors: []string{"matches 2 occurrences"},
|
||||
// File should not be modified on error.
|
||||
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
|
||||
},
|
||||
{
|
||||
// With replace_all the cascading edit replaces
|
||||
// both occurrences.
|
||||
name: "EditEditReplaceAll",
|
||||
contents: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "foo bar"},
|
||||
edits: []workspacesdk.FileEdits{
|
||||
{
|
||||
Path: filepath.Join(tmpdir, "edit-edit-ra"),
|
||||
Edits: []workspacesdk.FileEdit{
|
||||
{
|
||||
Search: "foo",
|
||||
Replace: "bar",
|
||||
},
|
||||
{
|
||||
Search: "bar",
|
||||
Replace: "qux",
|
||||
ReplaceAll: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "qux qux"},
|
||||
},
|
||||
{
|
||||
name: "Multiline",
|
||||
@@ -720,7 +748,7 @@ func TestEditFiles(t *testing.T) {
|
||||
expected: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "goodbye world"},
|
||||
},
|
||||
{
|
||||
name: "NoMatchStillSucceeds",
|
||||
name: "NoMatchErrors",
|
||||
contents: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
|
||||
edits: []workspacesdk.FileEdits{
|
||||
{
|
||||
@@ -733,9 +761,46 @@ func TestEditFiles(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
errCode: http.StatusBadRequest,
|
||||
errors: []string{"search string not found in file"},
|
||||
// File should remain unchanged.
|
||||
expected: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
|
||||
},
|
||||
{
|
||||
name: "AmbiguousExactMatch",
|
||||
contents: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
|
||||
edits: []workspacesdk.FileEdits{
|
||||
{
|
||||
Path: filepath.Join(tmpdir, "ambig-exact"),
|
||||
Edits: []workspacesdk.FileEdit{
|
||||
{
|
||||
Search: "foo",
|
||||
Replace: "qux",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
errCode: http.StatusBadRequest,
|
||||
errors: []string{"matches 3 occurrences"},
|
||||
expected: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
|
||||
},
|
||||
{
|
||||
name: "ReplaceAllExact",
|
||||
contents: map[string]string{filepath.Join(tmpdir, "ra-exact"): "foo bar foo baz foo"},
|
||||
edits: []workspacesdk.FileEdits{
|
||||
{
|
||||
Path: filepath.Join(tmpdir, "ra-exact"),
|
||||
Edits: []workspacesdk.FileEdit{
|
||||
{
|
||||
Search: "foo",
|
||||
Replace: "qux",
|
||||
ReplaceAll: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]string{filepath.Join(tmpdir, "ra-exact"): "qux bar qux baz qux"},
|
||||
},
|
||||
{
|
||||
name: "MixedWhitespaceMultiline",
|
||||
contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
//go:build !windows
|
||||
|
||||
package agentproc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// procSysProcAttr returns the SysProcAttr to use when spawning
|
||||
// processes. On Unix, Setpgid creates a new process group so
|
||||
// that signals can be delivered to the entire group (the shell
|
||||
// and all its children).
|
||||
func procSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
}
|
||||
}
|
||||
|
||||
// signalProcess sends a signal to the process group rooted at p.
|
||||
// Using the negative PID sends the signal to every process in the
|
||||
// group, ensuring child processes (e.g. from shell pipelines) are
|
||||
// also signaled.
|
||||
func signalProcess(p *os.Process, sig syscall.Signal) error {
|
||||
return syscall.Kill(-p.Pid, sig)
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package agentproc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// procSysProcAttr returns the SysProcAttr to use when spawning
|
||||
// processes. On Windows, process groups are not supported in the
|
||||
// same way as Unix, so this returns an empty struct.
|
||||
func procSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{}
|
||||
}
|
||||
|
||||
// signalProcess sends a signal directly to the process. Windows
|
||||
// does not support process group signaling, so we fall back to
|
||||
// sending the signal to the process itself.
|
||||
func signalProcess(p *os.Process, _ syscall.Signal) error {
|
||||
return p.Kill()
|
||||
}
|
||||
@@ -113,6 +113,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
cmd.Dir = req.WorkDir
|
||||
}
|
||||
cmd.Stdin = nil
|
||||
cmd.SysProcAttr = procSysProcAttr()
|
||||
|
||||
// WaitDelay ensures cmd.Wait returns promptly after
|
||||
// the process is killed, even if child processes are
|
||||
@@ -272,13 +273,15 @@ func (m *manager) signal(id string, sig string) error {
|
||||
|
||||
switch sig {
|
||||
case "kill":
|
||||
if err := proc.cmd.Process.Kill(); err != nil {
|
||||
// Use process group kill to ensure child processes
|
||||
// (e.g. from shell pipelines) are also killed.
|
||||
if err := signalProcess(proc.cmd.Process, syscall.SIGKILL); err != nil {
|
||||
return xerrors.Errorf("kill process: %w", err)
|
||||
}
|
||||
case "terminate":
|
||||
//nolint:revive // syscall.SIGTERM is portable enough
|
||||
// for our supported platforms.
|
||||
if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||
// Use process group signal to ensure child processes
|
||||
// are also terminated.
|
||||
if err := signalProcess(proc.cmd.Process, syscall.SIGTERM); err != nil {
|
||||
return xerrors.Errorf("terminate process: %w", err)
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -46,6 +46,7 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
|
||||
autoUpdates string
|
||||
copyParametersFrom string
|
||||
useParameterDefaults bool
|
||||
noWait bool
|
||||
// Organization context is only required if more than 1 template
|
||||
// shares the same name across multiple organizations.
|
||||
orgContext = NewOrganizationContext()
|
||||
@@ -372,6 +373,14 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
|
||||
|
||||
cliutil.WarnMatchedProvisioners(inv.Stderr, workspace.LatestBuild.MatchedProvisioners, workspace.LatestBuild.Job)
|
||||
|
||||
if noWait {
|
||||
_, _ = fmt.Fprintf(inv.Stdout,
|
||||
"\nThe %s workspace has been created and is building in the background.\n",
|
||||
cliui.Keyword(workspace.Name),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, workspace.LatestBuild.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("watch build: %w", err)
|
||||
@@ -445,6 +454,12 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
|
||||
Description: "Automatically accept parameter defaults when no value is provided.",
|
||||
Value: serpent.BoolOf(&useParameterDefaults),
|
||||
},
|
||||
serpent.Option{
|
||||
Flag: "no-wait",
|
||||
Env: "CODER_CREATE_NO_WAIT",
|
||||
Description: "Return immediately after creating the workspace. The build will run in the background.",
|
||||
Value: serpent.BoolOf(&noWait),
|
||||
},
|
||||
cliui.SkipPromptOption(),
|
||||
)
|
||||
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
|
||||
|
||||
@@ -603,6 +603,81 @@ func TestCreate(t *testing.T) {
|
||||
assert.Nil(t, ws.AutostartSchedule, "expected workspace autostart schedule to be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoWait", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv, root := clitest.New(t, "create", "my-workspace",
|
||||
"--template", template.Name,
|
||||
"-y",
|
||||
"--no-wait",
|
||||
)
|
||||
clitest.SetupConfig(t, member, root)
|
||||
doneChan := make(chan struct{})
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
pty.ExpectMatchContext(ctx, "building in the background")
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
|
||||
// Verify workspace was actually created.
|
||||
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ws.TemplateName, template.Name)
|
||||
})
|
||||
|
||||
t.Run("NoWaitWithParameterDefaults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses([]*proto.RichParameter{
|
||||
{Name: "region", Type: "string", DefaultValue: "us-east-1"},
|
||||
{Name: "instance_type", Type: "string", DefaultValue: "t3.micro"},
|
||||
}))
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv, root := clitest.New(t, "create", "my-workspace",
|
||||
"--template", template.Name,
|
||||
"-y",
|
||||
"--use-parameter-defaults",
|
||||
"--no-wait",
|
||||
)
|
||||
clitest.SetupConfig(t, member, root)
|
||||
doneChan := make(chan struct{})
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
pty.ExpectMatchContext(ctx, "building in the background")
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
|
||||
// Verify workspace was created and parameters were applied.
|
||||
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ws.TemplateName, template.Name)
|
||||
|
||||
buildParams, err := member.WorkspaceBuildParameters(ctx, ws.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "region", Value: "us-east-1"})
|
||||
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "instance_type", Value: "t3.micro"})
|
||||
})
|
||||
}
|
||||
|
||||
func prepareEchoResponses(parameters []*proto.RichParameter, presets ...*proto.Preset) *echo.Responses {
|
||||
|
||||
@@ -1000,6 +1000,12 @@ func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool
|
||||
Properties: sdkTool.Schema.Properties,
|
||||
Required: sdkTool.Schema.Required,
|
||||
},
|
||||
Annotations: mcp.ToolAnnotation{
|
||||
ReadOnlyHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.ReadOnlyHint),
|
||||
DestructiveHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.DestructiveHint),
|
||||
IdempotentHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.IdempotentHint),
|
||||
OpenWorldHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.OpenWorldHint),
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
+16
-1
@@ -81,7 +81,13 @@ func TestExpMcpServer(t *testing.T) {
|
||||
var toolsResponse struct {
|
||||
Result struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Name string `json:"name"`
|
||||
Annotations struct {
|
||||
ReadOnlyHint *bool `json:"readOnlyHint"`
|
||||
DestructiveHint *bool `json:"destructiveHint"`
|
||||
IdempotentHint *bool `json:"idempotentHint"`
|
||||
OpenWorldHint *bool `json:"openWorldHint"`
|
||||
} `json:"annotations"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
@@ -94,6 +100,15 @@ func TestExpMcpServer(t *testing.T) {
|
||||
}
|
||||
slices.Sort(foundTools)
|
||||
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
|
||||
annotations := toolsResponse.Result.Tools[0].Annotations
|
||||
require.NotNil(t, annotations.ReadOnlyHint)
|
||||
require.NotNil(t, annotations.DestructiveHint)
|
||||
require.NotNil(t, annotations.IdempotentHint)
|
||||
require.NotNil(t, annotations.OpenWorldHint)
|
||||
assert.True(t, *annotations.ReadOnlyHint)
|
||||
assert.False(t, *annotations.DestructiveHint)
|
||||
assert.True(t, *annotations.IdempotentHint)
|
||||
assert.False(t, *annotations.OpenWorldHint)
|
||||
|
||||
// Call the tool and ensure it works.
|
||||
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
|
||||
|
||||
@@ -214,7 +214,7 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
|
||||
} else {
|
||||
updated, err = client.CreateOrganizationRole(ctx, customRole)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("patch role: %w", err)
|
||||
return xerrors.Errorf("create role: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+26
-17
@@ -113,6 +113,20 @@ func (r *RootCmd) supportBundle() *serpent.Command {
|
||||
)
|
||||
cliLog.Debug(inv.Context(), "invocation", slog.F("args", strings.Join(os.Args, " ")))
|
||||
|
||||
// Bypass rate limiting for support bundle collection since it makes many API calls.
|
||||
// Note: this can only be done by the owner user.
|
||||
if ok, err := support.CanGenerateFull(inv.Context(), client); err == nil && ok {
|
||||
cliLog.Debug(inv.Context(), "running as owner")
|
||||
client.HTTPClient.Transport = &codersdk.HeaderTransport{
|
||||
Transport: client.HTTPClient.Transport,
|
||||
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
|
||||
}
|
||||
} else if !ok {
|
||||
cliLog.Warn(inv.Context(), "not running as owner, not all information available")
|
||||
} else {
|
||||
cliLog.Error(inv.Context(), "failed to look up current user", slog.Error(err))
|
||||
}
|
||||
|
||||
// Check if we're running inside a workspace
|
||||
if val, found := os.LookupEnv("CODER"); found && val == "true" {
|
||||
cliui.Warn(inv.Stderr, "Running inside Coder workspace; this can affect results!")
|
||||
@@ -200,12 +214,6 @@ func (r *RootCmd) supportBundle() *serpent.Command {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "pprof data collection will take approximately 30 seconds...")
|
||||
}
|
||||
|
||||
// Bypass rate limiting for support bundle collection since it makes many API calls.
|
||||
client.HTTPClient.Transport = &codersdk.HeaderTransport{
|
||||
Transport: client.HTTPClient.Transport,
|
||||
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
|
||||
}
|
||||
|
||||
deps := support.Deps{
|
||||
Client: client,
|
||||
// Support adds a sink so we don't need to supply one ourselves.
|
||||
@@ -354,19 +362,20 @@ func summarizeBundle(inv *serpent.Invocation, bun *support.Bundle) {
|
||||
return
|
||||
}
|
||||
|
||||
if bun.Deployment.Config == nil {
|
||||
cliui.Error(inv.Stdout, "No deployment configuration available!")
|
||||
return
|
||||
var docsURL string
|
||||
if bun.Deployment.Config != nil {
|
||||
docsURL = bun.Deployment.Config.Values.DocsURL.String()
|
||||
} else {
|
||||
cliui.Warn(inv.Stdout, "No deployment configuration available. This may require the Owner role.")
|
||||
}
|
||||
|
||||
docsURL := bun.Deployment.Config.Values.DocsURL.String()
|
||||
if bun.Deployment.HealthReport == nil {
|
||||
cliui.Error(inv.Stdout, "No deployment health report available!")
|
||||
return
|
||||
}
|
||||
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
|
||||
if len(deployHealthSummary) > 0 {
|
||||
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
|
||||
if bun.Deployment.HealthReport != nil {
|
||||
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
|
||||
if len(deployHealthSummary) > 0 {
|
||||
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
|
||||
}
|
||||
} else {
|
||||
cliui.Warn(inv.Stdout, "No deployment health report available.")
|
||||
}
|
||||
|
||||
if bun.Network.Netcheck == nil {
|
||||
|
||||
+30
-3
@@ -132,12 +132,35 @@ func TestSupportBundle(t *testing.T) {
|
||||
assertBundleContents(t, path, true, false, []string{secretValue})
|
||||
})
|
||||
|
||||
t.Run("NoPrivilege", func(t *testing.T) {
|
||||
t.Run("MemberCanGenerateBundle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--yes")
|
||||
|
||||
d := t.TempDir()
|
||||
path := filepath.Join(d, "bundle.zip")
|
||||
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--output-file", path, "--yes")
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
err := inv.Run()
|
||||
require.ErrorContains(t, err, "failed authorization check")
|
||||
require.NoError(t, err)
|
||||
r, err := zip.OpenReader(path)
|
||||
require.NoError(t, err, "open zip file")
|
||||
defer r.Close()
|
||||
fileNames := make(map[string]struct{}, len(r.File))
|
||||
for _, f := range r.File {
|
||||
fileNames[f.Name] = struct{}{}
|
||||
}
|
||||
// These should always be present in the zip structure, even if
|
||||
// the content is null/empty for non-admin users.
|
||||
for _, name := range []string{
|
||||
"deployment/buildinfo.json",
|
||||
"deployment/config.json",
|
||||
"workspace/workspace.json",
|
||||
"logs.txt",
|
||||
"cli_logs.txt",
|
||||
"network/netcheck.json",
|
||||
"network/interfaces.json",
|
||||
} {
|
||||
require.Contains(t, fileNames, name)
|
||||
}
|
||||
})
|
||||
|
||||
// This ensures that the CLI does not panic when trying to generate a support bundle
|
||||
@@ -159,6 +182,10 @@ func TestSupportBundle(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("received request: %s %s", r.Method, r.URL)
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/users/me":
|
||||
resp := codersdk.User{}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
assert.NoError(t, json.NewEncoder(w).Encode(resp))
|
||||
case "/api/v2/authcheck":
|
||||
// Fake auth check
|
||||
resp := codersdk.AuthorizationResponse{
|
||||
|
||||
+4
@@ -20,6 +20,10 @@ OPTIONS:
|
||||
--copy-parameters-from string, $CODER_WORKSPACE_COPY_PARAMETERS_FROM
|
||||
Specify the source workspace name to copy parameters from.
|
||||
|
||||
--no-wait bool, $CODER_CREATE_NO_WAIT
|
||||
Return immediately after creating the workspace. The build will run in
|
||||
the background.
|
||||
|
||||
--parameter string-array, $CODER_RICH_PARAMETER
|
||||
Rich parameter value in the format "name=value".
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"last_seen_at": "====[timestamp]=====",
|
||||
"name": "test-daemon",
|
||||
"version": "v0.0.0-devel",
|
||||
"api_version": "1.15",
|
||||
"api_version": "1.16",
|
||||
"provisioners": [
|
||||
"echo"
|
||||
],
|
||||
|
||||
@@ -24,6 +24,10 @@ OPTIONS:
|
||||
-p, --password string
|
||||
Specifies a password for the new user.
|
||||
|
||||
--service-account bool
|
||||
Create a user account intended to be used by a service or as an
|
||||
intermediary rather than by a human.
|
||||
|
||||
-u, --username string
|
||||
Specifies a username for the new user.
|
||||
|
||||
|
||||
+5
@@ -752,6 +752,11 @@ workspace_prebuilds:
|
||||
# limit; disabled when set to zero.
|
||||
# (default: 3, type: int)
|
||||
failure_hard_limit: 3
|
||||
# Configure the background chat processing daemon.
|
||||
chat:
|
||||
# How many pending chats a worker should acquire per polling cycle.
|
||||
# (default: 10, type: int)
|
||||
acquireBatchSize: 10
|
||||
aibridge:
|
||||
# Whether to start an in-memory aibridged instance.
|
||||
# (default: false, type: bool)
|
||||
|
||||
+37
-12
@@ -17,13 +17,14 @@ import (
|
||||
|
||||
func (r *RootCmd) userCreate() *serpent.Command {
|
||||
var (
|
||||
email string
|
||||
username string
|
||||
name string
|
||||
password string
|
||||
disableLogin bool
|
||||
loginType string
|
||||
orgContext = NewOrganizationContext()
|
||||
email string
|
||||
username string
|
||||
name string
|
||||
password string
|
||||
disableLogin bool
|
||||
loginType string
|
||||
serviceAccount bool
|
||||
orgContext = NewOrganizationContext()
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "create",
|
||||
@@ -32,6 +33,23 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
serpent.RequireNArgs(0),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
if serviceAccount {
|
||||
switch {
|
||||
case loginType != "":
|
||||
return xerrors.New("You cannot use --login-type with --service-account")
|
||||
case password != "":
|
||||
return xerrors.New("You cannot use --password with --service-account")
|
||||
case email != "":
|
||||
return xerrors.New("You cannot use --email with --service-account")
|
||||
case disableLogin:
|
||||
return xerrors.New("You cannot use --disable-login with --service-account")
|
||||
}
|
||||
}
|
||||
|
||||
if disableLogin && loginType != "" {
|
||||
return xerrors.New("You cannot specify both --disable-login and --login-type")
|
||||
}
|
||||
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -59,7 +77,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if email == "" {
|
||||
if email == "" && !serviceAccount {
|
||||
email, err = cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: "Email:",
|
||||
Validate: func(s string) error {
|
||||
@@ -87,10 +105,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
}
|
||||
}
|
||||
userLoginType := codersdk.LoginTypePassword
|
||||
if disableLogin && loginType != "" {
|
||||
return xerrors.New("You cannot specify both --disable-login and --login-type")
|
||||
}
|
||||
if disableLogin {
|
||||
if disableLogin || serviceAccount {
|
||||
userLoginType = codersdk.LoginTypeNone
|
||||
} else if loginType != "" {
|
||||
userLoginType = codersdk.LoginType(loginType)
|
||||
@@ -111,6 +126,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
Password: password,
|
||||
OrganizationIDs: []uuid.UUID{organization.ID},
|
||||
UserLoginType: userLoginType,
|
||||
ServiceAccount: serviceAccount,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -127,6 +143,10 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
case codersdk.LoginTypeOIDC:
|
||||
authenticationMethod = `Login is authenticated through the configured OIDC provider.`
|
||||
}
|
||||
if serviceAccount {
|
||||
email = "n/a"
|
||||
authenticationMethod = "Service accounts must authenticate with a token and cannot log in."
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, `A new user has been created!
|
||||
Share the instructions below to get them started.
|
||||
@@ -194,6 +214,11 @@ Create a workspace `+pretty.Sprint(cliui.DefaultStyles.Code, "coder create")+`!
|
||||
)),
|
||||
Value: serpent.StringOf(&loginType),
|
||||
},
|
||||
{
|
||||
Flag: "service-account",
|
||||
Description: "Create a user account intended to be used by a service or as an intermediary rather than by a human.",
|
||||
Value: serpent.BoolOf(&serviceAccount),
|
||||
},
|
||||
}
|
||||
|
||||
orgContext.AttachOptions(cmd)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -124,4 +125,56 @@ func TestUserCreate(t *testing.T) {
|
||||
assert.Equal(t, args[5], created.Username)
|
||||
assert.Empty(t, created.Name)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
name: "ServiceAccount",
|
||||
args: []string{"--service-account", "-u", "dean"},
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountLoginType",
|
||||
args: []string{"--service-account", "-u", "dean", "--login-type", "none"},
|
||||
err: "You cannot use --login-type with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountDisableLogin",
|
||||
args: []string{"--service-account", "-u", "dean", "--disable-login"},
|
||||
err: "You cannot use --disable-login with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountEmail",
|
||||
args: []string{"--service-account", "-u", "dean", "--email", "dean@coder.com"},
|
||||
err: "You cannot use --email with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountPassword",
|
||||
args: []string{"--service-account", "-u", "dean", "--password", "1n5ecureP4ssw0rd!"},
|
||||
err: "You cannot use --password with --service-account",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
inv, root := clitest.New(t, append([]string{"users", "create"}, tt.args...)...)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
err := inv.Run()
|
||||
if tt.err == "" {
|
||||
require.NoError(t, err)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created, err := client.User(ctx, "dean")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, codersdk.LoginTypeNone, created.LoginType)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
// Package aiseats is the AGPL version the package.
|
||||
// The actual implementation is in `enterprise/aiseats`.
|
||||
package aiseats
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
type Reason struct {
|
||||
EventType database.AiSeatUsageReason
|
||||
Description string
|
||||
}
|
||||
|
||||
// ReasonAIBridge constructs a reason for usage originating from AI Bridge.
|
||||
func ReasonAIBridge(description string) Reason {
|
||||
return Reason{EventType: database.AiSeatUsageReasonAibridge, Description: description}
|
||||
}
|
||||
|
||||
// ReasonTask constructs a reason for usage originating from tasks.
|
||||
func ReasonTask(description string) Reason {
|
||||
return Reason{EventType: database.AiSeatUsageReasonTask, Description: description}
|
||||
}
|
||||
|
||||
// SeatTracker records AI seat consumption state.
|
||||
type SeatTracker interface {
|
||||
// RecordUsage does not return an error to prevent blocking the user from using
|
||||
// AI features. This method is used to record usage, not enforce it.
|
||||
RecordUsage(ctx context.Context, userID uuid.UUID, reason Reason)
|
||||
}
|
||||
|
||||
// Noop is an AGPL seat tracker that does nothing.
|
||||
type Noop struct{}
|
||||
|
||||
func (Noop) RecordUsage(context.Context, uuid.UUID, Reason) {}
|
||||
Generated
+276
-27
@@ -481,46 +481,47 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/archive": {
|
||||
"post": {
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Archive a chat",
|
||||
"operationId": "archive-chat",
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/desktop": {
|
||||
"/chats/insights/pull-requests": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Watch chat desktop",
|
||||
"operationId": "watch-chat-desktop",
|
||||
"summary": "Get PR insights",
|
||||
"operationId": "get-pr-insights",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"description": "Start date (RFC3339)",
|
||||
"name": "start_date",
|
||||
"in": "query",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "End date (RFC3339)",
|
||||
"name": "end_date",
|
||||
"in": "query",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"101": {
|
||||
"description": "Switching Protocols"
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.PRInsightsResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -4862,7 +4863,7 @@ const docTemplate = `{
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
@@ -4870,7 +4871,7 @@ const docTemplate = `{
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12757,6 +12758,9 @@ const docTemplate = `{
|
||||
},
|
||||
"bridge": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeConfig"
|
||||
},
|
||||
"chat": {
|
||||
"$ref": "#/definitions/codersdk.ChatConfig"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13814,6 +13818,14 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"acquire_batch_size": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ConnectionLatency": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -17140,6 +17152,191 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PRInsightsModelBreakdown": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cost_per_merged_pr_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"display_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"merge_rate": {
|
||||
"type": "number"
|
||||
},
|
||||
"merged_prs": {
|
||||
"type": "integer"
|
||||
},
|
||||
"model_config_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"total_additions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_cost_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_deletions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_prs": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PRInsightsPullRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"additions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"approved": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"author_avatar_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"author_login": {
|
||||
"type": "string"
|
||||
},
|
||||
"base_branch": {
|
||||
"type": "string"
|
||||
},
|
||||
"changed_files": {
|
||||
"type": "integer"
|
||||
},
|
||||
"changes_requested": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"commits": {
|
||||
"type": "integer"
|
||||
},
|
||||
"cost_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"deletions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"draft": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"model_display_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"pr_number": {
|
||||
"type": "integer"
|
||||
},
|
||||
"pr_title": {
|
||||
"type": "string"
|
||||
},
|
||||
"pr_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"reviewer_count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"state": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PRInsightsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"by_model": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.PRInsightsModelBreakdown"
|
||||
}
|
||||
},
|
||||
"recent_prs": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.PRInsightsPullRequest"
|
||||
}
|
||||
},
|
||||
"summary": {
|
||||
"$ref": "#/definitions/codersdk.PRInsightsSummary"
|
||||
},
|
||||
"time_series": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.PRInsightsTimeSeriesEntry"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PRInsightsSummary": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"approval_rate": {
|
||||
"type": "number"
|
||||
},
|
||||
"cost_per_merged_pr_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"merge_rate": {
|
||||
"type": "number"
|
||||
},
|
||||
"prev_cost_per_merged_pr_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"prev_merge_rate": {
|
||||
"type": "number"
|
||||
},
|
||||
"prev_total_prs_created": {
|
||||
"type": "integer"
|
||||
},
|
||||
"prev_total_prs_merged": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_additions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_cost_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_deletions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_prs_created": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_prs_merged": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PRInsightsTimeSeriesEntry": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"date": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"prs_closed": {
|
||||
"type": "integer"
|
||||
},
|
||||
"prs_created": {
|
||||
"type": "integer"
|
||||
},
|
||||
"prs_merged": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PaginatedMembersResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -18353,6 +18550,9 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -18521,7 +18721,8 @@ const docTemplate = `{
|
||||
"idp_sync_settings_role",
|
||||
"workspace_agent",
|
||||
"workspace_app",
|
||||
"task"
|
||||
"task",
|
||||
"ai_seat"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ResourceTypeTemplate",
|
||||
@@ -18549,7 +18750,8 @@ const docTemplate = `{
|
||||
"ResourceTypeIdpSyncSettingsRole",
|
||||
"ResourceTypeWorkspaceAgent",
|
||||
"ResourceTypeWorkspaceApp",
|
||||
"ResourceTypeTask"
|
||||
"ResourceTypeTask",
|
||||
"ResourceTypeAISeat"
|
||||
]
|
||||
},
|
||||
"codersdk.Response": {
|
||||
@@ -18761,6 +18963,19 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ShareableWorkspaceOwners": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"none",
|
||||
"everyone",
|
||||
"service_accounts"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ShareableWorkspaceOwnersNone",
|
||||
"ShareableWorkspaceOwnersEveryone",
|
||||
"ShareableWorkspaceOwnersServiceAccounts"
|
||||
]
|
||||
},
|
||||
"codersdk.SharedWorkspaceActor": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19659,6 +19874,9 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -20369,7 +20587,21 @@ const docTemplate = `{
|
||||
"codersdk.UpdateWorkspaceSharingSettingsRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"shareable_workspace_owners": {
|
||||
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
|
||||
"enum": [
|
||||
"none",
|
||||
"everyone",
|
||||
"service_accounts"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
|
||||
}
|
||||
]
|
||||
},
|
||||
"sharing_disabled": {
|
||||
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use ` + "`" + `ShareableWorkspaceOwners` + "`" + ` instead",
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
@@ -20491,6 +20723,9 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -22210,7 +22445,21 @@ const docTemplate = `{
|
||||
"codersdk.WorkspaceSharingSettings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"shareable_workspace_owners": {
|
||||
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
|
||||
"enum": [
|
||||
"none",
|
||||
"everyone",
|
||||
"service_accounts"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
|
||||
}
|
||||
]
|
||||
},
|
||||
"sharing_disabled": {
|
||||
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use ` + "`" + `ShareableWorkspaceOwners` + "`" + ` instead",
|
||||
"type": "boolean"
|
||||
},
|
||||
"sharing_globally_disabled": {
|
||||
|
||||
Generated
+262
-25
@@ -410,42 +410,43 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/archive": {
|
||||
"post": {
|
||||
"tags": ["Chats"],
|
||||
"summary": "Archive a chat",
|
||||
"operationId": "archive-chat",
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/desktop": {
|
||||
"/chats/insights/pull-requests": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Watch chat desktop",
|
||||
"operationId": "watch-chat-desktop",
|
||||
"summary": "Get PR insights",
|
||||
"operationId": "get-pr-insights",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"description": "Start date (RFC3339)",
|
||||
"name": "start_date",
|
||||
"in": "query",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "End date (RFC3339)",
|
||||
"name": "end_date",
|
||||
"in": "query",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"101": {
|
||||
"description": "Switching Protocols"
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.PRInsightsResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -4301,7 +4302,7 @@
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
@@ -4309,7 +4310,7 @@
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -11359,6 +11360,9 @@
|
||||
},
|
||||
"bridge": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeConfig"
|
||||
},
|
||||
"chat": {
|
||||
"$ref": "#/definitions/codersdk.ChatConfig"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -12381,6 +12385,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"acquire_batch_size": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ConnectionLatency": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -15581,6 +15593,191 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PRInsightsModelBreakdown": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cost_per_merged_pr_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"display_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"merge_rate": {
|
||||
"type": "number"
|
||||
},
|
||||
"merged_prs": {
|
||||
"type": "integer"
|
||||
},
|
||||
"model_config_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"total_additions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_cost_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_deletions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_prs": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PRInsightsPullRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"additions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"approved": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"author_avatar_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"author_login": {
|
||||
"type": "string"
|
||||
},
|
||||
"base_branch": {
|
||||
"type": "string"
|
||||
},
|
||||
"changed_files": {
|
||||
"type": "integer"
|
||||
},
|
||||
"changes_requested": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"commits": {
|
||||
"type": "integer"
|
||||
},
|
||||
"cost_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"deletions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"draft": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"model_display_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"pr_number": {
|
||||
"type": "integer"
|
||||
},
|
||||
"pr_title": {
|
||||
"type": "string"
|
||||
},
|
||||
"pr_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"reviewer_count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"state": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PRInsightsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"by_model": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.PRInsightsModelBreakdown"
|
||||
}
|
||||
},
|
||||
"recent_prs": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.PRInsightsPullRequest"
|
||||
}
|
||||
},
|
||||
"summary": {
|
||||
"$ref": "#/definitions/codersdk.PRInsightsSummary"
|
||||
},
|
||||
"time_series": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.PRInsightsTimeSeriesEntry"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PRInsightsSummary": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"approval_rate": {
|
||||
"type": "number"
|
||||
},
|
||||
"cost_per_merged_pr_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"merge_rate": {
|
||||
"type": "number"
|
||||
},
|
||||
"prev_cost_per_merged_pr_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"prev_merge_rate": {
|
||||
"type": "number"
|
||||
},
|
||||
"prev_total_prs_created": {
|
||||
"type": "integer"
|
||||
},
|
||||
"prev_total_prs_merged": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_additions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_cost_micros": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_deletions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_prs_created": {
|
||||
"type": "integer"
|
||||
},
|
||||
"total_prs_merged": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PRInsightsTimeSeriesEntry": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"date": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"prs_closed": {
|
||||
"type": "integer"
|
||||
},
|
||||
"prs_created": {
|
||||
"type": "integer"
|
||||
},
|
||||
"prs_merged": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.PaginatedMembersResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -16747,6 +16944,9 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -16910,7 +17110,8 @@
|
||||
"idp_sync_settings_role",
|
||||
"workspace_agent",
|
||||
"workspace_app",
|
||||
"task"
|
||||
"task",
|
||||
"ai_seat"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ResourceTypeTemplate",
|
||||
@@ -16938,7 +17139,8 @@
|
||||
"ResourceTypeIdpSyncSettingsRole",
|
||||
"ResourceTypeWorkspaceAgent",
|
||||
"ResourceTypeWorkspaceApp",
|
||||
"ResourceTypeTask"
|
||||
"ResourceTypeTask",
|
||||
"ResourceTypeAISeat"
|
||||
]
|
||||
},
|
||||
"codersdk.Response": {
|
||||
@@ -17146,6 +17348,15 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ShareableWorkspaceOwners": {
|
||||
"type": "string",
|
||||
"enum": ["none", "everyone", "service_accounts"],
|
||||
"x-enum-varnames": [
|
||||
"ShareableWorkspaceOwnersNone",
|
||||
"ShareableWorkspaceOwnersEveryone",
|
||||
"ShareableWorkspaceOwnersServiceAccounts"
|
||||
]
|
||||
},
|
||||
"codersdk.SharedWorkspaceActor": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -18007,6 +18218,9 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -18682,7 +18896,17 @@
|
||||
"codersdk.UpdateWorkspaceSharingSettingsRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"shareable_workspace_owners": {
|
||||
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
|
||||
"enum": ["none", "everyone", "service_accounts"],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
|
||||
}
|
||||
]
|
||||
},
|
||||
"sharing_disabled": {
|
||||
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use `ShareableWorkspaceOwners` instead",
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
@@ -18786,6 +19010,9 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"is_service_account": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"last_seen_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -20421,7 +20648,17 @@
|
||||
"codersdk.WorkspaceSharingSettings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"shareable_workspace_owners": {
|
||||
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
|
||||
"enum": ["none", "everyone", "service_accounts"],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
|
||||
}
|
||||
]
|
||||
},
|
||||
"sharing_disabled": {
|
||||
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use `ShareableWorkspaceOwners` instead",
|
||||
"type": "boolean"
|
||||
},
|
||||
"sharing_globally_disabled": {
|
||||
|
||||
@@ -32,7 +32,8 @@ type Auditable interface {
|
||||
idpsync.OrganizationSyncSettings |
|
||||
idpsync.GroupSyncSettings |
|
||||
idpsync.RoleSyncSettings |
|
||||
database.TaskTable
|
||||
database.TaskTable |
|
||||
database.AiSeatState
|
||||
}
|
||||
|
||||
// Map is a map of changed fields in an audited resource. It maps field names to
|
||||
|
||||
@@ -132,6 +132,8 @@ func ResourceTarget[T Auditable](tgt T) string {
|
||||
return "Organization Role Sync"
|
||||
case database.TaskTable:
|
||||
return typed.Name
|
||||
case database.AiSeatState:
|
||||
return "AI Seat"
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown resource %T for ResourceTarget", tgt))
|
||||
}
|
||||
@@ -196,6 +198,8 @@ func ResourceID[T Auditable](tgt T) uuid.UUID {
|
||||
return noID // Org field on audit log has org id
|
||||
case database.TaskTable:
|
||||
return typed.ID
|
||||
case database.AiSeatState:
|
||||
return typed.UserID
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown resource %T for ResourceID", tgt))
|
||||
}
|
||||
@@ -251,6 +255,8 @@ func ResourceType[T Auditable](tgt T) database.ResourceType {
|
||||
return database.ResourceTypeIdpSyncSettingsGroup
|
||||
case database.TaskTable:
|
||||
return database.ResourceTypeTask
|
||||
case database.AiSeatState:
|
||||
return database.ResourceTypeAiSeat
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown resource %T for ResourceType", typed))
|
||||
}
|
||||
@@ -309,6 +315,8 @@ func ResourceRequiresOrgID[T Auditable]() bool {
|
||||
return true
|
||||
case database.TaskTable:
|
||||
return true
|
||||
case database.AiSeatState:
|
||||
return false
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown resource %T for ResourceRequiresOrgID", tgt))
|
||||
}
|
||||
|
||||
+592
-320
File diff suppressed because it is too large
Load Diff
@@ -2,13 +2,20 @@ package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
)
|
||||
|
||||
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
|
||||
@@ -84,3 +91,135 @@ func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
|
||||
require.ErrorContains(t, err, loadErr.Error())
|
||||
require.Equal(t, chat, refreshed)
|
||||
}
|
||||
|
||||
func TestResolveInstructionsReusesTurnLocalWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
workspaceAgent := database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
OperatingSystem: "linux",
|
||||
Directory: "/home/coder/project",
|
||||
ExpandedDirectory: "/home/coder/project",
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
|
||||
workspacesdk.LSResponse{},
|
||||
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
|
||||
).Times(1)
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/project/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
nil,
|
||||
"",
|
||||
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
|
||||
).Times(1)
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
instructionCache: make(map[uuid.UUID]cachedInstruction),
|
||||
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
return conn, func() {}, nil
|
||||
},
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
instruction := server.resolveInstructions(
|
||||
ctx,
|
||||
chat,
|
||||
workspaceCtx.getWorkspaceAgent,
|
||||
workspaceCtx.getWorkspaceConn,
|
||||
)
|
||||
require.Contains(t, instruction, "Operating System: linux")
|
||||
require.Contains(t, instruction, "Working Directory: /home/coder/project")
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
initialAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
refreshedAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{initialAgent}, nil),
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{refreshedAgent}, nil),
|
||||
)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
|
||||
var dialed []uuid.UUID
|
||||
server := &Server{db: db}
|
||||
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
dialed = append(dialed, agentID)
|
||||
if agentID == initialAgent.ID {
|
||||
return nil, nil, xerrors.New("dial failed")
|
||||
}
|
||||
return conn, func() {}, nil
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, conn, gotConn)
|
||||
require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed)
|
||||
}
|
||||
|
||||
+781
-612
File diff suppressed because it is too large
Load Diff
@@ -42,6 +42,11 @@ type PersistedStep struct {
|
||||
Content []fantasy.Content
|
||||
Usage fantasy.Usage
|
||||
ContextLimit sql.NullInt64
|
||||
// Runtime is the wall-clock duration of this step,
|
||||
// covering LLM streaming, tool execution, and retries.
|
||||
// Zero indicates the duration was not measured (e.g.
|
||||
// interrupted steps).
|
||||
Runtime time.Duration
|
||||
}
|
||||
|
||||
// RunOptions configures a single streaming chat loop run.
|
||||
@@ -260,6 +265,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
|
||||
for step := 0; totalSteps < opts.MaxSteps; step++ {
|
||||
totalSteps++
|
||||
stepStart := time.Now()
|
||||
// Copy messages so that provider-specific caching
|
||||
// mutations don't leak back to the caller's slice.
|
||||
// copy copies Message structs by value, so field
|
||||
@@ -365,6 +371,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
Content: result.content,
|
||||
Usage: result.usage,
|
||||
ContextLimit: contextLimit,
|
||||
Runtime: time.Since(stepStart),
|
||||
}); err != nil {
|
||||
if errors.Is(err, ErrInterrupted) {
|
||||
persistInterruptedStep(ctx, opts, &result)
|
||||
@@ -610,10 +617,12 @@ func processStepStream(
|
||||
result.providerMetadata = part.ProviderMetadata
|
||||
|
||||
case fantasy.StreamPartTypeError:
|
||||
// Detect interruption: context canceled with
|
||||
// ErrInterrupted as the cause.
|
||||
if errors.Is(part.Error, context.Canceled) &&
|
||||
errors.Is(context.Cause(ctx), ErrInterrupted) {
|
||||
// Detect interruption: the stream may surface the
|
||||
// cancel as context.Canceled or propagate the
|
||||
// ErrInterrupted cause directly, depending on
|
||||
// the provider implementation.
|
||||
if errors.Is(context.Cause(ctx), ErrInterrupted) &&
|
||||
(errors.Is(part.Error, context.Canceled) || errors.Is(part.Error, ErrInterrupted)) {
|
||||
// Flush in-progress content so that
|
||||
// persistInterruptedStep has access to partial
|
||||
// text, reasoning, and tool calls that were
|
||||
@@ -631,6 +640,23 @@ func processStepStream(
|
||||
}
|
||||
}
|
||||
|
||||
// The stream iterator may stop yielding parts without
|
||||
// producing a StreamPartTypeError when the context is
|
||||
// canceled (e.g. some providers close the response body
|
||||
// silently). Detect this case and flush partial content
|
||||
// so that persistInterruptedStep can save it.
|
||||
if ctx.Err() != nil &&
|
||||
errors.Is(context.Cause(ctx), ErrInterrupted) {
|
||||
flushActiveState(
|
||||
&result,
|
||||
activeTextContent,
|
||||
activeReasoningContent,
|
||||
activeToolCalls,
|
||||
toolNames,
|
||||
)
|
||||
return result, ErrInterrupted
|
||||
}
|
||||
|
||||
hasLocalToolCalls := false
|
||||
for _, tc := range result.toolCalls {
|
||||
if !tc.ProviderExecuted {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
@@ -64,6 +65,8 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
|
||||
require.Equal(t, 1, persistStepCalls)
|
||||
require.True(t, persistedStep.ContextLimit.Valid)
|
||||
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
|
||||
require.Greater(t, persistedStep.Runtime, time.Duration(0),
|
||||
"step runtime should be positive")
|
||||
|
||||
require.NotEmpty(t, capturedCall.Prompt)
|
||||
require.False(t, containsPromptSentinel(capturedCall.Prompt))
|
||||
|
||||
@@ -82,7 +82,7 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
|
||||
options := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Reasoning: &codersdk.ChatModelReasoningOptions{
|
||||
Enabled: boolPtr(true),
|
||||
},
|
||||
Provider: &codersdk.ChatModelOpenRouterProvider{
|
||||
@@ -92,7 +92,7 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
}
|
||||
defaults := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Reasoning: &codersdk.ChatModelReasoningOptions{
|
||||
Enabled: boolPtr(false),
|
||||
Exclude: boolPtr(true),
|
||||
MaxTokens: int64Ptr(123),
|
||||
|
||||
@@ -78,10 +78,10 @@ type ProcessToolOptions struct {
|
||||
|
||||
// ExecuteArgs are the parameters accepted by the execute tool.
|
||||
type ExecuteArgs struct {
|
||||
Command string `json:"command"`
|
||||
Timeout *string `json:"timeout,omitempty"`
|
||||
WorkDir *string `json:"workdir,omitempty"`
|
||||
RunInBackground *bool `json:"run_in_background,omitempty"`
|
||||
Command string `json:"command" description:"The shell command to execute."`
|
||||
Timeout *string `json:"timeout,omitempty" description:"Timeout duration (e.g. '30s', '5m'). Default is 10s. Only applies to foreground commands."`
|
||||
WorkDir *string `json:"workdir,omitempty" description:"Working directory for the command."`
|
||||
RunInBackground *bool `json:"run_in_background,omitempty" description:"Run this command in the background without blocking. Use for long-running processes like dev servers, file watchers, or builds that run longer than 5 seconds. Do NOT use shell & to background processes — it will not work correctly. Always use this parameter instead."`
|
||||
}
|
||||
|
||||
// Execute returns an AgentTool that runs a shell command in the
|
||||
@@ -89,7 +89,7 @@ type ExecuteArgs struct {
|
||||
func Execute(options ExecuteOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"execute",
|
||||
"Execute a shell command in the workspace.",
|
||||
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding.",
|
||||
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
@@ -122,6 +122,16 @@ func executeTool(
|
||||
|
||||
background := args.RunInBackground != nil && *args.RunInBackground
|
||||
|
||||
// Detect shell-style backgrounding (trailing &) and promote to
|
||||
// background mode. Models sometimes use "cmd &" instead of the
|
||||
// run_in_background parameter, which causes the shell to fork
|
||||
// and exit immediately, leaving an untracked orphan process.
|
||||
trimmed := strings.TrimSpace(args.Command)
|
||||
if !background && strings.HasSuffix(trimmed, "&") && !strings.HasSuffix(trimmed, "&&") {
|
||||
background = true
|
||||
args.Command = strings.TrimSpace(strings.TrimSuffix(trimmed, "&"))
|
||||
}
|
||||
|
||||
var workDir string
|
||||
if args.WorkDir != nil {
|
||||
workDir = *args.WorkDir
|
||||
|
||||
@@ -92,7 +92,7 @@ func TestAnthropicWebSearchRoundTrip(t *testing.T) {
|
||||
// Verify the chat completed and messages were persisted.
|
||||
chatData, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
chatMsgs, err := client.GetChatMessages(ctx, chat.ID)
|
||||
chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Chat status after step 1: %s, messages: %d",
|
||||
chatData.Status, len(chatMsgs.Messages))
|
||||
@@ -154,7 +154,7 @@ func TestAnthropicWebSearchRoundTrip(t *testing.T) {
|
||||
// Verify the follow-up completed and produced content.
|
||||
chatData2, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
chatMsgs2, err := client.GetChatMessages(ctx, chat.ID)
|
||||
chatMsgs2, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Chat status after step 2: %s, messages: %d",
|
||||
chatData2.Status, len(chatMsgs2.Messages))
|
||||
|
||||
@@ -62,6 +62,7 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
messages []database.ChatMessage,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
generatedTitle *generatedChatTitle,
|
||||
logger slog.Logger,
|
||||
) {
|
||||
input, ok := titleInput(chat, messages)
|
||||
@@ -111,7 +112,8 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
return
|
||||
}
|
||||
chat.Title = title
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange)
|
||||
generatedTitle.Store(title)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -84,6 +84,14 @@ func (p *Server) isAnthropicConfigured(ctx context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Server) isDesktopEnabled(ctx context.Context) bool {
|
||||
enabled, err := p.db.GetChatDesktopEnabled(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (p *Server) subagentTools(ctx context.Context, currentChat func() database.Chat) []fantasy.AgentTool {
|
||||
tools := []fantasy.AgentTool{
|
||||
fantasy.NewAgentTool(
|
||||
@@ -253,9 +261,8 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
|
||||
}
|
||||
|
||||
// Only include the computer use tool when an Anthropic
|
||||
// provider is configured, since it requires an Anthropic
|
||||
// model.
|
||||
if p.isAnthropicConfigured(ctx) {
|
||||
// provider is configured and desktop is enabled.
|
||||
if p.isAnthropicConfigured(ctx) && p.isDesktopEnabled(ctx) {
|
||||
tools = append(tools, fantasy.NewAgentTool(
|
||||
"spawn_computer_use_agent",
|
||||
"Spawn a dedicated computer use agent that can see the desktop "+
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
@@ -144,14 +145,20 @@ func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func chatdTestContext(t *testing.T) context.Context {
|
||||
t.Helper()
|
||||
return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong))
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// No Anthropic key in ProviderAPIKeys.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// Create a root parent chat.
|
||||
@@ -176,12 +183,13 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// Provide an Anthropic key so the provider check passes.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// Create a root parent chat.
|
||||
@@ -232,16 +240,42 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
|
||||
assert.Contains(t, resp.Content, "delegated chats cannot create child subagents")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_DesktopDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-desktop-disabled",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
assert.Nil(t, tool, "spawn_computer_use_agent tool must be omitted when desktop is disabled")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// Provide an Anthropic key so the tool can proceed.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// The parent uses an OpenAI model.
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// ComputeUsagePeriodBounds returns the UTC-aligned start and end bounds for the
|
||||
// active usage-limit period containing now.
|
||||
func ComputeUsagePeriodBounds(now time.Time, period codersdk.ChatUsageLimitPeriod) (start, end time.Time) {
|
||||
utcNow := now.UTC()
|
||||
|
||||
switch period {
|
||||
case codersdk.ChatUsageLimitPeriodDay:
|
||||
start = time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
|
||||
end = start.AddDate(0, 0, 1)
|
||||
case codersdk.ChatUsageLimitPeriodWeek:
|
||||
// Walk backward to Monday of the current ISO week.
|
||||
// ISO 8601 weeks always start on Monday, so this never
|
||||
// crosses an ISO-week boundary.
|
||||
start = time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
|
||||
for start.Weekday() != time.Monday {
|
||||
start = start.AddDate(0, 0, -1)
|
||||
}
|
||||
end = start.AddDate(0, 0, 7)
|
||||
case codersdk.ChatUsageLimitPeriodMonth:
|
||||
start = time.Date(utcNow.Year(), utcNow.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
end = start.AddDate(0, 1, 0)
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown chat usage limit period: %q", period))
|
||||
}
|
||||
|
||||
return start, end
|
||||
}
|
||||
|
||||
// ResolveUsageLimitStatus resolves the current usage-limit status for userID.
|
||||
//
|
||||
// Note: There is a potential race condition where two concurrent messages
|
||||
// from the same user can both pass the limit check if processed in
|
||||
// parallel, allowing brief overage. This is acceptable because:
|
||||
// - Cost is only known after the LLM API returns.
|
||||
// - Overage is bounded by message cost × concurrency.
|
||||
// - Fail-open is the deliberate design choice for this feature.
|
||||
//
|
||||
// Architecture note: today this path enforces one period globally
|
||||
// (day/week/month) from config.
|
||||
// To support simultaneous periods, add nullable
|
||||
// daily/weekly/monthly_limit_micros columns on override tables, where NULL
|
||||
// means no limit for that period.
|
||||
// Then scan spend once over the widest active window with conditional SUMs
|
||||
// for each period and compare each spend/limit pair Go-side, blocking on
|
||||
// whichever period is tightest.
|
||||
func ResolveUsageLimitStatus(ctx context.Context, db database.Store, userID uuid.UUID, now time.Time) (*codersdk.ChatUsageLimitStatus, error) {
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon access for
|
||||
// deployment config reads and cross-user chat spend aggregation.
|
||||
authCtx := dbauthz.AsChatd(ctx)
|
||||
|
||||
config, err := db.GetChatUsageLimitConfig(authCtx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if !config.Enabled {
|
||||
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
|
||||
}
|
||||
|
||||
period, ok := mapDBPeriodToSDK(config.Period)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("invalid chat usage limit period %q", config.Period)
|
||||
}
|
||||
|
||||
// Resolve effective limit in a single query:
|
||||
// individual override > group limit > global default.
|
||||
effectiveLimit, err := db.ResolveUserChatSpendLimit(authCtx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// -1 means limits are disabled (shouldn't happen since we checked above,
|
||||
// but handle gracefully).
|
||||
if effectiveLimit < 0 {
|
||||
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
|
||||
}
|
||||
|
||||
start, end := ComputeUsagePeriodBounds(now, period)
|
||||
|
||||
spendTotal, err := db.GetUserChatSpendInPeriod(authCtx, database.GetUserChatSpendInPeriodParams{
|
||||
UserID: userID,
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &codersdk.ChatUsageLimitStatus{
|
||||
IsLimited: true,
|
||||
Period: period,
|
||||
SpendLimitMicros: &effectiveLimit,
|
||||
CurrentSpend: spendTotal,
|
||||
PeriodStart: start,
|
||||
PeriodEnd: end,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func mapDBPeriodToSDK(dbPeriod string) (codersdk.ChatUsageLimitPeriod, bool) {
|
||||
switch dbPeriod {
|
||||
case string(codersdk.ChatUsageLimitPeriodDay):
|
||||
return codersdk.ChatUsageLimitPeriodDay, true
|
||||
case string(codersdk.ChatUsageLimitPeriodWeek):
|
||||
return codersdk.ChatUsageLimitPeriodWeek, true
|
||||
case string(codersdk.ChatUsageLimitPeriodMonth):
|
||||
return codersdk.ChatUsageLimitPeriodMonth, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package chatd //nolint:testpackage // Keeps chatd unit tests in the package.
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestComputeUsagePeriodBounds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newYork, err := time.LoadLocation("America/New_York")
|
||||
if err != nil {
|
||||
t.Fatalf("load America/New_York: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
now time.Time
|
||||
period codersdk.ChatUsageLimitPeriod
|
||||
wantStart time.Time
|
||||
wantEnd time.Time
|
||||
}{
|
||||
{
|
||||
name: "day/mid_day",
|
||||
now: time.Date(2025, time.June, 15, 14, 30, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "day/midnight_exactly",
|
||||
now: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "day/end_of_day",
|
||||
now: time.Date(2025, time.June, 15, 23, 59, 59, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/wednesday",
|
||||
now: time.Date(2025, time.June, 11, 10, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/monday",
|
||||
now: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/sunday",
|
||||
now: time.Date(2025, time.June, 15, 23, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/year_boundary",
|
||||
now: time.Date(2024, time.December, 31, 12, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2024, time.December, 30, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.January, 6, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/mid_month",
|
||||
now: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/first_day",
|
||||
now: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/last_day",
|
||||
now: time.Date(2025, time.June, 30, 23, 59, 59, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/february",
|
||||
now: time.Date(2025, time.February, 15, 12, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.February, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.March, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/leap_year_february",
|
||||
now: time.Date(2024, time.February, 29, 12, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2024, time.March, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "day/non_utc_timezone",
|
||||
now: time.Date(2025, time.June, 15, 22, 0, 0, 0, newYork),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 17, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
start, end := ComputeUsagePeriodBounds(tc.now, tc.period)
|
||||
if !start.Equal(tc.wantStart) {
|
||||
t.Errorf("start: got %v, want %v", start, tc.wantStart)
|
||||
}
|
||||
if !end.Equal(tc.wantEnd) {
|
||||
t.Errorf("end: got %v, want %v", end, tc.wantEnd)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+887
-164
File diff suppressed because it is too large
Load Diff
+641
-29
@@ -2,6 +2,7 @@ package coderd_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -16,14 +17,19 @@ import (
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/websocket"
|
||||
@@ -54,6 +60,93 @@ func newChatClientWithDatabase(t testing.TB) (*codersdk.Client, database.Store)
|
||||
})
|
||||
}
|
||||
|
||||
func requireChatUsageLimitExceededError(
|
||||
t *testing.T,
|
||||
err error,
|
||||
wantSpentMicros int64,
|
||||
wantLimitMicros int64,
|
||||
wantResetsAt time.Time,
|
||||
) *codersdk.ChatUsageLimitExceededResponse {
|
||||
t.Helper()
|
||||
|
||||
sdkErr, ok := codersdk.AsError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
require.Equal(t, "Chat usage limit exceeded.", sdkErr.Message)
|
||||
|
||||
limitErr := codersdk.ChatUsageLimitExceededFrom(err)
|
||||
require.NotNil(t, limitErr)
|
||||
require.Equal(t, "Chat usage limit exceeded.", limitErr.Message)
|
||||
require.Equal(t, wantSpentMicros, limitErr.SpentMicros)
|
||||
require.Equal(t, wantLimitMicros, limitErr.LimitMicros)
|
||||
require.True(
|
||||
t,
|
||||
limitErr.ResetsAt.Equal(wantResetsAt),
|
||||
"expected resets_at %s, got %s",
|
||||
wantResetsAt.UTC().Format(time.RFC3339),
|
||||
limitErr.ResetsAt.UTC().Format(time.RFC3339),
|
||||
)
|
||||
|
||||
return limitErr
|
||||
}
|
||||
|
||||
func enableDailyChatUsageLimit(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
limitMicros int64,
|
||||
) time.Time {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.UpsertChatUsageLimitConfig(
|
||||
dbauthz.AsSystemRestricted(ctx),
|
||||
database.UpsertChatUsageLimitConfigParams{
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: limitMicros,
|
||||
Period: string(codersdk.ChatUsageLimitPeriodDay),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, periodEnd := chatd.ComputeUsagePeriodBounds(time.Now(), codersdk.ChatUsageLimitPeriodDay)
|
||||
return periodEnd
|
||||
}
|
||||
|
||||
func insertAssistantCostMessage(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
modelConfigID uuid.UUID,
|
||||
totalCostMicros int64,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("assistant"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Content: assistantContent,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
TotalTokens: sql.NullInt64{},
|
||||
ReasoningTokens: sql.NullInt64{},
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{Int64: totalCostMicros, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPostChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -88,7 +181,7 @@ func TestPostChats(t *testing.T) {
|
||||
|
||||
chatResult, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chat.ID, chatResult.ID)
|
||||
|
||||
@@ -126,7 +219,7 @@ func TestPostChats(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
for _, message := range messagesResult.Messages {
|
||||
require.NotEqual(t, codersdk.ChatMessageRoleSystem, message.Role)
|
||||
@@ -324,6 +417,33 @@ func TestPostChats(t *testing.T) {
|
||||
require.Equal(t, "Invalid input part.", sdkErr.Message)
|
||||
require.Equal(t, `content[0].type "image" is not supported.`, sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("UsageLimitExceeded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
|
||||
|
||||
existingChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "existing-limit-chat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
insertAssistantCostMessage(ctx, t, db, existingChat.ID, modelConfig.ID, 100)
|
||||
|
||||
_, err = client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "over limit",
|
||||
}},
|
||||
})
|
||||
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListChats(t *testing.T) {
|
||||
@@ -616,6 +736,127 @@ func TestWatchChats(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DiffStatusChangeIncludesDiffStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
DeploymentValues: chatDeploymentValues(t),
|
||||
})
|
||||
db := api.Database
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
// Insert a chat and a diff status row.
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "diff status watch test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
refreshedAt := time.Now().UTC().Truncate(time.Second)
|
||||
staleAt := refreshedAt.Add(time.Hour)
|
||||
_, err = db.UpsertChatDiffStatusReference(
|
||||
dbauthz.AsSystemRestricted(ctx),
|
||||
database.UpsertChatDiffStatusReferenceParams{
|
||||
ChatID: chat.ID,
|
||||
Url: sql.NullString{String: "https://github.com/coder/coder/pull/99", Valid: true},
|
||||
GitBranch: "feature/test",
|
||||
GitRemoteOrigin: "git@github.com:coder/coder.git",
|
||||
StaleAt: staleAt,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
_, err = db.UpsertChatDiffStatus(
|
||||
dbauthz.AsSystemRestricted(ctx),
|
||||
database.UpsertChatDiffStatusParams{
|
||||
ChatID: chat.ID,
|
||||
Url: sql.NullString{String: "https://github.com/coder/coder/pull/99", Valid: true},
|
||||
PullRequestState: sql.NullString{String: "open", Valid: true},
|
||||
Additions: 42,
|
||||
Deletions: 7,
|
||||
ChangedFiles: 5,
|
||||
RefreshedAt: refreshedAt,
|
||||
StaleAt: staleAt,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Open the watch WebSocket.
|
||||
conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Read the initial ping.
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
// Publish a diff_status_change event via pubsub,
|
||||
// mimicking what PublishDiffStatusChange does after
|
||||
// it reads the diff status from the DB.
|
||||
dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus)
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindDiffStatusChange,
|
||||
Chat: codersdk.Chat{
|
||||
ID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
Title: chat.Title,
|
||||
Status: codersdk.ChatStatus(chat.Status),
|
||||
CreatedAt: chat.CreatedAt,
|
||||
UpdatedAt: chat.UpdatedAt,
|
||||
DiffStatus: &sdkDiffStatus,
|
||||
},
|
||||
}
|
||||
payload, err := json.Marshal(event)
|
||||
require.NoError(t, err)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatEventChannel(user.UserID), payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read events until we find the diff_status_change.
|
||||
for {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
require.NoError(t, err)
|
||||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var received coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if received.Kind != coderdpubsub.ChatEventKindDiffStatusChange ||
|
||||
received.Chat.ID != chat.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify the event carries the full DiffStatus.
|
||||
require.NotNil(t, received.Chat.DiffStatus, "diff_status_change event must include DiffStatus")
|
||||
ds := received.Chat.DiffStatus
|
||||
require.Equal(t, chat.ID, ds.ChatID)
|
||||
require.NotNil(t, ds.URL)
|
||||
require.Equal(t, "https://github.com/coder/coder/pull/99", *ds.URL)
|
||||
require.NotNil(t, ds.PullRequestState)
|
||||
require.Equal(t, "open", *ds.PullRequestState)
|
||||
require.EqualValues(t, 42, ds.Additions)
|
||||
require.EqualValues(t, 7, ds.Deletions)
|
||||
require.EqualValues(t, 5, ds.ChangedFiles)
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Unauthenticated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1362,7 +1603,7 @@ func TestGetChat(t *testing.T) {
|
||||
|
||||
chatResult, err := client.GetChat(ctx, createdChat.ID)
|
||||
require.NoError(t, err)
|
||||
messagesResult, err := client.GetChatMessages(ctx, createdChat.ID)
|
||||
messagesResult, err := client.GetChatMessages(ctx, createdChat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, createdChat.ID, chatResult.ID)
|
||||
require.Equal(t, firstUser.UserID, chatResult.OwnerID)
|
||||
@@ -1447,7 +1688,7 @@ func TestArchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatsBeforeArchive, 2)
|
||||
|
||||
err = client.ArchiveChat(ctx, chatToArchive.ID)
|
||||
err = client.UpdateChat(ctx, chatToArchive.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Default (no filter) returns only non-archived chats.
|
||||
@@ -1481,7 +1722,7 @@ func TestArchiveChat(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
err := client.ArchiveChat(ctx, uuid.New())
|
||||
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
@@ -1524,7 +1765,7 @@ func TestArchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the parent via the API.
|
||||
err = client.ArchiveChat(ctx, parentChat.ID)
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
require.NoError(t, err)
|
||||
|
||||
// archived:false should exclude the entire archived family.
|
||||
@@ -1571,7 +1812,7 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the chat first.
|
||||
err = client.ArchiveChat(ctx, chat.ID)
|
||||
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's archived.
|
||||
@@ -1582,7 +1823,7 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
require.Len(t, archivedChats, 1)
|
||||
require.True(t, archivedChats[0].Archived)
|
||||
// Unarchive the chat.
|
||||
err = client.UnarchiveChat(ctx, chat.ID)
|
||||
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's no longer archived.
|
||||
@@ -1621,10 +1862,9 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Trying to unarchive a non-archived chat should fail.
|
||||
err = client.UnarchiveChat(ctx, chat.ID)
|
||||
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1632,7 +1872,7 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
err := client.UnarchiveChat(ctx, uuid.New())
|
||||
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
@@ -1686,7 +1926,7 @@ func TestPostChatMessages(t *testing.T) {
|
||||
require.True(t, hasTextPart(created.QueuedMessage.Content, messageText))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
@@ -1714,7 +1954,7 @@ func TestPostChatMessages(t *testing.T) {
|
||||
require.True(t, hasTextPart(created.Message.Content, messageText))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
@@ -1761,6 +2001,34 @@ func TestPostChatMessages(t *testing.T) {
|
||||
require.Equal(t, "content[0].text cannot be empty.", sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("UsageLimitExceeded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "initial message for usage-limit test",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
|
||||
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
|
||||
|
||||
_, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "over limit",
|
||||
}},
|
||||
})
|
||||
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
|
||||
})
|
||||
|
||||
t.Run("ChatNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1829,7 +2097,7 @@ func TestChatMessageWithFileReferences(t *testing.T) {
|
||||
|
||||
var found bool
|
||||
require.Eventually(t, func() bool {
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
@@ -1889,7 +2157,7 @@ func TestChatMessageWithFileReferences(t *testing.T) {
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
@@ -1942,7 +2210,7 @@ func TestChatMessageWithFileReferences(t *testing.T) {
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
@@ -1995,7 +2263,7 @@ func TestChatMessageWithFileReferences(t *testing.T) {
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
@@ -2085,7 +2353,7 @@ func TestChatMessageWithFileReferences(t *testing.T) {
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
@@ -2275,7 +2543,7 @@ func TestChatMessageWithFiles(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify file parts omit inline data in the API response.
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
for _, msg := range messagesResult.Messages {
|
||||
for _, part := range msg.Content {
|
||||
@@ -2371,7 +2639,7 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var userMessageID int64
|
||||
@@ -2403,7 +2671,7 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
}
|
||||
require.True(t, foundEditedText)
|
||||
|
||||
messagesResult, err = client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, err = client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
foundEditedInChat := false
|
||||
foundOriginalInChat := false
|
||||
@@ -2456,7 +2724,7 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find the user message ID.
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var userMessageID int64
|
||||
@@ -2499,7 +2767,7 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
require.True(t, foundFile, "edited message should preserve file_id")
|
||||
|
||||
// GET the chat messages and verify the file_id persists.
|
||||
messagesResult, err = client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, err = client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var foundTextInChat, foundFileInChat bool
|
||||
@@ -2521,6 +2789,46 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
require.True(t, foundFileInChat, "chat should preserve file_id after edit")
|
||||
})
|
||||
|
||||
t.Run("UsageLimitExceeded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello before edit",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var userMessageID int64
|
||||
for _, message := range messagesResult.Messages {
|
||||
if message.Role == codersdk.ChatMessageRoleUser {
|
||||
userMessageID = message.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotZero(t, userMessageID)
|
||||
|
||||
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
|
||||
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
|
||||
|
||||
_, err = client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "edited over limit",
|
||||
}},
|
||||
})
|
||||
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
|
||||
})
|
||||
|
||||
t.Run("MessageNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -3114,7 +3422,7 @@ func TestDeleteChatQueuedMessage(t *testing.T) {
|
||||
res.Body.Close()
|
||||
require.Equal(t, http.StatusNoContent, res.StatusCode)
|
||||
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
for _, queued := range messagesResult.QueuedMessages {
|
||||
require.NotEqual(t, queuedMessage.ID, queued.ID)
|
||||
@@ -3217,7 +3525,7 @@ func TestPromoteChatQueuedMessage(t *testing.T) {
|
||||
}
|
||||
require.True(t, foundPromotedText)
|
||||
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
|
||||
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
for _, queued := range messagesResult.QueuedMessages {
|
||||
require.NotEqual(t, queuedMessage.ID, queued.ID)
|
||||
@@ -3230,6 +3538,81 @@ func TestPromoteChatQueuedMessage(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PromotesAlreadyQueuedMessageAfterLimitReached", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
enableDailyChatUsageLimit(ctx, t, db, 100)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "promote queued usage limit",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
const queuedText = "queued message for promote route"
|
||||
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(queuedText),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
queuedMessage, err := db.InsertChatQueuedMessage(
|
||||
dbauthz.AsSystemRestricted(ctx),
|
||||
database.InsertChatQueuedMessageParams{
|
||||
ChatID: chat.ID,
|
||||
Content: queuedContent,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
|
||||
|
||||
_, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
promoteRes, err := client.Request(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
fmt.Sprintf("/api/experimental/chats/%s/queue/%d/promote", chat.ID, queuedMessage.ID),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer promoteRes.Body.Close()
|
||||
require.Equal(t, http.StatusOK, promoteRes.StatusCode)
|
||||
|
||||
var promoted codersdk.ChatMessage
|
||||
err = json.NewDecoder(promoteRes.Body).Decode(&promoted)
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, promoted.ID)
|
||||
require.Equal(t, chat.ID, promoted.ChatID)
|
||||
require.Equal(t, codersdk.ChatMessageRoleUser, promoted.Role)
|
||||
|
||||
foundPromotedText := false
|
||||
for _, part := range promoted.Content {
|
||||
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == queuedText {
|
||||
foundPromotedText = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundPromotedText)
|
||||
|
||||
queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
for _, queued := range queuedMessages {
|
||||
require.NotEqual(t, queuedMessage.ID, queued.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InvalidQueuedMessageID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -3261,6 +3644,133 @@ func TestPromoteChatQueuedMessage(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatUsageLimitOverrideRoutes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UpsertUserOverrideRequiresPositiveSpendLimit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _ := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
|
||||
res, err := client.Request(
|
||||
ctx,
|
||||
http.MethodPut,
|
||||
fmt.Sprintf("/api/experimental/chats/usage-limits/overrides/%s", member.ID),
|
||||
map[string]any{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
err = codersdk.ReadBodyAsError(res)
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid chat usage limit override.", sdkErr.Message)
|
||||
require.Equal(t, "Spend limit must be greater than 0.", sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("UpsertUserOverrideMissingUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.UpsertChatUsageLimitOverride(ctx, uuid.New(), codersdk.UpsertChatUsageLimitOverrideRequest{
|
||||
SpendLimitMicros: 7_000_000,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusNotFound)
|
||||
require.Equal(t, "User not found.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("DeleteUserOverrideMissingUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
err := client.DeleteChatUsageLimitOverride(ctx, uuid.New())
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "User not found.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("DeleteUserOverrideMissingOverride", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
|
||||
err := client.DeleteChatUsageLimitOverride(ctx, member.ID)
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Chat usage limit override not found.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("UpsertGroupOverrideIncludesMemberCount", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
group := dbgen.Group(t, db, database.Group{OrganizationID: firstUser.OrganizationID})
|
||||
dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: member.ID})
|
||||
dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: database.PrebuildsSystemUserID})
|
||||
|
||||
override, err := client.UpsertChatUsageLimitGroupOverride(ctx, group.ID, codersdk.UpsertChatUsageLimitGroupOverrideRequest{
|
||||
SpendLimitMicros: 7_000_000,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, group.ID, override.GroupID)
|
||||
require.EqualValues(t, 1, override.MemberCount)
|
||||
require.NotNil(t, override.SpendLimitMicros)
|
||||
require.EqualValues(t, 7_000_000, *override.SpendLimitMicros)
|
||||
|
||||
config, err := client.GetChatUsageLimitConfig(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
var listed *codersdk.ChatUsageLimitGroupOverride
|
||||
for i := range config.GroupOverrides {
|
||||
if config.GroupOverrides[i].GroupID == group.ID {
|
||||
listed = &config.GroupOverrides[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, listed)
|
||||
require.EqualValues(t, 1, listed.MemberCount)
|
||||
})
|
||||
|
||||
t.Run("UpsertGroupOverrideMissingGroup", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.UpsertChatUsageLimitGroupOverride(ctx, uuid.New(), codersdk.UpsertChatUsageLimitGroupOverrideRequest{
|
||||
SpendLimitMicros: 7_000_000,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusNotFound)
|
||||
require.Equal(t, "Group not found.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("DeleteGroupOverrideMissingOverride", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
group := dbgen.Group(t, db, database.Group{OrganizationID: firstUser.OrganizationID})
|
||||
|
||||
err := client.DeleteChatUsageLimitGroupOverride(ctx, group.ID)
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Chat usage limit group override not found.", sdkErr.Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostChatFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -4002,7 +4512,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
t.Run("AdminCanSet", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
SystemPrompt: "You are a helpful coding assistant.",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -4016,7 +4526,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Unset by sending an empty string.
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
SystemPrompt: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -4029,7 +4539,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
t.Run("NonAdminFails", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
SystemPrompt: "This should fail.",
|
||||
})
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
@@ -4050,7 +4560,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
tooLong := strings.Repeat("a", 131073)
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
SystemPrompt: tooLong,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
@@ -4058,6 +4568,108 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatDesktopEnabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReturnsFalseWhenUnset", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
resp, err := adminClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("AdminCanSetTrue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := adminClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("AdminCanSetFalse", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
// Set true first, then set false.
|
||||
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := adminClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("NonAdminCanRead", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := memberClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("NonAdminWriteFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
err := memberClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
})
|
||||
|
||||
t.Run("UnauthenticatedFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
anonClient := codersdk.New(adminClient.URL)
|
||||
_, err := anonClient.GetChatDesktopEnabled(ctx)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
|
||||
t.Helper()
|
||||
|
||||
|
||||
+45
-12
@@ -10,6 +10,7 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
httppprof "net/http/pprof"
|
||||
"net/url"
|
||||
@@ -44,6 +45,7 @@ import (
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||||
"github.com/coder/coder/v2/coderd/aiseats"
|
||||
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
|
||||
"github.com/coder/coder/v2/coderd/appearance"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
@@ -629,7 +631,9 @@ func New(options *Options) *API {
|
||||
),
|
||||
dbRolluper: options.DatabaseRolluper,
|
||||
ProfileCollector: defaultProfileCollector{},
|
||||
AISeatTracker: aiseats.Noop{},
|
||||
}
|
||||
|
||||
api.WorkspaceAppsProvider = workspaceapps.NewDBTokenProvider(
|
||||
ctx,
|
||||
options.Logger.Named("workspaceapps"),
|
||||
@@ -763,17 +767,26 @@ func New(options *Options) *API {
|
||||
}
|
||||
api.agentProvider = stn
|
||||
|
||||
maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
|
||||
if maxChatsPerAcquire > math.MaxInt32 {
|
||||
maxChatsPerAcquire = math.MaxInt32
|
||||
}
|
||||
if maxChatsPerAcquire < math.MinInt32 {
|
||||
maxChatsPerAcquire = math.MinInt32
|
||||
}
|
||||
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chats"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
Logger: options.Logger.Named("chats"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
})
|
||||
gitSyncLogger := options.Logger.Named("gitsync")
|
||||
refresher := gitsync.NewRefresher(
|
||||
@@ -1146,6 +1159,9 @@ func New(options *Options) *API {
|
||||
r.Get("/summary", api.chatCostSummary)
|
||||
})
|
||||
})
|
||||
r.Route("/insights", func(r chi.Router) {
|
||||
r.Get("/pull-requests", api.prInsights)
|
||||
})
|
||||
r.Route("/files", func(r chi.Router) {
|
||||
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
|
||||
r.Post("/", api.postChatFile)
|
||||
@@ -1154,6 +1170,8 @@ func New(options *Options) *API {
|
||||
r.Route("/config", func(r chi.Router) {
|
||||
r.Get("/system-prompt", api.getChatSystemPrompt)
|
||||
r.Put("/system-prompt", api.putChatSystemPrompt)
|
||||
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
|
||||
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
|
||||
r.Get("/user-prompt", api.getUserChatCustomPrompt)
|
||||
r.Put("/user-prompt", api.putUserChatCustomPrompt)
|
||||
})
|
||||
@@ -1175,13 +1193,25 @@ func New(options *Options) *API {
|
||||
r.Delete("/", api.deleteChatModelConfig)
|
||||
})
|
||||
})
|
||||
r.Route("/usage-limits", func(r chi.Router) {
|
||||
r.Get("/", api.getChatUsageLimitConfig)
|
||||
r.Put("/", api.updateChatUsageLimitConfig)
|
||||
r.Get("/status", api.getMyChatUsageLimitStatus)
|
||||
r.Route("/overrides/{user}", func(r chi.Router) {
|
||||
r.Put("/", api.upsertChatUsageLimitOverride)
|
||||
r.Delete("/", api.deleteChatUsageLimitOverride)
|
||||
})
|
||||
r.Route("/group-overrides/{group}", func(r chi.Router) {
|
||||
r.Put("/", api.upsertChatUsageLimitGroupOverride)
|
||||
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
|
||||
})
|
||||
})
|
||||
r.Route("/{chat}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractChatParam(options.Database))
|
||||
r.Get("/", api.getChat)
|
||||
r.Get("/git/watch", api.watchChatGit)
|
||||
r.Get("/desktop", api.watchChatDesktop)
|
||||
r.Post("/archive", api.archiveChat)
|
||||
r.Post("/unarchive", api.unarchiveChat)
|
||||
r.Patch("/", api.patchChat)
|
||||
r.Get("/messages", api.getChatMessages)
|
||||
r.Post("/messages", api.postChatMessages)
|
||||
r.Patch("/messages/{message}", api.patchChatMessage)
|
||||
@@ -2033,6 +2063,8 @@ type API struct {
|
||||
dbRolluper *dbrollup.Rolluper
|
||||
// chatDaemon handles background processing of pending chats.
|
||||
chatDaemon *chatd.Server
|
||||
// AISeatTracker records AI seat usage.
|
||||
AISeatTracker aiseats.SeatTracker
|
||||
// gitSyncWorker refreshes stale chat diff statuses in the
|
||||
// background.
|
||||
gitSyncWorker *gitsync.Worker
|
||||
@@ -2245,6 +2277,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
|
||||
provisionerdserver.Options{
|
||||
OIDCConfig: api.OIDCConfig,
|
||||
ExternalAuthConfigs: api.ExternalAuthConfigs,
|
||||
AISeatTracker: api.AISeatTracker,
|
||||
Clock: api.Clock,
|
||||
HeartbeatFn: options.heartbeatFn,
|
||||
},
|
||||
|
||||
@@ -879,6 +879,15 @@ func createAnotherUserRetry(t testing.TB, client *codersdk.Client, organizationI
|
||||
m(&req)
|
||||
}
|
||||
|
||||
// Service accounts cannot have a password or email and must
|
||||
// use login_type=none. Enforce this after mutators so callers
|
||||
// only need to set ServiceAccount=true.
|
||||
if req.ServiceAccount {
|
||||
req.Password = ""
|
||||
req.Email = ""
|
||||
req.UserLoginType = codersdk.LoginTypeNone
|
||||
}
|
||||
|
||||
user, err := client.CreateUserWithOrgs(context.Background(), req)
|
||||
var apiError *codersdk.Error
|
||||
// If the user already exists by username or email conflict, try again up to "retries" times.
|
||||
|
||||
@@ -13,32 +13,64 @@ var _ usage.Inserter = (*UsageInserter)(nil)
|
||||
|
||||
type UsageInserter struct {
|
||||
sync.Mutex
|
||||
events []usagetypes.DiscreteEvent
|
||||
discreteEvents []usagetypes.DiscreteEvent
|
||||
heartbeatEvents []usagetypes.HeartbeatEvent
|
||||
seenHeartbeats map[string]struct{}
|
||||
}
|
||||
|
||||
func NewUsageInserter() *UsageInserter {
|
||||
return &UsageInserter{
|
||||
events: []usagetypes.DiscreteEvent{},
|
||||
discreteEvents: []usagetypes.DiscreteEvent{},
|
||||
seenHeartbeats: map[string]struct{}{},
|
||||
heartbeatEvents: []usagetypes.HeartbeatEvent{},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, event usagetypes.DiscreteEvent) error {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
u.events = append(u.events, event)
|
||||
u.discreteEvents = append(u.discreteEvents, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UsageInserter) GetEvents() []usagetypes.DiscreteEvent {
|
||||
func (u *UsageInserter) InsertHeartbeatUsageEvent(_ context.Context, _ database.Store, id string, event usagetypes.HeartbeatEvent) error {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.events))
|
||||
copy(eventsCopy, u.events)
|
||||
if _, seen := u.seenHeartbeats[id]; seen {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.seenHeartbeats[id] = struct{}{}
|
||||
u.heartbeatEvents = append(u.heartbeatEvents, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UsageInserter) GetHeartbeatEvents() []usagetypes.HeartbeatEvent {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
eventsCopy := make([]usagetypes.HeartbeatEvent, len(u.heartbeatEvents))
|
||||
copy(eventsCopy, u.heartbeatEvents)
|
||||
return eventsCopy
|
||||
}
|
||||
|
||||
func (u *UsageInserter) GetDiscreteEvents() []usagetypes.DiscreteEvent {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.discreteEvents))
|
||||
copy(eventsCopy, u.discreteEvents)
|
||||
return eventsCopy
|
||||
}
|
||||
|
||||
func (u *UsageInserter) TotalEventCount() int {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
return len(u.discreteEvents) + len(u.heartbeatEvents)
|
||||
}
|
||||
|
||||
func (u *UsageInserter) Reset() {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
u.events = []usagetypes.DiscreteEvent{}
|
||||
u.seenHeartbeats = map[string]struct{}{}
|
||||
u.discreteEvents = []usagetypes.DiscreteEvent{}
|
||||
u.heartbeatEvents = []usagetypes.HeartbeatEvent{}
|
||||
}
|
||||
|
||||
@@ -6,22 +6,27 @@ type CheckConstraint string
|
||||
|
||||
// CheckConstraint enums.
|
||||
const (
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
|
||||
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersChatSpendLimitMicrosCheck CheckConstraint = "users_chat_spend_limit_micros_check" // users
|
||||
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
|
||||
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/render"
|
||||
@@ -194,13 +195,14 @@ func MinimalUserFromVisibleUser(user database.VisibleUser) codersdk.MinimalUser
|
||||
|
||||
func ReducedUser(user database.User) codersdk.ReducedUser {
|
||||
return codersdk.ReducedUser{
|
||||
MinimalUser: MinimalUser(user),
|
||||
Email: user.Email,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
LastSeenAt: user.LastSeenAt,
|
||||
Status: codersdk.UserStatus(user.Status),
|
||||
LoginType: codersdk.LoginType(user.LoginType),
|
||||
MinimalUser: MinimalUser(user),
|
||||
Email: user.Email,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
LastSeenAt: user.LastSeenAt,
|
||||
Status: codersdk.UserStatus(user.Status),
|
||||
LoginType: codersdk.LoginType(user.LoginType),
|
||||
IsServiceAccount: user.IsServiceAccount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1164,3 +1166,86 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
value := v.Int64
|
||||
return &value
|
||||
}
|
||||
|
||||
// ChatDiffStatus converts a database.ChatDiffStatus to a
|
||||
// codersdk.ChatDiffStatus. When status is nil an empty value
|
||||
// containing only the chatID is returned.
|
||||
func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.ChatDiffStatus {
|
||||
result := codersdk.ChatDiffStatus{
|
||||
ChatID: chatID,
|
||||
}
|
||||
if status == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
result.ChatID = status.ChatID
|
||||
if status.Url.Valid {
|
||||
u := strings.TrimSpace(status.Url.String)
|
||||
if u != "" {
|
||||
result.URL = &u
|
||||
}
|
||||
}
|
||||
if result.URL == nil {
|
||||
// Try to build a branch URL from the stored origin.
|
||||
// Since this function does not have access to the API
|
||||
// instance, we construct a GitHub provider directly as
|
||||
// a best-effort fallback.
|
||||
// TODO: This uses the default github.com API base URL,
|
||||
// so branch URLs for GitHub Enterprise instances will
|
||||
// be incorrect. To fix this, this function would need
|
||||
// access to the external auth configs.
|
||||
gp := gitprovider.New("github", "", nil)
|
||||
if gp != nil {
|
||||
if owner, repo, _, ok := gp.ParseRepositoryOrigin(status.GitRemoteOrigin); ok {
|
||||
branchURL := gp.BuildBranchURL(owner, repo, status.GitBranch)
|
||||
if branchURL != "" {
|
||||
result.URL = &branchURL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if status.PullRequestState.Valid {
|
||||
pullRequestState := strings.TrimSpace(status.PullRequestState.String)
|
||||
if pullRequestState != "" {
|
||||
result.PullRequestState = &pullRequestState
|
||||
}
|
||||
}
|
||||
result.PullRequestTitle = status.PullRequestTitle
|
||||
result.PullRequestDraft = status.PullRequestDraft
|
||||
result.ChangesRequested = status.ChangesRequested
|
||||
result.Additions = status.Additions
|
||||
result.Deletions = status.Deletions
|
||||
result.ChangedFiles = status.ChangedFiles
|
||||
if status.AuthorLogin.Valid {
|
||||
result.AuthorLogin = &status.AuthorLogin.String
|
||||
}
|
||||
if status.AuthorAvatarUrl.Valid {
|
||||
result.AuthorAvatarURL = &status.AuthorAvatarUrl.String
|
||||
}
|
||||
if status.BaseBranch.Valid {
|
||||
result.BaseBranch = &status.BaseBranch.String
|
||||
}
|
||||
if status.HeadBranch.Valid {
|
||||
result.HeadBranch = &status.HeadBranch.String
|
||||
}
|
||||
if status.PrNumber.Valid {
|
||||
result.PRNumber = &status.PrNumber.Int32
|
||||
}
|
||||
if status.Commits.Valid {
|
||||
result.Commits = &status.Commits.Int32
|
||||
}
|
||||
if status.Approved.Valid {
|
||||
result.Approved = &status.Approved.Bool
|
||||
}
|
||||
if status.ReviewerCount.Valid {
|
||||
result.ReviewerCount = &status.ReviewerCount.Int32
|
||||
}
|
||||
if status.RefreshedAt.Valid {
|
||||
refreshedAt := status.RefreshedAt.Time
|
||||
result.RefreshedAt = &refreshedAt
|
||||
}
|
||||
staleAt := status.StaleAt
|
||||
result.StaleAt = &staleAt
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1264,7 +1264,7 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID uuid.UUID, added, re
|
||||
// System roles are stored in the database but have a fixed, code-defined
|
||||
// meaning. Do not rewrite the name for them so the static "who can assign
|
||||
// what" mapping applies.
|
||||
if !rbac.SystemRoleName(roleName.Name) {
|
||||
if !rolestore.IsSystemRoleName(roleName.Name) {
|
||||
// To support a dynamic mapping of what roles can assign what, we need
|
||||
// to store this in the database. For now, just use a static role so
|
||||
// owners and org admins can assign roles.
|
||||
@@ -1726,6 +1726,13 @@ func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountCon
|
||||
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.CountEnabledModelsWithoutPricing(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
|
||||
return nil, err
|
||||
@@ -1854,6 +1861,20 @@ func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.Dele
|
||||
return q.db.DeleteChatQueuedMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatUsageLimitGroupOverride(ctx, groupID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatUsageLimitUserOverride(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -2124,12 +2145,12 @@ func (q *querier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) erro
|
||||
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, params database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
// This is a system-only function.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
|
||||
return q.db.DeleteWorkspaceACLsByOrganization(ctx, params)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
|
||||
@@ -2327,6 +2348,13 @@ func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Tim
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysLastUsedAfter)(ctx, lastUsed)
|
||||
}
|
||||
|
||||
func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceLicense); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.GetActiveAISeatCount(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
|
||||
return nil, err
|
||||
@@ -2454,6 +2482,17 @@ func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCo
|
||||
return q.db.GetChatCostSummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
// The desktop-enabled flag is a deployment-wide setting read by any
|
||||
// authenticated chat user and by chatd when deciding whether to expose
|
||||
// computer-use tooling. We only require that an explicit actor is present
|
||||
// in the context so unauthenticated calls fail closed.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return false, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatDesktopEnabled(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
@@ -2532,6 +2571,14 @@ func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetC
|
||||
return q.db.GetChatMessagesByChatID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessagesByChatIDDescPaginated(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
@@ -2596,8 +2643,33 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
return q.db.GetChatSystemPrompt(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
|
||||
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatUsageLimitConfig{}, err
|
||||
}
|
||||
return q.db.GetChatUsageLimitConfig(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.GetChatUsageLimitGroupOverrideRow{}, err
|
||||
}
|
||||
return q.db.GetChatUsageLimitGroupOverride(ctx, groupID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.GetChatUsageLimitUserOverrideRow{}, err
|
||||
}
|
||||
return q.db.GetChatUsageLimitUserOverride(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.GetAuthorizedChats(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
@@ -3087,6 +3159,34 @@ func (q *querier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg da
|
||||
return q.db.GetOrganizationsWithPrebuildStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsPerModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsRecentPRs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.GetPRInsightsSummaryRow{}, err
|
||||
}
|
||||
return q.db.GetPRInsightsSummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsTimeSeries(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
version, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
|
||||
if err != nil {
|
||||
@@ -3750,6 +3850,13 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
|
||||
return q.db.GetUserChatCustomPrompt(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.GetUserChatSpendInPeriod(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
@@ -3757,6 +3864,13 @@ func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64,
|
||||
return q.db.GetUserCount(ctx, includeSystem)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.GetUserGroupSpendLimit(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
|
||||
// Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms.
|
||||
if err := q.authorizeContext(ctx, policy.ActionViewInsights, rbac.ResourceTemplate); err != nil {
|
||||
@@ -4426,6 +4540,13 @@ func (q *querier) InsertAIBridgeInterception(ctx context.Context, arg database.I
|
||||
return insert(q.log, q.auth, rbac.ResourceAibridgeInterception.WithOwner(arg.InitiatorID.String()), q.db.InsertAIBridgeInterception)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
|
||||
return database.AIBridgeModelThought{}, err
|
||||
}
|
||||
return q.db.InsertAIBridgeModelThought(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
// All aibridge_token_usages records belong to the initiator of their associated interception.
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
|
||||
@@ -5115,6 +5236,20 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
|
||||
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
|
||||
}
|
||||
|
||||
func (q *querier) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListChatUsageLimitGroupOverrides(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListChatUsageLimitOverrides(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
|
||||
}
|
||||
@@ -5234,6 +5369,13 @@ func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveU
|
||||
return q.db.RemoveUserFromGroups(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.ResolveUserChatSpendLimit(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
@@ -6417,6 +6559,13 @@ func (q *querier) UpdateWorkspacesTTLByTemplateID(ctx context.Context, arg datab
|
||||
return q.db.UpdateWorkspacesTTLByTemplateID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return q.db.UpsertAISeatState(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertAnnouncementBanners(ctx context.Context, value string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
@@ -6438,6 +6587,13 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
|
||||
return q.db.UpsertBoundaryUsageStats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatDesktopEnabled(ctx, enableDesktop)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
@@ -6469,6 +6625,27 @@ func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) erro
|
||||
return q.db.UpsertChatSystemPrompt(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatUsageLimitConfig{}, err
|
||||
}
|
||||
return q.db.UpsertChatUsageLimitConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.UpsertChatUsageLimitGroupOverrideRow{}, err
|
||||
}
|
||||
return q.db.UpsertChatUsageLimitGroupOverride(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.UpsertChatUsageLimitUserOverrideRow{}, err
|
||||
}
|
||||
return q.db.UpsertChatUsageLimitUserOverride(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
|
||||
return database.ConnectionLog{}, err
|
||||
@@ -6656,6 +6833,13 @@ func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg databa
|
||||
return q.db.UpsertWorkspaceAppAuditSession(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceUsageEvent); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return q.db.UsageEventExistsByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) ValidateGroupIDs(ctx context.Context, groupIDs []uuid.UUID) (database.ValidateGroupIDsRow, error) {
|
||||
// This check is probably overly restrictive, but the "correct" check isn't
|
||||
// necessarily obvious. It's only used as a verification check for ACLs right
|
||||
@@ -6751,3 +6935,7 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.ListAIBridgeModels(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
return q.GetChats(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -513,6 +513,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatCostSummary(gomock.Any(), arg).Return(row, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(row)
|
||||
}))
|
||||
s.Run("CountEnabledModelsWithoutPricing", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().CountEnabledModelsWithoutPricing(gomock.Any()).Return(int64(3), nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(int64(3))
|
||||
}))
|
||||
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
@@ -558,6 +562,14 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatMessagesByChatIDDescPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
arg := database.GetChatMessagesByChatIDDescPaginatedParams{ChatID: chat.ID, BeforeID: 0, LimitVal: 50}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatMessagesByChatIDDescPaginated(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetLastChatMessageByRole", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
@@ -606,12 +618,17 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
c1 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
c2 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
params := database.GetChatsByOwnerIDParams{OwnerID: c1.OwnerID}
|
||||
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), params).Return([]database.Chat{c1, c2}, nil).AnyTimes()
|
||||
check.Args(params).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
|
||||
s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
params := database.GetChatsParams{}
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
s.Run("GetAuthorizedChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
params := database.GetChatsParams{}
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
|
||||
// No asserts here because it re-routes through GetChats which uses SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -624,6 +641,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatSystemPrompt(gomock.Any()).Return("prompt", nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
@@ -833,6 +854,146 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
|
||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
|
||||
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetUserChatSpendInPeriod", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetUserChatSpendInPeriodParams{
|
||||
UserID: uuid.New(),
|
||||
StartTime: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndTime: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
spend := int64(123)
|
||||
dbm.EXPECT().GetUserChatSpendInPeriod(gomock.Any(), arg).Return(spend, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.UserID.String()), policy.ActionRead).Returns(spend)
|
||||
}))
|
||||
s.Run("GetUserGroupSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
limit := int64(456)
|
||||
dbm.EXPECT().GetUserGroupSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
|
||||
}))
|
||||
s.Run("ResolveUserChatSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
limit := int64(789)
|
||||
dbm.EXPECT().ResolveUserChatSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
|
||||
}))
|
||||
s.Run("GetChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
now := dbtime.Now()
|
||||
config := database.ChatUsageLimitConfig{
|
||||
ID: 1,
|
||||
Singleton: true,
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: 1_000_000,
|
||||
Period: "monthly",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
dbm.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(config, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
groupID := uuid.New()
|
||||
override := database.GetChatUsageLimitGroupOverrideRow{
|
||||
GroupID: groupID,
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 2_000_000, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().GetChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(override, nil).AnyTimes()
|
||||
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
|
||||
}))
|
||||
s.Run("GetChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
override := database.GetChatUsageLimitUserOverrideRow{
|
||||
UserID: userID,
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 3_000_000, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().GetChatUsageLimitUserOverride(gomock.Any(), userID).Return(override, nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
|
||||
}))
|
||||
s.Run("ListChatUsageLimitGroupOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
overrides := []database.ListChatUsageLimitGroupOverridesRow{{
|
||||
GroupID: uuid.New(),
|
||||
GroupName: "group-name",
|
||||
GroupDisplayName: "Group Name",
|
||||
GroupAvatarUrl: "https://example.com/group.png",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 4_000_000, Valid: true},
|
||||
MemberCount: 5,
|
||||
}}
|
||||
dbm.EXPECT().ListChatUsageLimitGroupOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
|
||||
}))
|
||||
s.Run("ListChatUsageLimitOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
overrides := []database.ListChatUsageLimitOverridesRow{{
|
||||
UserID: uuid.New(),
|
||||
Username: "usage-limit-user",
|
||||
Name: "Usage Limit User",
|
||||
AvatarURL: "https://example.com/avatar.png",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 5_000_000, Valid: true},
|
||||
}}
|
||||
dbm.EXPECT().ListChatUsageLimitOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
|
||||
}))
|
||||
s.Run("UpsertChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
now := dbtime.Now()
|
||||
arg := database.UpsertChatUsageLimitConfigParams{
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: 6_000_000,
|
||||
Period: "monthly",
|
||||
}
|
||||
config := database.ChatUsageLimitConfig{
|
||||
ID: 1,
|
||||
Singleton: true,
|
||||
Enabled: arg.Enabled,
|
||||
DefaultLimitMicros: arg.DefaultLimitMicros,
|
||||
Period: arg.Period,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
dbm.EXPECT().UpsertChatUsageLimitConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("UpsertChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.UpsertChatUsageLimitGroupOverrideParams{
|
||||
SpendLimitMicros: 7_000_000,
|
||||
GroupID: uuid.New(),
|
||||
}
|
||||
override := database.UpsertChatUsageLimitGroupOverrideRow{
|
||||
GroupID: arg.GroupID,
|
||||
Name: "group",
|
||||
DisplayName: "Group",
|
||||
AvatarURL: "",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().UpsertChatUsageLimitGroupOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
|
||||
}))
|
||||
s.Run("UpsertChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.UpsertChatUsageLimitUserOverrideParams{
|
||||
SpendLimitMicros: 8_000_000,
|
||||
UserID: uuid.New(),
|
||||
}
|
||||
override := database.UpsertChatUsageLimitUserOverrideRow{
|
||||
UserID: arg.UserID,
|
||||
Username: "user",
|
||||
Name: "User",
|
||||
AvatarURL: "",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().UpsertChatUsageLimitUserOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
|
||||
}))
|
||||
s.Run("DeleteChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
groupID := uuid.New()
|
||||
dbm.EXPECT().DeleteChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(nil).AnyTimes()
|
||||
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestFile() {
|
||||
@@ -1155,6 +1316,14 @@ func (s *MethodTestSuite) TestProvisionerJob() {
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestLicense() {
|
||||
s.Run("GetActiveAISeatCount", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetActiveAISeatCount(gomock.Any()).Return(int64(100), nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceLicense, policy.ActionRead).Returns(int64(100))
|
||||
}))
|
||||
s.Run("UpsertAISeatState", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertAISeatState(gomock.Any(), gomock.Any()).Return(true, nil).AnyTimes()
|
||||
check.Args(database.UpsertAISeatStateParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate)
|
||||
}))
|
||||
s.Run("GetLicenses", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
a := database.License{ID: 1}
|
||||
b := database.License{ID: 2}
|
||||
@@ -1324,7 +1493,7 @@ func (s *MethodTestSuite) TestOrganization() {
|
||||
org := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
arg := database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
||||
ID: org.ID,
|
||||
WorkspaceSharingDisabled: true,
|
||||
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
|
||||
}
|
||||
dbm.EXPECT().GetOrganizationByID(gomock.Any(), org.ID).Return(org, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateOrganizationWorkspaceSharingSettings(gomock.Any(), arg).Return(org, nil).AnyTimes()
|
||||
@@ -1755,6 +1924,26 @@ func (s *MethodTestSuite) TestTemplate() {
|
||||
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
|
||||
}))
|
||||
s.Run("GetPRInsightsSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsSummaryParams{}
|
||||
dbm.EXPECT().GetPRInsightsSummary(gomock.Any(), arg).Return(database.GetPRInsightsSummaryRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsTimeSeries", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsTimeSeriesParams{}
|
||||
dbm.EXPECT().GetPRInsightsTimeSeries(gomock.Any(), arg).Return([]database.GetPRInsightsTimeSeriesRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsPerModel", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsPerModelParams{}
|
||||
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsRecentPRsParams{}
|
||||
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetTelemetryTaskEventsParams{}
|
||||
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
|
||||
@@ -2243,9 +2432,12 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
check.Args(w.ID).Asserts(w, policy.ActionShare)
|
||||
}))
|
||||
s.Run("DeleteWorkspaceACLsByOrganization", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
orgID := uuid.New()
|
||||
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), orgID).Return(nil).AnyTimes()
|
||||
check.Args(orgID).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
arg := database.DeleteWorkspaceACLsByOrganizationParams{
|
||||
OrganizationID: uuid.New(),
|
||||
ExcludeServiceAccounts: false,
|
||||
}
|
||||
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
@@ -4951,6 +5143,12 @@ func (s *MethodTestSuite) TestUsageEvents() {
|
||||
check.Args(params).Asserts(rbac.ResourceUsageEvent, policy.ActionCreate)
|
||||
}))
|
||||
|
||||
s.Run("UsageEventExistsByID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
id := uuid.NewString()
|
||||
db.EXPECT().UsageEventExistsByID(gomock.Any(), id).Return(true, nil)
|
||||
check.Args(id).Asserts(rbac.ResourceUsageEvent, policy.ActionRead)
|
||||
}))
|
||||
|
||||
s.Run("SelectUsageEventsForPublishing", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
now := dbtime.Now()
|
||||
db.EXPECT().SelectUsageEventsForPublishing(gomock.Any(), now).Return([]database.UsageEvent{}, nil)
|
||||
@@ -5011,6 +5209,17 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
check.Args(params).Asserts(intc, policy.ActionCreate)
|
||||
}))
|
||||
|
||||
s.Run("InsertAIBridgeModelThought", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
intID := uuid.UUID{2}
|
||||
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
|
||||
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
|
||||
|
||||
params := database.InsertAIBridgeModelThoughtParams{InterceptionID: intc.ID}
|
||||
expected := testutil.Fake(s.T(), faker, database.AIBridgeModelThought{InterceptionID: intc.ID})
|
||||
db.EXPECT().InsertAIBridgeModelThought(gomock.Any(), params).Return(expected, nil).AnyTimes()
|
||||
check.Args(params).Asserts(intc, policy.ActionUpdate)
|
||||
}))
|
||||
|
||||
s.Run("InsertAIBridgeTokenUsage", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
intID := uuid.UUID{2}
|
||||
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/v2/coderd/rbac/rolestore"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
)
|
||||
|
||||
@@ -143,7 +144,7 @@ func (s *MethodTestSuite) Mocked(testCaseF func(dmb *dbmock.MockStore, faker *go
|
||||
UUID: pair.OrganizationID,
|
||||
Valid: pair.OrganizationID != uuid.Nil,
|
||||
},
|
||||
IsSystem: rbac.SystemRoleName(pair.Name),
|
||||
IsSystem: rolestore.IsSystemRoleName(pair.Name),
|
||||
ID: uuid.New(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -650,34 +650,26 @@ func Organization(t testing.TB, db database.Store, orig database.Organization) d
|
||||
})
|
||||
require.NoError(t, err, "insert organization")
|
||||
|
||||
// Populate the placeholder organization-member system role (created by
|
||||
// DB trigger/migration) so org members have expected permissions.
|
||||
//nolint:gocritic // ReconcileOrgMemberRole needs the system:update
|
||||
// Populate the placeholder system roles (created by DB
|
||||
// trigger/migration) so org members have expected permissions.
|
||||
//nolint:gocritic // ReconcileSystemRole needs the system:update
|
||||
// permission that `genCtx` does not have.
|
||||
sysCtx := dbauthz.AsSystemRestricted(genCtx)
|
||||
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
|
||||
Name: rbac.RoleOrgMember(),
|
||||
OrganizationID: uuid.NullUUID{
|
||||
UUID: org.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}, org.WorkspaceSharingDisabled)
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// The trigger that creates the placeholder role didn't run (e.g.,
|
||||
// triggers were disabled in the test). Create the role manually.
|
||||
err = rolestore.CreateOrgMemberRole(sysCtx, db, org)
|
||||
require.NoError(t, err, "create organization-member role")
|
||||
|
||||
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
|
||||
Name: rbac.RoleOrgMember(),
|
||||
OrganizationID: uuid.NullUUID{
|
||||
UUID: org.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}, org.WorkspaceSharingDisabled)
|
||||
for roleName := range rolestore.SystemRoleNames {
|
||||
role := database.CustomRole{
|
||||
Name: roleName,
|
||||
OrganizationID: uuid.NullUUID{UUID: org.ID, Valid: true},
|
||||
}
|
||||
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// The trigger that creates the placeholder role didn't run (e.g.,
|
||||
// triggers were disabled in the test). Create the role manually.
|
||||
err = rolestore.CreateSystemRole(sysCtx, db, org, roleName)
|
||||
require.NoError(t, err, "create role "+roleName)
|
||||
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
|
||||
}
|
||||
require.NoError(t, err, "reconcile role "+roleName)
|
||||
}
|
||||
require.NoError(t, err, "reconcile organization-member role")
|
||||
|
||||
return org
|
||||
}
|
||||
|
||||
@@ -288,6 +288,14 @@ func (m queryMetricsStore) CountConnectionLogs(ctx context.Context, arg database
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountEnabledModelsWithoutPricing(ctx)
|
||||
m.queryLatencies.WithLabelValues("CountEnabledModelsWithoutPricing").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountEnabledModelsWithoutPricing").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountInProgressPrebuilds(ctx)
|
||||
@@ -408,6 +416,22 @@ func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg data
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatUsageLimitGroupOverride(ctx, groupID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitGroupOverride").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatUsageLimitUserOverride(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitUserOverride").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
|
||||
@@ -672,10 +696,11 @@ func (m queryMetricsStore) DeleteWorkspaceACLByID(ctx context.Context, id uuid.U
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
|
||||
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteWorkspaceACLsByOrganization").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteWorkspaceACLsByOrganization").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -871,6 +896,14 @@ func (m queryMetricsStore) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActiveAISeatCount(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetActiveAISeatCount").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveAISeatCount").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
|
||||
@@ -1015,6 +1048,14 @@ func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDesktopEnabled(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatDesktopEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDesktopEnabled").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
|
||||
@@ -1063,6 +1104,14 @@ func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatIDDescPaginated(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessagesByChatIDDescPaginated").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatIDDescPaginated").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesForPromptByChatID(ctx, chatID)
|
||||
@@ -1127,11 +1176,35 @@ func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
|
||||
m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
|
||||
r0, r1 := m.s.GetChatUsageLimitConfig(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatUsageLimitConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatUsageLimitGroupOverride(ctx, groupID)
|
||||
m.queryLatencies.WithLabelValues("GetChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitGroupOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatUsageLimitUserOverride(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitUserOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -1671,6 +1744,38 @@ func (m queryMetricsStore) GetOrganizationsWithPrebuildStatus(ctx context.Contex
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsPerModel(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsPerModel").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsPerModel").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsRecentPRs(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsRecentPRs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsRecentPRs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsSummary(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsSummary").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsSummary").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsTimeSeries(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsTimeSeries").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsTimeSeries").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetParameterSchemasByJobID(ctx, jobID)
|
||||
@@ -2255,6 +2360,14 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetUserChatSpendInPeriod").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatSpendInPeriod").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserCount(ctx, includeSystem)
|
||||
@@ -2263,6 +2376,14 @@ func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool)
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserGroupSpendLimit(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetUserGroupSpendLimit").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserGroupSpendLimit").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserLatencyInsights(ctx, arg)
|
||||
@@ -2871,6 +2992,14 @@ func (m queryMetricsStore) InsertAIBridgeInterception(ctx context.Context, arg d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertAIBridgeModelThought(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertAIBridgeModelThought").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertAIBridgeModelThought").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertAIBridgeTokenUsage(ctx, arg)
|
||||
@@ -3495,6 +3624,22 @@ func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChatUsageLimitGroupOverrides(ctx)
|
||||
m.queryLatencies.WithLabelValues("ListChatUsageLimitGroupOverrides").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitGroupOverrides").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChatUsageLimitOverrides(ctx)
|
||||
m.queryLatencies.WithLabelValues("ListChatUsageLimitOverrides").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitOverrides").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListProvisionerKeysByOrganization(ctx, organizationID)
|
||||
@@ -3607,6 +3752,14 @@ func (m queryMetricsStore) RemoveUserFromGroups(ctx context.Context, arg databas
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ResolveUserChatSpendLimit(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ResolveUserChatSpendLimit").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ResolveUserChatSpendLimit").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
|
||||
@@ -3859,6 +4012,7 @@ func (m queryMetricsStore) UpdateOrganizationWorkspaceSharingSettings(ctx contex
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateOrganizationWorkspaceSharingSettings(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateOrganizationWorkspaceSharingSettings").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateOrganizationWorkspaceSharingSettings").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4406,6 +4560,14 @@ func (m queryMetricsStore) UpdateWorkspacesTTLByTemplateID(ctx context.Context,
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertAISeatState(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertAISeatState").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertAISeatState").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertAnnouncementBanners(ctx context.Context, value string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertAnnouncementBanners(ctx, value)
|
||||
@@ -4430,6 +4592,14 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatDesktopEnabled(ctx, enableDesktop)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDesktopEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDesktopEnabled").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
|
||||
@@ -4454,6 +4624,30 @@ func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value str
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatUsageLimitGroupOverride(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitGroupOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatUsageLimitUserOverride(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitUserOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
|
||||
@@ -4630,6 +4824,14 @@ func (m queryMetricsStore) UpsertWorkspaceAppAuditSession(ctx context.Context, a
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UsageEventExistsByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("UsageEventExistsByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UsageEventExistsByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ValidateGroupIDs(ctx, groupIds)
|
||||
@@ -4749,3 +4951,11 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeModels").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAuthorizedChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -424,6 +424,21 @@ func (mr *MockStoreMockRecorder) CountConnectionLogs(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountConnectionLogs", reflect.TypeOf((*MockStore)(nil).CountConnectionLogs), ctx, arg)
|
||||
}
|
||||
|
||||
// CountEnabledModelsWithoutPricing mocks base method.
|
||||
func (m *MockStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountEnabledModelsWithoutPricing", ctx)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountEnabledModelsWithoutPricing indicates an expected call of CountEnabledModelsWithoutPricing.
|
||||
func (mr *MockStoreMockRecorder) CountEnabledModelsWithoutPricing(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEnabledModelsWithoutPricing", reflect.TypeOf((*MockStore)(nil).CountEnabledModelsWithoutPricing), ctx)
|
||||
}
|
||||
|
||||
// CountInProgressPrebuilds mocks base method.
|
||||
func (m *MockStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -639,6 +654,34 @@ func (mr *MockStoreMockRecorder) DeleteChatQueuedMessage(ctx, arg any) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).DeleteChatQueuedMessage), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitGroupOverride mocks base method.
|
||||
func (m *MockStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatUsageLimitGroupOverride", ctx, groupID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitGroupOverride indicates an expected call of DeleteChatUsageLimitGroupOverride.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitGroupOverride), ctx, groupID)
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitUserOverride mocks base method.
|
||||
func (m *MockStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatUsageLimitUserOverride", ctx, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitUserOverride indicates an expected call of DeleteChatUsageLimitUserOverride.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitUserOverride), ctx, userID)
|
||||
}
|
||||
|
||||
// DeleteCryptoKey mocks base method.
|
||||
func (m *MockStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1112,17 +1155,17 @@ func (mr *MockStoreMockRecorder) DeleteWorkspaceACLByID(ctx, id any) *gomock.Cal
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization mocks base method.
|
||||
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, organizationID)
|
||||
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization indicates an expected call of DeleteWorkspaceACLsByOrganization.
|
||||
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, organizationID any) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, organizationID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteWorkspaceAgentPortShare mocks base method.
|
||||
@@ -1478,6 +1521,21 @@ func (mr *MockStoreMockRecorder) GetAPIKeysLastUsedAfter(ctx, lastUsed any) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysLastUsedAfter", reflect.TypeOf((*MockStore)(nil).GetAPIKeysLastUsedAfter), ctx, lastUsed)
|
||||
}
|
||||
|
||||
// GetActiveAISeatCount mocks base method.
|
||||
func (m *MockStore) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveAISeatCount", ctx)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveAISeatCount indicates an expected call of GetActiveAISeatCount.
|
||||
func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
|
||||
}
|
||||
|
||||
// GetActivePresetPrebuildSchedules mocks base method.
|
||||
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1673,6 +1731,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedAuditLogsOffset(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedAuditLogsOffset", reflect.TypeOf((*MockStore)(nil).GetAuthorizedAuditLogsOffset), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// GetAuthorizedChats mocks base method.
|
||||
func (m *MockStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthorizedChats", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAuthorizedChats indicates an expected call of GetAuthorizedChats.
|
||||
func (mr *MockStoreMockRecorder) GetAuthorizedChats(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedChats", reflect.TypeOf((*MockStore)(nil).GetAuthorizedChats), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// GetAuthorizedConnectionLogsOffset mocks base method.
|
||||
func (m *MockStore) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1838,6 +1911,21 @@ func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDesktopEnabled", ctx)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDesktopEnabled indicates an expected call of GetChatDesktopEnabled.
|
||||
func (mr *MockStoreMockRecorder) GetChatDesktopEnabled(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).GetChatDesktopEnabled), ctx)
|
||||
}
|
||||
|
||||
// GetChatDiffStatusByChatID mocks base method.
|
||||
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1928,6 +2016,21 @@ func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, arg any) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatIDDescPaginated mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessagesByChatIDDescPaginated", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatIDDescPaginated indicates an expected call of GetChatMessagesByChatIDDescPaginated.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessagesByChatIDDescPaginated(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatIDDescPaginated", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatIDDescPaginated), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatMessagesForPromptByChatID mocks base method.
|
||||
func (m *MockStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2048,19 +2151,64 @@ func (mr *MockStoreMockRecorder) GetChatSystemPrompt(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatSystemPrompt), ctx)
|
||||
}
|
||||
|
||||
// GetChatsByOwnerID mocks base method.
|
||||
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
// GetChatUsageLimitConfig mocks base method.
|
||||
func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "GetChatUsageLimitConfig", ctx)
|
||||
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatUsageLimitConfig indicates an expected call of GetChatUsageLimitConfig.
|
||||
func (mr *MockStoreMockRecorder) GetChatUsageLimitConfig(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitConfig), ctx)
|
||||
}
|
||||
|
||||
// GetChatUsageLimitGroupOverride mocks base method.
|
||||
func (m *MockStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatUsageLimitGroupOverride", ctx, groupID)
|
||||
ret0, _ := ret[0].(database.GetChatUsageLimitGroupOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatUsageLimitGroupOverride indicates an expected call of GetChatUsageLimitGroupOverride.
|
||||
func (mr *MockStoreMockRecorder) GetChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitGroupOverride), ctx, groupID)
|
||||
}
|
||||
|
||||
// GetChatUsageLimitUserOverride mocks base method.
|
||||
func (m *MockStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatUsageLimitUserOverride", ctx, userID)
|
||||
ret0, _ := ret[0].(database.GetChatUsageLimitUserOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatUsageLimitUserOverride indicates an expected call of GetChatUsageLimitUserOverride.
|
||||
func (mr *MockStoreMockRecorder) GetChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitUserOverride), ctx, userID)
|
||||
}
|
||||
|
||||
// GetChats mocks base method.
|
||||
func (m *MockStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChats", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
|
||||
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, arg any) *gomock.Call {
|
||||
// GetChats indicates an expected call of GetChats.
|
||||
func (mr *MockStoreMockRecorder) GetChats(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChats", reflect.TypeOf((*MockStore)(nil).GetChats), ctx, arg)
|
||||
}
|
||||
|
||||
// GetConnectionLogsOffset mocks base method.
|
||||
@@ -3068,6 +3216,66 @@ func (mr *MockStoreMockRecorder) GetOrganizationsWithPrebuildStatus(ctx, arg any
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsWithPrebuildStatus", reflect.TypeOf((*MockStore)(nil).GetOrganizationsWithPrebuildStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsPerModel mocks base method.
|
||||
func (m *MockStore) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsPerModel", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsPerModelRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsPerModel indicates an expected call of GetPRInsightsPerModel.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsPerModel(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPerModel", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPerModel), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs mocks base method.
|
||||
func (m *MockStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsRecentPRs", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsRecentPRsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs indicates an expected call of GetPRInsightsRecentPRs.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsRecentPRs(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsRecentPRs", reflect.TypeOf((*MockStore)(nil).GetPRInsightsRecentPRs), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsSummary mocks base method.
|
||||
func (m *MockStore) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsSummary", ctx, arg)
|
||||
ret0, _ := ret[0].(database.GetPRInsightsSummaryRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsSummary indicates an expected call of GetPRInsightsSummary.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsSummary(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsSummary", reflect.TypeOf((*MockStore)(nil).GetPRInsightsSummary), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsTimeSeries mocks base method.
|
||||
func (m *MockStore) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsTimeSeries", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsTimeSeriesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsTimeSeries indicates an expected call of GetPRInsightsTimeSeries.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsTimeSeries(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsTimeSeries", reflect.TypeOf((*MockStore)(nil).GetPRInsightsTimeSeries), ctx, arg)
|
||||
}
|
||||
|
||||
// GetParameterSchemasByJobID mocks base method.
|
||||
func (m *MockStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4193,6 +4401,21 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatSpendInPeriod mocks base method.
|
||||
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserChatSpendInPeriod", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserChatSpendInPeriod indicates an expected call of GetUserChatSpendInPeriod.
|
||||
func (mr *MockStoreMockRecorder) GetUserChatSpendInPeriod(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatSpendInPeriod", reflect.TypeOf((*MockStore)(nil).GetUserChatSpendInPeriod), ctx, arg)
|
||||
}
|
||||
|
||||
// GetUserCount mocks base method.
|
||||
func (m *MockStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4208,6 +4431,21 @@ func (mr *MockStoreMockRecorder) GetUserCount(ctx, includeSystem any) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserCount", reflect.TypeOf((*MockStore)(nil).GetUserCount), ctx, includeSystem)
|
||||
}
|
||||
|
||||
// GetUserGroupSpendLimit mocks base method.
|
||||
func (m *MockStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserGroupSpendLimit", ctx, userID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserGroupSpendLimit indicates an expected call of GetUserGroupSpendLimit.
|
||||
func (mr *MockStoreMockRecorder) GetUserGroupSpendLimit(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserGroupSpendLimit", reflect.TypeOf((*MockStore)(nil).GetUserGroupSpendLimit), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserLatencyInsights mocks base method.
|
||||
func (m *MockStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5362,6 +5600,21 @@ func (mr *MockStoreMockRecorder) InsertAIBridgeInterception(ctx, arg any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeInterception", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeInterception), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertAIBridgeModelThought mocks base method.
|
||||
func (m *MockStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertAIBridgeModelThought", ctx, arg)
|
||||
ret0, _ := ret[0].(database.AIBridgeModelThought)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertAIBridgeModelThought indicates an expected call of InsertAIBridgeModelThought.
|
||||
func (mr *MockStoreMockRecorder) InsertAIBridgeModelThought(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeModelThought", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeModelThought), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertAIBridgeTokenUsage mocks base method.
|
||||
func (m *MockStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6547,6 +6800,36 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListChatUsageLimitGroupOverrides mocks base method.
|
||||
func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListChatUsageLimitGroupOverrides", ctx)
|
||||
ret0, _ := ret[0].([]database.ListChatUsageLimitGroupOverridesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListChatUsageLimitGroupOverrides indicates an expected call of ListChatUsageLimitGroupOverrides.
|
||||
func (mr *MockStoreMockRecorder) ListChatUsageLimitGroupOverrides(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitGroupOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitGroupOverrides), ctx)
|
||||
}
|
||||
|
||||
// ListChatUsageLimitOverrides mocks base method.
|
||||
func (m *MockStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListChatUsageLimitOverrides", ctx)
|
||||
ret0, _ := ret[0].([]database.ListChatUsageLimitOverridesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListChatUsageLimitOverrides indicates an expected call of ListChatUsageLimitOverrides.
|
||||
func (mr *MockStoreMockRecorder) ListChatUsageLimitOverrides(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitOverrides), ctx)
|
||||
}
|
||||
|
||||
// ListProvisionerKeysByOrganization mocks base method.
|
||||
func (m *MockStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6785,6 +7068,21 @@ func (mr *MockStoreMockRecorder) RemoveUserFromGroups(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), ctx, arg)
|
||||
}
|
||||
|
||||
// ResolveUserChatSpendLimit mocks base method.
|
||||
func (m *MockStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ResolveUserChatSpendLimit", ctx, userID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ResolveUserChatSpendLimit indicates an expected call of ResolveUserChatSpendLimit.
|
||||
func (mr *MockStoreMockRecorder) ResolveUserChatSpendLimit(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveUserChatSpendLimit", reflect.TypeOf((*MockStore)(nil).ResolveUserChatSpendLimit), ctx, userID)
|
||||
}
|
||||
|
||||
// RevokeDBCryptKey mocks base method.
|
||||
func (m *MockStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8229,6 +8527,21 @@ func (mr *MockStoreMockRecorder) UpdateWorkspacesTTLByTemplateID(ctx, arg any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspacesTTLByTemplateID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspacesTTLByTemplateID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertAISeatState mocks base method.
|
||||
func (m *MockStore) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertAISeatState", ctx, arg)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertAISeatState indicates an expected call of UpsertAISeatState.
|
||||
func (mr *MockStoreMockRecorder) UpsertAISeatState(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertAISeatState", reflect.TypeOf((*MockStore)(nil).UpsertAISeatState), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertAnnouncementBanners mocks base method.
|
||||
func (m *MockStore) UpsertAnnouncementBanners(ctx context.Context, value string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8272,6 +8585,20 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDesktopEnabled", ctx, enableDesktop)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatDesktopEnabled indicates an expected call of UpsertChatDesktopEnabled.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDesktopEnabled(ctx, enableDesktop any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).UpsertChatDesktopEnabled), ctx, enableDesktop)
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatus mocks base method.
|
||||
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8316,6 +8643,51 @@ func (mr *MockStoreMockRecorder) UpsertChatSystemPrompt(ctx, value any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatSystemPrompt), ctx, value)
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitConfig mocks base method.
|
||||
func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatUsageLimitConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitConfig indicates an expected call of UpsertChatUsageLimitConfig.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitGroupOverride mocks base method.
|
||||
func (m *MockStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatUsageLimitGroupOverride", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UpsertChatUsageLimitGroupOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitGroupOverride indicates an expected call of UpsertChatUsageLimitGroupOverride.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitGroupOverride(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitGroupOverride), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitUserOverride mocks base method.
|
||||
func (m *MockStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatUsageLimitUserOverride", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UpsertChatUsageLimitUserOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitUserOverride indicates an expected call of UpsertChatUsageLimitUserOverride.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitUserOverride(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitUserOverride), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertConnectionLog mocks base method.
|
||||
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8633,6 +9005,21 @@ func (mr *MockStoreMockRecorder) UpsertWorkspaceAppAuditSession(ctx, arg any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertWorkspaceAppAuditSession", reflect.TypeOf((*MockStore)(nil).UpsertWorkspaceAppAuditSession), ctx, arg)
|
||||
}
|
||||
|
||||
// UsageEventExistsByID mocks base method.
|
||||
func (m *MockStore) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UsageEventExistsByID", ctx, id)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UsageEventExistsByID indicates an expected call of UsageEventExistsByID.
|
||||
func (mr *MockStoreMockRecorder) UsageEventExistsByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageEventExistsByID", reflect.TypeOf((*MockStore)(nil).UsageEventExistsByID), ctx, id)
|
||||
}
|
||||
|
||||
// ValidateGroupIDs mocks base method.
|
||||
func (m *MockStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+112
-13
@@ -10,6 +10,11 @@ CREATE TYPE agent_key_scope_enum AS ENUM (
|
||||
'no_user_data'
|
||||
);
|
||||
|
||||
CREATE TYPE ai_seat_usage_reason AS ENUM (
|
||||
'aibridge',
|
||||
'task'
|
||||
);
|
||||
|
||||
CREATE TYPE api_key_scope AS ENUM (
|
||||
'coder:all',
|
||||
'coder:application_connect',
|
||||
@@ -503,7 +508,14 @@ CREATE TYPE resource_type AS ENUM (
|
||||
'workspace_agent',
|
||||
'workspace_app',
|
||||
'prebuilds_settings',
|
||||
'task'
|
||||
'task',
|
||||
'ai_seat'
|
||||
);
|
||||
|
||||
CREATE TYPE shareable_workspace_owners AS ENUM (
|
||||
'none',
|
||||
'everyone',
|
||||
'service_accounts'
|
||||
);
|
||||
|
||||
CREATE TYPE startup_script_behavior AS ENUM (
|
||||
@@ -608,28 +620,35 @@ CREATE FUNCTION aggregate_usage_event() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
-- Check for supported event types and throw error for unknown types
|
||||
IF NEW.event_type NOT IN ('dc_managed_agents_v1') THEN
|
||||
-- Check for supported event types and throw error for unknown types.
|
||||
IF NEW.event_type NOT IN ('dc_managed_agents_v1', 'hb_ai_seats_v1') THEN
|
||||
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
|
||||
END IF;
|
||||
|
||||
INSERT INTO usage_events_daily (day, event_type, usage_data)
|
||||
VALUES (
|
||||
-- Extract the date from the created_at timestamp, always using UTC for
|
||||
-- consistency
|
||||
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
|
||||
NEW.event_type,
|
||||
NEW.event_data
|
||||
)
|
||||
ON CONFLICT (day, event_type) DO UPDATE SET
|
||||
usage_data = CASE
|
||||
-- Handle simple counter events by summing the count
|
||||
-- Handle simple counter events by summing the count.
|
||||
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
-- Heartbeat events: keep the max value seen that day
|
||||
WHEN NEW.event_type IN ('hb_ai_seats_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
GREATEST(
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0),
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
)
|
||||
END;
|
||||
|
||||
RETURN NEW;
|
||||
@@ -786,7 +805,7 @@ BEGIN
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION insert_org_member_system_role() RETURNS trigger
|
||||
CREATE FUNCTION insert_organization_system_roles() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
@@ -801,7 +820,8 @@ BEGIN
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
) VALUES
|
||||
(
|
||||
'organization-member',
|
||||
'',
|
||||
NEW.id,
|
||||
@@ -812,6 +832,18 @@ BEGIN
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
),
|
||||
(
|
||||
'organization-service-account',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
@@ -1046,6 +1078,15 @@ BEGIN
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE TABLE ai_seat_state (
|
||||
user_id uuid NOT NULL,
|
||||
first_used_at timestamp with time zone NOT NULL,
|
||||
last_used_at timestamp with time zone NOT NULL,
|
||||
last_event_type ai_seat_usage_reason NOT NULL,
|
||||
last_event_description text NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE aibridge_interceptions (
|
||||
id uuid NOT NULL,
|
||||
initiator_id uuid NOT NULL,
|
||||
@@ -1071,6 +1112,15 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
metadata jsonb,
|
||||
created_at timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
|
||||
|
||||
CREATE TABLE aibridge_token_usages (
|
||||
id uuid NOT NULL,
|
||||
interception_id uuid NOT NULL,
|
||||
@@ -1239,7 +1289,8 @@ CREATE TABLE chat_messages (
|
||||
compressed boolean DEFAULT false NOT NULL,
|
||||
created_by uuid,
|
||||
content_version smallint NOT NULL,
|
||||
total_cost_micros bigint
|
||||
total_cost_micros bigint,
|
||||
runtime_ms bigint
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_messages_id_seq
|
||||
@@ -1303,6 +1354,28 @@ CREATE SEQUENCE chat_queued_messages_id_seq
|
||||
|
||||
ALTER SEQUENCE chat_queued_messages_id_seq OWNED BY chat_queued_messages.id;
|
||||
|
||||
CREATE TABLE chat_usage_limit_config (
|
||||
id bigint NOT NULL,
|
||||
singleton boolean DEFAULT true NOT NULL,
|
||||
enabled boolean DEFAULT false NOT NULL,
|
||||
default_limit_micros bigint DEFAULT 0 NOT NULL,
|
||||
period text DEFAULT 'month'::text NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT chat_usage_limit_config_default_limit_micros_check CHECK ((default_limit_micros >= 0)),
|
||||
CONSTRAINT chat_usage_limit_config_period_check CHECK ((period = ANY (ARRAY['day'::text, 'week'::text, 'month'::text]))),
|
||||
CONSTRAINT chat_usage_limit_config_singleton_check CHECK (singleton)
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_usage_limit_config_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
ALTER SEQUENCE chat_usage_limit_config_id_seq OWNED BY chat_usage_limit_config.id;
|
||||
|
||||
CREATE TABLE chats (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
@@ -1459,7 +1532,9 @@ CREATE TABLE groups (
|
||||
avatar_url text DEFAULT ''::text NOT NULL,
|
||||
quota_allowance integer DEFAULT 0 NOT NULL,
|
||||
display_name text DEFAULT ''::text NOT NULL,
|
||||
source group_source DEFAULT 'user'::group_source NOT NULL
|
||||
source group_source DEFAULT 'user'::group_source NOT NULL,
|
||||
chat_spend_limit_micros bigint,
|
||||
CONSTRAINT groups_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0)))
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN groups.display_name IS 'Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.';
|
||||
@@ -1494,7 +1569,9 @@ CREATE TABLE users (
|
||||
one_time_passcode_expires_at timestamp with time zone,
|
||||
is_system boolean DEFAULT false NOT NULL,
|
||||
is_service_account boolean DEFAULT false NOT NULL,
|
||||
chat_spend_limit_micros bigint,
|
||||
CONSTRAINT one_time_passcode_set CHECK ((((hashed_one_time_passcode IS NULL) AND (one_time_passcode_expires_at IS NULL)) OR ((hashed_one_time_passcode IS NOT NULL) AND (one_time_passcode_expires_at IS NOT NULL)))),
|
||||
CONSTRAINT users_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0))),
|
||||
CONSTRAINT users_email_not_empty CHECK (((is_service_account = true) = (email = ''::text))),
|
||||
CONSTRAINT users_service_account_login_type CHECK (((is_service_account = false) OR (login_type = 'none'::login_type))),
|
||||
CONSTRAINT users_username_min_length CHECK ((length(username) >= 1))
|
||||
@@ -1782,9 +1859,11 @@ CREATE TABLE organizations (
|
||||
display_name text NOT NULL,
|
||||
icon text DEFAULT ''::text NOT NULL,
|
||||
deleted boolean DEFAULT false NOT NULL,
|
||||
workspace_sharing_disabled boolean DEFAULT false NOT NULL
|
||||
shareable_workspace_owners shareable_workspace_owners DEFAULT 'everyone'::shareable_workspace_owners NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
|
||||
|
||||
CREATE TABLE parameter_schemas (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
@@ -2584,7 +2663,7 @@ CREATE TABLE usage_events (
|
||||
publish_started_at timestamp with time zone,
|
||||
published_at timestamp with time zone,
|
||||
failure_message text,
|
||||
CONSTRAINT usage_event_type_check CHECK ((event_type = 'dc_managed_agents_v1'::text))
|
||||
CONSTRAINT usage_event_type_check CHECK ((event_type = ANY (ARRAY['dc_managed_agents_v1'::text, 'hb_ai_seats_v1'::text])))
|
||||
);
|
||||
|
||||
COMMENT ON TABLE usage_events IS 'usage_events contains usage data that is collected from the product and potentially shipped to the usage collector service.';
|
||||
@@ -3141,6 +3220,8 @@ ALTER TABLE ONLY chat_messages ALTER COLUMN id SET DEFAULT nextval('chat_message
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages ALTER COLUMN id SET DEFAULT nextval('chat_queued_messages_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY chat_usage_limit_config ALTER COLUMN id SET DEFAULT nextval('chat_usage_limit_config_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY licenses ALTER COLUMN id SET DEFAULT nextval('licenses_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY provisioner_job_logs ALTER COLUMN id SET DEFAULT nextval('provisioner_job_logs_id_seq'::regclass);
|
||||
@@ -3156,6 +3237,9 @@ ALTER TABLE ONLY workspace_resource_metadata ALTER COLUMN id SET DEFAULT nextval
|
||||
ALTER TABLE ONLY workspace_agent_stats
|
||||
ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY ai_seat_state
|
||||
ADD CONSTRAINT ai_seat_state_pkey PRIMARY KEY (user_id);
|
||||
|
||||
ALTER TABLE ONLY aibridge_interceptions
|
||||
ADD CONSTRAINT aibridge_interceptions_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3198,6 +3282,12 @@ ALTER TABLE ONLY chat_providers
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_usage_limit_config
|
||||
ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_usage_limit_config
|
||||
ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3510,6 +3600,8 @@ CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptio
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_thread_root_id ON aibridge_interceptions USING btree (thread_root_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_token_usages_interception_id ON aibridge_token_usages USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_token_usages_provider_response_id ON aibridge_token_usages USING btree (provider_response_id);
|
||||
@@ -3550,6 +3642,8 @@ CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USIN
|
||||
|
||||
CREATE INDEX idx_chat_messages_created_at ON chat_messages USING btree (created_at);
|
||||
|
||||
CREATE INDEX idx_chat_messages_owner_spend ON chat_messages USING btree (chat_id, created_at) WHERE (total_cost_micros IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs USING btree (provider);
|
||||
@@ -3624,6 +3718,8 @@ CREATE INDEX idx_template_versions_has_ai_task ON template_versions USING btree
|
||||
|
||||
CREATE UNIQUE INDEX idx_unique_preset_name ON template_version_presets USING btree (name, template_version_id);
|
||||
|
||||
CREATE INDEX idx_usage_events_ai_seats ON usage_events USING btree (event_type, created_at) WHERE (event_type = 'hb_ai_seats_v1'::text);
|
||||
|
||||
CREATE INDEX idx_usage_events_select_for_publishing ON usage_events USING btree (published_at, publish_started_at, created_at);
|
||||
|
||||
CREATE INDEX idx_user_deleted_deleted_at ON user_deleted USING btree (deleted_at);
|
||||
@@ -3798,7 +3894,7 @@ CREATE TRIGGER trigger_delete_oauth2_provider_app_token AFTER DELETE ON oauth2_p
|
||||
|
||||
CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted();
|
||||
|
||||
CREATE TRIGGER trigger_insert_org_member_system_role AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_org_member_system_role();
|
||||
CREATE TRIGGER trigger_insert_organization_system_roles AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_organization_system_roles();
|
||||
|
||||
CREATE TRIGGER trigger_nullify_next_start_at_on_workspace_autostart_modificati AFTER UPDATE ON workspaces FOR EACH ROW EXECUTE FUNCTION nullify_next_start_at_on_workspace_autostart_modification();
|
||||
|
||||
@@ -3816,6 +3912,9 @@ COMMENT ON TRIGGER workspace_agent_name_unique_trigger ON workspace_agents IS 'U
|
||||
the uniqueness requirement. A trigger allows us to enforce uniqueness going
|
||||
forward without requiring a migration to clean up historical data.';
|
||||
|
||||
ALTER TABLE ONLY ai_seat_state
|
||||
ADD CONSTRAINT ai_seat_state_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY aibridge_interceptions
|
||||
ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ type ForeignKeyConstraint string
|
||||
|
||||
// ForeignKeyConstraint enums.
|
||||
const (
|
||||
ForeignKeyAiSeatStateUserID ForeignKeyConstraint = "ai_seat_state_user_id_fkey" // ALTER TABLE ONLY ai_seat_state ADD CONSTRAINT ai_seat_state_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -26,6 +26,7 @@ func TestCustomQueriesSyncedRowScan(t *testing.T) {
|
||||
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
|
||||
"GetWorkspaces": "GetAuthorizedWorkspaces",
|
||||
"GetUsers": "GetAuthorizedUsers",
|
||||
"GetChats": "GetAuthorizedChats",
|
||||
}
|
||||
|
||||
// Scan custom
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
DROP TABLE ai_seat_state;
|
||||
|
||||
DROP TYPE ai_seat_usage_reason;
|
||||
@@ -0,0 +1,13 @@
|
||||
CREATE TYPE ai_seat_usage_reason AS ENUM (
|
||||
'aibridge',
|
||||
'task'
|
||||
);
|
||||
|
||||
CREATE TABLE ai_seat_state (
|
||||
user_id uuid NOT NULL PRIMARY KEY REFERENCES users (id) ON DELETE CASCADE,
|
||||
first_used_at timestamptz NOT NULL,
|
||||
last_used_at timestamptz NOT NULL,
|
||||
last_event_type ai_seat_usage_reason NOT NULL,
|
||||
last_event_description text NOT NULL,
|
||||
updated_at timestamptz NOT NULL
|
||||
);
|
||||
@@ -0,0 +1 @@
|
||||
-- resource_type enum values cannot be removed safely; no-op.
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TYPE resource_type ADD VALUE IF NOT EXISTS 'ai_seat';
|
||||
@@ -0,0 +1,4 @@
|
||||
DROP INDEX IF EXISTS idx_chat_messages_owner_spend;
|
||||
ALTER TABLE groups DROP COLUMN IF EXISTS chat_spend_limit_micros;
|
||||
ALTER TABLE users DROP COLUMN IF EXISTS chat_spend_limit_micros;
|
||||
DROP TABLE IF EXISTS chat_usage_limit_config;
|
||||
@@ -0,0 +1,32 @@
|
||||
-- 1. Singleton config table
|
||||
CREATE TABLE chat_usage_limit_config (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
-- Only one row allowed (enforced by CHECK).
|
||||
singleton BOOLEAN NOT NULL DEFAULT TRUE CHECK (singleton),
|
||||
UNIQUE (singleton),
|
||||
enabled BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
-- Limit per user per period, in micro-dollars (1 USD = 1,000,000).
|
||||
default_limit_micros BIGINT NOT NULL DEFAULT 0
|
||||
CHECK (default_limit_micros >= 0),
|
||||
-- Period length: 'day', 'week', or 'month'.
|
||||
period TEXT NOT NULL DEFAULT 'month'
|
||||
CHECK (period IN ('day', 'week', 'month')),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Seed a single disabled row so reads never return empty.
|
||||
INSERT INTO chat_usage_limit_config (singleton) VALUES (TRUE);
|
||||
|
||||
-- 2. Per-user overrides (inline on users table).
|
||||
ALTER TABLE users ADD COLUMN chat_spend_limit_micros BIGINT DEFAULT NULL
|
||||
CHECK (chat_spend_limit_micros IS NULL OR chat_spend_limit_micros > 0);
|
||||
|
||||
-- 3. Per-group overrides (inline on groups table).
|
||||
ALTER TABLE groups ADD COLUMN chat_spend_limit_micros BIGINT DEFAULT NULL
|
||||
CHECK (chat_spend_limit_micros IS NULL OR chat_spend_limit_micros > 0);
|
||||
|
||||
-- Speed up per-user spend aggregation in the usage-limit hot path.
|
||||
CREATE INDEX idx_chat_messages_owner_spend
|
||||
ON chat_messages (chat_id, created_at)
|
||||
WHERE total_cost_micros IS NOT NULL;
|
||||
@@ -0,0 +1,3 @@
|
||||
DROP INDEX idx_aibridge_model_thoughts_interception_id;
|
||||
|
||||
DROP TABLE aibridge_model_thoughts;
|
||||
@@ -0,0 +1,10 @@
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id UUID NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
metadata jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
|
||||
|
||||
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts(interception_id);
|
||||
+52
@@ -0,0 +1,52 @@
|
||||
DELETE FROM custom_roles
|
||||
WHERE name = 'organization-service-account' AND is_system = true;
|
||||
|
||||
ALTER TABLE organizations
|
||||
ADD COLUMN workspace_sharing_disabled boolean NOT NULL DEFAULT false;
|
||||
|
||||
-- Migrate back: 'none' -> disabled, everything else -> enabled.
|
||||
UPDATE organizations
|
||||
SET workspace_sharing_disabled = true
|
||||
WHERE shareable_workspace_owners = 'none';
|
||||
|
||||
ALTER TABLE organizations DROP COLUMN shareable_workspace_owners;
|
||||
|
||||
DROP TYPE shareable_workspace_owners;
|
||||
|
||||
-- Restore the original single-role trigger from migration 408.
|
||||
DROP TRIGGER IF EXISTS trigger_insert_organization_system_roles ON organizations;
|
||||
DROP FUNCTION IF EXISTS insert_organization_system_roles;
|
||||
|
||||
CREATE OR REPLACE FUNCTION insert_org_member_system_role() RETURNS trigger AS $$
|
||||
BEGIN
|
||||
INSERT INTO custom_roles (
|
||||
name,
|
||||
display_name,
|
||||
organization_id,
|
||||
site_permissions,
|
||||
org_permissions,
|
||||
user_permissions,
|
||||
member_permissions,
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
'organization-member',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_insert_org_member_system_role
|
||||
AFTER INSERT ON organizations
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION insert_org_member_system_role();
|
||||
@@ -0,0 +1,101 @@
|
||||
CREATE TYPE shareable_workspace_owners AS ENUM ('none', 'everyone', 'service_accounts');
|
||||
|
||||
ALTER TABLE organizations
|
||||
ADD COLUMN shareable_workspace_owners shareable_workspace_owners NOT NULL DEFAULT 'everyone';
|
||||
|
||||
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
|
||||
|
||||
-- Migrate existing data from the boolean column.
|
||||
UPDATE organizations
|
||||
SET shareable_workspace_owners = 'none'
|
||||
WHERE workspace_sharing_disabled = true;
|
||||
|
||||
ALTER TABLE organizations DROP COLUMN workspace_sharing_disabled;
|
||||
|
||||
-- Defensively rename any existing 'organization-service-account' roles
|
||||
-- so they don't collide with the new system role.
|
||||
UPDATE custom_roles
|
||||
SET name = name || '-' || id::text
|
||||
-- lower(name) is part of the existing unique index
|
||||
WHERE lower(name) = 'organization-service-account';
|
||||
|
||||
-- Create skeleton organization-service-account system roles for all
|
||||
-- existing organizations, mirroring what migration 408 did for
|
||||
-- organization-member.
|
||||
INSERT INTO custom_roles (
|
||||
name,
|
||||
display_name,
|
||||
organization_id,
|
||||
site_permissions,
|
||||
org_permissions,
|
||||
user_permissions,
|
||||
member_permissions,
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
'organization-service-account',
|
||||
'',
|
||||
id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
FROM
|
||||
organizations;
|
||||
|
||||
-- Replace the single-role trigger with one that creates both system
|
||||
-- roles when a new organization is inserted.
|
||||
DROP TRIGGER IF EXISTS trigger_insert_org_member_system_role ON organizations;
|
||||
DROP FUNCTION IF EXISTS insert_org_member_system_role;
|
||||
|
||||
CREATE OR REPLACE FUNCTION insert_organization_system_roles() RETURNS trigger AS $$
|
||||
BEGIN
|
||||
INSERT INTO custom_roles (
|
||||
name,
|
||||
display_name,
|
||||
organization_id,
|
||||
site_permissions,
|
||||
org_permissions,
|
||||
user_permissions,
|
||||
member_permissions,
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES
|
||||
(
|
||||
'organization-member',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
),
|
||||
(
|
||||
'organization-service-account',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_insert_organization_system_roles
|
||||
AFTER INSERT ON organizations
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION insert_organization_system_roles();
|
||||
@@ -0,0 +1,38 @@
|
||||
DROP INDEX IF EXISTS idx_usage_events_ai_seats;
|
||||
|
||||
-- Remove hb_ai_seats_v1 rows so the original constraint can be restored.
|
||||
DELETE FROM usage_events WHERE event_type = 'hb_ai_seats_v1';
|
||||
DELETE FROM usage_events_daily WHERE event_type = 'hb_ai_seats_v1';
|
||||
|
||||
-- Restore original constraint.
|
||||
ALTER TABLE usage_events
|
||||
DROP CONSTRAINT usage_event_type_check,
|
||||
ADD CONSTRAINT usage_event_type_check CHECK (event_type IN ('dc_managed_agents_v1'));
|
||||
|
||||
-- Restore the original aggregate function without hb_ai_seats_v1 support.
|
||||
CREATE OR REPLACE FUNCTION aggregate_usage_event()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
IF NEW.event_type NOT IN ('dc_managed_agents_v1') THEN
|
||||
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
|
||||
END IF;
|
||||
|
||||
INSERT INTO usage_events_daily (day, event_type, usage_data)
|
||||
VALUES (
|
||||
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
|
||||
NEW.event_type,
|
||||
NEW.event_data
|
||||
)
|
||||
ON CONFLICT (day, event_type) DO UPDATE SET
|
||||
usage_data = CASE
|
||||
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
END;
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
@@ -0,0 +1,50 @@
|
||||
-- Expand the CHECK constraint to allow hb_ai_seats_v1.
|
||||
ALTER TABLE usage_events
|
||||
DROP CONSTRAINT usage_event_type_check,
|
||||
ADD CONSTRAINT usage_event_type_check CHECK (event_type IN ('dc_managed_agents_v1', 'hb_ai_seats_v1'));
|
||||
|
||||
-- Partial index for efficient lookups of AI seat heartbeat events by time.
|
||||
-- This will be used for the admin dashboard to see seat count over time.
|
||||
CREATE INDEX idx_usage_events_ai_seats
|
||||
ON usage_events (event_type, created_at)
|
||||
WHERE event_type = 'hb_ai_seats_v1';
|
||||
|
||||
-- Update the aggregate function to handle hb_ai_seats_v1 events.
|
||||
-- Heartbeat events replace the previous value for the same time period.
|
||||
CREATE OR REPLACE FUNCTION aggregate_usage_event()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
-- Check for supported event types and throw error for unknown types.
|
||||
IF NEW.event_type NOT IN ('dc_managed_agents_v1', 'hb_ai_seats_v1') THEN
|
||||
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
|
||||
END IF;
|
||||
|
||||
INSERT INTO usage_events_daily (day, event_type, usage_data)
|
||||
VALUES (
|
||||
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
|
||||
NEW.event_type,
|
||||
NEW.event_data
|
||||
)
|
||||
ON CONFLICT (day, event_type) DO UPDATE SET
|
||||
usage_data = CASE
|
||||
-- Handle simple counter events by summing the count.
|
||||
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
-- Heartbeat events: keep the max value seen that day
|
||||
WHEN NEW.event_type IN ('hb_ai_seats_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
GREATEST(
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0),
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
)
|
||||
END;
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE chat_messages DROP COLUMN runtime_ms;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE chat_messages ADD COLUMN runtime_ms bigint;
|
||||
Vendored
+28
@@ -0,0 +1,28 @@
|
||||
-- Fixture for migration 000443_three_options_for_allowed_workspace_sharing.
|
||||
-- Inserts a custom role named 'Organization-Service-Account' (mixed case)
|
||||
-- to ensure the migration's case-insensitive rename catches it.
|
||||
INSERT INTO custom_roles (
|
||||
name,
|
||||
display_name,
|
||||
organization_id,
|
||||
site_permissions,
|
||||
org_permissions,
|
||||
user_permissions,
|
||||
member_permissions,
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
VALUES (
|
||||
'Organization-Service-Account',
|
||||
'User-created role',
|
||||
'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1',
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
false,
|
||||
NOW(),
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT DO NOTHING;
|
||||
@@ -0,0 +1,11 @@
|
||||
INSERT INTO
|
||||
ai_seat_state (
|
||||
user_id,
|
||||
first_used_at,
|
||||
last_used_at,
|
||||
last_event_type,
|
||||
last_event_description,
|
||||
updated_at
|
||||
)
|
||||
VALUES
|
||||
('30095c71-380b-457a-8995-97b8ee6e5307', NOW(), NOW(), 'task'::ai_seat_usage_reason, 'Used for AI task', NOW());
|
||||
@@ -0,0 +1,5 @@
|
||||
UPDATE users SET chat_spend_limit_micros = 5000000
|
||||
WHERE id = 'fc1511ef-4fcf-4a3b-98a1-8df64160e35a';
|
||||
|
||||
UPDATE groups SET chat_spend_limit_micros = 10000000
|
||||
WHERE id = 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1';
|
||||
+13
@@ -0,0 +1,13 @@
|
||||
INSERT INTO
|
||||
aibridge_model_thoughts (
|
||||
interception_id,
|
||||
content,
|
||||
metadata,
|
||||
created_at
|
||||
)
|
||||
VALUES (
|
||||
'be003e1e-b38f-43bf-847d-928074dd0aa8', -- from 000370_aibridge.up.sql
|
||||
'The user is asking about their workspaces. I should use the coder_list_workspaces tool to retrieve this information.',
|
||||
'{"source": "commentary"}',
|
||||
'2025-09-15 12:45:19.123456+00'
|
||||
);
|
||||
+20
@@ -0,0 +1,20 @@
|
||||
INSERT INTO usage_events (
|
||||
id,
|
||||
event_type,
|
||||
event_data,
|
||||
created_at,
|
||||
publish_started_at,
|
||||
published_at,
|
||||
failure_message
|
||||
)
|
||||
VALUES
|
||||
-- Unpublished hb_ai_seats_v1 event.
|
||||
(
|
||||
'ai-seats-event1',
|
||||
'hb_ai_seats_v1',
|
||||
'{"count":3}',
|
||||
'2023-06-01 00:00:00+00',
|
||||
NULL,
|
||||
NULL,
|
||||
NULL
|
||||
);
|
||||
@@ -52,6 +52,7 @@ type customQuerier interface {
|
||||
auditLogQuerier
|
||||
connectionLogQuerier
|
||||
aibridgeQuerier
|
||||
chatQuerier
|
||||
}
|
||||
|
||||
type templateQuerier interface {
|
||||
@@ -451,6 +452,7 @@ func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams,
|
||||
&i.OneTimePasscodeExpiresAt,
|
||||
&i.IsSystem,
|
||||
&i.IsServiceAccount,
|
||||
&i.ChatSpendLimitMicros,
|
||||
&i.Count,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
@@ -737,6 +739,68 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
|
||||
return count, nil
|
||||
}
|
||||
|
||||
type chatQuerier interface {
|
||||
GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]Chat, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]Chat, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, rbac.ConfigChats())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
filtered, err := insertAuthorizedFilter(getChats, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
// The name comment is for metric tracking
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedChats :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.OwnerID,
|
||||
arg.Archived,
|
||||
arg.AfterID,
|
||||
arg.OffsetOpt,
|
||||
arg.LimitOpt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Chat
|
||||
for rows.Next() {
|
||||
var i Chat
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
type aibridgeQuerier interface {
|
||||
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error)
|
||||
CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
|
||||
+166
-13
@@ -741,6 +741,64 @@ func AllAgentKeyScopeEnumValues() []AgentKeyScopeEnum {
|
||||
}
|
||||
}
|
||||
|
||||
type AiSeatUsageReason string
|
||||
|
||||
const (
|
||||
AiSeatUsageReasonAibridge AiSeatUsageReason = "aibridge"
|
||||
AiSeatUsageReasonTask AiSeatUsageReason = "task"
|
||||
)
|
||||
|
||||
func (e *AiSeatUsageReason) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = AiSeatUsageReason(s)
|
||||
case string:
|
||||
*e = AiSeatUsageReason(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for AiSeatUsageReason: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullAiSeatUsageReason struct {
|
||||
AiSeatUsageReason AiSeatUsageReason `json:"ai_seat_usage_reason"`
|
||||
Valid bool `json:"valid"` // Valid is true if AiSeatUsageReason is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (ns *NullAiSeatUsageReason) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
ns.AiSeatUsageReason, ns.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
ns.Valid = true
|
||||
return ns.AiSeatUsageReason.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullAiSeatUsageReason) Value() (driver.Value, error) {
|
||||
if !ns.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return string(ns.AiSeatUsageReason), nil
|
||||
}
|
||||
|
||||
func (e AiSeatUsageReason) Valid() bool {
|
||||
switch e {
|
||||
case AiSeatUsageReasonAibridge,
|
||||
AiSeatUsageReasonTask:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AllAiSeatUsageReasonValues() []AiSeatUsageReason {
|
||||
return []AiSeatUsageReason{
|
||||
AiSeatUsageReasonAibridge,
|
||||
AiSeatUsageReasonTask,
|
||||
}
|
||||
}
|
||||
|
||||
type AppSharingLevel string
|
||||
|
||||
const (
|
||||
@@ -2969,6 +3027,7 @@ const (
|
||||
ResourceTypeWorkspaceApp ResourceType = "workspace_app"
|
||||
ResourceTypePrebuildsSettings ResourceType = "prebuilds_settings"
|
||||
ResourceTypeTask ResourceType = "task"
|
||||
ResourceTypeAiSeat ResourceType = "ai_seat"
|
||||
)
|
||||
|
||||
func (e *ResourceType) Scan(src interface{}) error {
|
||||
@@ -3033,7 +3092,8 @@ func (e ResourceType) Valid() bool {
|
||||
ResourceTypeWorkspaceAgent,
|
||||
ResourceTypeWorkspaceApp,
|
||||
ResourceTypePrebuildsSettings,
|
||||
ResourceTypeTask:
|
||||
ResourceTypeTask,
|
||||
ResourceTypeAiSeat:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -3067,6 +3127,68 @@ func AllResourceTypeValues() []ResourceType {
|
||||
ResourceTypeWorkspaceApp,
|
||||
ResourceTypePrebuildsSettings,
|
||||
ResourceTypeTask,
|
||||
ResourceTypeAiSeat,
|
||||
}
|
||||
}
|
||||
|
||||
type ShareableWorkspaceOwners string
|
||||
|
||||
const (
|
||||
ShareableWorkspaceOwnersNone ShareableWorkspaceOwners = "none"
|
||||
ShareableWorkspaceOwnersEveryone ShareableWorkspaceOwners = "everyone"
|
||||
ShareableWorkspaceOwnersServiceAccounts ShareableWorkspaceOwners = "service_accounts"
|
||||
)
|
||||
|
||||
func (e *ShareableWorkspaceOwners) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ShareableWorkspaceOwners(s)
|
||||
case string:
|
||||
*e = ShareableWorkspaceOwners(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ShareableWorkspaceOwners: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullShareableWorkspaceOwners struct {
|
||||
ShareableWorkspaceOwners ShareableWorkspaceOwners `json:"shareable_workspace_owners"`
|
||||
Valid bool `json:"valid"` // Valid is true if ShareableWorkspaceOwners is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (ns *NullShareableWorkspaceOwners) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
ns.ShareableWorkspaceOwners, ns.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
ns.Valid = true
|
||||
return ns.ShareableWorkspaceOwners.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullShareableWorkspaceOwners) Value() (driver.Value, error) {
|
||||
if !ns.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return string(ns.ShareableWorkspaceOwners), nil
|
||||
}
|
||||
|
||||
func (e ShareableWorkspaceOwners) Valid() bool {
|
||||
switch e {
|
||||
case ShareableWorkspaceOwnersNone,
|
||||
ShareableWorkspaceOwnersEveryone,
|
||||
ShareableWorkspaceOwnersServiceAccounts:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AllShareableWorkspaceOwnersValues() []ShareableWorkspaceOwners {
|
||||
return []ShareableWorkspaceOwners{
|
||||
ShareableWorkspaceOwnersNone,
|
||||
ShareableWorkspaceOwnersEveryone,
|
||||
ShareableWorkspaceOwnersServiceAccounts,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3916,6 +4038,14 @@ type AIBridgeInterception struct {
|
||||
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
|
||||
}
|
||||
|
||||
// Audit log of model thinking in intercepted requests in AI Bridge
|
||||
type AIBridgeModelThought struct {
|
||||
InterceptionID uuid.UUID `db:"interception_id" json:"interception_id"`
|
||||
Content string `db:"content" json:"content"`
|
||||
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
// Audit log of tokens used by intercepted requests in AI Bridge
|
||||
type AIBridgeTokenUsage struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
@@ -3975,6 +4105,15 @@ type APIKey struct {
|
||||
AllowList AllowList `db:"allow_list" json:"allow_list"`
|
||||
}
|
||||
|
||||
type AiSeatState struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
FirstUsedAt time.Time `db:"first_used_at" json:"first_used_at"`
|
||||
LastUsedAt time.Time `db:"last_used_at" json:"last_used_at"`
|
||||
LastEventType AiSeatUsageReason `db:"last_event_type" json:"last_event_type"`
|
||||
LastEventDescription string `db:"last_event_description" json:"last_event_description"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type AuditLog struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Time time.Time `db:"time" json:"time"`
|
||||
@@ -4085,6 +4224,7 @@ type ChatMessage struct {
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
ContentVersion int16 `db:"content_version" json:"content_version"`
|
||||
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
|
||||
RuntimeMs sql.NullInt64 `db:"runtime_ms" json:"runtime_ms"`
|
||||
}
|
||||
|
||||
type ChatModelConfig struct {
|
||||
@@ -4126,6 +4266,16 @@ type ChatQueuedMessage struct {
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
type ChatUsageLimitConfig struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
Singleton bool `db:"singleton" json:"singleton"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
DefaultLimitMicros int64 `db:"default_limit_micros" json:"default_limit_micros"`
|
||||
Period string `db:"period" json:"period"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type ConnectionLog struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ConnectTime time.Time `db:"connect_time" json:"connect_time"`
|
||||
@@ -4238,7 +4388,8 @@ type Group struct {
|
||||
// Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
// Source indicates how the group was created. It can be created by a user manually, or through some system process like OIDC group sync.
|
||||
Source GroupSource `db:"source" json:"source"`
|
||||
Source GroupSource `db:"source" json:"source"`
|
||||
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
|
||||
}
|
||||
|
||||
// Joins group members with user information, organization ID, group name. Includes both regular group members and organization members (as part of the "Everyone" group).
|
||||
@@ -4446,16 +4597,17 @@ type OAuth2ProviderAppToken struct {
|
||||
}
|
||||
|
||||
type Organization struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
IsDefault bool `db:"is_default" json:"is_default"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
Icon string `db:"icon" json:"icon"`
|
||||
Deleted bool `db:"deleted" json:"deleted"`
|
||||
WorkspaceSharingDisabled bool `db:"workspace_sharing_disabled" json:"workspace_sharing_disabled"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
IsDefault bool `db:"is_default" json:"is_default"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
Icon string `db:"icon" json:"icon"`
|
||||
Deleted bool `db:"deleted" json:"deleted"`
|
||||
// Controls whose workspaces can be shared: none, everyone, or service_accounts.
|
||||
ShareableWorkspaceOwners ShareableWorkspaceOwners `db:"shareable_workspace_owners" json:"shareable_workspace_owners"`
|
||||
}
|
||||
|
||||
type OrganizationMember struct {
|
||||
@@ -5008,7 +5160,8 @@ type User struct {
|
||||
// Determines if a user is a system user, and therefore cannot login or perform normal actions
|
||||
IsSystem bool `db:"is_system" json:"is_system"`
|
||||
// Determines if a user is an admin-managed account that cannot login
|
||||
IsServiceAccount bool `db:"is_service_account" json:"is_service_account"`
|
||||
IsServiceAccount bool `db:"is_service_account" json:"is_service_account"`
|
||||
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
|
||||
}
|
||||
|
||||
type UserConfig struct {
|
||||
|
||||
@@ -77,6 +77,9 @@ type sqlcQuerier interface {
|
||||
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
|
||||
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
|
||||
CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error)
|
||||
// Counts enabled, non-deleted model configs that lack both input and
|
||||
// output pricing in their JSONB options.cost configuration.
|
||||
CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error)
|
||||
// CountInProgressPrebuilds returns the number of in-progress prebuilds, grouped by preset ID and transition.
|
||||
// Prebuild considered in-progress if it's in the "pending", "starting", "stopping", or "deleting" state.
|
||||
CountInProgressPrebuilds(ctx context.Context) ([]CountInProgressPrebuildsRow, error)
|
||||
@@ -99,6 +102,8 @@ type sqlcQuerier interface {
|
||||
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
|
||||
DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error
|
||||
DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error)
|
||||
DeleteCustomRole(ctx context.Context, arg DeleteCustomRoleParams) error
|
||||
DeleteExpiredAPIKeys(ctx context.Context, arg DeleteExpiredAPIKeysParams) (int64, error)
|
||||
@@ -145,7 +150,7 @@ type sqlcQuerier interface {
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error
|
||||
DeleteWorkspaceACLsByOrganization(ctx context.Context, arg DeleteWorkspaceACLsByOrganizationParams) error
|
||||
DeleteWorkspaceAgentPortShare(ctx context.Context, arg DeleteWorkspaceAgentPortShareParams) error
|
||||
DeleteWorkspaceAgentPortSharesByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||
DeleteWorkspaceSubAgentByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -187,6 +192,7 @@ type sqlcQuerier interface {
|
||||
GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error)
|
||||
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
|
||||
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
|
||||
GetActiveAISeatCount(ctx context.Context) (int64, error)
|
||||
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
|
||||
GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
|
||||
@@ -228,12 +234,14 @@ type sqlcQuerier interface {
|
||||
// Aggregate cost summary for a single user within a date range.
|
||||
// Only counts assistant-role messages.
|
||||
GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error)
|
||||
GetChatDesktopEnabled(ctx context.Context) (bool, error)
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
|
||||
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
|
||||
GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
|
||||
@@ -242,7 +250,10 @@ type sqlcQuerier interface {
|
||||
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
|
||||
GetChatSystemPrompt(ctx context.Context) (string, error)
|
||||
GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerIDParams) ([]Chat, error)
|
||||
GetChatUsageLimitConfig(ctx context.Context) (ChatUsageLimitConfig, error)
|
||||
GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (GetChatUsageLimitGroupOverrideRow, error)
|
||||
GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (GetChatUsageLimitUserOverrideRow, error)
|
||||
GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, error)
|
||||
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
|
||||
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
|
||||
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
|
||||
@@ -330,6 +341,18 @@ type sqlcQuerier interface {
|
||||
// GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
|
||||
// membership status for the prebuilds system user (org membership, group existence, group membership).
|
||||
GetOrganizationsWithPrebuildStatus(ctx context.Context, arg GetOrganizationsWithPrebuildStatusParams) ([]GetOrganizationsWithPrebuildStatusRow, error)
|
||||
// Returns PR metrics grouped by the model used for each chat.
|
||||
GetPRInsightsPerModel(ctx context.Context, arg GetPRInsightsPerModelParams) ([]GetPRInsightsPerModelRow, error)
|
||||
// Returns individual PR rows with cost for the recent PRs table.
|
||||
GetPRInsightsRecentPRs(ctx context.Context, arg GetPRInsightsRecentPRsParams) ([]GetPRInsightsRecentPRsRow, error)
|
||||
// PR Insights queries for the /agents analytics dashboard.
|
||||
// These aggregate data from chat_diff_statuses (PR metadata) joined
|
||||
// with chats and chat_messages (cost) to power the PR Insights view.
|
||||
// Returns aggregate PR metrics for the given date range.
|
||||
// The handler calls this twice (current + previous period) for trends.
|
||||
GetPRInsightsSummary(ctx context.Context, arg GetPRInsightsSummaryParams) (GetPRInsightsSummaryRow, error)
|
||||
// Returns daily PR counts grouped by state for the chart.
|
||||
GetPRInsightsTimeSeries(ctx context.Context, arg GetPRInsightsTimeSeriesParams) ([]GetPRInsightsTimeSeriesRow, error)
|
||||
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
|
||||
GetPrebuildMetrics(ctx context.Context) ([]GetPrebuildMetricsRow, error)
|
||||
GetPrebuildsSettings(ctx context.Context) (string, error)
|
||||
@@ -494,7 +517,11 @@ type sqlcQuerier interface {
|
||||
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
|
||||
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
|
||||
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
|
||||
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
// Returns the minimum (most restrictive) group limit for a user.
|
||||
// Returns -1 if the user has no group limits applied.
|
||||
GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error)
|
||||
// GetUserLatencyInsights returns the median and 95th percentile connection
|
||||
// latency that users have experienced. The result can be filtered on
|
||||
// template_ids, meaning only user data from workspaces based on those templates
|
||||
@@ -598,6 +625,7 @@ type sqlcQuerier interface {
|
||||
GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]GetWorkspacesEligibleForTransitionRow, error)
|
||||
GetWorkspacesForWorkspaceMetrics(ctx context.Context) ([]GetWorkspacesForWorkspaceMetricsRow, error)
|
||||
InsertAIBridgeInterception(ctx context.Context, arg InsertAIBridgeInterceptionParams) (AIBridgeInterception, error)
|
||||
InsertAIBridgeModelThought(ctx context.Context, arg InsertAIBridgeModelThoughtParams) (AIBridgeModelThought, error)
|
||||
InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIBridgeTokenUsageParams) (AIBridgeTokenUsage, error)
|
||||
InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBridgeToolUsageParams) (AIBridgeToolUsage, error)
|
||||
InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIBridgeUserPromptParams) (AIBridgeUserPrompt, error)
|
||||
@@ -694,6 +722,8 @@ type sqlcQuerier interface {
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
|
||||
ListChatUsageLimitGroupOverrides(ctx context.Context) ([]ListChatUsageLimitGroupOverridesRow, error)
|
||||
ListChatUsageLimitOverrides(ctx context.Context) ([]ListChatUsageLimitOverridesRow, error)
|
||||
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
@@ -714,6 +744,12 @@ type sqlcQuerier interface {
|
||||
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error)
|
||||
// Resolves the effective spend limit for a user using the hierarchy:
|
||||
// 1. Individual user override (highest priority)
|
||||
// 2. Minimum group limit across all user's groups
|
||||
// 3. Global default from config
|
||||
// Returns -1 if limits are not enabled.
|
||||
ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error)
|
||||
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
|
||||
// Note that this selects from the CTE, not the original table. The CTE is named
|
||||
// the same as the original table to trick sqlc into reusing the existing struct
|
||||
@@ -833,6 +869,8 @@ type sqlcQuerier interface {
|
||||
UpdateWorkspaceTTL(ctx context.Context, arg UpdateWorkspaceTTLParams) error
|
||||
UpdateWorkspacesDormantDeletingAtByTemplateID(ctx context.Context, arg UpdateWorkspacesDormantDeletingAtByTemplateIDParams) ([]WorkspaceTable, error)
|
||||
UpdateWorkspacesTTLByTemplateID(ctx context.Context, arg UpdateWorkspacesTTLByTemplateIDParams) error
|
||||
// Returns true if a new rows was inserted, false otherwise.
|
||||
UpsertAISeatState(ctx context.Context, arg UpsertAISeatStateParams) (bool, error)
|
||||
UpsertAnnouncementBanners(ctx context.Context, value string) error
|
||||
UpsertApplicationName(ctx context.Context, value string) error
|
||||
// Upserts boundary usage statistics for a replica. On INSERT (new period), uses
|
||||
@@ -840,9 +878,13 @@ type sqlcQuerier interface {
|
||||
// cumulative values for unique counts (accurate period totals). Request counts
|
||||
// are always deltas, accumulated in DB. Returns true if insert, false if update.
|
||||
UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error)
|
||||
UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
UpsertChatSystemPrompt(ctx context.Context, value string) error
|
||||
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
|
||||
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
|
||||
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
|
||||
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
|
||||
// The default proxy is implied and not actually stored in the database.
|
||||
// So we need to store it's configuration here for display purposes.
|
||||
@@ -877,6 +919,7 @@ type sqlcQuerier interface {
|
||||
// was started. This means that a new row was inserted (no previous session) or
|
||||
// the updated_at is older than stale interval.
|
||||
UpsertWorkspaceAppAuditSession(ctx context.Context, arg UpsertWorkspaceAppAuditSessionParams) (bool, error)
|
||||
UsageEventExistsByID(ctx context.Context, id string) (bool, error)
|
||||
ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (ValidateGroupIDsRow, error)
|
||||
ValidateUserIDs(ctx context.Context, userIds []uuid.UUID) (ValidateUserIDsRow, error)
|
||||
}
|
||||
|
||||
+504
-58
@@ -1235,6 +1235,230 @@ func TestGetAuthorizedWorkspacesAndAgentsByOwnerID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAuthorizedChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
sqlDB := testSQLDB(t)
|
||||
err := migrations.Up(sqlDB)
|
||||
require.NoError(t, err)
|
||||
db := database.New(sqlDB)
|
||||
authorizer := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
|
||||
// Create users with different roles.
|
||||
owner := dbgen.User(t, db, database.User{
|
||||
RBACRoles: []string{rbac.RoleOwner().String()},
|
||||
})
|
||||
member := dbgen.User(t, db, database.User{})
|
||||
secondMember := dbgen.User(t, db, database.User{})
|
||||
|
||||
// Create FK dependencies: a chat provider and model config.
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create 3 chats owned by owner.
|
||||
for i := range 3 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("owner chat %d", i+1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create 2 chats owned by member.
|
||||
for i := range 2 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: member.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("member chat %d", i+1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("sqlQuerier", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
// Member should only see their own 2 chats.
|
||||
memberSubject, _, err := httpmw.UserRBACSubject(ctx, db, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
preparedMember, err := authorizer.Prepare(ctx, memberSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
require.NoError(t, err)
|
||||
memberRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedMember)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberRows, 2)
|
||||
for _, row := range memberRows {
|
||||
require.Equal(t, member.ID, row.OwnerID, "member should only see own chats")
|
||||
}
|
||||
|
||||
// Owner should see at least the 5 pre-created chats (site-wide
|
||||
// access). Parallel subtests may add more, so use GreaterOrEqual.
|
||||
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, db, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
preparedOwner, err := authorizer.Prepare(ctx, ownerSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
require.NoError(t, err)
|
||||
ownerRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOwner)
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, len(ownerRows), 5)
|
||||
|
||||
// secondMember has no chats and should see 0.
|
||||
secondSubject, _, err := httpmw.UserRBACSubject(ctx, db, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
preparedSecond, err := authorizer.Prepare(ctx, secondSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
require.NoError(t, err)
|
||||
secondRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedSecond)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secondRows, 0)
|
||||
|
||||
// Org admin should NOT see other users' chats — chats are
|
||||
// not org-scoped resources.
|
||||
orgs, err := db.GetOrganizations(ctx, database.GetOrganizationsParams{})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, orgs)
|
||||
orgAdmin := dbgen.User(t, db, database.User{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: orgAdmin.ID,
|
||||
OrganizationID: orgs[0].ID,
|
||||
Roles: []string{rbac.RoleOrgAdmin()},
|
||||
})
|
||||
orgAdminSubject, _, err := httpmw.UserRBACSubject(ctx, db, orgAdmin.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
preparedOrgAdmin, err := authorizer.Prepare(ctx, orgAdminSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
require.NoError(t, err)
|
||||
orgAdminRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOrgAdmin)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, orgAdminRows, 0, "org admin with no chats should see 0 chats")
|
||||
|
||||
// OwnerID filter: member queries their own chats.
|
||||
memberFilterSelf, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
||||
OwnerID: member.ID,
|
||||
}, preparedMember)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberFilterSelf, 2)
|
||||
|
||||
// OwnerID filter: member queries owner's chats → sees 0.
|
||||
memberFilterOwner, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
||||
OwnerID: owner.ID,
|
||||
}, preparedMember)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberFilterOwner, 0)
|
||||
})
|
||||
|
||||
t.Run("dbauthz", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
authzdb := dbauthz.New(db, authorizer, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
||||
|
||||
// As member: should see only own 2 chats.
|
||||
memberSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
memberCtx := dbauthz.As(ctx, memberSubject)
|
||||
memberRows, err := authzdb.GetChats(memberCtx, database.GetChatsParams{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberRows, 2)
|
||||
for _, row := range memberRows {
|
||||
require.Equal(t, member.ID, row.OwnerID, "member should only see own chats")
|
||||
}
|
||||
|
||||
// As owner: should see at least the 5 pre-created chats.
|
||||
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
ownerCtx := dbauthz.As(ctx, ownerSubject)
|
||||
ownerRows, err := authzdb.GetChats(ownerCtx, database.GetChatsParams{})
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, len(ownerRows), 5)
|
||||
|
||||
// As secondMember: should see 0 chats.
|
||||
secondSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
secondCtx := dbauthz.As(ctx, secondSubject)
|
||||
secondRows, err := authzdb.GetChats(secondCtx, database.GetChatsParams{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secondRows, 0)
|
||||
})
|
||||
|
||||
t.Run("pagination", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
// Use a dedicated user for pagination to avoid interference
|
||||
// with the other parallel subtests.
|
||||
paginationUser := dbgen.User(t, db, database.User{})
|
||||
for i := range 7 {
|
||||
_, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: paginationUser.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: fmt.Sprintf("pagination chat %d", i+1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
pagUserSubject, _, err := httpmw.UserRBACSubject(ctx, db, paginationUser.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
||||
require.NoError(t, err)
|
||||
preparedMember, err := authorizer.Prepare(ctx, pagUserSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch first page with limit=2.
|
||||
page1, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
||||
LimitOpt: 2,
|
||||
}, preparedMember)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, page1, 2)
|
||||
for _, row := range page1 {
|
||||
require.Equal(t, paginationUser.ID, row.OwnerID, "paginated results must belong to pagination user")
|
||||
}
|
||||
|
||||
// Fetch remaining pages and collect all chat IDs.
|
||||
allIDs := make(map[uuid.UUID]struct{})
|
||||
for _, row := range page1 {
|
||||
allIDs[row.ID] = struct{}{}
|
||||
}
|
||||
offset := int32(2)
|
||||
for {
|
||||
page, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
||||
LimitOpt: 2,
|
||||
OffsetOpt: offset,
|
||||
}, preparedMember)
|
||||
require.NoError(t, err)
|
||||
for _, row := range page {
|
||||
require.Equal(t, paginationUser.ID, row.OwnerID, "paginated results must belong to pagination user")
|
||||
allIDs[row.ID] = struct{}{}
|
||||
}
|
||||
if len(page) < 2 {
|
||||
break
|
||||
}
|
||||
offset += int32(len(page)) //nolint:gosec // Test code, pagination values are small.
|
||||
}
|
||||
|
||||
// All 7 member chats should be accounted for with no leakage.
|
||||
require.Len(t, allIDs, 7, "pagination should return all member chats exactly once")
|
||||
})
|
||||
}
|
||||
|
||||
func TestInsertWorkspaceAgentLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
@@ -2431,6 +2655,42 @@ func TestDeleteCustomRoleDoesNotDeleteSystemRole(t *testing.T) {
|
||||
require.True(t, roles[0].IsSystem)
|
||||
}
|
||||
|
||||
func TestGetAuthorizationUserRolesImpliedOrgRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
regularUser := dbgen.User(t, db, database.User{})
|
||||
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
|
||||
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org.ID,
|
||||
UserID: regularUser.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org.ID,
|
||||
UserID: saUser.ID,
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
wantMember := rbac.RoleOrgMember() + ":" + org.ID.String()
|
||||
wantSA := rbac.RoleOrgServiceAccount() + ":" + org.ID.String()
|
||||
|
||||
// Regular users get the implied organization-member role.
|
||||
regularRoles, err := db.GetAuthorizationUserRoles(ctx, regularUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, regularRoles.Roles, wantMember)
|
||||
require.NotContains(t, regularRoles.Roles, wantSA)
|
||||
|
||||
// Service accounts get the implied organization-service-account role.
|
||||
saRoles, err := db.GetAuthorizationUserRoles(ctx, saUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, saRoles.Roles, wantSA)
|
||||
require.NotContains(t, saRoles.Roles, wantMember)
|
||||
}
|
||||
|
||||
func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -2441,82 +2701,155 @@ func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
|
||||
|
||||
updated, err := db.UpdateOrganizationWorkspaceSharingSettings(ctx, database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
||||
ID: org.ID,
|
||||
WorkspaceSharingDisabled: true,
|
||||
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, updated.WorkspaceSharingDisabled)
|
||||
require.Equal(t, database.ShareableWorkspaceOwnersNone, updated.ShareableWorkspaceOwners)
|
||||
|
||||
got, err := db.GetOrganizationByID(ctx, org.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, got.WorkspaceSharingDisabled)
|
||||
require.Equal(t, database.ShareableWorkspaceOwnersNone, got.ShareableWorkspaceOwners)
|
||||
}
|
||||
|
||||
func TestDeleteWorkspaceACLsByOrganization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org1 := dbgen.Organization(t, db, database.Organization{})
|
||||
org2 := dbgen.Organization(t, db, database.Organization{})
|
||||
t.Run("DeletesAll", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
owner1 := dbgen.User(t, db, database.User{})
|
||||
owner2 := dbgen.User(t, db, database.User{})
|
||||
sharedUser := dbgen.User(t, db, database.User{})
|
||||
sharedGroup := dbgen.Group(t, db, database.Group{
|
||||
OrganizationID: org1.ID,
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org1 := dbgen.Organization(t, db, database.Organization{})
|
||||
org2 := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
owner1 := dbgen.User(t, db, database.User{})
|
||||
owner2 := dbgen.User(t, db, database.User{})
|
||||
sharedUser := dbgen.User(t, db, database.User{})
|
||||
sharedGroup := dbgen.Group(t, db, database.Group{
|
||||
OrganizationID: org1.ID,
|
||||
})
|
||||
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: owner1.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org2.ID,
|
||||
UserID: owner2.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: sharedUser.ID,
|
||||
})
|
||||
|
||||
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner1.ID,
|
||||
OrganizationID: org1.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
sharedUser.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
GroupACL: database.WorkspaceACL{
|
||||
sharedGroup.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner2.ID,
|
||||
OrganizationID: org2.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
uuid.NewString(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
|
||||
OrganizationID: org1.ID,
|
||||
ExcludeServiceAccounts: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, got1.UserACL)
|
||||
require.Empty(t, got1.GroupACL)
|
||||
|
||||
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, got2.UserACL)
|
||||
})
|
||||
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: owner1.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org2.ID,
|
||||
UserID: owner2.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: sharedUser.ID,
|
||||
})
|
||||
t.Run("ExcludesServiceAccounts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner1.ID,
|
||||
OrganizationID: org1.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
regularUser := dbgen.User(t, db, database.User{})
|
||||
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
|
||||
sharedUser := dbgen.User(t, db, database.User{})
|
||||
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org.ID,
|
||||
UserID: regularUser.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org.ID,
|
||||
UserID: saUser.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org.ID,
|
||||
UserID: sharedUser.ID,
|
||||
})
|
||||
|
||||
regularWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: regularUser.ID,
|
||||
OrganizationID: org.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
sharedUser.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
saWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: saUser.ID,
|
||||
OrganizationID: org.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
sharedUser.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
|
||||
OrganizationID: org.ID,
|
||||
ExcludeServiceAccounts: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Regular user workspace ACLs should be cleared.
|
||||
gotRegular, err := db.GetWorkspaceByID(ctx, regularWS.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, gotRegular.UserACL)
|
||||
|
||||
// Service account workspace ACLs should be preserved.
|
||||
gotSA, err := db.GetWorkspaceByID(ctx, saWS.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.WorkspaceACL{
|
||||
sharedUser.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
GroupACL: database.WorkspaceACL{
|
||||
sharedGroup.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner2.ID,
|
||||
OrganizationID: org2.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
uuid.NewString(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
err := db.DeleteWorkspaceACLsByOrganization(ctx, org1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, got1.UserACL)
|
||||
require.Empty(t, got1.GroupACL)
|
||||
|
||||
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, got2.UserACL)
|
||||
}, gotSA.UserACL)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthorizedAuditLogs(t *testing.T) {
|
||||
@@ -7982,6 +8315,80 @@ func TestUsageEventsTrigger(t *testing.T) {
|
||||
require.WithinDuration(t, time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), rows[1].Day, time.Second)
|
||||
})
|
||||
|
||||
t.Run("HeartbeatAISeats", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
||||
|
||||
// Insert a heartbeat event.
|
||||
err := db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
||||
ID: "hb-1",
|
||||
EventType: "hb_ai_seats_v1",
|
||||
EventData: []byte(`{"count": 10}`),
|
||||
CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows := getDailyRows(ctx, sqlDB)
|
||||
require.Len(t, rows, 1)
|
||||
require.Equal(t, "hb_ai_seats_v1", rows[0].EventType)
|
||||
require.JSONEq(t, `{"count": 10}`, string(rows[0].UsageData))
|
||||
|
||||
// Insert a higher count on the same day — should take the max.
|
||||
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
||||
ID: "hb-2",
|
||||
EventType: "hb_ai_seats_v1",
|
||||
EventData: []byte(`{"count": 50}`),
|
||||
CreatedAt: time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows = getDailyRows(ctx, sqlDB)
|
||||
require.Len(t, rows, 1)
|
||||
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
|
||||
|
||||
// Insert a lower count on the same day — should keep the max (50).
|
||||
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
||||
ID: "hb-3",
|
||||
EventType: "hb_ai_seats_v1",
|
||||
EventData: []byte(`{"count": 25}`),
|
||||
CreatedAt: time.Date(2025, 1, 1, 18, 0, 0, 0, time.UTC),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows = getDailyRows(ctx, sqlDB)
|
||||
require.Len(t, rows, 1)
|
||||
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
|
||||
|
||||
// Insert on a different day.
|
||||
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
||||
ID: "hb-4",
|
||||
EventType: "hb_ai_seats_v1",
|
||||
EventData: []byte(`{"count": 5}`),
|
||||
CreatedAt: time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows = getDailyRows(ctx, sqlDB)
|
||||
require.Len(t, rows, 2)
|
||||
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
|
||||
require.JSONEq(t, `{"count": 5}`, string(rows[1].UsageData))
|
||||
|
||||
// Also insert a dc_managed_agents_v1 on the same first day to
|
||||
// verify different event types get separate daily rows.
|
||||
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
||||
ID: "dc-1",
|
||||
EventType: "dc_managed_agents_v1",
|
||||
EventData: []byte(`{"count": 7}`),
|
||||
CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows = getDailyRows(ctx, sqlDB)
|
||||
require.Len(t, rows, 3)
|
||||
})
|
||||
|
||||
t.Run("UnknownEventType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -9441,3 +9848,42 @@ func TestGetWorkspaceBuildMetricsByResourceID(t *testing.T) {
|
||||
require.Equal(t, "success", row.WorstStatus)
|
||||
})
|
||||
}
|
||||
|
||||
// TestUpsertAISeats verifies 'UpsertAISeatState' only returns true when a new
|
||||
// row is inserted.
|
||||
func TestUpsertAISeats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlDB := testSQLDB(t)
|
||||
err := migrations.Up(sqlDB)
|
||||
require.NoError(t, err)
|
||||
db := database.New(sqlDB)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
now := dbtime.Now()
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
newRow, err := db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
|
||||
UserID: user.ID,
|
||||
FirstUsedAt: now.Add(time.Hour * -24),
|
||||
LastEventType: database.AiSeatUsageReasonTask,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, newRow)
|
||||
|
||||
alreadyExists, err := db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
|
||||
UserID: user.ID,
|
||||
FirstUsedAt: now.Add(time.Hour * -23),
|
||||
LastEventType: database.AiSeatUsageReasonTask,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, alreadyExists)
|
||||
|
||||
alreadyExists, err = db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
|
||||
UserID: user.ID,
|
||||
FirstUsedAt: now,
|
||||
LastEventType: database.AiSeatUsageReasonTask,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, alreadyExists)
|
||||
}
|
||||
|
||||
+1024
-56
File diff suppressed because it is too large
Load Diff
@@ -53,6 +53,14 @@ INSERT INTO aibridge_tool_usages (
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: InsertAIBridgeModelThought :one
|
||||
INSERT INTO aibridge_model_thoughts (
|
||||
interception_id, content, metadata, created_at
|
||||
) VALUES (
|
||||
@interception_id, @content, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetAIBridgeInterceptionByID :one
|
||||
SELECT
|
||||
*
|
||||
@@ -362,6 +370,11 @@ WITH
|
||||
WHERE started_at < @before_time::timestamp with time zone
|
||||
),
|
||||
-- CTEs are executed in order.
|
||||
model_thoughts AS (
|
||||
DELETE FROM aibridge_model_thoughts
|
||||
WHERE interception_id IN (SELECT id FROM to_delete)
|
||||
RETURNING 1
|
||||
),
|
||||
tool_usages AS (
|
||||
DELETE FROM aibridge_tool_usages
|
||||
WHERE interception_id IN (SELECT id FROM to_delete)
|
||||
@@ -384,6 +397,7 @@ WITH
|
||||
)
|
||||
-- Cumulative count.
|
||||
SELECT (
|
||||
(SELECT COUNT(*) FROM model_thoughts) +
|
||||
(SELECT COUNT(*) FROM tool_usages) +
|
||||
(SELECT COUNT(*) FROM token_usages) +
|
||||
(SELECT COUNT(*) FROM user_prompts) +
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
-- name: UpsertAISeatState :one
|
||||
-- Returns true if a new rows was inserted, false otherwise.
|
||||
INSERT INTO ai_seat_state (
|
||||
user_id,
|
||||
first_used_at,
|
||||
last_used_at,
|
||||
last_event_type,
|
||||
last_event_description,
|
||||
updated_at
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $2, $3, $4, $2)
|
||||
ON CONFLICT (user_id) DO UPDATE
|
||||
SET
|
||||
last_used_at = EXCLUDED.last_used_at,
|
||||
last_event_type = EXCLUDED.last_event_type,
|
||||
last_event_description = EXCLUDED.last_event_description,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
RETURNING
|
||||
-- Postgres vodoo to know if a row was inserted.
|
||||
(xmax = 0)::boolean AS is_new;
|
||||
|
||||
-- name: GetActiveAISeatCount :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
FROM
|
||||
ai_seat_state ais
|
||||
JOIN
|
||||
users u
|
||||
ON
|
||||
ais.user_id = u.id
|
||||
WHERE
|
||||
u.status = 'active'::user_status
|
||||
AND u.deleted = false
|
||||
AND u.is_system = false;
|
||||
@@ -0,0 +1,118 @@
|
||||
-- PR Insights queries for the /agents analytics dashboard.
|
||||
-- These aggregate data from chat_diff_statuses (PR metadata) joined
|
||||
-- with chats and chat_messages (cost) to power the PR Insights view.
|
||||
|
||||
-- name: GetPRInsightsSummary :one
|
||||
-- Returns aggregate PR metrics for the given date range.
|
||||
-- The handler calls this twice (current + previous period) for trends.
|
||||
SELECT
|
||||
COUNT(*)::bigint AS total_prs_created,
|
||||
COUNT(*) FILTER (WHERE cds.pull_request_state = 'merged')::bigint AS total_prs_merged,
|
||||
COUNT(*) FILTER (WHERE cds.pull_request_state = 'closed')::bigint AS total_prs_closed,
|
||||
COALESCE(SUM(cds.additions), 0)::bigint AS total_additions,
|
||||
COALESCE(SUM(cds.deletions), 0)::bigint AS total_deletions,
|
||||
COALESCE(SUM(cc.cost_micros), 0)::bigint AS total_cost_micros,
|
||||
COALESCE(SUM(cc.cost_micros) FILTER (WHERE cds.pull_request_state = 'merged'), 0)::bigint AS merged_cost_micros
|
||||
FROM chat_diff_statuses cds
|
||||
JOIN chats c ON c.id = cds.chat_id
|
||||
LEFT JOIN (
|
||||
SELECT
|
||||
COALESCE(ch.root_chat_id, ch.id) AS root_id,
|
||||
COALESCE(SUM(cm.total_cost_micros), 0) AS cost_micros
|
||||
FROM chat_messages cm
|
||||
JOIN chats ch ON ch.id = cm.chat_id
|
||||
WHERE cm.total_cost_micros IS NOT NULL
|
||||
GROUP BY COALESCE(ch.root_chat_id, ch.id)
|
||||
) cc ON cc.root_id = COALESCE(c.root_chat_id, c.id)
|
||||
WHERE cds.pull_request_state IS NOT NULL
|
||||
AND c.created_at >= @start_date::timestamptz
|
||||
AND c.created_at < @end_date::timestamptz
|
||||
AND (sqlc.narg('owner_id')::uuid IS NULL OR c.owner_id = sqlc.narg('owner_id')::uuid);
|
||||
|
||||
-- name: GetPRInsightsTimeSeries :many
|
||||
-- Returns daily PR counts grouped by state for the chart.
|
||||
SELECT
|
||||
date_trunc('day', c.created_at)::timestamptz AS date,
|
||||
COUNT(*)::bigint AS prs_created,
|
||||
COUNT(*) FILTER (WHERE cds.pull_request_state = 'merged')::bigint AS prs_merged,
|
||||
COUNT(*) FILTER (WHERE cds.pull_request_state = 'closed')::bigint AS prs_closed
|
||||
FROM chat_diff_statuses cds
|
||||
JOIN chats c ON c.id = cds.chat_id
|
||||
WHERE cds.pull_request_state IS NOT NULL
|
||||
AND c.created_at >= @start_date::timestamptz
|
||||
AND c.created_at < @end_date::timestamptz
|
||||
AND (sqlc.narg('owner_id')::uuid IS NULL OR c.owner_id = sqlc.narg('owner_id')::uuid)
|
||||
GROUP BY date_trunc('day', c.created_at)
|
||||
ORDER BY date_trunc('day', c.created_at);
|
||||
|
||||
-- name: GetPRInsightsPerModel :many
|
||||
-- Returns PR metrics grouped by the model used for each chat.
|
||||
SELECT
|
||||
cmc.id AS model_config_id,
|
||||
cmc.display_name,
|
||||
cmc.provider,
|
||||
COUNT(*)::bigint AS total_prs,
|
||||
COUNT(*) FILTER (WHERE cds.pull_request_state = 'merged')::bigint AS merged_prs,
|
||||
COALESCE(SUM(cds.additions), 0)::bigint AS total_additions,
|
||||
COALESCE(SUM(cds.deletions), 0)::bigint AS total_deletions,
|
||||
COALESCE(SUM(cc.cost_micros), 0)::bigint AS total_cost_micros,
|
||||
COALESCE(SUM(cc.cost_micros) FILTER (WHERE cds.pull_request_state = 'merged'), 0)::bigint AS merged_cost_micros
|
||||
FROM chat_diff_statuses cds
|
||||
JOIN chats c ON c.id = cds.chat_id
|
||||
JOIN chat_model_configs cmc ON cmc.id = c.last_model_config_id
|
||||
LEFT JOIN (
|
||||
SELECT
|
||||
COALESCE(ch.root_chat_id, ch.id) AS root_id,
|
||||
COALESCE(SUM(cm.total_cost_micros), 0) AS cost_micros
|
||||
FROM chat_messages cm
|
||||
JOIN chats ch ON ch.id = cm.chat_id
|
||||
WHERE cm.total_cost_micros IS NOT NULL
|
||||
GROUP BY COALESCE(ch.root_chat_id, ch.id)
|
||||
) cc ON cc.root_id = COALESCE(c.root_chat_id, c.id)
|
||||
WHERE cds.pull_request_state IS NOT NULL
|
||||
AND c.created_at >= @start_date::timestamptz
|
||||
AND c.created_at < @end_date::timestamptz
|
||||
AND (sqlc.narg('owner_id')::uuid IS NULL OR c.owner_id = sqlc.narg('owner_id')::uuid)
|
||||
GROUP BY cmc.id, cmc.display_name, cmc.provider
|
||||
ORDER BY total_prs DESC;
|
||||
|
||||
-- name: GetPRInsightsRecentPRs :many
|
||||
-- Returns individual PR rows with cost for the recent PRs table.
|
||||
SELECT
|
||||
c.id AS chat_id,
|
||||
cds.pull_request_title AS pr_title,
|
||||
cds.url AS pr_url,
|
||||
cds.pr_number,
|
||||
cds.pull_request_state AS state,
|
||||
cds.pull_request_draft AS draft,
|
||||
cds.additions,
|
||||
cds.deletions,
|
||||
cds.changed_files,
|
||||
cds.commits,
|
||||
cds.approved,
|
||||
cds.changes_requested,
|
||||
cds.reviewer_count,
|
||||
cds.author_login,
|
||||
cds.author_avatar_url,
|
||||
COALESCE(cds.base_branch, '')::text AS base_branch,
|
||||
COALESCE(cmc.display_name, cmc.model)::text AS model_display_name,
|
||||
COALESCE(cc.cost_micros, 0)::bigint AS cost_micros,
|
||||
c.created_at
|
||||
FROM chat_diff_statuses cds
|
||||
JOIN chats c ON c.id = cds.chat_id
|
||||
JOIN chat_model_configs cmc ON cmc.id = c.last_model_config_id
|
||||
LEFT JOIN (
|
||||
SELECT
|
||||
COALESCE(ch.root_chat_id, ch.id) AS root_id,
|
||||
COALESCE(SUM(cm.total_cost_micros), 0) AS cost_micros
|
||||
FROM chat_messages cm
|
||||
JOIN chats ch ON ch.id = cm.chat_id
|
||||
WHERE cm.total_cost_micros IS NOT NULL
|
||||
GROUP BY COALESCE(ch.root_chat_id, ch.id)
|
||||
) cc ON cc.root_id = COALESCE(c.root_chat_id, c.id)
|
||||
WHERE cds.pull_request_state IS NOT NULL
|
||||
AND c.created_at >= @start_date::timestamptz
|
||||
AND c.created_at < @end_date::timestamptz
|
||||
AND (sqlc.narg('owner_id')::uuid IS NULL OR c.owner_id = sqlc.narg('owner_id')::uuid)
|
||||
ORDER BY c.created_at DESC
|
||||
LIMIT @limit_val::int;
|
||||
@@ -40,6 +40,23 @@ WHERE
|
||||
ORDER BY
|
||||
created_at ASC;
|
||||
|
||||
-- name: GetChatMessagesByChatIDDescPaginated :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND CASE
|
||||
WHEN @before_id::bigint > 0 THEN id < @before_id::bigint
|
||||
ELSE true
|
||||
END
|
||||
AND visibility IN ('user', 'both')
|
||||
ORDER BY
|
||||
id DESC
|
||||
LIMIT
|
||||
COALESCE(NULLIF(@limit_val::int, 0), 50);
|
||||
|
||||
-- name: GetChatMessagesForPromptByChatID :many
|
||||
WITH latest_compressed_summary AS (
|
||||
SELECT
|
||||
@@ -96,13 +113,16 @@ ORDER BY
|
||||
created_at ASC,
|
||||
id ASC;
|
||||
|
||||
-- name: GetChatsByOwnerID :many
|
||||
-- name: GetChats :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
owner_id = @owner_id::uuid
|
||||
CASE
|
||||
WHEN @owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN chats.owner_id = @owner_id
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN sqlc.narg('archived') :: boolean IS NULL THEN true
|
||||
ELSE chats.archived = sqlc.narg('archived') :: boolean
|
||||
@@ -126,6 +146,8 @@ WHERE
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedChats
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all rows, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
@@ -183,7 +205,8 @@ INSERT INTO chat_messages (
|
||||
cache_read_tokens,
|
||||
context_limit,
|
||||
compressed,
|
||||
total_cost_micros
|
||||
total_cost_micros,
|
||||
runtime_ms
|
||||
) VALUES (
|
||||
@chat_id::uuid,
|
||||
sqlc.narg('created_by')::uuid,
|
||||
@@ -200,7 +223,8 @@ INSERT INTO chat_messages (
|
||||
sqlc.narg('cache_read_tokens')::bigint,
|
||||
sqlc.narg('context_limit')::bigint,
|
||||
COALESCE(sqlc.narg('compressed')::boolean, FALSE),
|
||||
sqlc.narg('total_cost_micros')::bigint
|
||||
sqlc.narg('total_cost_micros')::bigint,
|
||||
sqlc.narg('runtime_ms')::bigint
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -683,3 +707,128 @@ LIMIT
|
||||
sqlc.arg('page_limit')::int
|
||||
OFFSET
|
||||
sqlc.arg('page_offset')::int;
|
||||
|
||||
-- name: GetChatUsageLimitConfig :one
|
||||
SELECT * FROM chat_usage_limit_config WHERE singleton = TRUE LIMIT 1;
|
||||
|
||||
-- name: UpsertChatUsageLimitConfig :one
|
||||
INSERT INTO chat_usage_limit_config (singleton, enabled, default_limit_micros, period, updated_at)
|
||||
VALUES (TRUE, @enabled::boolean, @default_limit_micros::bigint, @period::text, NOW())
|
||||
ON CONFLICT (singleton) DO UPDATE SET
|
||||
enabled = EXCLUDED.enabled,
|
||||
default_limit_micros = EXCLUDED.default_limit_micros,
|
||||
period = EXCLUDED.period,
|
||||
updated_at = NOW()
|
||||
RETURNING *;
|
||||
|
||||
-- name: ListChatUsageLimitOverrides :many
|
||||
SELECT u.id AS user_id, u.username, u.name, u.avatar_url,
|
||||
u.chat_spend_limit_micros AS spend_limit_micros
|
||||
FROM users u
|
||||
WHERE u.chat_spend_limit_micros IS NOT NULL
|
||||
ORDER BY u.username ASC;
|
||||
|
||||
-- name: UpsertChatUsageLimitUserOverride :one
|
||||
UPDATE users
|
||||
SET chat_spend_limit_micros = @spend_limit_micros::bigint
|
||||
WHERE id = @user_id::uuid
|
||||
RETURNING id AS user_id, username, name, avatar_url, chat_spend_limit_micros AS spend_limit_micros;
|
||||
|
||||
-- name: DeleteChatUsageLimitUserOverride :exec
|
||||
UPDATE users SET chat_spend_limit_micros = NULL WHERE id = @user_id::uuid;
|
||||
|
||||
-- name: GetChatUsageLimitUserOverride :one
|
||||
SELECT id AS user_id, chat_spend_limit_micros AS spend_limit_micros
|
||||
FROM users
|
||||
WHERE id = @user_id::uuid AND chat_spend_limit_micros IS NOT NULL;
|
||||
|
||||
-- name: GetUserChatSpendInPeriod :one
|
||||
SELECT COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_spend_micros
|
||||
FROM chat_messages cm
|
||||
JOIN chats c ON c.id = cm.chat_id
|
||||
WHERE c.owner_id = @user_id::uuid
|
||||
AND cm.created_at >= @start_time::timestamptz
|
||||
AND cm.created_at < @end_time::timestamptz
|
||||
AND cm.total_cost_micros IS NOT NULL;
|
||||
|
||||
-- name: CountEnabledModelsWithoutPricing :one
|
||||
-- Counts enabled, non-deleted model configs that lack both input and
|
||||
-- output pricing in their JSONB options.cost configuration.
|
||||
SELECT COUNT(*)::bigint AS count
|
||||
FROM chat_model_configs
|
||||
WHERE enabled = TRUE
|
||||
AND deleted = FALSE
|
||||
AND (
|
||||
options->'cost' IS NULL
|
||||
OR options->'cost' = 'null'::jsonb
|
||||
OR (
|
||||
(options->'cost'->>'input_price_per_million_tokens' IS NULL)
|
||||
AND (options->'cost'->>'output_price_per_million_tokens' IS NULL)
|
||||
)
|
||||
);
|
||||
|
||||
-- name: ListChatUsageLimitGroupOverrides :many
|
||||
SELECT
|
||||
g.id AS group_id,
|
||||
g.name AS group_name,
|
||||
g.display_name AS group_display_name,
|
||||
g.avatar_url AS group_avatar_url,
|
||||
g.chat_spend_limit_micros AS spend_limit_micros,
|
||||
(SELECT COUNT(*)
|
||||
FROM group_members_expanded gme
|
||||
WHERE gme.group_id = g.id
|
||||
AND gme.user_is_system = FALSE) AS member_count
|
||||
FROM groups g
|
||||
WHERE g.chat_spend_limit_micros IS NOT NULL
|
||||
ORDER BY g.name ASC;
|
||||
|
||||
-- name: UpsertChatUsageLimitGroupOverride :one
|
||||
UPDATE groups
|
||||
SET chat_spend_limit_micros = @spend_limit_micros::bigint
|
||||
WHERE id = @group_id::uuid
|
||||
RETURNING id AS group_id, name, display_name, avatar_url, chat_spend_limit_micros AS spend_limit_micros;
|
||||
|
||||
-- name: DeleteChatUsageLimitGroupOverride :exec
|
||||
UPDATE groups SET chat_spend_limit_micros = NULL WHERE id = @group_id::uuid;
|
||||
|
||||
-- name: GetChatUsageLimitGroupOverride :one
|
||||
SELECT id AS group_id, chat_spend_limit_micros AS spend_limit_micros
|
||||
FROM groups
|
||||
WHERE id = @group_id::uuid AND chat_spend_limit_micros IS NOT NULL;
|
||||
|
||||
-- name: GetUserGroupSpendLimit :one
|
||||
-- Returns the minimum (most restrictive) group limit for a user.
|
||||
-- Returns -1 if the user has no group limits applied.
|
||||
SELECT COALESCE(MIN(g.chat_spend_limit_micros), -1)::bigint AS limit_micros
|
||||
FROM groups g
|
||||
JOIN group_members_expanded gme ON gme.group_id = g.id
|
||||
WHERE gme.user_id = @user_id::uuid
|
||||
AND g.chat_spend_limit_micros IS NOT NULL;
|
||||
|
||||
-- name: ResolveUserChatSpendLimit :one
|
||||
-- Resolves the effective spend limit for a user using the hierarchy:
|
||||
-- 1. Individual user override (highest priority)
|
||||
-- 2. Minimum group limit across all user's groups
|
||||
-- 3. Global default from config
|
||||
-- Returns -1 if limits are not enabled.
|
||||
SELECT CASE
|
||||
-- If limits are disabled, return -1.
|
||||
WHEN NOT cfg.enabled THEN -1
|
||||
-- Individual override takes priority.
|
||||
WHEN u.chat_spend_limit_micros IS NOT NULL THEN u.chat_spend_limit_micros
|
||||
-- Group limit (minimum across all user's groups) is next.
|
||||
WHEN gl.limit_micros IS NOT NULL THEN gl.limit_micros
|
||||
-- Fall back to global default.
|
||||
ELSE cfg.default_limit_micros
|
||||
END::bigint AS effective_limit_micros
|
||||
FROM chat_usage_limit_config cfg
|
||||
CROSS JOIN users u
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT MIN(g.chat_spend_limit_micros) AS limit_micros
|
||||
FROM groups g
|
||||
JOIN group_members_expanded gme ON gme.group_id = g.id
|
||||
WHERE gme.user_id = @user_id::uuid
|
||||
AND g.chat_spend_limit_micros IS NOT NULL
|
||||
) gl ON TRUE
|
||||
WHERE u.id = @user_id::uuid
|
||||
LIMIT 1;
|
||||
|
||||
@@ -147,7 +147,7 @@ WHERE
|
||||
UPDATE
|
||||
organizations
|
||||
SET
|
||||
workspace_sharing_disabled = @workspace_sharing_disabled,
|
||||
shareable_workspace_owners = @shareable_workspace_owners,
|
||||
updated_at = @updated_at
|
||||
WHERE
|
||||
id = @id
|
||||
|
||||
@@ -140,3 +140,23 @@ SELECT
|
||||
-- name: UpsertChatSystemPrompt :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('agents_chat_system_prompt', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_system_prompt';
|
||||
|
||||
-- name: GetChatDesktopEnabled :one
|
||||
SELECT
|
||||
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop;
|
||||
|
||||
-- name: UpsertChatDesktopEnabled :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES (
|
||||
'agents_desktop_enabled',
|
||||
CASE
|
||||
WHEN sqlc.arg(enable_desktop)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET value = CASE
|
||||
WHEN sqlc.arg(enable_desktop)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
WHERE site_configs.key = 'agents_desktop_enabled';
|
||||
|
||||
@@ -15,6 +15,11 @@ VALUES
|
||||
(@id, @event_type, @event_data, @created_at, NULL, NULL, NULL)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
|
||||
-- name: UsageEventExistsByID :one
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM usage_events WHERE id = @id
|
||||
)::bool;
|
||||
|
||||
-- name: SelectUsageEventsForPublishing :many
|
||||
WITH usage_events AS (
|
||||
UPDATE
|
||||
|
||||
@@ -391,9 +391,21 @@ SELECT
|
||||
array_agg(org_roles || ':' || organization_members.organization_id::text)
|
||||
FROM
|
||||
organization_members,
|
||||
-- All org_members get the organization-member role for their orgs
|
||||
-- All org members get an implied role for their orgs. Most members
|
||||
-- get organization-member, but service accounts will get
|
||||
-- organization-service-account instead. They're largely the same,
|
||||
-- but having them be distinct means we can allow configuring
|
||||
-- service-accounts to have slightly broader permissions–such as
|
||||
-- for workspace sharing.
|
||||
unnest(
|
||||
array_append(roles, 'organization-member')
|
||||
array_append(
|
||||
roles,
|
||||
CASE WHEN users.is_service_account THEN
|
||||
'organization-service-account'
|
||||
ELSE
|
||||
'organization-member'
|
||||
END
|
||||
)
|
||||
) AS org_roles
|
||||
WHERE
|
||||
user_id = users.id
|
||||
|
||||
@@ -955,7 +955,13 @@ SET
|
||||
group_acl = '{}'::jsonb,
|
||||
user_acl = '{}'::jsonb
|
||||
WHERE
|
||||
organization_id = @organization_id;
|
||||
organization_id = @organization_id
|
||||
AND (
|
||||
NOT @exclude_service_accounts::boolean
|
||||
OR owner_id NOT IN (
|
||||
SELECT id FROM users WHERE is_service_account = true
|
||||
)
|
||||
);
|
||||
|
||||
-- name: GetRegularWorkspaceCreateMetrics :many
|
||||
-- Count regular workspaces: only those whose first successful 'start' build
|
||||
|
||||
@@ -235,6 +235,7 @@ sql:
|
||||
aibridge_tool_usage: AIBridgeToolUsage
|
||||
aibridge_token_usage: AIBridgeTokenUsage
|
||||
aibridge_user_prompt: AIBridgeUserPrompt
|
||||
aibridge_model_thought: AIBridgeModelThought
|
||||
rules:
|
||||
- name: do-not-use-public-schema-in-queries
|
||||
message: "do not use public schema in queries"
|
||||
|
||||
@@ -7,6 +7,7 @@ type UniqueConstraint string
|
||||
// UniqueConstraint enums.
|
||||
const (
|
||||
UniqueAgentStatsPkey UniqueConstraint = "agent_stats_pkey" // ALTER TABLE ONLY workspace_agent_stats ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id);
|
||||
UniqueAiSeatStatePkey UniqueConstraint = "ai_seat_state_pkey" // ALTER TABLE ONLY ai_seat_state ADD CONSTRAINT ai_seat_state_pkey PRIMARY KEY (user_id);
|
||||
UniqueAibridgeInterceptionsPkey UniqueConstraint = "aibridge_interceptions_pkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_pkey PRIMARY KEY (id);
|
||||
UniqueAibridgeTokenUsagesPkey UniqueConstraint = "aibridge_token_usages_pkey" // ALTER TABLE ONLY aibridge_token_usages ADD CONSTRAINT aibridge_token_usages_pkey PRIMARY KEY (id);
|
||||
UniqueAibridgeToolUsagesPkey UniqueConstraint = "aibridge_tool_usages_pkey" // ALTER TABLE ONLY aibridge_tool_usages ADD CONSTRAINT aibridge_tool_usages_pkey PRIMARY KEY (id);
|
||||
@@ -21,6 +22,8 @@ const (
|
||||
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
|
||||
UniqueChatProvidersProviderKey UniqueConstraint = "chat_providers_provider_key" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
|
||||
UniqueChatQueuedMessagesPkey UniqueConstraint = "chat_queued_messages_pkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatUsageLimitConfigPkey UniqueConstraint = "chat_usage_limit_config_pkey" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
|
||||
UniqueChatUsageLimitConfigSingletonKey UniqueConstraint = "chat_usage_limit_config_singleton_key" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
|
||||
UniqueChatsPkey UniqueConstraint = "chats_pkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
|
||||
UniqueConnectionLogsPkey UniqueConstraint = "connection_logs_pkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
|
||||
UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence);
|
||||
|
||||
@@ -48,8 +48,8 @@ type Store interface {
|
||||
UpsertChatDiffStatusReference(
|
||||
ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams,
|
||||
) (database.ChatDiffStatus, error)
|
||||
GetChatsByOwnerID(
|
||||
ctx context.Context, arg database.GetChatsByOwnerIDParams,
|
||||
GetChats(
|
||||
ctx context.Context, arg database.GetChatsParams,
|
||||
) ([]database.Chat, error)
|
||||
}
|
||||
|
||||
@@ -250,7 +250,7 @@ func (w *Worker) MarkStale(
|
||||
return
|
||||
}
|
||||
|
||||
chats, err := w.store.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
|
||||
chats, err := w.store.GetChats(ctx, database.GetChatsParams{
|
||||
OwnerID: ownerID,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -469,8 +469,8 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
require.Equal(t, ownerID, arg.OwnerID)
|
||||
return []database.Chat{
|
||||
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
||||
@@ -478,13 +478,12 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
|
||||
{ID: chatOther, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
}, nil
|
||||
})
|
||||
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
mu.Lock()
|
||||
upsertRefCalls = append(upsertRefCalls, arg)
|
||||
mu.Unlock()
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(2)
|
||||
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
mu.Lock()
|
||||
upsertRefCalls = append(upsertRefCalls, arg)
|
||||
mu.Unlock()
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(2)
|
||||
|
||||
pub := func(_ context.Context, chatID uuid.UUID) error {
|
||||
mu.Lock()
|
||||
@@ -527,7 +526,7 @@ func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
||||
Return([]database.Chat{
|
||||
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
@@ -555,7 +554,7 @@ func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
||||
Return([]database.Chat{
|
||||
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
||||
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
||||
@@ -590,7 +589,7 @@ func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
||||
Return(nil, fmt.Errorf("db error"))
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
|
||||
@@ -136,6 +136,12 @@ func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool
|
||||
Properties: sdkTool.Schema.Properties,
|
||||
Required: sdkTool.Schema.Required,
|
||||
},
|
||||
Annotations: mcp.ToolAnnotation{
|
||||
ReadOnlyHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.ReadOnlyHint),
|
||||
DestructiveHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.DestructiveHint),
|
||||
IdempotentHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.IdempotentHint),
|
||||
OpenWorldHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.OpenWorldHint),
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
@@ -91,21 +91,41 @@ func TestMCPHTTP_E2E_ClientIntegration(t *testing.T) {
|
||||
|
||||
// Verify we have some expected Coder tools
|
||||
var foundTools []string
|
||||
for _, tool := range tools.Tools {
|
||||
var userTool *mcp.Tool
|
||||
var writeFileTool *mcp.Tool
|
||||
for i := range tools.Tools {
|
||||
tool := tools.Tools[i]
|
||||
foundTools = append(foundTools, tool.Name)
|
||||
switch tool.Name {
|
||||
case toolsdk.ToolNameGetAuthenticatedUser:
|
||||
userTool = &tools.Tools[i]
|
||||
case toolsdk.ToolNameWorkspaceWriteFile:
|
||||
writeFileTool = &tools.Tools[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Check for some basic tools that should be available
|
||||
assert.Contains(t, foundTools, toolsdk.ToolNameGetAuthenticatedUser, "Should have authenticated user tool")
|
||||
require.NotNil(t, userTool)
|
||||
require.NotNil(t, writeFileTool)
|
||||
require.NotNil(t, userTool.Annotations.ReadOnlyHint)
|
||||
require.NotNil(t, userTool.Annotations.DestructiveHint)
|
||||
require.NotNil(t, userTool.Annotations.IdempotentHint)
|
||||
require.NotNil(t, userTool.Annotations.OpenWorldHint)
|
||||
assert.True(t, *userTool.Annotations.ReadOnlyHint)
|
||||
assert.False(t, *userTool.Annotations.DestructiveHint)
|
||||
assert.True(t, *userTool.Annotations.IdempotentHint)
|
||||
assert.False(t, *userTool.Annotations.OpenWorldHint)
|
||||
require.NotNil(t, writeFileTool.Annotations.ReadOnlyHint)
|
||||
require.NotNil(t, writeFileTool.Annotations.DestructiveHint)
|
||||
require.NotNil(t, writeFileTool.Annotations.IdempotentHint)
|
||||
require.NotNil(t, writeFileTool.Annotations.OpenWorldHint)
|
||||
assert.False(t, *writeFileTool.Annotations.ReadOnlyHint)
|
||||
assert.True(t, *writeFileTool.Annotations.DestructiveHint)
|
||||
assert.False(t, *writeFileTool.Annotations.IdempotentHint)
|
||||
assert.False(t, *writeFileTool.Annotations.OpenWorldHint)
|
||||
|
||||
// Find and execute the authenticated user tool
|
||||
var userTool *mcp.Tool
|
||||
for _, tool := range tools.Tools {
|
||||
if tool.Name == toolsdk.ToolNameGetAuthenticatedUser {
|
||||
userTool = &tool
|
||||
break
|
||||
}
|
||||
}
|
||||
// Execute the authenticated user tool.
|
||||
require.NotNil(t, userTool, "Expected to find "+toolsdk.ToolNameGetAuthenticatedUser+" tool")
|
||||
|
||||
// Execute the tool
|
||||
|
||||
@@ -34,6 +34,7 @@ const (
|
||||
ServiceAgentMetricAggregator = "agent-metrics-aggregator"
|
||||
// ServiceTallymanPublisher publishes usage events to coder/tallyman.
|
||||
ServiceTallymanPublisher = "tallyman-publisher"
|
||||
ServiceUsageEventCron = "usage-event-cron"
|
||||
|
||||
RequestTypeTag = "coder_request_type"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
package provisionerdserver_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
||||
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
|
||||
)
|
||||
|
||||
func TestMergeExtraEnvs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initial map[string]string
|
||||
envs []*sdkproto.Env
|
||||
expected map[string]string
|
||||
expectErr string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
initial: map[string]string{},
|
||||
envs: nil,
|
||||
expected: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "default_replace",
|
||||
initial: map[string]string{},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "FOO", Value: "bar"},
|
||||
},
|
||||
expected: map[string]string{"FOO": "bar"},
|
||||
},
|
||||
{
|
||||
name: "explicit_replace",
|
||||
initial: map[string]string{"FOO": "old"},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "FOO", Value: "new", MergeStrategy: "replace"},
|
||||
},
|
||||
expected: map[string]string{"FOO": "new"},
|
||||
},
|
||||
{
|
||||
name: "empty_strategy_defaults_to_replace",
|
||||
initial: map[string]string{"FOO": "old"},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "FOO", Value: "new", MergeStrategy: ""},
|
||||
},
|
||||
expected: map[string]string{"FOO": "new"},
|
||||
},
|
||||
{
|
||||
name: "append_to_existing",
|
||||
initial: map[string]string{"PATH": "/usr/bin"},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "PATH", Value: "/custom/bin", MergeStrategy: "append"},
|
||||
},
|
||||
expected: map[string]string{"PATH": "/usr/bin:/custom/bin"},
|
||||
},
|
||||
{
|
||||
name: "append_no_existing",
|
||||
initial: map[string]string{},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "PATH", Value: "/custom/bin", MergeStrategy: "append"},
|
||||
},
|
||||
expected: map[string]string{"PATH": "/custom/bin"},
|
||||
},
|
||||
{
|
||||
name: "append_to_empty_value",
|
||||
initial: map[string]string{"PATH": ""},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "PATH", Value: "/custom/bin", MergeStrategy: "append"},
|
||||
},
|
||||
expected: map[string]string{"PATH": "/custom/bin"},
|
||||
},
|
||||
{
|
||||
name: "prepend_to_existing",
|
||||
initial: map[string]string{"PATH": "/usr/bin"},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "PATH", Value: "/custom/bin", MergeStrategy: "prepend"},
|
||||
},
|
||||
expected: map[string]string{"PATH": "/custom/bin:/usr/bin"},
|
||||
},
|
||||
{
|
||||
name: "prepend_no_existing",
|
||||
initial: map[string]string{},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "PATH", Value: "/custom/bin", MergeStrategy: "prepend"},
|
||||
},
|
||||
expected: map[string]string{"PATH": "/custom/bin"},
|
||||
},
|
||||
{
|
||||
name: "error_no_duplicate",
|
||||
initial: map[string]string{},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "FOO", Value: "bar", MergeStrategy: "error"},
|
||||
},
|
||||
expected: map[string]string{"FOO": "bar"},
|
||||
},
|
||||
{
|
||||
name: "error_with_duplicate",
|
||||
initial: map[string]string{"FOO": "existing"},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "FOO", Value: "new", MergeStrategy: "error"},
|
||||
},
|
||||
expectErr: "duplicate env var",
|
||||
},
|
||||
{
|
||||
name: "multiple_appends_same_key",
|
||||
initial: map[string]string{},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "PATH", Value: "/a/bin", MergeStrategy: "append"},
|
||||
{Name: "PATH", Value: "/b/bin", MergeStrategy: "append"},
|
||||
},
|
||||
expected: map[string]string{"PATH": "/a/bin:/b/bin"},
|
||||
},
|
||||
{
|
||||
name: "multiple_prepends_same_key",
|
||||
initial: map[string]string{},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "PATH", Value: "/a/bin", MergeStrategy: "prepend"},
|
||||
{Name: "PATH", Value: "/b/bin", MergeStrategy: "prepend"},
|
||||
},
|
||||
expected: map[string]string{"PATH": "/b/bin:/a/bin"},
|
||||
},
|
||||
{
|
||||
name: "mixed_strategies",
|
||||
initial: map[string]string{},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "PATH", Value: "/first", MergeStrategy: "append"},
|
||||
{Name: "PATH", Value: "/override", MergeStrategy: "replace"},
|
||||
},
|
||||
expected: map[string]string{"PATH": "/override"},
|
||||
},
|
||||
{
|
||||
name: "mixed_keys",
|
||||
initial: map[string]string{},
|
||||
envs: []*sdkproto.Env{
|
||||
{Name: "PATH", Value: "/a", MergeStrategy: "append"},
|
||||
{Name: "HOME", Value: "/home/user"},
|
||||
{Name: "PATH", Value: "/b", MergeStrategy: "append"},
|
||||
},
|
||||
expected: map[string]string{
|
||||
"PATH": "/a:/b",
|
||||
"HOME": "/home/user",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
env := make(map[string]string)
|
||||
for k, v := range tc.initial {
|
||||
env[k] = v
|
||||
}
|
||||
err := provisionerdserver.MergeExtraEnvs(env, tc.envs)
|
||||
if tc.expectErr != "" {
|
||||
require.ErrorContains(t, err, tc.expectErr)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.expected, env)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
protobuf "google.golang.org/protobuf/proto"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/aiseats"
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -76,6 +77,7 @@ const (
|
||||
type Options struct {
|
||||
OIDCConfig promoauth.OAuth2Config
|
||||
ExternalAuthConfigs []*externalauth.Config
|
||||
AISeatTracker aiseats.SeatTracker
|
||||
|
||||
// Clock for testing
|
||||
Clock quartz.Clock
|
||||
@@ -120,6 +122,7 @@ type server struct {
|
||||
NotificationsEnqueuer notifications.Enqueuer
|
||||
PrebuildsOrchestrator *atomic.Pointer[prebuilds.ReconciliationOrchestrator]
|
||||
UsageInserter *atomic.Pointer[usage.Inserter]
|
||||
AISeatTracker aiseats.SeatTracker
|
||||
Experiments codersdk.Experiments
|
||||
|
||||
OIDCConfig promoauth.OAuth2Config
|
||||
@@ -215,6 +218,9 @@ func NewServer(
|
||||
if err := tags.Valid(); err != nil {
|
||||
return nil, xerrors.Errorf("invalid tags: %w", err)
|
||||
}
|
||||
if options.AISeatTracker == nil {
|
||||
options.AISeatTracker = aiseats.Noop{}
|
||||
}
|
||||
if options.AcquireJobLongPollDur == 0 {
|
||||
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
|
||||
}
|
||||
@@ -253,6 +259,7 @@ func NewServer(
|
||||
heartbeatFn: options.HeartbeatFn,
|
||||
PrebuildsOrchestrator: prebuildsOrchestrator,
|
||||
UsageInserter: usageInserter,
|
||||
AISeatTracker: options.AISeatTracker,
|
||||
metrics: metrics,
|
||||
Experiments: experiments,
|
||||
}
|
||||
@@ -2437,6 +2444,12 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro
|
||||
})
|
||||
}
|
||||
|
||||
// Record AI seat usage for successful task workspace builds.
|
||||
if workspaceBuild.Transition == database.WorkspaceTransitionStart && workspace.TaskID.Valid {
|
||||
s.AISeatTracker.RecordUsage(ctx, workspace.OwnerID,
|
||||
aiseats.ReasonTask("task workspace build succeeded"))
|
||||
}
|
||||
|
||||
if s.PrebuildsOrchestrator != nil && input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
|
||||
// Track resource replacements, if there are any.
|
||||
orchestrator := s.PrebuildsOrchestrator.Load()
|
||||
@@ -2821,12 +2834,11 @@ func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
|
||||
}
|
||||
|
||||
env := make(map[string]string)
|
||||
// For now, we only support adding extra envs, not overriding
|
||||
// existing ones or performing other manipulations. In future
|
||||
// we may write these to a separate table so we can perform
|
||||
// conditional logic on the agent.
|
||||
for _, e := range prAgent.ExtraEnvs {
|
||||
env[e.Name] = e.Value
|
||||
// Apply extra envs with merge strategy support.
|
||||
// When multiple coder_env resources define the same name,
|
||||
// the merge_strategy controls how values are combined.
|
||||
if err := MergeExtraEnvs(env, prAgent.ExtraEnvs); err != nil {
|
||||
return err
|
||||
}
|
||||
// Allow the agent defined envs to override extra envs.
|
||||
for k, v := range prAgent.Env {
|
||||
@@ -3422,14 +3434,54 @@ func insertDevcontainerSubagent(
|
||||
return subAgentID, nil
|
||||
}
|
||||
|
||||
// MergeExtraEnvs applies extra environment variables to the given map,
|
||||
// respecting the merge_strategy field on each env. When merge_strategy
|
||||
// is empty or "replace", the value overwrites any existing entry.
|
||||
// "append" and "prepend" join values with a ":" separator (PATH-style).
|
||||
// "error" causes a failure if the key already exists.
|
||||
func MergeExtraEnvs(env map[string]string, extraEnvs []*sdkproto.Env) error {
|
||||
for _, e := range extraEnvs {
|
||||
strategy := e.GetMergeStrategy()
|
||||
if strategy == "" {
|
||||
strategy = "replace"
|
||||
}
|
||||
existing, exists := env[e.GetName()]
|
||||
switch strategy {
|
||||
case "error":
|
||||
if exists {
|
||||
return xerrors.Errorf(
|
||||
"duplicate env var %q: merge_strategy is %q but variable is already defined",
|
||||
e.GetName(), strategy,
|
||||
)
|
||||
}
|
||||
env[e.GetName()] = e.GetValue()
|
||||
case "append":
|
||||
if exists && existing != "" {
|
||||
env[e.GetName()] = existing + ":" + e.GetValue()
|
||||
} else {
|
||||
env[e.GetName()] = e.GetValue()
|
||||
}
|
||||
case "prepend":
|
||||
if exists && existing != "" {
|
||||
env[e.GetName()] = e.GetValue() + ":" + existing
|
||||
} else {
|
||||
env[e.GetName()] = e.GetValue()
|
||||
}
|
||||
default: // "replace"
|
||||
env[e.GetName()] = e.GetValue()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeSubagentEnvs(envs []*sdkproto.Env) (pqtype.NullRawMessage, error) {
|
||||
if len(envs) == 0 {
|
||||
return pqtype.NullRawMessage{}, nil
|
||||
}
|
||||
|
||||
subAgentEnvs := make(map[string]string, len(envs))
|
||||
for _, env := range envs {
|
||||
subAgentEnvs[env.GetName()] = env.GetValue()
|
||||
if err := MergeExtraEnvs(subAgentEnvs, envs); err != nil {
|
||||
return pqtype.NullRawMessage{}, err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(subAgentEnvs)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user