Compare commits


94 Commits

Author SHA1 Message Date
Kyle Carberry 986d6a856d perf(site): bypass React re-renders during panel drag-resize
The main bottleneck was setWidth() firing on every pointermove,
which re-rendered RightPanel → SidebarTabView → GitPanel →
DiffViewer → all FileDiff components on every pixel of movement.

Now during drag, the --panel-width CSS custom property is set
directly on the DOM via panelRef, skipping React reconciliation
entirely. React state is only committed on pointerup for
localStorage persistence.

With this change the only React re-renders during a normal drag
are: one on the first pointermove (dragSnap null → normal) and
one on pointerup. Everything in between is pure DOM mutation.
2026-03-18 17:00:44 +00:00
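The pattern described above (direct DOM writes on pointermove, a single React commit on pointerup) can be sketched roughly as follows. This is a minimal illustration, not the actual implementation: `PANEL_WIDTH_VAR`, `clampPanelWidth`, and the width bounds are invented for the sketch, and the element is duck-typed so the helpers stay testable outside a browser.

```typescript
const PANEL_WIDTH_VAR = "--panel-width"; // hypothetical CSS custom property name

// Pure helper: clamp the dragged width to sane bounds (bounds are assumptions).
function clampPanelWidth(px: number, min = 240, max = 800): number {
	return Math.min(max, Math.max(min, px));
}

// Duck-typed element so the sketch does not require a real DOM.
interface PanelLike {
	style: { setProperty(name: string, value: string): void };
}

// Called on every pointermove: mutate the DOM directly, no setState,
// so React reconciliation is skipped entirely during the drag.
function applyDragWidth(
	panel: PanelLike,
	clientX: number,
	viewportWidth: number,
): number {
	const width = clampPanelWidth(viewportWidth - clientX);
	panel.style.setProperty(PANEL_WIDTH_VAR, `${width}px`);
	return width;
}

// Called once on pointerup: the single React state commit, which is
// where persistence (e.g. localStorage) can hook in.
function commitDragWidth(setWidth: (w: number) => void, width: number): void {
	setWidth(width);
}
```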
Kyle Carberry 351ab6c7c7 chore: apply biome formatting fixes 2026-03-18 16:52:19 +00:00
Kyle Carberry d3d0ea3622 feat(site): add StressLargeDiff story for manual perf testing
Adds a story with 100 chat messages and a 10,000-line diff (50 files
x 200 lines each) with the sidebar panel open. Useful for manually
verifying that drag-resize and scroll feel smooth.
2026-03-18 16:36:51 +00:00
Kyle Carberry c0b71a1161 fix(site): remove spacer div from DiffViewer 2026-03-18 16:34:40 +00:00
Kyle Carberry 828c9e23f5 fix(site): reduce diff panel lag during drag-resize with CSS containment
Add three CSS-level optimizations to improve responsiveness when
dragging the right panel resize handle with large diffs open:

- content-visibility: auto on each file diff wrapper in DiffViewer so
  the browser skips layout/paint for off-screen diffs entirely during
  resize and scroll.
- contain: layout style paint on the RightPanel children container and
  SidebarTabView tab panel to isolate the diff subtree from external
  layout recalculation.
- pointer-events: none on panel children during active drag to
  eliminate hit-testing against the expensive Shadow DOM elements from
  @pierre/diffs.
2026-03-18 16:32:42 +00:00
Kyle Carberry a130a7dc97 fix: renumber duplicate migration 000444 to 000445 (#23229)
Two migrations were merged with the same number 000444:
- `000444_usage_events_ai_seats` (#22689, merged first at 09:30) — keeps
000444
- `000444_chat_message_runtime_ms` (#23219, merged second at 10:57) —
renumbered to **000445**

This collision causes `golang-migrate` to fail at runtime since it reads
both files as the same version.

**Fix:** Rename `000444_chat_message_runtime_ms.{up,down}.sql` →
`000445_chat_message_runtime_ms.{up,down}.sql`.

Closes https://github.com/coder/internal/issues/1411
2026-03-18 11:30:33 -04:00
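A collision like the one above is mechanically detectable: golang-migrate treats the leading numeric prefix of each migration filename as the version, and up/down pairs legitimately share a prefix, so the real conflict is two *different* migration names under one version. A hypothetical checker (not part of this PR) might look like:

```typescript
// Hypothetical helper: report versions claimed by more than one migration.
// Up/down pairs share a base name and are fine; two distinct base names
// under the same version number is the golang-migrate collision.
function duplicateMigrationVersions(filenames: string[]): string[] {
	const byVersion = new Map<string, Set<string>>();
	for (const name of filenames) {
		const match = /^(\d+)_(.+?)\.(up|down)\.sql$/.exec(name);
		if (!match) continue; // not a migration file
		const [, version, base] = match;
		if (!byVersion.has(version)) byVersion.set(version, new Set());
		byVersion.get(version)!.add(base);
	}
	return [...byVersion.entries()]
		.filter(([, bases]) => bases.size > 1)
		.map(([version]) => version);
}
```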
Kyle Carberry d6fef96d72 feat: add PR insights analytics dashboard (#23215)
## What

Adds a new admin-only **PR Insights** page for the `/agents` analytics
view — a dashboard for engineering leaders to understand code shipped by
AI agents.

### Backend
- `GET /api/v2/chats/insights/pull-requests` — admin-only endpoint
- 4 SQL queries in `chatinsights.sql` aggregating `chat_diff_statuses`
joined with chat cost data (via root chat tree rollup)
- Runs 5 parallel DB queries: current summary, previous summary (for
trends), time series, per-model breakdown, recent PRs
- SDK types auto-generate to TypeScript

### Frontend (`PRInsightsView`)
- **Stat cards**: PRs created, Merged, Merge rate, Lines shipped,
Cost/merged PR — with trend badges comparing to previous period
- **Activity chart**: Stacked area chart (created/merged/closed) using
git color tokens (`git-added-bright`, `git-merged-bright`,
`git-deleted-bright`)
- **Model performance table**: Per-model PR counts, inline merge rate
bars, diff stats, cost breakdown
- **Recent PRs table**: Status badges, review state icons, author info,
external links
- **Time range filter**: 7d/14d/30d/90d button group
- **4 Storybook stories**: Default, HighPerformance, LowVolume, NoPRs

### Data source
All PR data comes from the existing `chat_diff_statuses` table
(populated by the `gitsync.Worker` background job that polls GitHub
every 120s). No new data collection required.

### Screenshot
View in Storybook: `pages/AgentsPage/PRInsightsView`
2026-03-18 15:29:29 +00:00
Kyle Carberry 4dd8531f37 feat: track step runtime_ms on chat messages (#23219)
## Summary

Adds a `runtime_ms` column to `chat_messages` that records the
wall-clock duration (in milliseconds) of each LLM step. This covers LLM
streaming, tool execution, and retries — the full time the agent is
"alive" for a step.

This is the foundation for billing by agent alive time. The column
follows the same pattern as `total_cost_micros`: stored per assistant
message, aggregatable with `SUM()` over time periods by user.

## Changes

- **Migration**: adds nullable `runtime_ms bigint` to `chat_messages`.
- **chatloop**: adds `Runtime time.Duration` field to `PersistedStep`,
records `stepStart` at the beginning of each step, and stores
`time.Since(stepStart)` once the step completes (covering stream + tool
execution + retries).
- **chatd**: passes `step.Runtime.Milliseconds()` to the assistant
message `InsertChatMessage` call; all other message types (system, user,
tool) get `NULL`.
- **Tests**: adds `runtime > 0` assertion in chatloop tests.

## Billing query pattern

Once ready, aggregation mirrors the existing cost queries:

```sql
SELECT COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms
FROM chat_messages cm
JOIN chats c ON c.id = cm.chat_id
WHERE c.owner_id = @user_id
  AND cm.created_at >= @start_time
  AND cm.created_at < @end_time
  AND cm.runtime_ms IS NOT NULL;
```
2026-03-18 10:57:35 -04:00
Danielle Maywood 3bcb7de7c0 fix(site): normalize chat message spacing for visual consistency (#23222) 2026-03-18 14:49:50 +00:00
Kacper Sawicki 1e07ec49a6 feat: add merge_strategy support for coder_env resources (#23107)
## Description

Implements the server-side merge logic for the `merge_strategy`
attribute added to `coder_env` in [terraform-provider-coder
v2.15.0](https://github.com/coder/terraform-provider-coder/pull/489).
This allows template authors to control how duplicate environment
variable names are combined across multiple `coder_env` resources.

Relates to https://github.com/coder/coder/issues/21885

## Supported strategies

| Strategy | Behavior |
|----------|----------|
| `replace` (default) | Last value wins — backward compatible |
| `append` | Joins values with `:` separator (e.g. PATH additions) |
| `prepend` | Prepends value with `:` separator |
| `error` | Fails the build if the variable is already defined |

## Example

```hcl
resource "coder_env" "path_tools" {
  agent_id       = coder_agent.dev.id
  name           = "PATH"
  value          = "/home/coder/tools/bin"
  merge_strategy = "append"
}
```

## Changes

- **Proto**: Added `merge_strategy` field to `Env` message in
`provisioner.proto`
- **State reader**: Updated `agentEnvAttributes` struct and proto
construction in `resources.go`
- **Merge logic**: Added `mergeExtraEnvs()` function in
`provisionerdserver.go` with strategy-aware merging for both agent envs
and devcontainer subagent envs
- **Tests**: 15 unit tests covering all strategies, edge cases (empty
values, mixed strategies, multiple appends)
- **Dependency**: Bumped `terraform-provider-coder` v2.14.0 → v2.15.0
- **Fixtures**: Updated `duplicate-env-keys` test fixtures and golden
files

## Ordering

When multiple resources `append` or `prepend` to the same key, they are
processed in alphabetical order by Terraform resource address (per the
determinism fix in #22706).
2026-03-18 15:43:28 +01:00
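Assuming the semantics in the strategy table above, the merge logic might look like this sketch. The real `mergeExtraEnvs()` is Go code in `provisionerdserver.go`; the function name, signature, and first-definition handling here are illustrative.

```typescript
type MergeStrategy = "replace" | "append" | "prepend" | "error";

// Sketch of strategy-aware env merging: `existing` is the value already
// set for the variable, `incoming` comes from a later coder_env resource.
function mergeEnvValue(
	name: string,
	existing: string | undefined,
	incoming: string,
	strategy: MergeStrategy,
): string {
	if (existing === undefined) return incoming; // first definition, nothing to merge
	switch (strategy) {
		case "replace":
			return incoming; // last value wins (backward-compatible default)
		case "append":
			return `${existing}:${incoming}`;
		case "prepend":
			return `${incoming}:${existing}`;
		case "error":
			throw new Error(`environment variable ${name} is already defined`);
	}
}
```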
Steven Masley 84de391f26 chore: add tallyman events for ai seat tracking (#22689)
AI seat tracking is inserted as a heartbeat into the usage events table.
2026-03-18 09:30:22 -05:00
Kyle Carberry b83b93ea5c feat: add workspace awareness system message on chat creation (#23213)
When a chat is created via `chatd`, a system message is now inserted
informing the model whether the chat was created with or without a
workspace.

**With workspace:**
> This chat is attached to a workspace. You can use workspace tools like
execute, read_file, write_file, etc.

**Without workspace:**
> There is no workspace associated with this chat yet. Create one using
the create_workspace tool before using workspace tools like execute,
read_file, write_file, etc.

This is a model-only visibility system message (not shown to users) that
helps the model understand its available capabilities upfront —
particularly important for subagents spawned without a workspace, which
previously would attempt to use workspace tools and fail.

**Changes:**
- `coderd/chatd/chatd.go`: Added workspace awareness constants and
inserted the system message in `CreateChat` after the system prompt,
before the initial user message.
- `coderd/chatd/chatd_test.go`: Added
`TestCreateChatInsertsWorkspaceAwarenessMessage` with sub-tests for both
with-workspace and without-workspace cases.
2026-03-18 14:01:46 +00:00
Hugo Dutka 014e5b4f57 chore(site): remove experiment label from agents virtual desktop (#23217)
The "experiment" label is not needed since Coder Agents as a whole is an
experimental feature.
2026-03-18 13:55:30 +00:00
Ethan fc3508dc60 feat: configure acquire chat batch size (#23196)
## Summary
- add a hidden deployment config option for chat acquire batch size
(`CODER_CHAT_ACQUIRE_BATCH_SIZE` / `chat.acquireBatchSize`)
- thread the configured value into chatd startup while preserving the
existing default of `10`
- clamp the deployment value to the `int32` range before passing it into
chatd
- regenerate the API/docs/types/testdata artifacts for the new config
field

## Why
`chatd` currently acquires pending chats in batches of `10` via a
compile-time default. This change makes that batch size
operator-configurable from deployment config, so we can tune acquisition
behavior without another code change.
2026-03-19 00:54:32 +11:00
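Clamping a deployment value into the `int32` range, as described above, is a small guard at the API boundary. A generic sketch (not the actual Go code in the PR):

```typescript
const INT32_MIN = -(2 ** 31);
const INT32_MAX = 2 ** 31 - 1;

// Clamp an operator-supplied number into int32 range before passing it
// to an interface typed as int32, truncating any fractional part.
function clampToInt32(value: number): number {
	return Math.min(INT32_MAX, Math.max(INT32_MIN, Math.trunc(value)));
}
```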
Mathias Fredriksson 8b4d35798a refactor: type both chat message parsers (#23176)
Both message parsers accepted untyped input and relied on scattered
asRecord/asString calls to extract fields at runtime. With the
discriminated ChatMessagePart union, both accept typed input directly
and narrow via switch (part.type).

parseMessageContent narrows from (content: unknown) to
(content: readonly ChatMessagePart[] | undefined), removing legacy
input shape handling the Go backend normalizes away.
applyMessagePartToStreamState narrows from Record<string, unknown>
to ChatMessagePart.

The SSE type guards had a & Record<string, unknown> intersection
that widened everything untyped downstream. Since the data comes
from our own API, the intersection was removed and all handlers in
ChatContext now use generated types directly.

Fixes tool_call_id and tool_name variant tags in codersdk/chats.go:
marked optional to match reality (Go guards against empty values,
omitempty omits them at the wire level).

Refs #23168, #23175
2026-03-18 15:50:57 +02:00
Danielle Maywood d69dcf18de fix: balance visual padding on agent chat sidebar items (#23211) 2026-03-18 13:44:37 +00:00
Cian Johnston fe82d0aeb9 fix: allow member users to generate support bundles (#23040)
Fixes AIGOV-141

The `coder support bundle` command previously required admin permissions
(`Read DeploymentConfig`) and would abort entirely for non-admin
`member` users with:

```
failed authorization check: cannot Read DeploymentValues
```

This change makes the command **degrade gracefully** instead of failing
outright.

<details>
<summary>
Changes
</summary>

### `support/support.go`
- **`Run()`**: The authorization check for `Read DeploymentValues` is
now a soft warning instead of a hard gate. Unauthenticated users (401)
still fail, but authenticated users with insufficient permissions
proceed with reduced data.
- **`DeploymentInfo()`**: `DeploymentConfig` and `DebugHealth` fetches
now handle 403/401 responses gracefully, matching the existing pattern
used by `DeploymentStats`, `Entitlements`, and `HealthSettings`.
- **`NetworkInfo()`**: Coordinator debug and tailnet debug fetches now
check response status codes for 403/401 before reading the body.

### `cli/support.go`
- **`summarizeBundle()`**: No longer returns early when `Config` or
`HealthReport` is nil. Instead prints warnings and continues summarizing
available data (e.g., netcheck).

### Tests
- `MissingPrivilege` → `MemberNoWorkspace`: Asserts member users can
generate a bundle successfully with degraded admin-only data.
- `NoPrivilege` → `MemberCanGenerateBundle`: Asserts the CLI produces a
valid zip bundle for member users.
- All existing tests continue to pass (`NoAuth`, `OK`, `OK_NoWorkspace`,
`DontPanic`, etc.).

## Behavior matrix

| User type | Before | After |
|---|---|---|
| **Admin** | Full bundle | Full bundle (no change) |
| **Member** | Hard error | Bundle with degraded admin-only data |
| **Unauthenticated** | Hard error | Hard error (no change) |

Related to PRODUCT-182
2026-03-18 13:43:10 +00:00
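The degrade-gracefully pattern in the behavior matrix above (401 stays fatal, 403 becomes a soft warning with reduced data) can be sketched as follows; the function name and result shape are hypothetical, not the actual `support` package API.

```typescript
interface FetchResult<T> {
	data?: T;
	warning?: string;
}

// Sketch of the soft-failure pattern: an unauthenticated user (401)
// still aborts, but an authenticated user lacking permission (403)
// proceeds with the data omitted and a warning recorded.
function classifyResponse<T>(
	label: string,
	status: number,
	body?: T,
): FetchResult<T> {
	if (status === 401) {
		throw new Error(`${label}: not authenticated`);
	}
	if (status === 403) {
		return { warning: `${label}: insufficient permissions, skipping` };
	}
	return { data: body };
}
```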
Ethan 81dba9da14 test: stabilize AgentsPageView analytics story date (#23216)
## Summary
The `AgentsPageView: Opens Analytics For Admins` story was flaky because
the analytics header renders a rolling 30-day date range in the
top-right corner. Since that range was based on the current date, the
story output changed every day.

This change makes the story deterministic by:
- adding an optional `analyticsNow` prop to `AgentsPageView`
- passing that value through to `AnalyticsPageContent` when the
analytics panel is shown
- setting a fixed local-noon timestamp in the story so the rendered
range label stays stable across timezones
2026-03-19 00:34:16 +11:00
Thomas Kosiewski 20ac96e68d feat(site): include chatId in editor deep links (#23214)
## Summary

- include the current agent chat ID in VS Code and Cursor deep links
opened from the agent detail page
- extend `getVSCodeHref` so `chatId` is added only when provided
- add focused tests for deep-link generation with and without `chatId`

## Testing

- `pnpm -C site run format -- src/modules/apps/apps.ts
src/modules/apps/apps.test.ts src/pages/AgentsPage/AgentDetail.tsx`
- `pnpm -C site run check -- src/modules/apps/apps.ts
src/modules/apps/apps.test.ts src/pages/AgentsPage/AgentDetail.tsx`
- `pnpm -C site exec vitest run src/modules/apps/apps.test.ts`
- `pnpm -C site run lint:types`

---
_Generated with [`mux`](https://github.com/coder/mux) • Model:
`openai:gpt-5.4` • Thinking: `high`_
2026-03-18 14:25:38 +01:00
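The "added only when provided" behavior for `chatId` can be sketched with a generic query builder. The actual `getVSCodeHref` signature and query parameter names are not shown in this log, so everything below is an assumption.

```typescript
// Hypothetical sketch: build an editor deep link, appending optional
// parameters (like chatId) only when they are defined.
function buildEditorHref(
	base: string,
	params: Record<string, string | undefined>,
): string {
	const query = Object.entries(params)
		.filter((entry): entry is [string, string] => entry[1] !== undefined)
		.map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(v)}`)
		.join("&");
	return query ? `${base}?${query}` : base;
}
```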
Atif Ali 677f90b78a chore: label community PRs on open (#23157) 2026-03-18 18:15:37 +05:00
35C4n0r d697213373 feat(docs/ai-coder/ai-bridge): update aibridge docs for codex to use model_provider (#23199) 2026-03-18 18:09:55 +05:00
Michael Suchacz 62144d230f feat(site): show PR link in TopBar header (#23178)
When a PR is detected for a chat, display a compact PR badge in the
AgentDetail TopBar. On mobile it is always visible; on desktop it is
hidden when the sidebar panel is open (which already surfaces PR info)
and shown when the panel is closed.

The badge shows a state-colored icon (open, draft, merged, closed) and
the PR title or number, linking to the PR URL. Only URLs confirmed as
real PRs (via explicit `pull_request_state` or a `/pull/<number>`
pathname) trigger the badge.

## Changes

- **`TopBar.tsx`** — Added `diffStatusData` prop, `PrStateIcon` helper,
and a PR link badge between the title and actions area. Hidden on
desktop when the sidebar panel is open.
- **`AgentDetailView.tsx`** — Pass `diffStatusData` through to
`AgentDetailTopBar`.
- **`TopBar.stories.tsx`** — Added stories for open, draft, merged, and
closed PR states.
2026-03-18 13:40:33 +01:00
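The badge-gating rule above (an explicit `pull_request_state`, or a `/pull/<number>` pathname) could be sketched as:

```typescript
// Sketch: decide whether a URL should render the PR badge. The
// pull_request_state field name comes from the description above; the
// pathname heuristic and function name are assumptions.
function isConfirmedPrUrl(url: string, pullRequestState?: string): boolean {
	if (pullRequestState) return true; // backend explicitly confirmed a PR
	try {
		const pathname = new URL(url).pathname;
		return /\/pull\/\d+(\/|$)/.test(pathname);
	} catch {
		return false; // not a parseable URL
	}
}
```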
Hugo Dutka 0d0c6c956d fix(dogfood): chrome desktop icons with compatibility flags (#23209)
Our dogfood image already included chrome. Since we run dogfood
workspaces in Docker, chrome requires some compatibility flags to run
properly. If you launch chrome without them, some webpages crash and
fail to load.

The newest release of https://github.com/coder/portabledesktop added an
icon dock. This PR edits the chrome `.desktop` files so when you open
chrome from the dock it runs with the correct flags.


https://github.com/user-attachments/assets/7bf880e1-22a4-4faa-8f7f-394863c6b127
2026-03-18 13:36:16 +01:00
Mathias Fredriksson 488ceb6e58 refactor(site/src/pages/AgentsPage): clean up RenderBlock types and dead fields (#23175)
RenderBlock's file-reference variant diverged from the API (camelCase
vs snake_case), and both file variants were defined inline duplicating
the generated ChatFilePart and ChatFileReferencePart types. The
thinking and file-reference variants carried dead fields (title, text)
that were never populated by the backend.

Replace inline definitions with references to generated types, remove
dead fields, and simplify ReasoningDisclosure (disclosure button path
was dead without title).

Refs #23168
2026-03-18 12:25:05 +00:00
Matt Vollmer 481c132135 docs: clarify agent permission inheritance and default security posture (#23194)
Addresses five documentation gaps identified from an internal agents
briefing Q&A, specifically around what permissions an agent inherits
from the user:

1. **No privilege escalation** — Added explicit statement that the agent
has the exact same permissions as the user. No escalation, no shared
service account.
2. **Cross-user workspace isolation** — Added statement that agents
cannot access workspaces belonging to other users.
3. **Default-state warning** — Added WARNING callouts that agent
workspaces inherit the user's full network access unless templates
explicitly restrict it.
4. **Tool boundary statement** — Added explicit statement that the agent
cannot act outside its defined tool set and has no direct access to the
Coder API.
5. **Template visibility scoped to user RBAC** — Clarified that template
selection respects the user's role and permissions.

Changes across 3 files:
- `docs/ai-coder/agents/index.md`
- `docs/ai-coder/agents/architecture.md`
- `docs/ai-coder/agents/platform-controls/template-optimization.md`

---
PR generated with Coder Agents
2026-03-18 12:15:50 +00:00
Kyle Carberry d42008e93d fix: persist partial assistant response when chat is interrupted mid-stream (#23193)
## Problem

When a user cancels a streaming chat response mid-stream, the partial
content disappears entirely — both from the UI and the database. The
streamed text vanishes as if the response never happened.

## Root Causes

Three issues combine to prevent partial message persistence on
interrupt:

### 1. StreamPartTypeError only matched `context.Canceled`
(`chatloop.go`)

The interrupt detection in `processStepStream` checked:
```go
errors.Is(part.Error, context.Canceled) && errors.Is(context.Cause(ctx), ErrInterrupted)
```
But some providers propagate `ErrInterrupted` directly as the stream
error rather than wrapping it in `context.Canceled`. This caused the
condition to fail, so `flushActiveState` was never called and partial
text accumulated in `activeTextContent` was lost.

### 2. No post-loop interrupt check (`chatloop.go`)

If the stream iterator stops yielding parts without producing a
`StreamPartTypeError` (e.g., a provider that silently closes the
response body on cancel), there was no check after the `for part :=
range stream` loop to detect the interrupt and flush active state.

### 3. Worker ownership check blocked interrupted persists (`chatd.go`)

`InterruptChat` → `setChatWaiting` clears `worker_id` in the DB
**before** the chatloop detects the interrupt. When
`persistInterruptedStep` (using `context.WithoutCancel`) tried to write
the partial message, the ownership check:
```go
if !lockedChat.WorkerID.Valid || lockedChat.WorkerID.UUID != p.workerID {
    return chatloop.ErrInterrupted  // always blocks!
}
```
unconditionally rejected the write. The error was silently logged as a
warning.

## Fix

- **Broaden the `StreamPartTypeError` interrupt detection** to match
both `context.Canceled` and `ErrInterrupted` as the stream error.
- **Add a post-loop interrupt check** in `processStepStream` that
flushes active state when the context was canceled with
`ErrInterrupted`.
- **Allow `persistStep` to write when the chat is in `waiting` status**
(interrupt) even if `worker_id` was cleared. The `pending` status (from
`EditMessage`, where history is truncated) still correctly blocks stale
writes.

## Testing

Added `TestInterruptChatPersistsPartialResponse` — an end-to-end
integration test that:
1. Streams partial text chunks from a mock LLM
2. Waits for the chatloop to publish `message_part` events (confirming
chunks were processed)
3. Interrupts the chat mid-stream
4. Verifies the partial assistant message is persisted in the database
with the expected text content
2026-03-18 11:48:28 +00:00
Danielle Maywood aa3cee6410 fix: polish agents UI (sidebar width, combobox, limits padding, back button) (#23204) 2026-03-18 11:46:56 +00:00
Danielle Maywood 4f566f92b5 fix(site): use ExternalImage for preset icons in task prompt (#23206) 2026-03-18 11:16:30 +00:00
Atif Ali bd5b62c976 feat: expose MCP tool annotations for tool grouping (#23195)
## Summary
- add shared MCP annotation metadata to toolsdk tools
- emit MCP tool annotations from both coderd and CLI MCP servers
- cover annotation serialization in toolsdk, coderd MCP e2e, and CLI MCP
tests

## Why
- Coder already exposed MCP tools, but it did not populate MCP tool
annotation hints (`readOnlyHint`, `destructiveHint`, `idempotentHint`,
`openWorldHint`).
- Hosts such as Claude Desktop use those hints to classify and group
tools, so without them Coder tools can get lumped together.
- This change adds a shared annotation source in `toolsdk` and has both
MCP servers emit those hints through `mcp.Tool.Annotations`, avoiding
drift between local and remote MCP implementations.

## Testing
- Tested locally on Claude Desktop and the tools are categorized
correctly.

<table>
<tr>
 <td> Before
 <td> After
<tr>
<td> <img width="613" height="183" alt="image"
src="https://github.com/user-attachments/assets/29d2e3fb-53bc-4ea7-bdb3-f10df4ef996b"
/>
<td> <img width="600" height="457" alt="image"
src="https://github.com/user-attachments/assets/cc384036-c9a7-4db9-9400-43ad51920ff5"
/>
</table>

Note: done using Coder Agents; reviewed and tested locally by a human
2026-03-18 10:21:45 +00:00
Mathias Fredriksson 66f809388e refactor: make ChatMessagePart a discriminated union in TypeScript (#23168)
The flat ChatMessagePart interface had 20+ optional fields, preventing
TypeScript from narrowing types on switch(part.type). Each consumer
needed runtime validation, type assertions, or defensive ?. chains.

Add `variants` struct tags to ChatMessagePart fields declaring which
union variants include each field. A codegen mutation in apitypings
reads these tags via reflect and generates per-variant sub-interfaces
(ChatTextPart, ChatReasoningPart, etc.) plus a union type alias.
A test validates every field has a variants tag or is explicitly
excluded, and every part type is covered.

Remove dead frontend code: normalizeBlockType, alias case branches
("thinking", "toolcall", "toolresult"), legacy field fallbacks
(line_number, typedBlock.name/id/input/output), and result_delta
handling. Add test coverage for args_delta streaming, provider_executed
skip logic, and source part parsing.
2026-03-18 09:27:51 +00:00
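The narrowing win described above is the standard discriminated-union pattern: a shared `type` discriminant lets `switch (part.type)` narrow each branch with no casts or defensive `?.` chains. A minimal sketch with simplified stand-ins for the generated variants (the real `ChatTextPart` etc. carry more fields):

```typescript
// Simplified stand-ins for the generated per-variant interfaces.
interface ChatTextPart {
	type: "text";
	text: string;
}
interface ChatReasoningPart {
	type: "reasoning";
	reasoning: string;
}
type ChatMessagePart = ChatTextPart | ChatReasoningPart;

// Inside each case, `part` is narrowed to the matching variant, so
// variant-specific fields are accessed without assertions.
function renderPart(part: ChatMessagePart): string {
	switch (part.type) {
		case "text":
			return part.text; // narrowed to ChatTextPart
		case "reasoning":
			return `(thinking) ${part.reasoning}`; // narrowed to ChatReasoningPart
	}
}
```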
Mathias Fredriksson 563c00fb2c fix(dogfood/coder): suppress du stderr in docker usage metadata (#23200)
Transient 'No such file or directory' errors from disappearing
overlay2 layers during container operations pollute the displayed
metadata value. Redirect stderr to /dev/null.
2026-03-18 10:54:13 +02:00
Hugo Dutka 817fb4e67a feat: virtual desktop settings toggle frontend (#23173)
Add a toggle in agents settings to enable/disable virtual desktop. The
Desktop tab (next to the Git tab) will only be visible if the feature is
enabled.

<img width="879" height="648" alt="Screenshot 2026-03-17 at 18 01 26"
src="https://github.com/user-attachments/assets/09fc3850-c88d-4c5c-b6e4-760590e53b95"
/>
2026-03-18 09:50:14 +01:00
Hugo Dutka 2cf47ec384 feat: virtual desktop settings toggle backend (#23171)
Adds a new `site_config` entry that controls whether the virtual desktop
feature for Coder Agents is enabled. It can be set via a new
`/api/experimental/chats/config/desktop-enabled` endpoint, which will be
used by the frontend.
2026-03-18 09:35:13 +01:00
Ethan 11481d7bed perf(coderd/chatd): reduce lock contention in instruction cache and persistStep (#23144)
## Summary

Two targeted performance improvements to the chatd server, identified
through benchmarking.

### 1. RWMutex for instruction cache

The instruction cache is read on every chat turn to fetch the home
instruction file for a workspace agent. Writes only occur on cache
misses (once per agent per 5-minute TTL window), making the access
pattern ~90%+ reads.

Switching from `sync.Mutex` to `sync.RWMutex` and using
`RLock`/`RUnlock` on the read path allows concurrent readers instead of
serializing them.

**Benchmark (200 concurrent chats):**
| | ns/op |
|---|---|
| Mutex | 108 |
| RWMutex | 32 |
| **Speedup** | **3.4x** |

### 2. Hoist JSON marshaling out of persistStep transaction

`MarshalParts`, `PartFromContent`, `CalculateTotalCostMicros`, and the
`usageForCost` struct population are pure CPU work that ran inside the
`FOR UPDATE` transaction in `persistStep`. They have zero dependency on
the database transaction.

Moving all marshal and cost-calculation calls above `p.db.InTx()` means
the row lock is held only for `GetChatByIDForUpdate` +
`InsertChatMessage` calls.

**Benchmark (16 goroutines contending on same lock):**
| Tool calls | Inside lock | Outside lock | Speedup |
|---|---|---|---|
| 1 | 13,977 ns/op | 1,055 ns/op | 13x |
| 5 | 38,203 ns/op | 3,769 ns/op | 10x |
| 10 | 67,353 ns/op | 7,284 ns/op | 9x |
| 20 | 145,864 ns/op | 14,045 ns/op | 10x |

No behavioral changes in either commit.
2026-03-18 16:12:14 +11:00
Ben Potter f3bf5baba0 chore: update coder/tailscale fork to 33e050fd4bd9 (#23191)
Updates the tailscale replace directive to pick up two new commits from
[coder/tailscale](https://github.com/coder/tailscale):

- [feat(magicsock): add DERPTLSConfig for custom TLS configuration
(#105)](https://github.com/coder/tailscale/commit/8ffb3e998ba9c11d770eacac9a2f3932ce36590d)
- [chore: improve logging for derp server mesh clients
(#107)](https://github.com/coder/tailscale/commit/33e050fd4bd97d9e805afb4df7fac7a1c6e4abf8)

Relates to: PRODUCT-204
2026-03-18 15:14:02 +11:00
Matt Vollmer 9df7fda5f6 docs: rename "Template Routing" to "Template Optimization" (#23192)
Renames the page title from "Template Routing" to "Template
Optimization" in both the markdown H1 header and the docs manifest
entry.

---

PR generated with Coder Agents
2026-03-17 20:37:39 -04:00
Matt Vollmer 665db7bdeb docs: add agent workspaces best practices guide (#23142)
Add a new docs page under /docs/ai-coder/agents/ covering best practices
for creating templates that are discoverable and useful to Coder Agents.

Covers template descriptions, dedicated agent templates, network
boundaries, credential scoping, parameter design, pre-installed tooling,
and prebuilt workspaces for reducing provisioning latency.

2026-03-17 19:28:46 -04:00
Asher 903cfb183f feat: add --service-account to cli user creation (#23186) 2026-03-17 14:07:20 -08:00
Kayla はな 49e5547c22 feat: add support for creating service accounts (#23140) 2026-03-17 15:36:20 -06:00
Michael Suchacz f9c265ca6e feat: expose PromptCacheKey in OpenAI model config form (#23185)
## Summary

Remove the `hidden` tag from the `PromptCacheKey` field on
`ChatModelOpenAIProviderOptions` so the auto-generated JSON schema
no longer marks it as hidden. This allows the admin model
configuration UI to render a "Prompt Cache Key" text input for
OpenAI models alongside other visible options like Reasoning Effort,
Service Tier, and Web Search.

## Changes

- **`codersdk/chats.go`**: Remove `hidden:"true"` from `PromptCacheKey`
struct tag.
- **`site/src/api/chatModelOptionsGenerated.json`**: Regenerated via
`make gen` — `hidden: true` removed from the `prompt_cache_key` entry.
- **`modelConfigFormLogic.test.ts`**: Extend existing "all fields set"
tests to cover extract and build roundtrip for `promptCacheKey`.

## How it works

The `hidden` Go struct tag propagates through the code generation
pipeline:

1. Go struct tag → `scripts/modeloptionsgen` →
`chatModelOptionsGenerated.json`
2. The frontend `getVisibleProviderFields()` filters out fields with
`hidden: true`
3. Removing the tag makes the field visible in the schema-driven form
renderer

No new UI components are needed — the existing `ModelConfigFields`
component
automatically renders the field as a text input based on the schema
(`type: "string"`, `input_type: "input"`).

The field appears as **"Prompt Cache Key"** with description
"Key for enabling cross-request prompt caching" in the OpenAI provider
section of the admin model configuration form.
2026-03-17 21:58:36 +01:00
Danielle Maywood a65a31a5a3 fix(site): symmetric horizontal padding on agents sidebar chat rows (#23187) 2026-03-17 20:50:11 +00:00
Danielle Maywood 22a4a33886 fix(site): restore gap between agent chat messages (#23188) 2026-03-17 20:49:14 +00:00
Charlie Voiselle d3c9469e13 fix: open coder_app links in new tab when open_in is tab (#23000)
Fixes #18573

## Changes

When a `coder_app` resource sets `open_in = "tab"`, clicking the app
link now opens in a new browser tab instead of navigating in the same
tab.

`target="_blank"` and `rel="noreferrer"` are set inline on the
`<a>` elements in `AppLink.tsx`, gated on `app.open_in === "tab"`. This
follows the codebase convention of co-locating `target` and `rel` at the
render site.

`noreferrer` suppresses the Referer header to avoid leaking workspace IDs
to destination servers and implies `noopener`.
`noopener` prevents tabnabbing — without it, the opened page can
redirect the Coder dashboard tab via `window.opener`. This is especially
relevant for same-origin path-based apps, which would otherwise have
full DOM access to the dashboard. 

> **Future enhancement**: template admins could opt into sending the
referrer via a `coder_app` setting, enabling feedback pages built around
workspace context.

## Tests

A vitest case is added in `AppLink.test.tsx` (rather than a Storybook
story, since the assertions are purely behavioral with no visual
component):

- **`sets target=_blank and rel=noopener noreferrer when open_in is
tab`** — renders the app link with `open_in: "tab"` and asserts
`target="_blank"` and `rel="noreferrer"` are present on the
anchor.

## Slim-window behavior

The `slim-window` test case and the `openAppInNewWindow()` comment in
`apps.ts` have been split out into a follow-up PR for separate review,
since the `window.open()` / `noopener` tradeoffs there deserve dedicated
discussion.

---------

Co-authored-by: Kayla はな <kayla@tree.camp>
2026-03-17 15:32:45 -04:00
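Assuming the gating described above, the anchor attributes might be derived like this sketch (co-locating `target` and `rel` at the render site; the helper name is hypothetical):

```typescript
interface AnchorTargetProps {
	target?: "_blank";
	rel?: string;
}

// Sketch: only apps with open_in === "tab" get new-tab attributes.
// rel="noreferrer" suppresses the Referer header and implies noopener,
// so the opened page cannot reach back through window.opener.
function anchorPropsFor(openIn: string): AnchorTargetProps {
	if (openIn === "tab") {
		return { target: "_blank", rel: "noreferrer" };
	}
	return {};
}
```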
George K 91ec0f1484 feat: add service_accounts workspace sharing mode (#23093)
Introduce a three-way workspace sharing setting (none, everyone,
service_accounts) replacing the boolean workspace_sharing_disabled.
In service_accounts mode, only service account-owned workspaces can be
shared while regular members' share permissions are removed. Adds a
new organization-service-account system role with per-org permissions
reconciled alongside the existing organization-member system role.

Related to:
https://linear.app/codercom/issue/PLAT-28/feat-service-accounts-sharing-mode-and-rbac-role

---------

Co-authored-by: Steven Masley <Emyrk@users.noreply.github.com>
Co-authored-by: Kayla はな <mckayla@hey.com>
2026-03-17 12:16:43 -07:00
Danielle Maywood 6b76e30321 fix(site): align workspace combobox styling with model selector (#23181) 2026-03-17 18:46:35 +00:00
Kyle Carberry 6fc9f195f1 fix: resolve chat message pagination scroll issues (#23169)
## Summary

Fixes four interrelated issues that caused scroll position jumps and
phantom scroll growth when paginating older chat messages.

## Changes

### 1. Removed client-side message windowing (`useMessageWindow`)

There were two competing sentinel systems: server-side pagination and
client-side windowing. The client windowing sentinel was nested deep
inside the timeline with no explicit IntersectionObserver `root`,
causing scroll position jumps when messages were prepended. Blink
(coder/blink) has no client-side windowing. Removed it entirely; server
pagination + `contentVisibility` handled performance.

### 2. Removed `contentVisibility: "auto"` from message sections

Each section had `contentVisibility: "auto"` with `containIntrinsicSize:
"1px 600px"`, causing the scroll region to grow/shrink as the browser
swapped 600px placeholders for actual heights while scrolling. This
created phantom scroll growth with no fetch involved.

### 3. Gated WebSocket on initial REST data

The WebSocket `Subscribe` snapshot calls `GetChatMessagesByChatID` (no
LIMIT) which returns every message when `afterMessageID` is 0. The
WebSocket effect opened before the REST page resolved, so
`lastMessageIdRef` was undefined, causing the server to replay the
entire history and defeating pagination. Added `initialDataLoaded` guard
so the socket waits for the first REST page.

### 4. Manual scroll position restoration

Replaced unreliable CSS scroll anchoring in `flex-col-reverse` with a
`ScrollAnchoredContainer` that snapshots `scrollHeight` before fetch and
restores `scrollTop` via `useLayoutEffect` after render. Disabled
browser scroll anchoring (`overflow-anchor: none`) to prevent conflicts.
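The restoration logic can be sketched as framework-free helpers (names are illustrative, not the actual `ScrollAnchoredContainer` internals; assumes any element-like object exposing `scrollHeight`/`scrollTop`):

```typescript
interface ScrollBox {
	scrollHeight: number;
	scrollTop: number;
}

// Capture scroll metrics immediately before older messages are prepended.
function snapshotScroll(el: ScrollBox): { height: number; top: number } {
	return { height: el.scrollHeight, top: el.scrollTop };
}

// After render (e.g. in useLayoutEffect), offset scrollTop by however much
// the content grew, so the viewport stays pinned to the same message.
function restoreScroll(
	el: ScrollBox,
	snap: { height: number; top: number },
): void {
	el.scrollTop = snap.top + (el.scrollHeight - snap.height);
}
```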
2026-03-17 14:26:53 -04:00
Mathias Fredriksson c2243addce fix(scripts/develop): allow empty access-url for devtunnel (#23166) 2026-03-17 18:06:55 +00:00
Danielle Maywood cd163d404b fix(site): strip SVN-style Index headers from diffs before parsing (#23179) 2026-03-17 17:57:00 +00:00
Danielle Maywood 41d12b8aa3 feat(site): improve edit-message UX with dedicated button and confirmation (#23172) 2026-03-17 17:39:28 +00:00
Kyle Carberry 497e1e6589 feat: render file references inline in user messages (#23174)
File references in user messages now render as inline chips (matching
the chat input style) instead of in a separate bordered section at the
bottom of the message bubble.

This reimplements #23131 which was accidentally reverted during the
merge of #23072 (the spend-limit UI PR resolved a merge conflict by
dropping the inline chip logic).

## Changes
- **FileReferenceNode.tsx**: Export `FileReferenceChip` so it can be
imported for read-only use (no remove button when `onRemove` is
omitted).
- **ConversationTimeline.tsx**: Iterate through `parsed.blocks` in
document order, rendering `response` blocks as text and `file-reference`
blocks as inline `FileReferenceChip` components. Removes the old
separated file-reference section with `border-t` divider.
- **ConversationTimeline.stories.tsx**: Added
`UserMessageWithInlineFileRef` and
`UserMessageWithMultipleInlineFileRefs` stories.
2026-03-17 16:52:00 +00:00
Kyle Carberry b779c9ee33 fix: use SQL-level auth filtering for chat listing (#23159)
## Problem

The chat listing endpoint (`GetChatsByOwnerID`) was using
`fetchWithPostFilter`, which fetches N rows from the database and then
filters them in Go memory using RBAC checks. This causes a pagination
bug: if the user requests `limit=25` but some rows fail the auth check,
fewer than 25 rows are returned even though more authorized rows exist
in the database. The client may incorrectly assume it has reached the
end of the list.

## Solution

Switch to the same pattern used by `GetWorkspaces`, `GetTemplates`, and
`GetUsers`: `prepareSQLFilter` + `GetAuthorized*` variant. The RBAC
filter is compiled to a SQL WHERE clause and injected into the query
before `ORDER BY`/`LIMIT`, so the database returns exactly the requested
number of authorized rows.

Additionally, `GetChatsByOwnerID` is renamed to `GetChats` with
`OwnerID` as an optional (nullable) filter parameter, matching the
`GetWorkspaces` naming convention.

## Changes

| File | Change |
|------|--------|
| `queries/chats.sql` | Renamed to `GetChats`, `owner_id` now optional via CASE/NULL, added `-- @authorize_filter` |
| `queries.sql.go` | Renamed constant, params struct (`GetChatsParams`), and method |
| `querier.go` | Interface method renamed |
| `modelqueries.go` | Added `chatQuerier` interface + `GetAuthorizedChats` impl |
| `dbauthz/dbauthz.go` | `GetChats` now uses `prepareSQLFilter` instead of `fetchWithPostFilter` |
| `dbauthz/dbauthz_test.go` | Updated tests for SQL filter pattern |
| `dbmock/dbmock.go` | Renamed + added mock for `GetAuthorizedChats` |
| `dbmetrics/querymetrics.go` | Renamed + added metrics wrapper |
| `rbac/regosql/configs.go` | Added `ChatConverter` (maps `org_owner` to empty string literal since `chats` has no `organization_id` column) |
| `rbac/authz.go` | Added `ConfigChats()` |
| `chats.go` | Handler uses renamed method with `uuid.NullUUID` |
| `searchquery/search.go` | Updated return type |
| `gitsync/worker.go` | Updated interface and call site |
| Various test files | Updated for renamed types |
2026-03-17 12:46:24 -04:00
Mathias Fredriksson 144b32a4b6 fix(scripts/develop): skip build on Windows via build tag (#23118)
Previously main.go used syscall.SysProcAttr{Setpgid: true} and
syscall.Kill, both undefined on Windows. This broke GOOS=windows
cross-compilation.

Add a //go:build !windows constraint to the package since it is
a dev-only tool that requires Unix utilities (bash, make, etc.)
and is not intended to run on Windows.

Refs #23054
Fixes coder/internal#1407
2026-03-18 02:02:06 +11:00
Kyle Carberry a40716b6fe fix(site): stop spamming chats list endpoint on diff_status_change events (#23167)
## Problem

The WebSocket handler for `diff_status_change` events in
`AgentsPage.tsx` was triggering a burst of redundant HTTP requests on
every event:

1. **`invalidateChatListQueries(queryClient)`** — Full refetch of the
chats list endpoint. Unnecessary because `updateInfiniteChatsCache`
already writes `diff_status` into the sidebar cache optimistically on
every event.

2. **`invalidateQueries({ queryKey: chatKey(id) })`** — Refetch of the
individual chat. Also unnecessary — the SSE event carries `diff_status`
in its payload and the optimistic updater writes it into the `chatKey`
cache directly. Worse, this call was missing `exact: true`, so TanStack
Query's prefix matching cascaded the invalidation to `chatMessagesKey`,
`chatDiffContentsKey`, and every other query under `["chats", id]`.

Since diff status changes fire frequently during active agent work, this
spammed the chats list endpoint and caused redundant refetches of
messages and diff contents on every single event.

## Fix

Strip the handler down to the one invalidation that's actually needed —
`chatDiffContentsKey` (the file-level diff contents aren't in the SSE
payload):

```typescript
if (chatEvent.kind === "diff_status_change") {
    void queryClient.invalidateQueries({
        queryKey: chatDiffContentsKey(updatedChat.id),
        exact: true,
    });
}
```

## Why tests didn't catch this

The existing tests in `chats.test.ts` cover query utilities in isolation
(e.g. `invalidateChatListQueries` scoping, mutation invalidation). The
WebSocket event handler lives in the `AgentsPage` component — there was
no test covering what the `diff_status_change` code path actually
invalidates.

Added regression tests verifying that `exact: true` prevents
prefix-match cascade vs the old behavior.
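For illustration, a minimal sketch of the prefix-matching behavior the regression tests guard against (simplified; the real matching is TanStack Query's):

```typescript
type QueryKey = readonly unknown[];

// Without `exact`, a filter key matches any query key it is a prefix of,
// which is how ["chats", id] cascades into ["chats", id, "messages"].
function keyMatches(filter: QueryKey, key: QueryKey, exact: boolean): boolean {
	const prefixMatch = filter.every(
		(part, i) => JSON.stringify(part) === JSON.stringify(key[i]),
	);
	if (exact) {
		return prefixMatch && filter.length === key.length;
	}
	return prefixMatch;
}
```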
2026-03-17 14:51:01 +00:00
Danielle Maywood 635c5d52a8 feat(site): move Settings and Analytics from dialogs to sidebar sub-navigation (#23126) 2026-03-17 14:48:09 +00:00
Kyle Carberry 075dfecd12 refactor: consolidate experimental chats API types (#23143)
## Summary

Consolidates three areas of type duplication in the experimental chats
API:

### 1. Merge archive/unarchive into `PATCH /{chat}`
- **Before:** `POST /{chat}/archive` + `POST /{chat}/unarchive` (two
endpoints, two handlers with mirrored logic)
- **After:** `PATCH /{chat}` accepting `{ "archived": true/false }` via
`UpdateChatRequest`
- Removes one endpoint and ~30 lines of duplicated handler code

### 2. Collapse identical request/response prompt types
- `ChatSystemPromptResponse` + `UpdateChatSystemPromptRequest` →
`ChatSystemPrompt`
- `UserChatCustomPromptResponse` + `UpdateUserChatCustomPromptRequest` →
`UserChatCustomPrompt`
- These pairs were field-for-field identical (single string field)

### 3. Merge duplicate reasoning options types
- `ChatModelOpenRouterReasoningOptions` +
`ChatModelVercelReasoningOptions` → `ChatModelReasoningOptions`
- Same 4 fields, same types — only field ordering and enum value sets
differed
- Unified type uses the superset of enum values

### Files changed
- `codersdk/chats.go` — SDK types and client methods
- `coderd/chats.go` — Handler consolidation
- `coderd/coderd.go` — Route change
- `coderd/chats_test.go` — Test updates
- `site/src/api/api.ts` — Frontend API client
- `site/src/api/queries/chats.ts` — Query mutations
- `site/src/api/queries/chats.test.ts` — Test mocks
- `site/src/pages/AgentsPage/AgentsPage.tsx` — Call site
- Generated files (`typesGenerated.ts`,
`chatModelOptionsGenerated.json`)

### Testing
- All Go tests pass (`TestArchiveChat`, `TestUnarchiveChat`,
`TestChatSystemPrompt`)
- All frontend tests pass (31/31 in `chats.test.ts`)
2026-03-17 14:31:11 +00:00
Hugo Dutka fdb1205bdf chore(agent): remove portabledesktop download logic (#23128)
The new way to install portabledesktop in a workspace will be via a
module: https://github.com/coder/registry/pull/805
2026-03-17 15:24:11 +01:00
Danielle Maywood 33a47fced3 fix(site): use theme-aware git tokens for PR status badges (#23148) 2026-03-17 14:12:36 +00:00
Kyle Carberry ca5158f94a fix: unify sidebar Git/Desktop tab styles with GitPanel tabs (#23164)
The active Git tab looked different than the Desktop tab, and they
didn't match the actual tabs in the Git section.
2026-03-17 10:08:18 -04:00
Hugo Dutka b7e0f42591 feat(dogfood): add the portabledesktop module (#23165) 2026-03-17 13:56:27 +00:00
Ethan 41bd7acf66 perf(chatd): remove redundant chat rereads (#23161)
## Summary
This PR removes two redundant chat rereads in `chatd`.

### Archive / unarchive
- `archiveChat` and `unarchiveChat` already come through
`httpmw.ChatParam`, so the handlers already have the `database.Chat`
row.
- Pass that row into `chatd.ArchiveChat` / `chatd.UnarchiveChat` instead
of rereading by ID before publishing the sidebar events.

### End-of-turn cleanup
- `processChat` no longer calls `GetChatByID` after the cleanup
transaction just to refresh the chat snapshot.
- Title generation already persists the generated title and emits its
own `title_change` event.
- To preserve best-effort title freshness for the cleanup path, the
async title-generation goroutine stores the generated title in per-turn
shared state and cleanup overlays it if available before publishing the
`status_change` event and dispatching push notifications.

## Why
- removes one DB read from archive / unarchive requests
- removes one DB read from completed turns, which is the larger hot-path
win
- keeps the existing pubsub/event contract intact instead of broadening
this into a larger event-model redesign

## Notes
- `title_change` remains the authoritative title update for clients
- cleanup does not wait for title generation; it uses the generated
title only when it is already available
2026-03-18 00:52:06 +11:00
Dean Sheather 87d4a29371 fix(site): add left offset to agents sidebar user dropdown (#23162) 2026-03-17 13:34:10 +00:00
Mathias Fredriksson a797a494ef feat: add starter template option and Coder Desktop URLs to scripts/develop (#23149)
- Add `--starter-template` option and properly create starter template
  with name and icon
- Add Coder Desktop URLs to listening banner
- Makefile tweak to avoid rebuilding `scripts/develop` every time Go
  code changes
2026-03-17 15:34:03 +02:00
Ethan a33605df58 perf(coderd/chatd): reuse workspace context within a turn (#23145)
## Summary
- reuse workspace agent context within a single `runChat()` turn
- remove duplicate latest-build agent lookups between
`resolveInstructions()` and `getWorkspaceConn()`
- avoid the extra `GetWorkspaceAgentByID` fetch when the selected
`WorkspaceAgent` already has the needed metadata
- add focused internal tests for reuse and refresh-on-dial-failure

## Why
This came out of a 5000-chat / 10-turn scaletest on bravo against a
single workspace.

The run completed successfully, but coderd stayed DB-pool bound, and one
workspace-backed hot path stood out:
- `GetWorkspaceAgentsInLatestBuildByWorkspaceID ≈ 46.7k`
- `GetWorkspaceByID ≈ 48.0k`
- `GetWorkspaceAgentByID ≈ 2.2k`

Within one `runChat()` turn, chatd was rediscovering the same workspace
agent multiple times just to resolve instructions and open the workspace
connection.

## What this changes
This PR introduces a **turn-local** workspace context helper so a single
acquired turn can:
- resolve the selected workspace agent once
- reuse that agent for instruction resolution
- reuse the same `AgentConn` for workspace tools and reload/compaction

This stays turn-local only, so a later turn on another replica still
rebuilds fresh context from the DB.

## Expected impact
This is an incremental improvement, not a full fix.

It should reduce duplicated workspace-agent lookups and shave some DB
pressure from a hot path for workspace-backed chats, while preserving
multi-replica correctness.

## Testing
- `go test ./coderd/chatd/...`
- `golangci-lint run ./coderd/chatd/...`
2026-03-18 00:33:44 +11:00
Dean Sheather 3c430a67fa fix(site): balance sidebar header spacing in agents page (#23163) 2026-03-18 00:33:14 +11:00
Dean Sheather abee77ac2f fix(site): move analytics date range above cards (#23158) 2026-03-17 12:58:43 +00:00
Kacper Sawicki 7946dc6645 fix(provisioner): skip duplicate-env-keys in generate.sh (#23155)
## Problem

When `generate.sh` is run (e.g. to regenerate fixtures after adding a
new field like `subagent_id`), the `duplicate-env-keys` fixture gets
UUID scrambling.

The `minimize_diff()` function uses a bash associative array keyed by
JSON field name (`deleted["id"]`). The `duplicate-env-keys` fixture has
multiple `coder_env` resources, each with the same key names (`id`,
`agent_id`). Since an associative array can only hold one value per key,
UUIDs get cross-contaminated or left as random terraform-generated
values.

Discovered while working on #23122.

## Fix

Add `duplicate-env-keys` to the `toskip` array in `generate.sh`,
alongside `kubernetes-metadata`. This fixture uses hand-crafted
placeholder UUIDs and should not be regenerated.

Relates to #21885.
2026-03-17 13:41:47 +01:00
Kyle Carberry eb828a6a86 fix: skip input refocus after send on mobile viewports (#23141) 2026-03-17 12:40:26 +00:00
Mathias Fredriksson 4e2d7ffaa7 refactor(site/src/pages/AgentsPage): use ChatMessagePart for editingFileBlocks (#23151)
Replace the ad-hoc camelCase file block shape ({ mediaType, fileId, data })
with snake_case fields matching ChatMessagePart from the API types.

The RenderBlock file variant now uses media_type/file_id instead of
mediaType/fileId. The parsers in messageParsing.ts and streamState.ts
pass validated ChatMessagePart objects through directly instead of
destructuring and reassembling with renamed fields. This eliminates
the needless API → camelCase → snake_case roundtrip that the edit
flow previously required.

Refs #22735
2026-03-17 04:10:08 -08:00
Mathias Fredriksson 524bca4c87 fix(site/src/pages/AgentsPage): fix chat image paste bugs and refactor queued message display (#22735)
handleSubmit (triggered via Enter key) didn't check isUploading, so
messages could be sent while an image upload was still in progress.
The send button was correctly disabled via canSend, but the keyboard
shortcut bypassed that guard.

QueuedMessagesList used untyped extraction helpers that fell through to
JSON.stringify for attachment-only messages. Replace them with a single
getQueuedMessageInfo function using typed ChatMessagePart access.
Show an attachment badge (ImageIcon + count) for file parts, and use a
consistent "[Queued message]" placeholder for all no-text situations.

Editing a queued message with file attachments silently dropped all
attachments because handleStartQueueEdit only accepted text. Thread
file blocks from QueuedMessagesList through the edit callback into
handleStartQueueEdit, which now calls setEditingFileBlocks. The
existing useEffect in AgentDetailInput picks these up and populates
the attachment UI. Also clear editingFileBlocks in handleCancelQueueEdit
and handleSendFromInput.
2026-03-17 14:00:31 +02:00
Danny Kopping 365de3e367 feat: record model thoughts (#22676)
Depends on https://github.com/coder/aibridge/pull/203
Closes https://github.com/coder/internal/issues/1337

---------

Signed-off-by: Danny Kopping <danny@coder.com>
2026-03-17 11:41:10 +00:00
Michael Suchacz 5d0eb772da fix(cored): fix flaky TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease (#23147) 2026-03-17 19:08:22 +11:00
Ethan 04fca84872 perf(coderd): reduce duplicated reads in push and webpush paths (#23115)
## Background

A 5000-chat scaletest (~50k turns, ~2m45s wall time) completed
successfully, but the main bottleneck was **DB pool starvation from
repeated reads**, not individually expensive SQL. The push/webpush path
showed a few especially noisy reads:

- `GetLastChatMessageByRole` for push body generation
- `GetEnabledChatProviders` + `GetChatModelConfigByID` for push summary
  model resolution
- `GetWebpushSubscriptionsByUserID` for every webpush dispatch

This PR keeps the optimizations that remove those duplicate reads while
leaving stream behavior unchanged.

## What changes in this PR

### 1. Reuse resolved chat state for push notifications

`maybeSendPushNotification` used to re-read the last assistant message
and re-resolve the chat model/provider after `runChat` had already done
that work.

Now `runChat` returns the final assistant text plus the already-resolved
model and provider keys, and the push goroutine uses that state
directly.

That removes the extra push-path reads for:

- `GetLastChatMessageByRole`
- the second `resolveChatModel` path
- the provider/model lookups that came with that second resolution

### 2. Cache webpush subscriptions during dispatch

`Dispatch()` previously hit `GetWebpushSubscriptionsByUserID` on every
push. A small per-user in-memory cache now avoids those repeated reads.

The follow-up fix keeps that optimization correct: `InvalidateUser()`
bumps a per-user generation so an older in-flight fetch cannot
repopulate the cache with pre-mutation data after subscribe/unsubscribe.

That preserves the cache win without letting local subscription changes
be silently overwritten by stale fetch results.
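A TypeScript sketch of the generation-guard idea (the actual implementation is Go; the class and method names here are illustrative only):

```typescript
// Per-user cache where a fetch started before invalidateUser() cannot
// repopulate the cache afterwards: writes carry the generation observed
// when the fetch began, and stale generations are rejected.
class GenerationCache<T> {
	private entries = new Map<string, T>();
	private generations = new Map<string, number>();

	generation(user: string): number {
		return this.generations.get(user) ?? 0;
	}

	get(user: string): T | undefined {
		return this.entries.get(user);
	}

	// Store only if no invalidation happened since the fetch started.
	setIfCurrent(user: string, value: T, observedGen: number): boolean {
		if (this.generation(user) !== observedGen) {
			return false;
		}
		this.entries.set(user, value);
		return true;
	}

	invalidateUser(user: string): void {
		this.entries.delete(user);
		this.generations.set(user, this.generation(user) + 1);
	}
}
```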

## Why this is safe

- The push change only reuses data already produced during the same
  chat run. It does not change notification semantics; if there is no
  assistant text to summarize, the existing fallback body still applies.
- The webpush change keeps the existing TTL and `410 Gone` cleanup
  behavior. The generation guard only prevents stale in-flight fetches
  from poisoning the shared cache after invalidation.
- The final PR does **not** change stream setup, pubsub/relay behavior,
  or chat status snapshot timing.

## Deliberately not included

- No stream-path optimization in `Subscribe`.
- No inline pubsub message payloads.
- No distributed cross-replica webpush cache invalidation.
2026-03-17 13:50:47 +11:00
Michael Suchacz 7cca2b6176 feat(site): add chat spend limit UI (#23072)
Frontend for agent chat spend limiting on `/agents`.

## Changes
- add the limits management UI, API hooks, and validation for
deployment, group, and user overrides
- show spend limit status in Agents analytics and usage summaries
- surface limit-related chat errors consistently in the agent detail
experience
- add shared currency and usage-limit messaging helpers plus related
stories/tests
2026-03-17 02:01:51 +01:00
Michael Suchacz 1031da9738 feat: add agent chat spend limiting (backend) (#23071)
Introduces deployment-scoped spend limiting for Coder Agents, enabling
administrators to control LLM costs at global, group, and individual
user levels.

## Changes

- **Database migration (000437)**: `chat_usage_limit_config`
(singleton), `chat_usage_limit_overrides` (per-user),
`chat_usage_limit_group_overrides` (per-group)
- **Single-query limit resolution**: individual override > min(group) >
global default via `ResolveUserChatSpendLimit`
- **Fail-open enforcement** in chatd with documented TOCTOU trade-off
- **Experimental API** under `/api/experimental/chats/usage-limits` for
CRUD on limits
- **`AsChatd` RBAC subject** for narrowly-scoped daemon access (replaces
`AsSystemRestricted`)
- **Generated TypeScript types** for the frontend SDK

## Hierarchy

1. Individual user override (highest)
2. Minimum of group limits
3. Global default
4. Disabled / unlimited

Currency stored as micro-dollars (`1,000,000` = $1.00).

Frontend PR: #23072
2026-03-17 01:24:03 +01:00
Kyle Carberry b69631cb35 chore(site): improve mobile layout for agent chat (#23139) 2026-03-16 18:36:37 -04:00
Kyle Carberry 7b0aa31b55 feat: render file references inline in user messages (#23131) 2026-03-16 17:17:23 -04:00
Steven Masley 93b9d70a9b chore: add audit log entry when ai seat is consumed (#22683)
When an AI seat is consumed, an audit log entry is made. This only happens the first time a seat is used.
2026-03-16 15:30:25 -05:00
Kyle Carberry 6972d073a2 fix: improve background process handling for agent tools (#23132)
## Problem

Models frequently use shell `&` instead of `run_in_background=true` when
starting long-running processes through `/agents`, causing them to die
shortly after starting. This happens because:

1. **No guidance in tool schema** — The `ExecuteArgs` struct had zero
`description` tags. The model saw `run_in_background: boolean
(optional)` with no explanation of when/why to use it.
2. **Shell `&` is silently broken** — `sh -c "command &"` forks the
process, the shell exits immediately, and the forked child becomes an
orphan not tracked by the process manager.
3. **No process group isolation** — The SSH subsystem sets `Setsid:
true` on spawned processes, but the agent process manager set no
`SysProcAttr` at all. Signals only hit the top-level `sh`, not child
processes.

## Investigation

Compared our implementation against **openai/codex** and **coder/mux**:

| Aspect | codex | mux | coder/coder (before) |
|--------|-------|-----|---------------------|
| Background flag | Yield/resume with `session_id` | `run_in_background` with rich description | `run_in_background` with **no description** |
| `&` handling | `setsid()` + `killpg()` | `detached: true` + `killProcessTree()` | **Nothing** — orphaned children escape |
| Process isolation | `setsid()` on every spawn | `set -m; nohup ... setsid` for background | **No `SysProcAttr` at all** |
| Signal delivery | `killpg(pgid, sig)` — entire group | `kill -15 -\$pid` — negative PID | `proc.cmd.Process.Signal()` — **PID only** |

## Changes

### Fix 1: Add descriptions to `ExecuteArgs` (highest impact)
The model now sees explicit guidance: *"Use for long-running processes
like dev servers, file watchers, or builds. Do NOT use shell & — it will
not work correctly."*

### Fix 2: Update tool description
The top-level execute tool description now reinforces: *"Use
run_in_background=true for long-running processes. Never use shell '&'
for backgrounding."*

### Fix 3: Detect trailing `&` and auto-promote to background
Defense-in-depth: if the model still uses `command &`, we strip the `&`
and promote to `run_in_background=true` automatically. Correctly
distinguishes `&` from `&&`.
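A sketch of the detection rule in TypeScript (the actual agent code is Go; this only illustrates the trailing-`&` vs `&&` distinction, not the real implementation):

```typescript
// If a command ends with a lone `&`, strip it and promote the call to
// run_in_background=true. A trailing `&&` is a logical AND, not a
// backgrounding operator, and is left untouched.
function promoteTrailingAmpersand(command: string): {
	command: string;
	runInBackground: boolean;
} {
	const trimmed = command.trimEnd();
	if (trimmed.endsWith("&") && !trimmed.endsWith("&&")) {
		return {
			command: trimmed.slice(0, -1).trimEnd(),
			runInBackground: true,
		};
	}
	return { command, runInBackground: false };
}
```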

### Fix 4: Process group isolation (`Setpgid`)
New platform-specific files (`proc_other.go` / `proc_windows.go`)
following the same pattern as `agentssh/exec_other.go`. Every spawned
process gets its own process group.

### Fix 5: Process group signaling
`signal()` now uses `syscall.Kill(-pid, sig)` on Unix to signal the
entire process group, ensuring child processes from shell pipelines are
also cleaned up.

## Testing
All existing `agent/agentproc` tests pass. Both packages compile
cleanly.
2026-03-16 16:22:10 -04:00
Kyle Carberry 89bb5bb945 ci: fix build job disk exhaustion on Depot runners (#23136)
## Problem

The `build` job on `main` has been failing intermittently (and now
consistently) with `no space left on device` on the
`depot-ubuntu-22.04-8` runner. The runner's disk fills up during Docker
image builds or SBOM generation, depending on how close to the limit a
given run lands.

The build was already at the boundary — the Go build cache alone is ~1.3
GB, build artifacts are ~2 GB, and Docker image builds + SBOM scans need
several hundred MB of headroom in `/tmp`. No single commit caused this;
cumulative growth in dependencies and the scheduled `coder-base:latest`
rebuild on Monday morning nudged it past the limit.

## Fix

Three changes to reclaim ~2 GB of disk before Docker runs:

1. **Build all platform archives and packages in the Build step** —
moves arm64/armv7 `.tar.gz` and `.deb` from the Docker step to the Build
step so we can clean caches in between.

2. **Clean up Go caches between Build and Docker** — once binaries are
compiled, the Go build cache and module cache aren't needed. Also
removes `.apk`/`.rpm` packages that are never uploaded.

3. **Set `DOCKER_IMAGE_NO_PREREQUISITES`** — tells make to skip
redundantly building `.deb`/`.rpm`/`.apk`/`.tar.gz` as prerequisites of
Docker image targets. The Makefile already supports this flag for
exactly this purpose.
2026-03-16 15:38:58 -04:00
Kyle Carberry b7eab35734 fix(site): scope chat cache helpers to chat-list queries only (#23134)
## Problem

`updateInfiniteChatsCache`, `prependToInfiniteChatsCache`, and
`readInfiniteChatsCache` use `setQueriesData({ queryKey: ["chats"] })`
which prefix-matches **all** queries starting with `"chats"`, including
`["chats", chatId, "messages"]`.

After #23083 converted chat messages to `useInfiniteQuery`, the cached
messages data gained a `.pages` property containing
`ChatMessagesResponse` objects (not `Chat[]` arrays). The `if
(!prev.pages)` guard no longer bailed out, and the updater called
`.map()` on these objects — `TypeError: Z.map is not a function`.

## Fix

Extract the `isChatListQuery` predicate that already existed inline in
`invalidateChatListQueries` and apply it to all four cache helpers. This
scopes them to sidebar queries (`["chats"]` or `["chats",
<filterOpts>]`) and skips per-chat queries (`["chats", <id>, ...]`).
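The predicate's shape, as inferred from the description above (an illustrative sketch, not the exact source):

```typescript
type QueryKey = readonly unknown[];

// Match ["chats"] or ["chats", <filterOpts>] (sidebar list queries),
// but not ["chats", <id>, ...] (per-chat queries such as messages).
function isChatListQuery(key: QueryKey): boolean {
	if (key[0] !== "chats") {
		return false;
	}
	if (key.length === 1) {
		return true;
	}
	// A filter-options object is the only allowed second element; a string
	// chat ID marks a per-chat query like ["chats", "<uuid>", "messages"].
	return key.length === 2 && typeof key[1] === "object" && key[1] !== null;
}
```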
2026-03-16 14:23:56 -04:00
Zach 3f76f312e4 feat(cli): add --no-wait flag to coder create (#22867)
Adds a `--no-wait` flag (CODER_CREATE_NO_WAIT) to the create command,
matching the existing pattern in `coder start`. When set, the `coder
create` command returns immediately after the workspace creation API
call succeeds instead of streaming build logs until completion.

This enables fire-and-forget workspace creation in CI/automation
contexts (e.g., GitHub Actions), where waiting for the build to finish
is unnecessary. Combined with other existing flags, users can create a
workspace with no interactivity, assuming the user is already
authenticated.
2026-03-16 11:54:30 -06:00
Steven Masley abf59ee7a6 feat: track ai seat usage (#22682)
When a user uses an AI feature, we record them in the `ai_seat_state` table as consuming a seat.

Debouncing was added to prevent excessive writes to the database for this feature; there is no need for frequent updates.
2026-03-16 12:36:26 -05:00
Steven Masley cabb611fd9 chore: implement database crud for AI seat usage (#22681)
Creates a new table `ai_seat_state` to keep track of when users consume an AI seat. Once a user consumes an AI seat, they remain in this table forever (as it stands today).
2026-03-16 11:53:20 -05:00
Matt Vollmer b2d8b67ff7 feat(site): add Early Access notice below agents chat input (#23130)
Adds a centered "Coder Agents is available via Early Access" line
directly beneath the chat input on the `/agents` index page. The "Early
Access" text links to
https://coder.com/docs/ai-coder/agents/early-access.

<img width="1192" height="683" alt="image"
src="https://github.com/user-attachments/assets/1823a5d2-6f02-48c2-ac70-a62b8f52be55"
/>

---

PR generated with Coder Agents
2026-03-16 12:45:17 -04:00
Thomas Kosiewski c1884148f0 feat: add VS Code iframe embed auth bootstrap (#23060)
## VS Code iframe embed auth via postMessage + setSessionToken

Adds embed auth for VS Code iframe integration, allowing the Coder agent
chat UI to be embedded in VS Code webviews without manual login — using
direct header auth instead of cookies.

### How it works

1. **Parent frame** (VS Code webview) loads an iframe pointing to
`/agents/:agentId/embed`
2. **Embed page** detects the user is signed out and posts
`coder:vscode-ready` to the parent
3. **Parent** responds with `coder:vscode-auth-bootstrap` containing the
user's Coder API token
4. **Embed page** calls `API.setSessionToken(token)` to set the
`Coder-Session-Token` header on all subsequent axios requests
5. **Embed page** fetches user + permissions, sets them in the React
Query cache atomically, and renders the authenticated agent chat UI

No cookies, no CSRF, no backend endpoint needed. The token is passed via
postMessage and used as a header on every API request.
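The embed-side handling can be sketched as a pure helper (a hypothetical shape for illustration; the real listener lives in `AgentEmbedPage.tsx` and passes the extracted token to `API.setSessionToken`):

```typescript
interface BootstrapMessage {
	type?: string;
	token?: string;
}

// Accept only the documented `coder:vscode-auth-bootstrap` message and
// return its token; anything else (wrong type, empty token) is ignored.
function extractBootstrapToken(data: BootstrapMessage): string | undefined {
	if (data.type !== "coder:vscode-auth-bootstrap") {
		return undefined;
	}
	return typeof data.token === "string" && data.token !== ""
		? data.token
		: undefined;
}
```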

### What changed

**Frontend** (`site/src/pages/AgentsPage/`):
- `AgentEmbedPage.tsx` — added postMessage bootstrap directly in the
embed page: listens for `coder:vscode-auth-bootstrap`, calls
`API.setSessionToken(token)`, fetches user/permissions atomically to
avoid race conditions
- `EmbedContext.tsx` — React context signaling embed mode (from previous
commit, unchanged)
- `AgentDetail/TopBar.tsx` — conditionally hides navigation elements in
embed mode (from previous commit, unchanged)
- Both `/agents/:agentId/embed` and `/agents/:agentId/embed/session`
routes live outside `RequireAuth`

**Auth bootstrap** (`site/src/api/queries/users.ts`):
- `bootstrapChatEmbedSessionFn` now calls `API.setSessionToken(token)`
instead of posting to a backend endpoint
- Fetches user and permissions directly via `API.getAuthenticatedUser()`
and `API.checkAuthorization()`, then sets both in the query cache
atomically — this avoids a race where `isSignedIn` flips before
permissions are loaded

**Removed** (no longer needed):
- `coderd/embedauth.go` — the `POST
/api/experimental/chats/embed-session` handler
- `coderd/embedauth_test.go` — backend tests for the endpoint
- `codersdk/embedauth.go` — `EmbedSessionTokenRequest` SDK type
- `site/src/api/api.ts` — `postChatEmbedSession` method
- `docs/user-guides/workspace-access/vscode-embed-auth.md` — doc page
for the old cookie flow
- Swagger/API doc entries for the endpoint

### Why not cookies?

The initial implementation used a backend endpoint to set an HttpOnly
session cookie. This required `SameSite=None; Secure` for cross-origin
iframes, which doesn't work over HTTP in development (Chrome requires
HTTPS for `Secure` cookies). The `setSessionToken` approach bypasses
cookies entirely — the token is set as an axios default header, and
header-based auth also naturally bypasses CSRF protection.

### Dogfooding

Tested end-to-end with a VS Code extension that:
1. Registers a `/openChat` deep link handler
(`vscode://coder.coder-remote/openChat?url=...&token=...&agentId=...`)
2. Starts a local HTTP reverse proxy (to work around VS Code webview
iframe sandboxing)
3. Loads `/agents/:agentId/embed` in an iframe through the proxy
4. Relays the postMessage handshake between the iframe and the extension
host
5. The embed page receives the token, calls `setSessionToken`, and
renders the chat

Verified: chat title, messages, and input field all display correctly in
VS Code's secondary sidebar panel.
2026-03-16 17:45:01 +01:00
Kyle Carberry 741af057dc feat: paginate chat messages endpoint with cursor-based infinite scroll (#23083)
Adds cursor-based pagination to the chat messages endpoint.

## Backend

- New `GetChatMessagesByChatIDPaginated` SQL query: returns messages in
`id DESC` order with a `before_id` keyset cursor and configurable
`limit`
- Handler parses `?before_id=N&limit=N` query params, uses the `LIMIT
N+1` trick to set `has_more` without a separate COUNT query
- Queued messages only returned on the first page (no cursor) since
they're always the most recent
- SDK client updated with `ChatMessagesPaginationOptions`
- Fully backward compatible: omitting params returns the 50 newest
messages

## Frontend

- Switches `getChatMessages` from `useQuery` to `useInfiniteQuery` with
cursor chaining via `getNextPageParam`
- Pages flattened and sorted by `id` ascending for chronological display
- `MessagesPaginationSentinel` component uses `IntersectionObserver`
(200px rootMargin prefetch) inside the existing `flex-col-reverse`
scroll container
- `flex-col-reverse` handles scroll anchoring natively when older
messages are prepended — no manual `scrollTop` adjustment needed (same
pattern as coder/blink)

## Why cursor-based instead of offset/limit

Offset-based pagination breaks when new messages arrive while paginating
backward (offsets shift, causing duplicates or missed messages). The
`before_id` cursor is stable regardless of inserts — each page is
deterministic.
2026-03-16 16:40:59 +00:00
Kyle Carberry 32a894d4a7 fix: error on ambiguous matches in edit_files tool (#23125)
## Problem

The `edit_files` tool used `strings.ReplaceAll` for exact substring
matches, silently replacing **every** occurrence. When an LLM's search
string wasn't unique in the file, this caused unintended edits. Fuzzy
matches (passes 2 and 3) only replaced the first occurrence, creating
inconsistent behavior. Zero matches were also silently ignored.

## Investigation

Investigated how **coder/mux** and **openai/codex** handle this:

| Tool | Multiple matches | No match | Flag |
|---|---|---|---|
| **coder/mux** `file_edit_replace_string` | Error (default
`replace_count=1`) | Error | `replace_count` (int, default 1, -1=all) |
| **openai/codex** `apply_patch` | Uses first match after cursor
(structural disambiguation via context lines + `@@` markers) | Error |
None (different paradigm) |
| **coder/coder** `edit_files` (before) | Exact: replaces all. Fuzzy:
replaces first. | Silent success | None |

## Solution

Adopted the mux approach (error on ambiguity) with a simpler
`replace_all: bool` instead of `replace_count: int`:

- **Default (`replace_all: false`)**: search string must match exactly
once. Multiple matches → error with guidance: *"search string matches N
occurrences. Include more surrounding context to make the match unique,
or set replace_all to true"*
- **`replace_all: true`**: replaces all occurrences (opt-in for
intentional bulk operations like variable renames)
- **Zero matches**: now returns an error instead of silently succeeding

Chose `bool` over `int` count because:
1. LLMs are bad at counting occurrences
2. The real intent is binary (one specific spot vs. all occurrences)
3. Simpler error recovery loop for the LLM
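The match-counting behavior can be sketched as follows. This is a simplified model of the rule, not the actual `agent/agentfiles` code (which also handles fuzzy passes); `applyEdit` is a hypothetical helper.

```go
package main

import (
	"fmt"
	"strings"
)

// applyEdit models the new rule: zero matches is an error, multiple
// matches is an error unless replaceAll is explicitly set.
func applyEdit(content, search, replace string, replaceAll bool) (string, error) {
	n := strings.Count(content, search)
	switch {
	case n == 0:
		return "", fmt.Errorf("search string not found")
	case n > 1 && !replaceAll:
		return "", fmt.Errorf(
			"search string matches %d occurrences. Include more surrounding context to make the match unique, or set replace_all to true", n)
	}
	return strings.ReplaceAll(content, search, replace), nil
}

func main() {
	// Ambiguous without replace_all → error with guidance.
	_, err := applyEdit("a b a", "a", "x", false)
	fmt.Println(err != nil) // true
	// Opting in replaces every occurrence.
	out, _ := applyEdit("a b a", "a", "x", true)
	fmt.Println(out) // x b x
	// Zero matches is now an error instead of a silent no-op.
	_, err = applyEdit("a b a", "z", "x", false)
	fmt.Println(err != nil) // true
}
```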

## Changes

| File | Change |
|---|---|
| `codersdk/workspacesdk/agentconn.go` | Add `ReplaceAll bool` to
`FileEdit` struct |
| `agent/agentfiles/files.go` | Count matches before replacing; error if
>1 and not opted in; error on zero matches; add `countLineMatches`
helper |
| `codersdk/toolsdk/toolsdk.go` | Expose `replace_all` in tool schema
with description |
| `agent/agentfiles/files_test.go` | Update existing tests, add
`EditEditAmbiguous`, `EditEditReplaceAll`, `NoMatchErrors`,
`AmbiguousExactMatch`, `ReplaceAllExact` |
2026-03-16 16:17:33 +00:00
Spike Curtis 4fdd48b3f5 chore: randomize task status update times in load generator (#23058)
fixes https://github.com/coder/scaletest/issues/92

Randomizes the time between task status updates so that we don't send them all at the same time for load testing.
2026-03-16 12:06:29 -04:00
Charlie Voiselle e94de0bdab fix(coderd): render HTML error page for OIDC email validation failures (#23059)
## Summary

When the email address returned from an OIDC provider doesn't match the
configured allowed domain list (or isn't verified), users previously saw
raw JSON dumped directly in the browser — an ugly and confusing
experience during a browser-redirect flow.

This PR replaces those JSON responses with the same styled static HTML
error page already used for group allow-list errors, signups-disabled,
and wrong-login-type errors.

## Changes

### `coderd/userauth.go`
Replaced 3 `httpapi.Write` calls in `userOIDC` with
`site.RenderStaticErrorPage`:

| Error case | Title shown |
|---|---|
| Email domain not in allowed list | "Unauthorized email" |
| Malformed email (no `@`) with domain restrictions | "Unauthorized
email" |
| `email_verified` is `false` | "Email not verified" |

All render HTTP 403 with `HideStatus: true` and a "Back to login" action
button.

### `coderd/userauth_test.go`
- Updated `AssertResponse` callbacks on existing table-driven tests
(`EmailNotVerified`, `NotInRequiredEmailDomain`,
`EmailDomainForbiddenWithLeadingAt`) to verify HTML Content-Type and
page content.
- Extended `TestOIDCDomainErrorMessage` to additionally assert HTML
rendering.
- Added new `TestOIDCErrorPageRendering` with 3 subtests covering all
error scenarios, verifying: HTML doctype, expected title/description,
"Back to login" link, and absence of JSON markers.

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-03-16 11:56:59 -04:00
Mathias Fredriksson fa8693605f test(provisioner/terraform/testdata): add subagent_id to UUID preservation and regenerate (#23122)
The minimize_diff function in generate.sh preserves autogenerated
values (UUIDs, tokens, etc.) across regeneration to keep diffs
minimal. The subagent_id field was missing from the preservation
list, causing unnecessary churn on devcontainer test data.

Also regenerates all testdata with the current terraform (1.14.5)
and coder provider (2.14.0).
2026-03-16 15:51:06 +00:00
Kyle Carberry af1be592cf fix: disable agent notification chime by default (#23124)
The completion chime on `/agents` was enabled by default for new users
(or when no localStorage preference existed). This changes the default
to disabled, so users must explicitly opt in via the sound toggle
button.

## Changes

- `getChimeEnabled()` now returns `false` when no preference is stored
(was `true`)
- `catch` fallback also returns `false` (was `true`)
- Updated tests to reflect the new default and explicitly enable the
chime in `maybePlayChime` tests
2026-03-16 11:49:04 -04:00
Kyle Carberry 6f97539122 fix: update sidebar diff status on WebSocket events (#23116)
## Problem

The sidebar diff status (PR icon, +additions/-deletions, file count) was
not updating in real-time. Users had to reload the page to see changes.

Two root causes:

1. **Frontend**: The `diff_status_change` WebSocket handler in
`AgentsPage.tsx` had an early `return` (line 398) that skipped
`updateInfiniteChatsCache`, so the sidebar's cache was never updated.
Even for other event types, the cache merge only spread `status` and
`title` — never `diff_status`.

2. **Server**: `publishChatPubsubEvent` in `chatd.go` constructed a
minimal `Chat` payload without `DiffStatus`, so even if the frontend
consumed the event, `updatedChat.diff_status` would be `undefined`.

## Fix

### Server (`coderd/chatd/chatd.go`)
- `publishChatPubsubEvent` now accepts an optional
`*codersdk.ChatDiffStatus` parameter; when non-nil it's set on the
outgoing `Chat` payload.
- `PublishDiffStatusChange` fetches the diff status from the DB,
converts it, and passes it through.
- Added `convertDBChatDiffStatus` (mirrors `coderd/chats.go`'s converter
to avoid circular import).
- All other callers pass `nil`.

### Frontend (`site/src/pages/AgentsPage/AgentsPage.tsx`)
- Removed the early `return` so `diff_status_change` events fall through
to the cache update logic.
- Added `isDiffStatusEvent` flag and spread `diff_status` into both the
infinite chats cache (sidebar) and the individual chat cache.
2026-03-16 15:41:32 +00:00
Kyle Carberry 530872873e chore: remove swagger annotations from experimental chat endpoints (#23120)
The `/archive` and `/desktop` chat endpoints had swagger route comments
(`@Summary`, `@ID`, `@Router`, etc.) that would cause them to appear in
generated API docs. Since these live under `/experimental/chats`, they
should not be documented.

This removes the swagger annotations and adds the standard `//
EXPERIMENTAL: this endpoint is experimental and is subject to change.`
comment to `archiveChat` (the `watchChatDesktop` handler already had it,
just needed the swagger block removed).
2026-03-16 08:41:13 -07:00
Matt Vollmer 115011bd70 docs: rename Chat API to Chats API (#23121)
Renames the page title and manifest label from "Chat API" to "Chats API"
to match the plural endpoint path (`/api/experimental/chats`).
2026-03-16 11:31:43 -04:00
315 changed files with 22704 additions and 6820 deletions
+18 -1
@@ -1198,7 +1198,7 @@ jobs:
make -j \
build/coder_linux_{amd64,arm64,armv7} \
build/coder_"$version"_windows_amd64.zip \
build/coder_"$version"_linux_amd64.{tar.gz,deb}
build/coder_"$version"_linux_{amd64,arm64,armv7}.{tar.gz,deb}
env:
# The Windows and Darwin slim binaries must be signed for Coder
# Desktop to accept them.
@@ -1216,11 +1216,28 @@ jobs:
GCLOUD_ACCESS_TOKEN: ${{ steps.gcloud_auth.outputs.access_token }}
JSIGN_PATH: /tmp/jsign-6.0.jar
# Free up disk space before building Docker images. The preceding
# Build step produces ~2 GB of binaries and packages, the Go build
# cache is ~1.3 GB, and node_modules is ~500 MB. Docker image
# builds, pushes, and SBOM generation need headroom that isn't
# available without reclaiming some of that space.
- name: Clean up build cache
run: |
set -euxo pipefail
# Go caches are no longer needed — binaries are already compiled.
go clean -cache -modcache
# Remove .apk and .rpm packages that are not uploaded as
# artifacts and were only built as make prerequisites.
rm -f ./build/*.apk ./build/*.rpm
- name: Build Linux Docker images
id: build-docker
env:
CODER_IMAGE_BASE: ghcr.io/coder/coder-preview
DOCKER_CLI_EXPERIMENTAL: "enabled"
# Skip building .deb/.rpm/.apk/.tar.gz as prerequisites for
# the Docker image targets — they were already built above.
DOCKER_IMAGE_NO_PREREQUISITES: "true"
run: |
set -euxo pipefail
+38
@@ -23,6 +23,44 @@ permissions:
concurrency: pr-${{ github.ref }}
jobs:
community-label:
runs-on: ubuntu-latest
permissions:
pull-requests: write
if: >-
${{
github.event_name == 'pull_request_target' &&
github.event.action == 'opened' &&
github.event.pull_request.author_association != 'MEMBER' &&
github.event.pull_request.author_association != 'COLLABORATOR' &&
github.event.pull_request.author_association != 'OWNER'
}}
steps:
- name: Add community label
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
script: |
const params = {
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
}
const labels = context.payload.pull_request.labels.map((label) => label.name)
if (labels.includes("community")) {
console.log('PR already has "community" label.')
return
}
console.log(
'Adding "community" label for author association "%s".',
context.payload.pull_request.author_association,
)
await github.rest.issues.addLabels({
...params,
labels: ["community"],
})
cla:
runs-on: ubuntu-latest
permissions:
+6 -11
@@ -136,18 +136,10 @@ endif
# the search path so that these exclusions match.
FIND_EXCLUSIONS= \
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
# Source files used for make targets, evaluated on use.
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
# Same as GO_SRC_FILES but excluding certain files that have problematic
# Makefile dependencies (e.g. pnpm).
MOST_GO_SRC_FILES := $(shell \
find . \
$(FIND_EXCLUSIONS) \
-type f \
-name '*.go' \
-not -name '*_test.go' \
-not -wholename './agent/agentcontainers/dcspec/dcspec_gen.go' \
)
# All the shell files in the repo, excluding ignored files.
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
@@ -514,7 +506,10 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
cp "$<" "$$output_file"
.PHONY: install
build/.bin/develop: go.mod go.sum $(GO_SRC_FILES)
# Only wildcard the go files in the develop directory to avoid rebuilds
# when project files are changed. Technically changes to some imports may
# not be detected, but it's unlikely to cause any issues.
build/.bin/develop: go.mod go.sum $(wildcard scripts/develop/*.go)
CGO_ENABLED=0 go build -o $@ ./scripts/develop
BOLD := $(shell tput bold 2>/dev/null)
+1 -1
@@ -389,7 +389,7 @@ func (a *agent) init() {
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
desktop := agentdesktop.NewPortableDesktop(
a.logger.Named("desktop"), a.execer, a.scriptDataDir,
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
)
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
a.reconnectingPTYServer = reconnectingpty.NewServer(
+24 -169
@@ -2,13 +2,9 @@ package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
@@ -24,28 +20,6 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
portableDesktopVersion = "v0.0.4"
downloadRetries = 3
downloadRetryDelay = time.Second
)
// platformBinaries maps GOARCH to download URL and expected SHA-256
// digest for each supported platform.
var platformBinaries = map[string]struct {
URL string
SHA256 string
}{
"amd64": {
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-x64",
SHA256: "a04e05e6c7d6f2e6b3acbf1729a7b21271276300b4fee321f4ffee6136538317",
},
"arm64": {
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-arm64",
SHA256: "b8cb9142dc32d46a608f25229cbe8168ff2a3aadc54253c74ff54cd347e16ca6",
},
}
// portableDesktopOutput is the JSON output from
// `portabledesktop up --json`.
type portableDesktopOutput struct {
@@ -78,43 +52,31 @@ type screenshotOutput struct {
// portableDesktop implements Desktop by shelling out to the
// portabledesktop CLI via agentexec.Execer.
type portableDesktop struct {
logger slog.Logger
execer agentexec.Execer
dataDir string // agent's ScriptDataDir, used for binary caching
logger slog.Logger
execer agentexec.Execer
scriptBinDir string // coder script bin directory
mu sync.Mutex
session *desktopSession // nil until started
binPath string // resolved path to binary, cached
closed bool
// httpClient is used for downloading the binary. If nil,
// http.DefaultClient is used.
httpClient *http.Client
}
// NewPortableDesktop creates a Desktop backed by the portabledesktop
// CLI binary, using execer to spawn child processes. dataDir is used
// to cache the downloaded binary.
// CLI binary, using execer to spawn child processes. scriptBinDir is
// the coder script bin directory checked for the binary.
func NewPortableDesktop(
logger slog.Logger,
execer agentexec.Execer,
dataDir string,
scriptBinDir string,
) Desktop {
return &portableDesktop{
logger: logger,
execer: execer,
dataDir: dataDir,
logger: logger,
execer: execer,
scriptBinDir: scriptBinDir,
}
}
// httpDo returns the HTTP client to use for downloads.
func (p *portableDesktop) httpDo() *http.Client {
if p.httpClient != nil {
return p.httpClient
}
return http.DefaultClient
}
// Start launches the desktop session (idempotent).
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
p.mu.Lock()
@@ -399,8 +361,8 @@ func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, e
return string(out), nil
}
// ensureBinary resolves or downloads the portabledesktop binary. It
// must be called while p.mu is held.
// ensureBinary resolves the portabledesktop binary from PATH or the
// coder script bin directory. It must be called while p.mu is held.
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
if p.binPath != "" {
return nil
@@ -415,130 +377,23 @@ func (p *portableDesktop) ensureBinary(ctx context.Context) error {
return nil
}
// 2. Platform checks.
if runtime.GOOS != "linux" {
return xerrors.New("portabledesktop is only supported on Linux")
}
bin, ok := platformBinaries[runtime.GOARCH]
if !ok {
return xerrors.Errorf("unsupported architecture for portabledesktop: %s", runtime.GOARCH)
}
// 3. Check cache.
cacheDir := filepath.Join(p.dataDir, "portabledesktop", bin.SHA256)
cachedPath := filepath.Join(cacheDir, "portabledesktop")
if info, err := os.Stat(cachedPath); err == nil && !info.IsDir() {
// Verify it is executable.
if info.Mode()&0o100 != 0 {
p.logger.Info(ctx, "using cached portabledesktop binary",
slog.F("path", cachedPath),
// 2. Check the coder script bin directory.
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
// On Windows, permission bits don't indicate executability,
// so accept any regular file.
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
p.logger.Info(ctx, "found portabledesktop in script bin directory",
slog.F("path", scriptBinPath),
)
p.binPath = cachedPath
p.binPath = scriptBinPath
return nil
}
}
// 4. Download with retry.
p.logger.Info(ctx, "downloading portabledesktop binary",
slog.F("url", bin.URL),
slog.F("version", portableDesktopVersion),
slog.F("arch", runtime.GOARCH),
)
var lastErr error
for attempt := range downloadRetries {
if err := downloadBinary(ctx, p.httpDo(), bin.URL, bin.SHA256, cachedPath); err != nil {
lastErr = err
p.logger.Warn(ctx, "download attempt failed",
slog.F("attempt", attempt+1),
slog.F("max_attempts", downloadRetries),
slog.Error(err),
)
if attempt < downloadRetries-1 {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(downloadRetryDelay):
}
}
continue
}
p.binPath = cachedPath
p.logger.Info(ctx, "downloaded portabledesktop binary",
slog.F("path", cachedPath),
)
return nil
}
return xerrors.Errorf("download portabledesktop after %d attempts: %w", downloadRetries, lastErr)
}
// downloadBinary fetches a binary from url, verifies its SHA-256
// digest matches expectedSHA256, and atomically writes it to destPath.
func downloadBinary(ctx context.Context, client *http.Client, url, expectedSHA256, destPath string) error {
if err := os.MkdirAll(filepath.Dir(destPath), 0o700); err != nil {
return xerrors.Errorf("create cache directory: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return xerrors.Errorf("create HTTP request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return xerrors.Errorf("HTTP GET %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return xerrors.Errorf("HTTP GET %s: status %d", url, resp.StatusCode)
}
// Write to a temp file in the same directory so the final rename
// is atomic on the same filesystem.
tmpFile, err := os.CreateTemp(filepath.Dir(destPath), "portabledesktop-download-*")
if err != nil {
return xerrors.Errorf("create temp file: %w", err)
}
tmpPath := tmpFile.Name()
// Clean up the temp file on any error path.
success := false
defer func() {
if !success {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
}
}()
// Stream the response body while computing SHA-256.
hasher := sha256.New()
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
return xerrors.Errorf("download body: %w", err)
}
if err := tmpFile.Close(); err != nil {
return xerrors.Errorf("close temp file: %w", err)
}
// Verify digest.
actualSHA256 := hex.EncodeToString(hasher.Sum(nil))
if actualSHA256 != expectedSHA256 {
return xerrors.Errorf(
"SHA-256 mismatch: expected %s, got %s",
expectedSHA256, actualSHA256,
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
slog.F("path", scriptBinPath),
slog.F("mode", info.Mode().String()),
)
}
if err := os.Chmod(tmpPath, 0o700); err != nil {
return xerrors.Errorf("chmod: %w", err)
}
if err := os.Rename(tmpPath, destPath); err != nil {
return xerrors.Errorf("rename to final path: %w", err)
}
success = true
return nil
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
}
@@ -2,11 +2,6 @@ package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
@@ -77,7 +72,6 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
// The "up" script prints the JSON line then sleeps until
// the context is canceled (simulating a long-running process).
@@ -88,13 +82,13 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
}
ctx := context.Background()
ctx := t.Context()
cfg, err := pd.Start(ctx)
require.NoError(t, err)
@@ -111,7 +105,6 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -120,13 +113,13 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := context.Background()
ctx := t.Context()
cfg1, err := pd.Start(ctx)
require.NoError(t, err)
@@ -154,7 +147,6 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -163,13 +155,13 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := context.Background()
ctx := t.Context()
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
require.NoError(t, err)
@@ -180,7 +172,6 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -189,13 +180,13 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := context.Background()
ctx := t.Context()
_, err := pd.Screenshot(ctx, ScreenshotOptions{
TargetWidth: 800,
TargetHeight: 600,
@@ -287,13 +278,13 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(context.Background(), pd)
err := tt.invoke(t.Context(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
@@ -372,13 +363,13 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(context.Background(), pd)
err := tt.invoke(t.Context(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
@@ -404,13 +395,13 @@ func TestPortableDesktop_CursorPosition(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
x, y, err := pd.CursorPosition(context.Background())
x, y, err := pd.CursorPosition(t.Context())
require.NoError(t, err)
assert.Equal(t, 100, x)
assert.Equal(t, 200, y)
@@ -428,13 +419,13 @@ func TestPortableDesktop_Close(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := context.Background()
ctx := t.Context()
_, err := pd.Start(ctx)
require.NoError(t, err)
@@ -457,81 +448,6 @@ func TestPortableDesktop_Close(t *testing.T) {
assert.Contains(t, err.Error(), "desktop is closed")
}
// --- downloadBinary tests ---
func TestDownloadBinary_Success(t *testing.T) {
t.Parallel()
binaryContent := []byte("#!/bin/sh\necho portable\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
err := downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
require.NoError(t, err)
// Verify the file exists and has correct content.
got, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
// Verify executable permissions.
info, err := os.Stat(destPath)
require.NoError(t, err)
assert.NotZero(t, info.Mode()&0o700, "binary should be executable")
}
func TestDownloadBinary_ChecksumMismatch(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("real binary content"))
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
err := downloadBinary(context.Background(), srv.Client(), srv.URL, wrongSHA, destPath)
require.Error(t, err)
assert.Contains(t, err.Error(), "SHA-256 mismatch")
// The destination file should not exist (temp file cleaned up).
_, statErr := os.Stat(destPath)
assert.True(t, os.IsNotExist(statErr), "dest file should not exist after checksum failure")
// No leftover temp files in the directory.
entries, err := os.ReadDir(destDir)
require.NoError(t, err)
assert.Empty(t, entries, "no leftover temp files should remain")
}
func TestDownloadBinary_HTTPError(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
err := downloadBinary(context.Background(), srv.Client(), srv.URL, "irrelevant", destPath)
require.Error(t, err)
assert.Contains(t, err.Error(), "status 404")
}
// --- ensureBinary tests ---
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
@@ -541,173 +457,89 @@ func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
// immediately without doing any work.
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: t.TempDir(),
binPath: "/already/set",
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(),
binPath: "/already/set",
}
err := pd.ensureBinary(context.Background())
err := pd.ensureBinary(t.Context())
require.NoError(t, err)
assert.Equal(t, "/already/set", pd.binPath)
}
func TestEnsureBinary_UsesCachedBinary(t *testing.T) {
func TestEnsureBinary_UsesScriptBinDir(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
bin, ok := platformBinaries[runtime.GOARCH]
if !ok {
t.Skipf("no platformBinary entry for %s", runtime.GOARCH)
}
dataDir := t.TempDir()
cacheDir := filepath.Join(dataDir, "portabledesktop", bin.SHA256)
require.NoError(t, os.MkdirAll(cacheDir, 0o700))
cachedPath := filepath.Join(cacheDir, "portabledesktop")
require.NoError(t, os.WriteFile(cachedPath, []byte("#!/bin/sh\n"), 0o600))
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
require.NoError(t, os.Chmod(binPath, 0o755))
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: dataDir,
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(context.Background())
err := pd.ensureBinary(t.Context())
require.NoError(t, err)
assert.Equal(t, cachedPath, pd.binPath)
assert.Equal(t, binPath, pd.binPath)
}
func TestEnsureBinary_Downloads(t *testing.T) {
func TestEnsureBinary_ScriptBinDirNotExecutable(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Windows does not support Unix permission bits")
}
// Cannot use t.Parallel because t.Setenv modifies the process
// environment and we override the package-level platformBinaries.
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
// environment.
binaryContent := []byte("#!/bin/sh\necho downloaded\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
// Write without execute permission.
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
_ = binPath
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
// Save and restore platformBinaries for this test.
origBinaries := platformBinaries
platformBinaries = map[string]struct {
URL string
SHA256 string
}{
runtime.GOARCH: {
URL: srv.URL + "/portabledesktop",
SHA256: expectedSHA,
},
}
t.Cleanup(func() { platformBinaries = origBinaries })
dataDir := t.TempDir()
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: dataDir,
httpClient: srv.Client(),
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
}
// Ensure PATH doesn't contain a real portabledesktop binary.
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
expectedPath := filepath.Join(dataDir, "portabledesktop", expectedSHA, "portabledesktop")
assert.Equal(t, expectedPath, pd.binPath)
// Verify the downloaded file has correct content.
got, err := os.ReadFile(expectedPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
}
func TestEnsureBinary_RetriesOnFailure(t *testing.T) {
t.Parallel()
binaryContent := []byte("#!/bin/sh\necho retried\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
var mu sync.Mutex
attempt := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
mu.Lock()
current := attempt
attempt++
mu.Unlock()
// Fail the first 2 attempts, succeed on the third.
if current < 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
// Test downloadBinary directly to avoid time.Sleep in
// ensureBinary's retry loop. We call it 3 times to simulate
// what ensureBinary would do.
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
var lastErr error
for i := range 3 {
lastErr = downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
if lastErr == nil {
break
}
if i < 2 {
// In the real code, ensureBinary sleeps here.
// We skip the sleep in tests.
continue
}
}
require.NoError(t, lastErr, "download should succeed on the third attempt")
got, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
mu.Lock()
assert.Equal(t, 3, attempt, "server should have been hit 3 times")
mu.Unlock()
}
func TestEnsureBinary_NotFound(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(), // empty directory
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
}
// Ensure that portableDesktop satisfies the Desktop interface at
// compile time. This uses the unexported type so it lives in the
// internal test package.
var _ Desktop = (*portableDesktop)(nil)
// Keep these imports referenced even when individual tests are
// skipped on non-Linux platforms, so the file compiles everywhere.
var (
_ = agentexec.DefaultExecer
_ = fmt.Sprintf
)
+89 -38
View File
@@ -447,13 +447,10 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
content := string(data)
for _, edit := range edits {
var ok bool
content, ok = fuzzyReplace(content, edit.Search, edit.Replace)
if !ok {
api.logger.Warn(ctx, "edit search string not found, skipping",
slog.F("path", path),
slog.F("search_preview", truncate(edit.Search, 64)),
)
var err error
content, err = fuzzyReplace(content, edit)
if err != nil {
return http.StatusBadRequest, xerrors.Errorf("edit %s: %w", path, err)
}
}
@@ -480,51 +477,92 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
return 0, nil
}
// fuzzyReplace attempts to find `search` inside `content` and replace its first
// occurrence with `replace`. It uses a cascading match strategy inspired by
// fuzzyReplace attempts to find `search` inside `content` and replace it
// with `replace`. It uses a cascading match strategy inspired by
// openai/codex's apply_patch:
//
// 1. Exact substring match (byte-for-byte).
// 2. Line-by-line match ignoring trailing whitespace on each line.
// 3. Line-by-line match ignoring all leading/trailing whitespace (indentation-tolerant).
// 3. Line-by-line match ignoring all leading/trailing whitespace
// (indentation-tolerant).
//
// When a fuzzy match is found (passes 2 or 3), the replacement is still applied
// at the byte offsets of the original content so that surrounding text (including
// indentation of untouched lines) is preserved.
// When edit.ReplaceAll is false (the default), the search string must
// match exactly one location. If multiple matches are found, an error
// is returned asking the caller to include more context or set
// replace_all.
//
// Returns the (possibly modified) content and a bool indicating whether a match
// was found.
func fuzzyReplace(content, search, replace string) (string, bool) {
// Pass 1 exact substring (replace all occurrences).
// When a fuzzy match is found (passes 2 or 3), the replacement is still
// applied at the byte offsets of the original content so that surrounding
// text (including indentation of untouched lines) is preserved.
func fuzzyReplace(content string, edit workspacesdk.FileEdit) (string, error) {
search := edit.Search
replace := edit.Replace
// Pass 1 exact substring match.
if strings.Contains(content, search) {
return strings.ReplaceAll(content, search, replace), true
if edit.ReplaceAll {
return strings.ReplaceAll(content, search, replace), nil
}
count := strings.Count(content, search)
if count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
// Exactly one match.
return strings.Replace(content, search, replace, 1), nil
}
// For line-level fuzzy matching we split both content and search into lines.
// For line-level fuzzy matching we split both content and search
// into lines.
contentLines := strings.SplitAfter(content, "\n")
searchLines := strings.SplitAfter(search, "\n")
// A trailing newline in the search produces an empty final element from
// SplitAfter. Drop it so it doesn't interfere with line matching.
// A trailing newline in the search produces an empty final element
// from SplitAfter. Drop it so it doesn't interfere with line
// matching.
if len(searchLines) > 0 && searchLines[len(searchLines)-1] == "" {
searchLines = searchLines[:len(searchLines)-1]
}
// Pass 2 trim trailing whitespace on each line.
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
trimRight := func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
// Pass 3 trim all leading and trailing whitespace (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
trimAll := func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
return content, false
// Pass 2 trim trailing whitespace on each line.
if start, end, ok := seekLines(contentLines, searchLines, trimRight); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimRight); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
}
// Pass 3 trim all leading and trailing whitespace
// (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, trimAll); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimAll); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
}
return "", xerrors.New("search string not found in file. Verify the search " +
"string matches the file content exactly, including whitespace " +
"and indentation")
}
// seekLines scans contentLines looking for a contiguous subsequence that matches
@@ -549,6 +587,26 @@ outer:
return 0, 0, false
}
// countLineMatches counts how many non-overlapping contiguous
// subsequences of contentLines match searchLines according to eq.
func countLineMatches(contentLines, searchLines []string, eq func(a, b string) bool) int {
count := 0
if len(searchLines) == 0 || len(searchLines) > len(contentLines) {
return count
}
outer:
for i := 0; i <= len(contentLines)-len(searchLines); i++ {
for j, sLine := range searchLines {
if !eq(contentLines[i+j], sLine) {
continue outer
}
}
count++
i += len(searchLines) - 1 // skip past this match
}
return count
}
// spliceLines replaces contentLines[start:end] with replacement text, returning
// the full content as a single string.
func spliceLines(contentLines []string, start, end int, replacement string) string {
@@ -562,10 +620,3 @@ func spliceLines(contentLines []string, start, end int, replacement string) stri
}
return b.String()
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
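The cascading strategy in the hunk above can be sketched in isolation. This simplified version covers pass 1 (exact, must be unique) and pass 2 (line-by-line, ignoring trailing whitespace); it omits replace_all and the indentation-tolerant pass 3, and is not the real fuzzyReplace:

```go
package main

import (
	"fmt"
	"strings"
)

// cascadeReplace replaces search with replace in content using a
// two-pass cascade. Simplified sketch of the approach only.
func cascadeReplace(content, search, replace string) (string, error) {
	// Pass 1: exact substring match, which must be unique.
	if n := strings.Count(content, search); n == 1 {
		return strings.Replace(content, search, replace, 1), nil
	} else if n > 1 {
		return "", fmt.Errorf("search matches %d occurrences", n)
	}
	// Pass 2: line-by-line match ignoring trailing whitespace.
	trimR := func(s string) string { return strings.TrimRight(s, " \t\r\n") }
	cl := strings.SplitAfter(content, "\n")
	sl := strings.SplitAfter(search, "\n")
	// Drop the empty final element a trailing newline produces.
	if len(sl) > 0 && sl[len(sl)-1] == "" {
		sl = sl[:len(sl)-1]
	}
outer:
	for i := 0; i+len(sl) <= len(cl); i++ {
		for j := range sl {
			if trimR(cl[i+j]) != trimR(sl[j]) {
				continue outer
			}
		}
		// Splice the replacement in at the matched lines, keeping
		// the surrounding content byte-for-byte.
		return strings.Join(cl[:i], "") + replace + strings.Join(cl[i+len(sl):], ""), nil
	}
	return "", fmt.Errorf("search string not found")
}

func main() {
	// "b  \n" has trailing spaces, so only pass 2 matches "b\n".
	out, err := cascadeReplace("a\nb  \nc\n", "b\n", "B\n")
	fmt.Printf("%q %v\n", out, err)
}
```

Note how pass 2 still preserves untouched lines exactly; only the matched span is rewritten, mirroring the byte-offset splice described in the doc comment above.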
+68 -3
View File
@@ -576,7 +576,9 @@ func TestEditFiles(t *testing.T) {
expected: map[string]string{filepath.Join(tmpdir, "edit1"): "bar bar"},
},
{
name: "EditEdit", // Edits affect previous edits.
// When the second edit creates ambiguity (two "bar"
// occurrences), it should fail.
name: "EditEditAmbiguous",
contents: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
@@ -593,7 +595,33 @@ func TestEditFiles(t *testing.T) {
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "qux qux"},
errCode: http.StatusBadRequest,
errors: []string{"matches 2 occurrences"},
// File should not be modified on error.
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
},
{
// With replace_all the cascading edit replaces
// both occurrences.
name: "EditEditReplaceAll",
contents: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "edit-edit-ra"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
{
Search: "bar",
Replace: "qux",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "qux qux"},
},
{
name: "Multiline",
@@ -720,7 +748,7 @@ func TestEditFiles(t *testing.T) {
expected: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "goodbye world"},
},
{
name: "NoMatchStillSucceeds",
name: "NoMatchErrors",
contents: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
edits: []workspacesdk.FileEdits{
{
@@ -733,9 +761,46 @@ func TestEditFiles(t *testing.T) {
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"search string not found in file"},
// File should remain unchanged.
expected: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
},
{
name: "AmbiguousExactMatch",
contents: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ambig-exact"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "qux",
},
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"matches 3 occurrences"},
expected: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
},
{
name: "ReplaceAllExact",
contents: map[string]string{filepath.Join(tmpdir, "ra-exact"): "foo bar foo baz foo"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ra-exact"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "qux",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "ra-exact"): "qux bar qux baz qux"},
},
{
name: "MixedWhitespaceMultiline",
contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
+26
View File
@@ -0,0 +1,26 @@
//go:build !windows
package agentproc
import (
"os"
"syscall"
)
// procSysProcAttr returns the SysProcAttr to use when spawning
// processes. On Unix, Setpgid creates a new process group so
// that signals can be delivered to the entire group (the shell
// and all its children).
func procSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Setpgid: true,
}
}
// signalProcess sends a signal to the process group rooted at p.
// Using the negative PID sends the signal to every process in the
// group, ensuring child processes (e.g. from shell pipelines) are
// also signaled.
func signalProcess(p *os.Process, sig syscall.Signal) error {
return syscall.Kill(-p.Pid, sig)
}
+20
View File
@@ -0,0 +1,20 @@
package agentproc
import (
"os"
"syscall"
)
// procSysProcAttr returns the SysProcAttr to use when spawning
// processes. On Windows, process groups are not supported in the
// same way as Unix, so this returns an empty struct.
func procSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{}
}
// signalProcess terminates the process. Windows has no Unix-style
// signals or process group signaling, so the signal argument is
// ignored and the process is killed directly.
func signalProcess(p *os.Process, _ syscall.Signal) error {
return p.Kill()
}
+7 -4
View File
@@ -113,6 +113,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
cmd.Dir = req.WorkDir
}
cmd.Stdin = nil
cmd.SysProcAttr = procSysProcAttr()
// WaitDelay ensures cmd.Wait returns promptly after
// the process is killed, even if child processes are
@@ -272,13 +273,15 @@ func (m *manager) signal(id string, sig string) error {
switch sig {
case "kill":
if err := proc.cmd.Process.Kill(); err != nil {
// Use process group kill to ensure child processes
// (e.g. from shell pipelines) are also killed.
if err := signalProcess(proc.cmd.Process, syscall.SIGKILL); err != nil {
return xerrors.Errorf("kill process: %w", err)
}
case "terminate":
//nolint:revive // syscall.SIGTERM is portable enough
// for our supported platforms.
if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
// Use process group signal to ensure child processes
// are also terminated.
if err := signalProcess(proc.cmd.Process, syscall.SIGTERM); err != nil {
return xerrors.Errorf("terminate process: %w", err)
}
default:
+15
View File
@@ -46,6 +46,7 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
autoUpdates string
copyParametersFrom string
useParameterDefaults bool
noWait bool
// Organization context is only required if more than 1 template
// shares the same name across multiple organizations.
orgContext = NewOrganizationContext()
@@ -372,6 +373,14 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
cliutil.WarnMatchedProvisioners(inv.Stderr, workspace.LatestBuild.MatchedProvisioners, workspace.LatestBuild.Job)
if noWait {
_, _ = fmt.Fprintf(inv.Stdout,
"\nThe %s workspace has been created and is building in the background.\n",
cliui.Keyword(workspace.Name),
)
return nil
}
err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, workspace.LatestBuild.ID)
if err != nil {
return xerrors.Errorf("watch build: %w", err)
@@ -445,6 +454,12 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
Description: "Automatically accept parameter defaults when no value is provided.",
Value: serpent.BoolOf(&useParameterDefaults),
},
serpent.Option{
Flag: "no-wait",
Env: "CODER_CREATE_NO_WAIT",
Description: "Return immediately after creating the workspace. The build will run in the background.",
Value: serpent.BoolOf(&noWait),
},
cliui.SkipPromptOption(),
)
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
+75
View File
@@ -603,6 +603,81 @@ func TestCreate(t *testing.T) {
assert.Nil(t, ws.AutostartSchedule, "expected workspace autostart schedule to be nil")
}
})
t.Run("NoWait", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
ctx := testutil.Context(t, testutil.WaitLong)
inv, root := clitest.New(t, "create", "my-workspace",
"--template", template.Name,
"-y",
"--no-wait",
)
clitest.SetupConfig(t, member, root)
doneChan := make(chan struct{})
pty := ptytest.New(t).Attach(inv)
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatchContext(ctx, "building in the background")
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify workspace was actually created.
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
assert.Equal(t, ws.TemplateName, template.Name)
})
t.Run("NoWaitWithParameterDefaults", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses([]*proto.RichParameter{
{Name: "region", Type: "string", DefaultValue: "us-east-1"},
{Name: "instance_type", Type: "string", DefaultValue: "t3.micro"},
}))
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
ctx := testutil.Context(t, testutil.WaitLong)
inv, root := clitest.New(t, "create", "my-workspace",
"--template", template.Name,
"-y",
"--use-parameter-defaults",
"--no-wait",
)
clitest.SetupConfig(t, member, root)
doneChan := make(chan struct{})
pty := ptytest.New(t).Attach(inv)
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatchContext(ctx, "building in the background")
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify workspace was created and parameters were applied.
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
assert.Equal(t, ws.TemplateName, template.Name)
buildParams, err := member.WorkspaceBuildParameters(ctx, ws.LatestBuild.ID)
require.NoError(t, err)
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "region", Value: "us-east-1"})
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "instance_type", Value: "t3.micro"})
})
}
func prepareEchoResponses(parameters []*proto.RichParameter, presets ...*proto.Preset) *echo.Responses {
+6
View File
@@ -1000,6 +1000,12 @@ func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool
Properties: sdkTool.Schema.Properties,
Required: sdkTool.Schema.Required,
},
Annotations: mcp.ToolAnnotation{
ReadOnlyHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.ReadOnlyHint),
DestructiveHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.DestructiveHint),
IdempotentHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.IdempotentHint),
OpenWorldHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.OpenWorldHint),
},
},
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var buf bytes.Buffer
+16 -1
View File
@@ -81,7 +81,13 @@ func TestExpMcpServer(t *testing.T) {
var toolsResponse struct {
Result struct {
Tools []struct {
Name string `json:"name"`
Name string `json:"name"`
Annotations struct {
ReadOnlyHint *bool `json:"readOnlyHint"`
DestructiveHint *bool `json:"destructiveHint"`
IdempotentHint *bool `json:"idempotentHint"`
OpenWorldHint *bool `json:"openWorldHint"`
} `json:"annotations"`
} `json:"tools"`
} `json:"result"`
}
@@ -94,6 +100,15 @@ func TestExpMcpServer(t *testing.T) {
}
slices.Sort(foundTools)
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
annotations := toolsResponse.Result.Tools[0].Annotations
require.NotNil(t, annotations.ReadOnlyHint)
require.NotNil(t, annotations.DestructiveHint)
require.NotNil(t, annotations.IdempotentHint)
require.NotNil(t, annotations.OpenWorldHint)
assert.True(t, *annotations.ReadOnlyHint)
assert.False(t, *annotations.DestructiveHint)
assert.True(t, *annotations.IdempotentHint)
assert.False(t, *annotations.OpenWorldHint)
// Call the tool and ensure it works.
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
+1 -1
View File
@@ -214,7 +214,7 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
} else {
updated, err = client.CreateOrganizationRole(ctx, customRole)
if err != nil {
return xerrors.Errorf("patch role: %w", err)
return xerrors.Errorf("create role: %w", err)
}
}
+26 -17
View File
@@ -113,6 +113,20 @@ func (r *RootCmd) supportBundle() *serpent.Command {
)
cliLog.Debug(inv.Context(), "invocation", slog.F("args", strings.Join(os.Args, " ")))
// Bypass rate limiting for support bundle collection since it makes many API calls.
// Note: this can only be done by the owner user.
if ok, err := support.CanGenerateFull(inv.Context(), client); err != nil {
cliLog.Error(inv.Context(), "failed to look up current user", slog.Error(err))
} else if !ok {
cliLog.Warn(inv.Context(), "not running as owner, not all information available")
} else {
cliLog.Debug(inv.Context(), "running as owner")
client.HTTPClient.Transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
}
}
// Check if we're running inside a workspace
if val, found := os.LookupEnv("CODER"); found && val == "true" {
cliui.Warn(inv.Stderr, "Running inside Coder workspace; this can affect results!")
@@ -200,12 +214,6 @@ func (r *RootCmd) supportBundle() *serpent.Command {
_, _ = fmt.Fprintln(inv.Stderr, "pprof data collection will take approximately 30 seconds...")
}
// Bypass rate limiting for support bundle collection since it makes many API calls.
client.HTTPClient.Transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
}
deps := support.Deps{
Client: client,
// Support adds a sink so we don't need to supply one ourselves.
@@ -354,19 +362,20 @@ func summarizeBundle(inv *serpent.Invocation, bun *support.Bundle) {
return
}
if bun.Deployment.Config == nil {
cliui.Error(inv.Stdout, "No deployment configuration available!")
return
var docsURL string
if bun.Deployment.Config != nil {
docsURL = bun.Deployment.Config.Values.DocsURL.String()
} else {
cliui.Warn(inv.Stdout, "No deployment configuration available. This may require the Owner role.")
}
docsURL := bun.Deployment.Config.Values.DocsURL.String()
if bun.Deployment.HealthReport == nil {
cliui.Error(inv.Stdout, "No deployment health report available!")
return
}
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
if len(deployHealthSummary) > 0 {
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
if bun.Deployment.HealthReport != nil {
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
if len(deployHealthSummary) > 0 {
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
}
} else {
cliui.Warn(inv.Stdout, "No deployment health report available.")
}
if bun.Network.Netcheck == nil {
+30 -3
View File
@@ -132,12 +132,35 @@ func TestSupportBundle(t *testing.T) {
assertBundleContents(t, path, true, false, []string{secretValue})
})
t.Run("NoPrivilege", func(t *testing.T) {
t.Run("MemberCanGenerateBundle", func(t *testing.T) {
t.Parallel()
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--yes")
d := t.TempDir()
path := filepath.Join(d, "bundle.zip")
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--output-file", path, "--yes")
clitest.SetupConfig(t, memberClient, root)
err := inv.Run()
require.ErrorContains(t, err, "failed authorization check")
require.NoError(t, err)
r, err := zip.OpenReader(path)
require.NoError(t, err, "open zip file")
defer r.Close()
fileNames := make(map[string]struct{}, len(r.File))
for _, f := range r.File {
fileNames[f.Name] = struct{}{}
}
// These should always be present in the zip structure, even if
// the content is null/empty for non-admin users.
for _, name := range []string{
"deployment/buildinfo.json",
"deployment/config.json",
"workspace/workspace.json",
"logs.txt",
"cli_logs.txt",
"network/netcheck.json",
"network/interfaces.json",
} {
require.Contains(t, fileNames, name)
}
})
// This ensures that the CLI does not panic when trying to generate a support bundle
@@ -159,6 +182,10 @@ func TestSupportBundle(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("received request: %s %s", r.Method, r.URL)
switch r.URL.Path {
case "/api/v2/users/me":
resp := codersdk.User{}
w.WriteHeader(http.StatusOK)
assert.NoError(t, json.NewEncoder(w).Encode(resp))
case "/api/v2/authcheck":
// Fake auth check
resp := codersdk.AuthorizationResponse{
+4
View File
@@ -20,6 +20,10 @@ OPTIONS:
--copy-parameters-from string, $CODER_WORKSPACE_COPY_PARAMETERS_FROM
Specify the source workspace name to copy parameters from.
--no-wait bool, $CODER_CREATE_NO_WAIT
Return immediately after creating the workspace. The build will run in
the background.
--parameter string-array, $CODER_RICH_PARAMETER
Rich parameter value in the format "name=value".
+1 -1
View File
@@ -7,7 +7,7 @@
"last_seen_at": "====[timestamp]=====",
"name": "test-daemon",
"version": "v0.0.0-devel",
"api_version": "1.15",
"api_version": "1.16",
"provisioners": [
"echo"
],
+4
View File
@@ -24,6 +24,10 @@ OPTIONS:
-p, --password string
Specifies a password for the new user.
--service-account bool
Create a user account intended to be used by a service or as an
intermediary rather than by a human.
-u, --username string
Specifies a username for the new user.
+5
View File
@@ -752,6 +752,11 @@ workspace_prebuilds:
# limit; disabled when set to zero.
# (default: 3, type: int)
failure_hard_limit: 3
# Configure the background chat processing daemon.
chat:
# How many pending chats a worker should acquire per polling cycle.
# (default: 10, type: int)
acquireBatchSize: 10
aibridge:
# Whether to start an in-memory aibridged instance.
# (default: false, type: bool)
+37 -12
View File
@@ -17,13 +17,14 @@ import (
func (r *RootCmd) userCreate() *serpent.Command {
var (
email string
username string
name string
password string
disableLogin bool
loginType string
orgContext = NewOrganizationContext()
email string
username string
name string
password string
disableLogin bool
loginType string
serviceAccount bool
orgContext = NewOrganizationContext()
)
cmd := &serpent.Command{
Use: "create",
@@ -32,6 +33,23 @@ func (r *RootCmd) userCreate() *serpent.Command {
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
if serviceAccount {
switch {
case loginType != "":
return xerrors.New("You cannot use --login-type with --service-account")
case password != "":
return xerrors.New("You cannot use --password with --service-account")
case email != "":
return xerrors.New("You cannot use --email with --service-account")
case disableLogin:
return xerrors.New("You cannot use --disable-login with --service-account")
}
}
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
client, err := r.InitClient(inv)
if err != nil {
return err
@@ -59,7 +77,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
return err
}
}
if email == "" {
if email == "" && !serviceAccount {
email, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Email:",
Validate: func(s string) error {
@@ -87,10 +105,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
}
}
userLoginType := codersdk.LoginTypePassword
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
if disableLogin {
if disableLogin || serviceAccount {
userLoginType = codersdk.LoginTypeNone
} else if loginType != "" {
userLoginType = codersdk.LoginType(loginType)
@@ -111,6 +126,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
Password: password,
OrganizationIDs: []uuid.UUID{organization.ID},
UserLoginType: userLoginType,
ServiceAccount: serviceAccount,
})
if err != nil {
return err
@@ -127,6 +143,10 @@ func (r *RootCmd) userCreate() *serpent.Command {
case codersdk.LoginTypeOIDC:
authenticationMethod = `Login is authenticated through the configured OIDC provider.`
}
if serviceAccount {
email = "n/a"
authenticationMethod = "Service accounts must authenticate with a token and cannot log in."
}
_, _ = fmt.Fprintln(inv.Stderr, `A new user has been created!
Share the instructions below to get them started.
@@ -194,6 +214,11 @@ Create a workspace `+pretty.Sprint(cliui.DefaultStyles.Code, "coder create")+`!
)),
Value: serpent.StringOf(&loginType),
},
{
Flag: "service-account",
Description: "Create a user account intended to be used by a service or as an intermediary rather than by a human.",
Value: serpent.BoolOf(&serviceAccount),
},
}
orgContext.AttachOptions(cmd)
+53
View File
@@ -8,6 +8,7 @@ import (
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
@@ -124,4 +125,56 @@ func TestUserCreate(t *testing.T) {
assert.Equal(t, args[5], created.Username)
assert.Empty(t, created.Name)
})
tests := []struct {
name string
args []string
err string
}{
{
name: "ServiceAccount",
args: []string{"--service-account", "-u", "dean"},
},
{
name: "ServiceAccountLoginType",
args: []string{"--service-account", "-u", "dean", "--login-type", "none"},
err: "You cannot use --login-type with --service-account",
},
{
name: "ServiceAccountDisableLogin",
args: []string{"--service-account", "-u", "dean", "--disable-login"},
err: "You cannot use --disable-login with --service-account",
},
{
name: "ServiceAccountEmail",
args: []string{"--service-account", "-u", "dean", "--email", "dean@coder.com"},
err: "You cannot use --email with --service-account",
},
{
name: "ServiceAccountPassword",
args: []string{"--service-account", "-u", "dean", "--password", "1n5ecureP4ssw0rd!"},
err: "You cannot use --password with --service-account",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
coderdtest.CreateFirstUser(t, client)
inv, root := clitest.New(t, append([]string{"users", "create"}, tt.args...)...)
clitest.SetupConfig(t, client, root)
err := inv.Run()
if tt.err == "" {
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitShort)
created, err := client.User(ctx, "dean")
require.NoError(t, err)
assert.Equal(t, codersdk.LoginTypeNone, created.LoginType)
} else {
require.Error(t, err)
require.ErrorContains(t, err, tt.err)
}
})
}
}
+38
View File
@@ -0,0 +1,38 @@
// Package aiseats is the AGPL version of the package.
// The actual implementation is in `enterprise/aiseats`.
package aiseats
import (
"context"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/database"
)
type Reason struct {
EventType database.AiSeatUsageReason
Description string
}
// ReasonAIBridge constructs a reason for usage originating from AI Bridge.
func ReasonAIBridge(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonAibridge, Description: description}
}
// ReasonTask constructs a reason for usage originating from tasks.
func ReasonTask(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonTask, Description: description}
}
// SeatTracker records AI seat consumption state.
type SeatTracker interface {
// RecordUsage does not return an error to prevent blocking the user from using
// AI features. This method is used to record usage, not enforce it.
RecordUsage(ctx context.Context, userID uuid.UUID, reason Reason)
}
// Noop is an AGPL seat tracker that does nothing.
type Noop struct{}
func (Noop) RecordUsage(context.Context, uuid.UUID, Reason) {}
+276 -27
View File
@@ -481,46 +481,47 @@ const docTemplate = `{
}
}
},
"/chats/{chat}/archive": {
"post": {
"tags": [
"Chats"
],
"summary": "Archive a chat",
"operationId": "archive-chat",
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/chats/{chat}/desktop": {
"/chats/insights/pull-requests": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": [
"application/json"
],
"tags": [
"Chats"
],
"summary": "Watch chat desktop",
"operationId": "watch-chat-desktop",
"summary": "Get PR insights",
"operationId": "get-pr-insights",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Chat ID",
"name": "chat",
"in": "path",
"description": "Start date (RFC3339)",
"name": "start_date",
"in": "query",
"required": true
},
{
"type": "string",
"description": "End date (RFC3339)",
"name": "end_date",
"in": "query",
"required": true
}
],
"responses": {
"101": {
"description": "Switching Protocols"
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.PRInsightsResponse"
}
}
},
"x-apidocgen": {
"skip": true
}
}
},
@@ -4862,7 +4863,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
}
}
],
@@ -4870,7 +4871,7 @@ const docTemplate = `{
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
}
}
}
@@ -12757,6 +12758,9 @@ const docTemplate = `{
},
"bridge": {
"$ref": "#/definitions/codersdk.AIBridgeConfig"
},
"chat": {
"$ref": "#/definitions/codersdk.ChatConfig"
}
}
},
@@ -13814,6 +13818,14 @@ const docTemplate = `{
}
}
},
"codersdk.ChatConfig": {
"type": "object",
"properties": {
"acquire_batch_size": {
"type": "integer"
}
}
},
"codersdk.ConnectionLatency": {
"type": "object",
"properties": {
@@ -17140,6 +17152,191 @@ const docTemplate = `{
}
}
},
"codersdk.PRInsightsModelBreakdown": {
"type": "object",
"properties": {
"cost_per_merged_pr_micros": {
"type": "integer"
},
"display_name": {
"type": "string"
},
"merge_rate": {
"type": "number"
},
"merged_prs": {
"type": "integer"
},
"model_config_id": {
"type": "string",
"format": "uuid"
},
"provider": {
"type": "string"
},
"total_additions": {
"type": "integer"
},
"total_cost_micros": {
"type": "integer"
},
"total_deletions": {
"type": "integer"
},
"total_prs": {
"type": "integer"
}
}
},
"codersdk.PRInsightsPullRequest": {
"type": "object",
"properties": {
"additions": {
"type": "integer"
},
"approved": {
"type": "boolean"
},
"author_avatar_url": {
"type": "string"
},
"author_login": {
"type": "string"
},
"base_branch": {
"type": "string"
},
"changed_files": {
"type": "integer"
},
"changes_requested": {
"type": "boolean"
},
"chat_id": {
"type": "string",
"format": "uuid"
},
"commits": {
"type": "integer"
},
"cost_micros": {
"type": "integer"
},
"created_at": {
"type": "string",
"format": "date-time"
},
"deletions": {
"type": "integer"
},
"draft": {
"type": "boolean"
},
"model_display_name": {
"type": "string"
},
"pr_number": {
"type": "integer"
},
"pr_title": {
"type": "string"
},
"pr_url": {
"type": "string"
},
"reviewer_count": {
"type": "integer"
},
"state": {
"type": "string"
}
}
},
"codersdk.PRInsightsResponse": {
"type": "object",
"properties": {
"by_model": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.PRInsightsModelBreakdown"
}
},
"recent_prs": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.PRInsightsPullRequest"
}
},
"summary": {
"$ref": "#/definitions/codersdk.PRInsightsSummary"
},
"time_series": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.PRInsightsTimeSeriesEntry"
}
}
}
},
"codersdk.PRInsightsSummary": {
"type": "object",
"properties": {
"approval_rate": {
"type": "number"
},
"cost_per_merged_pr_micros": {
"type": "integer"
},
"merge_rate": {
"type": "number"
},
"prev_cost_per_merged_pr_micros": {
"type": "integer"
},
"prev_merge_rate": {
"type": "number"
},
"prev_total_prs_created": {
"type": "integer"
},
"prev_total_prs_merged": {
"type": "integer"
},
"total_additions": {
"type": "integer"
},
"total_cost_micros": {
"type": "integer"
},
"total_deletions": {
"type": "integer"
},
"total_prs_created": {
"type": "integer"
},
"total_prs_merged": {
"type": "integer"
}
}
},
"codersdk.PRInsightsTimeSeriesEntry": {
"type": "object",
"properties": {
"date": {
"type": "string",
"format": "date-time"
},
"prs_closed": {
"type": "integer"
},
"prs_created": {
"type": "integer"
},
"prs_merged": {
"type": "integer"
}
}
},
"codersdk.PaginatedMembersResponse": {
"type": "object",
"properties": {
@@ -18353,6 +18550,9 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -18521,7 +18721,8 @@ const docTemplate = `{
"idp_sync_settings_role",
"workspace_agent",
"workspace_app",
"task"
"task",
"ai_seat"
],
"x-enum-varnames": [
"ResourceTypeTemplate",
@@ -18549,7 +18750,8 @@ const docTemplate = `{
"ResourceTypeIdpSyncSettingsRole",
"ResourceTypeWorkspaceAgent",
"ResourceTypeWorkspaceApp",
"ResourceTypeTask"
"ResourceTypeTask",
"ResourceTypeAISeat"
]
},
"codersdk.Response": {
@@ -18761,6 +18963,19 @@ const docTemplate = `{
}
}
},
"codersdk.ShareableWorkspaceOwners": {
"type": "string",
"enum": [
"none",
"everyone",
"service_accounts"
],
"x-enum-varnames": [
"ShareableWorkspaceOwnersNone",
"ShareableWorkspaceOwnersEveryone",
"ShareableWorkspaceOwnersServiceAccounts"
]
},
"codersdk.SharedWorkspaceActor": {
"type": "object",
"properties": {
@@ -19659,6 +19874,9 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -20369,7 +20587,21 @@ const docTemplate = `{
"codersdk.UpdateWorkspaceSharingSettingsRequest": {
"type": "object",
"properties": {
"shareable_workspace_owners": {
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
"enum": [
"none",
"everyone",
"service_accounts"
],
"allOf": [
{
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
}
]
},
"sharing_disabled": {
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use ` + "`" + `ShareableWorkspaceOwners` + "`" + ` instead",
"type": "boolean"
}
}
@@ -20491,6 +20723,9 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -22210,7 +22445,21 @@ const docTemplate = `{
"codersdk.WorkspaceSharingSettings": {
"type": "object",
"properties": {
"shareable_workspace_owners": {
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
"enum": [
"none",
"everyone",
"service_accounts"
],
"allOf": [
{
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
}
]
},
"sharing_disabled": {
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use ` + "`" + `ShareableWorkspaceOwners` + "`" + ` instead",
"type": "boolean"
},
"sharing_globally_disabled": {
+262 -25
@@ -410,42 +410,43 @@
}
}
},
"/chats/{chat}/archive": {
"post": {
"tags": ["Chats"],
"summary": "Archive a chat",
"operationId": "archive-chat",
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/chats/{chat}/desktop": {
"/chats/insights/pull-requests": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"tags": ["Chats"],
"summary": "Watch chat desktop",
"operationId": "watch-chat-desktop",
"summary": "Get PR insights",
"operationId": "get-pr-insights",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Chat ID",
"name": "chat",
"in": "path",
"description": "Start date (RFC3339)",
"name": "start_date",
"in": "query",
"required": true
},
{
"type": "string",
"description": "End date (RFC3339)",
"name": "end_date",
"in": "query",
"required": true
}
],
"responses": {
"101": {
"description": "Switching Protocols"
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.PRInsightsResponse"
}
}
},
"x-apidocgen": {
"skip": true
}
}
},
@@ -4301,7 +4302,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
}
}
],
@@ -4309,7 +4310,7 @@
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.UpdateWorkspaceSharingSettingsRequest"
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
}
}
}
@@ -11359,6 +11360,9 @@
},
"bridge": {
"$ref": "#/definitions/codersdk.AIBridgeConfig"
},
"chat": {
"$ref": "#/definitions/codersdk.ChatConfig"
}
}
},
@@ -12381,6 +12385,14 @@
}
}
},
"codersdk.ChatConfig": {
"type": "object",
"properties": {
"acquire_batch_size": {
"type": "integer"
}
}
},
"codersdk.ConnectionLatency": {
"type": "object",
"properties": {
@@ -15581,6 +15593,191 @@
}
}
},
"codersdk.PRInsightsModelBreakdown": {
"type": "object",
"properties": {
"cost_per_merged_pr_micros": {
"type": "integer"
},
"display_name": {
"type": "string"
},
"merge_rate": {
"type": "number"
},
"merged_prs": {
"type": "integer"
},
"model_config_id": {
"type": "string",
"format": "uuid"
},
"provider": {
"type": "string"
},
"total_additions": {
"type": "integer"
},
"total_cost_micros": {
"type": "integer"
},
"total_deletions": {
"type": "integer"
},
"total_prs": {
"type": "integer"
}
}
},
"codersdk.PRInsightsPullRequest": {
"type": "object",
"properties": {
"additions": {
"type": "integer"
},
"approved": {
"type": "boolean"
},
"author_avatar_url": {
"type": "string"
},
"author_login": {
"type": "string"
},
"base_branch": {
"type": "string"
},
"changed_files": {
"type": "integer"
},
"changes_requested": {
"type": "boolean"
},
"chat_id": {
"type": "string",
"format": "uuid"
},
"commits": {
"type": "integer"
},
"cost_micros": {
"type": "integer"
},
"created_at": {
"type": "string",
"format": "date-time"
},
"deletions": {
"type": "integer"
},
"draft": {
"type": "boolean"
},
"model_display_name": {
"type": "string"
},
"pr_number": {
"type": "integer"
},
"pr_title": {
"type": "string"
},
"pr_url": {
"type": "string"
},
"reviewer_count": {
"type": "integer"
},
"state": {
"type": "string"
}
}
},
"codersdk.PRInsightsResponse": {
"type": "object",
"properties": {
"by_model": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.PRInsightsModelBreakdown"
}
},
"recent_prs": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.PRInsightsPullRequest"
}
},
"summary": {
"$ref": "#/definitions/codersdk.PRInsightsSummary"
},
"time_series": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.PRInsightsTimeSeriesEntry"
}
}
}
},
"codersdk.PRInsightsSummary": {
"type": "object",
"properties": {
"approval_rate": {
"type": "number"
},
"cost_per_merged_pr_micros": {
"type": "integer"
},
"merge_rate": {
"type": "number"
},
"prev_cost_per_merged_pr_micros": {
"type": "integer"
},
"prev_merge_rate": {
"type": "number"
},
"prev_total_prs_created": {
"type": "integer"
},
"prev_total_prs_merged": {
"type": "integer"
},
"total_additions": {
"type": "integer"
},
"total_cost_micros": {
"type": "integer"
},
"total_deletions": {
"type": "integer"
},
"total_prs_created": {
"type": "integer"
},
"total_prs_merged": {
"type": "integer"
}
}
},
"codersdk.PRInsightsTimeSeriesEntry": {
"type": "object",
"properties": {
"date": {
"type": "string",
"format": "date-time"
},
"prs_closed": {
"type": "integer"
},
"prs_created": {
"type": "integer"
},
"prs_merged": {
"type": "integer"
}
}
},
"codersdk.PaginatedMembersResponse": {
"type": "object",
"properties": {
@@ -16747,6 +16944,9 @@
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -16910,7 +17110,8 @@
"idp_sync_settings_role",
"workspace_agent",
"workspace_app",
"task"
"task",
"ai_seat"
],
"x-enum-varnames": [
"ResourceTypeTemplate",
@@ -16938,7 +17139,8 @@
"ResourceTypeIdpSyncSettingsRole",
"ResourceTypeWorkspaceAgent",
"ResourceTypeWorkspaceApp",
"ResourceTypeTask"
"ResourceTypeTask",
"ResourceTypeAISeat"
]
},
"codersdk.Response": {
@@ -17146,6 +17348,15 @@
}
}
},
"codersdk.ShareableWorkspaceOwners": {
"type": "string",
"enum": ["none", "everyone", "service_accounts"],
"x-enum-varnames": [
"ShareableWorkspaceOwnersNone",
"ShareableWorkspaceOwnersEveryone",
"ShareableWorkspaceOwnersServiceAccounts"
]
},
"codersdk.SharedWorkspaceActor": {
"type": "object",
"properties": {
@@ -18007,6 +18218,9 @@
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -18682,7 +18896,17 @@
"codersdk.UpdateWorkspaceSharingSettingsRequest": {
"type": "object",
"properties": {
"shareable_workspace_owners": {
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
"enum": ["none", "everyone", "service_accounts"],
"allOf": [
{
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
}
]
},
"sharing_disabled": {
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use `ShareableWorkspaceOwners` instead",
"type": "boolean"
}
}
@@ -18786,6 +19010,9 @@
"type": "string",
"format": "uuid"
},
"is_service_account": {
"type": "boolean"
},
"last_seen_at": {
"type": "string",
"format": "date-time"
@@ -20421,7 +20648,17 @@
"codersdk.WorkspaceSharingSettings": {
"type": "object",
"properties": {
"shareable_workspace_owners": {
"description": "ShareableWorkspaceOwners controls whose workspaces can be shared\nwithin the organization.",
"enum": ["none", "everyone", "service_accounts"],
"allOf": [
{
"$ref": "#/definitions/codersdk.ShareableWorkspaceOwners"
}
]
},
"sharing_disabled": {
"description": "SharingDisabled is deprecated and left for backward compatibility\npurposes.\nDeprecated: use `ShareableWorkspaceOwners` instead",
"type": "boolean"
},
"sharing_globally_disabled": {
+2 -1
@@ -32,7 +32,8 @@ type Auditable interface {
idpsync.OrganizationSyncSettings |
idpsync.GroupSyncSettings |
idpsync.RoleSyncSettings |
database.TaskTable
database.TaskTable |
database.AiSeatState
}
// Map is a map of changed fields in an audited resource. It maps field names to
+8
@@ -132,6 +132,8 @@ func ResourceTarget[T Auditable](tgt T) string {
return "Organization Role Sync"
case database.TaskTable:
return typed.Name
case database.AiSeatState:
return "AI Seat"
default:
panic(fmt.Sprintf("unknown resource %T for ResourceTarget", tgt))
}
@@ -196,6 +198,8 @@ func ResourceID[T Auditable](tgt T) uuid.UUID {
return noID // Org field on audit log has org id
case database.TaskTable:
return typed.ID
case database.AiSeatState:
return typed.UserID
default:
panic(fmt.Sprintf("unknown resource %T for ResourceID", tgt))
}
@@ -251,6 +255,8 @@ func ResourceType[T Auditable](tgt T) database.ResourceType {
return database.ResourceTypeIdpSyncSettingsGroup
case database.TaskTable:
return database.ResourceTypeTask
case database.AiSeatState:
return database.ResourceTypeAiSeat
default:
panic(fmt.Sprintf("unknown resource %T for ResourceType", typed))
}
@@ -309,6 +315,8 @@ func ResourceRequiresOrgID[T Auditable]() bool {
return true
case database.TaskTable:
return true
case database.AiSeatState:
return false
default:
panic(fmt.Sprintf("unknown resource %T for ResourceRequiresOrgID", tgt))
}
+592 -320
File diff suppressed because it is too large
+139
@@ -2,13 +2,20 @@ package chatd
import (
"context"
"sync"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
)
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
@@ -84,3 +91,135 @@ func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
require.ErrorContains(t, err, loadErr.Error())
require.Equal(t, chat, refreshed)
}
func TestResolveInstructionsReusesTurnLocalWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: uuid.New(),
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
workspacesdk.LSResponse{},
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
).Times(1)
conn.EXPECT().ReadFile(
gomock.Any(),
"/home/coder/project/AGENTS.md",
int64(0),
int64(maxInstructionFileBytes+1),
).Return(
nil,
"",
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
).Times(1)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := &Server{
db: db,
logger: logger,
instructionCache: make(map[uuid.UUID]cachedInstruction),
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return conn, func() {}, nil
},
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
instruction := server.resolveInstructions(
ctx,
chat,
workspaceCtx.getWorkspaceAgent,
workspaceCtx.getWorkspaceConn,
)
require.Contains(t, instruction, "Operating System: linux")
require.Contains(t, instruction, "Working Directory: /home/coder/project")
}
func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
initialAgent := database.WorkspaceAgent{ID: uuid.New()}
refreshedAgent := database.WorkspaceAgent{ID: uuid.New()}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{initialAgent}, nil),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{refreshedAgent}, nil),
)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
var dialed []uuid.UUID
server := &Server{db: db}
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialed = append(dialed, agentID)
if agentID == initialAgent.ID {
return nil, nil, xerrors.New("dial failed")
}
return conn, func() {}, nil
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.NoError(t, err)
require.Same(t, conn, gotConn)
require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed)
}
+781 -612
File diff suppressed because it is too large
+30 -4
@@ -42,6 +42,11 @@ type PersistedStep struct {
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
// Runtime is the wall-clock duration of this step,
// covering LLM streaming, tool execution, and retries.
// Zero indicates the duration was not measured (e.g.
// interrupted steps).
Runtime time.Duration
}
// RunOptions configures a single streaming chat loop run.
@@ -260,6 +265,7 @@ func Run(ctx context.Context, opts RunOptions) error {
for step := 0; totalSteps < opts.MaxSteps; step++ {
totalSteps++
stepStart := time.Now()
// Copy messages so that provider-specific caching
// mutations don't leak back to the caller's slice.
// copy copies Message structs by value, so field
@@ -365,6 +371,7 @@ func Run(ctx context.Context, opts RunOptions) error {
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
Runtime: time.Since(stepStart),
}); err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
@@ -610,10 +617,12 @@ func processStepStream(
result.providerMetadata = part.ProviderMetadata
case fantasy.StreamPartTypeError:
// Detect interruption: context canceled with
// ErrInterrupted as the cause.
if errors.Is(part.Error, context.Canceled) &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
// Detect interruption: the stream may surface the
// cancel as context.Canceled or propagate the
// ErrInterrupted cause directly, depending on
// the provider implementation.
if errors.Is(context.Cause(ctx), ErrInterrupted) &&
(errors.Is(part.Error, context.Canceled) || errors.Is(part.Error, ErrInterrupted)) {
// Flush in-progress content so that
// persistInterruptedStep has access to partial
// text, reasoning, and tool calls that were
@@ -631,6 +640,23 @@ func processStepStream(
}
}
// The stream iterator may stop yielding parts without
// producing a StreamPartTypeError when the context is
// canceled (e.g. some providers close the response body
// silently). Detect this case and flush partial content
// so that persistInterruptedStep can save it.
if ctx.Err() != nil &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
flushActiveState(
&result,
activeTextContent,
activeReasoningContent,
activeToolCalls,
toolNames,
)
return result, ErrInterrupted
}
hasLocalToolCalls := false
for _, tc := range result.toolCalls {
if !tc.ProviderExecuted {
+3
@@ -7,6 +7,7 @@ import (
"strings"
"sync"
"testing"
"time"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
@@ -64,6 +65,8 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
require.Equal(t, 1, persistStepCalls)
require.True(t, persistedStep.ContextLimit.Valid)
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
require.Greater(t, persistedStep.Runtime, time.Duration(0),
"step runtime should be positive")
require.NotEmpty(t, capturedCall.Prompt)
require.False(t, containsPromptSentinel(capturedCall.Prompt))
@@ -82,7 +82,7 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
options := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Reasoning: &codersdk.ChatModelReasoningOptions{
Enabled: boolPtr(true),
},
Provider: &codersdk.ChatModelOpenRouterProvider{
@@ -92,7 +92,7 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
}
defaults := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Reasoning: &codersdk.ChatModelReasoningOptions{
Enabled: boolPtr(false),
Exclude: boolPtr(true),
MaxTokens: int64Ptr(123),
+15 -5
@@ -78,10 +78,10 @@ type ProcessToolOptions struct {
// ExecuteArgs are the parameters accepted by the execute tool.
type ExecuteArgs struct {
Command string `json:"command"`
Timeout *string `json:"timeout,omitempty"`
WorkDir *string `json:"workdir,omitempty"`
RunInBackground *bool `json:"run_in_background,omitempty"`
Command string `json:"command" description:"The shell command to execute."`
Timeout *string `json:"timeout,omitempty" description:"Timeout duration (e.g. '30s', '5m'). Default is 10s. Only applies to foreground commands."`
WorkDir *string `json:"workdir,omitempty" description:"Working directory for the command."`
RunInBackground *bool `json:"run_in_background,omitempty" description:"Run this command in the background without blocking. Use for long-running processes like dev servers, file watchers, or builds that run longer than 5 seconds. Do NOT use shell & to background processes — it will not work correctly. Always use this parameter instead."`
}
// Execute returns an AgentTool that runs a shell command in the
@@ -89,7 +89,7 @@ type ExecuteArgs struct {
func Execute(options ExecuteOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"execute",
"Execute a shell command in the workspace.",
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding.",
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
@@ -122,6 +122,16 @@ func executeTool(
background := args.RunInBackground != nil && *args.RunInBackground
// Detect shell-style backgrounding (trailing &) and promote to
// background mode. Models sometimes use "cmd &" instead of the
// run_in_background parameter, which causes the shell to fork
// and exit immediately, leaving an untracked orphan process.
trimmed := strings.TrimSpace(args.Command)
if !background && strings.HasSuffix(trimmed, "&") && !strings.HasSuffix(trimmed, "&&") {
background = true
args.Command = strings.TrimSpace(strings.TrimSuffix(trimmed, "&"))
}
var workDir string
if args.WorkDir != nil {
workDir = *args.WorkDir
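The trailing-`&` promotion above can be isolated as a small helper: a command ending in a single `&` is stripped and rerouted through the background path, while `&&` chains are left alone. A standalone sketch (helper name hypothetical):

```go
package main

import (
	"fmt"
	"strings"
)

// promoteBackground mirrors the detection in executeTool: shell-style
// backgrounding ("cmd &") is rewritten into explicit background mode so
// the process is tracked instead of becoming an orphaned shell child.
// "&&" must not trigger promotion.
func promoteBackground(command string) (cmd string, background bool) {
	trimmed := strings.TrimSpace(command)
	if strings.HasSuffix(trimmed, "&") && !strings.HasSuffix(trimmed, "&&") {
		return strings.TrimSpace(strings.TrimSuffix(trimmed, "&")), true
	}
	return trimmed, false
}

func main() {
	cmd, bg := promoteBackground("npm run dev &")
	fmt.Println(cmd, bg) // npm run dev true
	cmd, bg = promoteBackground("make build && make test")
	fmt.Println(cmd, bg) // make build && make test false
}
```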
+2 -2
@@ -92,7 +92,7 @@ func TestAnthropicWebSearchRoundTrip(t *testing.T) {
// Verify the chat completed and messages were persisted.
chatData, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
chatMsgs, err := client.GetChatMessages(ctx, chat.ID)
chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
t.Logf("Chat status after step 1: %s, messages: %d",
chatData.Status, len(chatMsgs.Messages))
@@ -154,7 +154,7 @@ func TestAnthropicWebSearchRoundTrip(t *testing.T) {
// Verify the follow-up completed and produced content.
chatData2, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
chatMsgs2, err := client.GetChatMessages(ctx, chat.ID)
chatMsgs2, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
t.Logf("Chat status after step 2: %s, messages: %d",
chatData2.Status, len(chatMsgs2.Messages))
+3 -1
@@ -62,6 +62,7 @@ func (p *Server) maybeGenerateChatTitle(
messages []database.ChatMessage,
fallbackModel fantasy.LanguageModel,
keys chatprovider.ProviderAPIKeys,
generatedTitle *generatedChatTitle,
logger slog.Logger,
) {
input, ok := titleInput(chat, messages)
@@ -111,7 +112,8 @@ func (p *Server) maybeGenerateChatTitle(
return
}
chat.Title = title
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange)
generatedTitle.Store(title)
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
return
}
+10 -3
View File
@@ -84,6 +84,14 @@ func (p *Server) isAnthropicConfigured(ctx context.Context) bool {
return false
}
func (p *Server) isDesktopEnabled(ctx context.Context) bool {
enabled, err := p.db.GetChatDesktopEnabled(ctx)
if err != nil {
return false
}
return enabled
}
func (p *Server) subagentTools(ctx context.Context, currentChat func() database.Chat) []fantasy.AgentTool {
tools := []fantasy.AgentTool{
fantasy.NewAgentTool(
@@ -253,9 +261,8 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
}
// Only include the computer use tool when an Anthropic
// provider is configured, since it requires an Anthropic
// model.
if p.isAnthropicConfigured(ctx) {
// provider is configured and desktop is enabled.
if p.isAnthropicConfigured(ctx) && p.isDesktopEnabled(ctx) {
tools = append(tools, fantasy.NewAgentTool(
"spawn_computer_use_agent",
"Spawn a dedicated computer use agent that can see the desktop "+
+37 -3
@@ -15,6 +15,7 @@ import (
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
@@ -144,14 +145,20 @@ func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
return nil
}
func chatdTestContext(t *testing.T) context.Context {
t.Helper()
return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong))
}
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// No Anthropic key in ProviderAPIKeys.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
ctx := testutil.Context(t, testutil.WaitLong)
ctx := chatdTestContext(t)
user, model := seedInternalChatDeps(ctx, t, db)
// Create a root parent chat.
@@ -176,12 +183,13 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// Provide an Anthropic key so the provider check passes.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := testutil.Context(t, testutil.WaitLong)
ctx := chatdTestContext(t)
user, model := seedInternalChatDeps(ctx, t, db)
// Create a root parent chat.
@@ -232,16 +240,42 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
assert.Contains(t, resp.Content, "delegated chats cannot create child subagents")
}
func TestSpawnComputerUseAgent_DesktopDisabled(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := chatdTestContext(t)
user, model := seedInternalChatDeps(ctx, t, db)
parent, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
Title: "parent-desktop-disabled",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
parentChat, err := db.GetChatByID(ctx, parent.ID)
require.NoError(t, err)
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
tool := findToolByName(tools, "spawn_computer_use_agent")
assert.Nil(t, tool, "spawn_computer_use_agent tool must be omitted when desktop is disabled")
}
func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// Provide an Anthropic key so the tool can proceed.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := testutil.Context(t, testutil.WaitLong)
ctx := chatdTestContext(t)
user, model := seedInternalChatDeps(ctx, t, db)
// The parent uses an OpenAI model.
+128
@@ -0,0 +1,128 @@
package chatd
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
)
// ComputeUsagePeriodBounds returns the UTC-aligned start and end bounds for the
// active usage-limit period containing now.
func ComputeUsagePeriodBounds(now time.Time, period codersdk.ChatUsageLimitPeriod) (start, end time.Time) {
utcNow := now.UTC()
switch period {
case codersdk.ChatUsageLimitPeriodDay:
start = time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
end = start.AddDate(0, 0, 1)
case codersdk.ChatUsageLimitPeriodWeek:
// Walk backward to Monday of the current ISO week.
// ISO 8601 weeks always start on Monday, so this never
// crosses an ISO-week boundary.
start = time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
for start.Weekday() != time.Monday {
start = start.AddDate(0, 0, -1)
}
end = start.AddDate(0, 0, 7)
case codersdk.ChatUsageLimitPeriodMonth:
start = time.Date(utcNow.Year(), utcNow.Month(), 1, 0, 0, 0, 0, time.UTC)
end = start.AddDate(0, 1, 0)
default:
panic(fmt.Sprintf("unknown chat usage limit period: %q", period))
}
return start, end
}
// ResolveUsageLimitStatus resolves the current usage-limit status for userID.
//
// Note: There is a potential race condition where two concurrent messages
// from the same user can both pass the limit check if processed in
// parallel, allowing brief overage. This is acceptable because:
// - Cost is only known after the LLM API returns.
// - Overage is bounded by message cost × concurrency.
// - Fail-open is the deliberate design choice for this feature.
//
// Architecture note: today this path enforces one period globally
// (day/week/month) from config.
// To support simultaneous periods, add nullable
// daily/weekly/monthly_limit_micros columns on override tables, where NULL
// means no limit for that period.
// Then scan spend once over the widest active window with conditional SUMs
// for each period and compare each spend/limit pair Go-side, blocking on
// whichever period is tightest.
func ResolveUsageLimitStatus(ctx context.Context, db database.Store, userID uuid.UUID, now time.Time) (*codersdk.ChatUsageLimitStatus, error) {
//nolint:gocritic // AsChatd provides narrowly-scoped daemon access for
// deployment config reads and cross-user chat spend aggregation.
authCtx := dbauthz.AsChatd(ctx)
config, err := db.GetChatUsageLimitConfig(authCtx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
}
return nil, err
}
if !config.Enabled {
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
}
period, ok := mapDBPeriodToSDK(config.Period)
if !ok {
return nil, xerrors.Errorf("invalid chat usage limit period %q", config.Period)
}
// Resolve effective limit in a single query:
// individual override > group limit > global default.
effectiveLimit, err := db.ResolveUserChatSpendLimit(authCtx, userID)
if err != nil {
return nil, err
}
// -1 means limits are disabled (shouldn't happen since we checked above,
// but handle gracefully).
if effectiveLimit < 0 {
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
}
start, end := ComputeUsagePeriodBounds(now, period)
spendTotal, err := db.GetUserChatSpendInPeriod(authCtx, database.GetUserChatSpendInPeriodParams{
UserID: userID,
StartTime: start,
EndTime: end,
})
if err != nil {
return nil, err
}
return &codersdk.ChatUsageLimitStatus{
IsLimited: true,
Period: period,
SpendLimitMicros: &effectiveLimit,
CurrentSpend: spendTotal,
PeriodStart: start,
PeriodEnd: end,
}, nil
}
func mapDBPeriodToSDK(dbPeriod string) (codersdk.ChatUsageLimitPeriod, bool) {
switch dbPeriod {
case string(codersdk.ChatUsageLimitPeriodDay):
return codersdk.ChatUsageLimitPeriodDay, true
case string(codersdk.ChatUsageLimitPeriodWeek):
return codersdk.ChatUsageLimitPeriodWeek, true
case string(codersdk.ChatUsageLimitPeriodMonth):
return codersdk.ChatUsageLimitPeriodMonth, true
default:
return "", false
}
}
@@ -0,0 +1,132 @@
package chatd //nolint:testpackage // Keeps chatd unit tests in the package.
import (
"testing"
"time"
"github.com/coder/coder/v2/codersdk"
)
func TestComputeUsagePeriodBounds(t *testing.T) {
t.Parallel()
newYork, err := time.LoadLocation("America/New_York")
if err != nil {
t.Fatalf("load America/New_York: %v", err)
}
tests := []struct {
name string
now time.Time
period codersdk.ChatUsageLimitPeriod
wantStart time.Time
wantEnd time.Time
}{
{
name: "day/mid_day",
now: time.Date(2025, time.June, 15, 14, 30, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodDay,
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "day/midnight_exactly",
now: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodDay,
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "day/end_of_day",
now: time.Date(2025, time.June, 15, 23, 59, 59, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodDay,
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "week/wednesday",
now: time.Date(2025, time.June, 11, 10, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodWeek,
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "week/monday",
now: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodWeek,
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "week/sunday",
now: time.Date(2025, time.June, 15, 23, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodWeek,
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "week/year_boundary",
now: time.Date(2024, time.December, 31, 12, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodWeek,
wantStart: time.Date(2024, time.December, 30, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.January, 6, 0, 0, 0, 0, time.UTC),
},
{
name: "month/mid_month",
now: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodMonth,
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "month/first_day",
now: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodMonth,
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "month/last_day",
now: time.Date(2025, time.June, 30, 23, 59, 59, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodMonth,
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "month/february",
now: time.Date(2025, time.February, 15, 12, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodMonth,
wantStart: time.Date(2025, time.February, 1, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.March, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "month/leap_year_february",
now: time.Date(2024, time.February, 29, 12, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodMonth,
wantStart: time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2024, time.March, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "day/non_utc_timezone",
now: time.Date(2025, time.June, 15, 22, 0, 0, 0, newYork),
period: codersdk.ChatUsageLimitPeriodDay,
wantStart: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 17, 0, 0, 0, 0, time.UTC),
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
start, end := ComputeUsagePeriodBounds(tc.now, tc.period)
if !start.Equal(tc.wantStart) {
t.Errorf("start: got %v, want %v", start, tc.wantStart)
}
if !end.Equal(tc.wantEnd) {
t.Errorf("end: got %v, want %v", end, tc.wantEnd)
}
})
}
}
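The resolver comment above describes limit precedence as "individual override > group limit > global default", resolved in one query by ResolveUserChatSpendLimit. A hedged Go-side sketch of that precedence (all names hypothetical; the real resolution happens in SQL, not Go):

```go
package main

import "fmt"

// resolveLimit illustrates the precedence order only: a per-user
// override beats a group limit, which beats the global default.
// nil means "no value set at that level".
func resolveLimit(userOverride, groupLimit *int64, globalDefault int64) int64 {
	if userOverride != nil {
		return *userOverride
	}
	if groupLimit != nil {
		return *groupLimit
	}
	return globalDefault
}

func main() {
	seven := int64(7_000_000)
	three := int64(3_000_000)
	fmt.Println(resolveLimit(&seven, &three, 1_000_000)) // individual override wins
	fmt.Println(resolveLimit(nil, &three, 1_000_000))    // falls back to group limit
	fmt.Println(resolveLimit(nil, nil, 1_000_000))       // global default
}
```

Doing this in a single query instead keeps the check to one round trip per message, which matters since the limit is evaluated on every chat create, message, and edit.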
File diff suppressed because it is too large (+887 -164)
@@ -2,6 +2,7 @@ package coderd_test
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
@@ -16,14 +17,19 @@ import (
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/externalauth"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
@@ -54,6 +60,93 @@ func newChatClientWithDatabase(t testing.TB) (*codersdk.Client, database.Store)
})
}
func requireChatUsageLimitExceededError(
t *testing.T,
err error,
wantSpentMicros int64,
wantLimitMicros int64,
wantResetsAt time.Time,
) *codersdk.ChatUsageLimitExceededResponse {
t.Helper()
sdkErr, ok := codersdk.AsError(err)
require.True(t, ok)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
require.Equal(t, "Chat usage limit exceeded.", sdkErr.Message)
limitErr := codersdk.ChatUsageLimitExceededFrom(err)
require.NotNil(t, limitErr)
require.Equal(t, "Chat usage limit exceeded.", limitErr.Message)
require.Equal(t, wantSpentMicros, limitErr.SpentMicros)
require.Equal(t, wantLimitMicros, limitErr.LimitMicros)
require.True(
t,
limitErr.ResetsAt.Equal(wantResetsAt),
"expected resets_at %s, got %s",
wantResetsAt.UTC().Format(time.RFC3339),
limitErr.ResetsAt.UTC().Format(time.RFC3339),
)
return limitErr
}
func enableDailyChatUsageLimit(
ctx context.Context,
t *testing.T,
db database.Store,
limitMicros int64,
) time.Time {
t.Helper()
_, err := db.UpsertChatUsageLimitConfig(
dbauthz.AsSystemRestricted(ctx),
database.UpsertChatUsageLimitConfigParams{
Enabled: true,
DefaultLimitMicros: limitMicros,
Period: string(codersdk.ChatUsageLimitPeriodDay),
},
)
require.NoError(t, err)
_, periodEnd := chatd.ComputeUsagePeriodBounds(time.Now(), codersdk.ChatUsageLimitPeriodDay)
return periodEnd
}
func insertAssistantCostMessage(
ctx context.Context,
t *testing.T,
db database.Store,
chatID uuid.UUID,
modelConfigID uuid.UUID,
totalCostMicros int64,
) {
t.Helper()
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("assistant"),
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chatID,
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
Role: database.ChatMessageRoleAssistant,
ContentVersion: chatprompt.CurrentContentVersion,
Content: assistantContent,
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{},
OutputTokens: sql.NullInt64{},
TotalTokens: sql.NullInt64{},
ReasoningTokens: sql.NullInt64{},
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{Int64: totalCostMicros, Valid: true},
})
require.NoError(t, err)
}
func TestPostChats(t *testing.T) {
t.Parallel()
@@ -88,7 +181,7 @@ func TestPostChats(t *testing.T) {
chatResult, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
require.Equal(t, chat.ID, chatResult.ID)
@@ -126,7 +219,7 @@ func TestPostChats(t *testing.T) {
})
require.NoError(t, err)
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
for _, message := range messagesResult.Messages {
require.NotEqual(t, codersdk.ChatMessageRoleSystem, message.Role)
@@ -324,6 +417,33 @@ func TestPostChats(t *testing.T) {
require.Equal(t, "Invalid input part.", sdkErr.Message)
require.Equal(t, `content[0].type "image" is not supported.`, sdkErr.Detail)
})
t.Run("UsageLimitExceeded", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
existingChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "existing-limit-chat",
})
require.NoError(t, err)
insertAssistantCostMessage(ctx, t, db, existingChat.ID, modelConfig.ID, 100)
_, err = client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "over limit",
}},
})
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
})
}
func TestListChats(t *testing.T) {
@@ -616,6 +736,127 @@ func TestWatchChats(t *testing.T) {
}
})
t.Run("DiffStatusChangeIncludesDiffStatus", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
DeploymentValues: chatDeploymentValues(t),
})
db := api.Database
user := coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
// Insert a chat and a diff status row.
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "diff status watch test",
})
require.NoError(t, err)
refreshedAt := time.Now().UTC().Truncate(time.Second)
staleAt := refreshedAt.Add(time.Hour)
_, err = db.UpsertChatDiffStatusReference(
dbauthz.AsSystemRestricted(ctx),
database.UpsertChatDiffStatusReferenceParams{
ChatID: chat.ID,
Url: sql.NullString{String: "https://github.com/coder/coder/pull/99", Valid: true},
GitBranch: "feature/test",
GitRemoteOrigin: "git@github.com:coder/coder.git",
StaleAt: staleAt,
},
)
require.NoError(t, err)
_, err = db.UpsertChatDiffStatus(
dbauthz.AsSystemRestricted(ctx),
database.UpsertChatDiffStatusParams{
ChatID: chat.ID,
Url: sql.NullString{String: "https://github.com/coder/coder/pull/99", Valid: true},
PullRequestState: sql.NullString{String: "open", Valid: true},
Additions: 42,
Deletions: 7,
ChangedFiles: 5,
RefreshedAt: refreshedAt,
StaleAt: staleAt,
},
)
require.NoError(t, err)
// Open the watch WebSocket.
conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil)
require.NoError(t, err)
defer conn.Close(websocket.StatusNormalClosure, "done")
type watchEvent struct {
Type codersdk.ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
// Read the initial ping.
var ping watchEvent
err = wsjson.Read(ctx, conn, &ping)
require.NoError(t, err)
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
// Publish a diff_status_change event via pubsub,
// mimicking what PublishDiffStatusChange does after
// it reads the diff status from the DB.
dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
require.NoError(t, err)
sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus)
event := coderdpubsub.ChatEvent{
Kind: coderdpubsub.ChatEventKindDiffStatusChange,
Chat: codersdk.Chat{
ID: chat.ID,
OwnerID: chat.OwnerID,
Title: chat.Title,
Status: codersdk.ChatStatus(chat.Status),
CreatedAt: chat.CreatedAt,
UpdatedAt: chat.UpdatedAt,
DiffStatus: &sdkDiffStatus,
},
}
payload, err := json.Marshal(event)
require.NoError(t, err)
err = api.Pubsub.Publish(coderdpubsub.ChatEventChannel(user.UserID), payload)
require.NoError(t, err)
// Read events until we find the diff_status_change.
for {
var update watchEvent
err = wsjson.Read(ctx, conn, &update)
require.NoError(t, err)
if update.Type == codersdk.ServerSentEventTypePing {
continue
}
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
var received coderdpubsub.ChatEvent
err = json.Unmarshal(update.Data, &received)
require.NoError(t, err)
if received.Kind != coderdpubsub.ChatEventKindDiffStatusChange ||
received.Chat.ID != chat.ID {
continue
}
// Verify the event carries the full DiffStatus.
require.NotNil(t, received.Chat.DiffStatus, "diff_status_change event must include DiffStatus")
ds := received.Chat.DiffStatus
require.Equal(t, chat.ID, ds.ChatID)
require.NotNil(t, ds.URL)
require.Equal(t, "https://github.com/coder/coder/pull/99", *ds.URL)
require.NotNil(t, ds.PullRequestState)
require.Equal(t, "open", *ds.PullRequestState)
require.EqualValues(t, 42, ds.Additions)
require.EqualValues(t, 7, ds.Deletions)
require.EqualValues(t, 5, ds.ChangedFiles)
break
}
})
t.Run("Unauthenticated", func(t *testing.T) {
t.Parallel()
@@ -1362,7 +1603,7 @@ func TestGetChat(t *testing.T) {
chatResult, err := client.GetChat(ctx, createdChat.ID)
require.NoError(t, err)
messagesResult, err := client.GetChatMessages(ctx, createdChat.ID)
messagesResult, err := client.GetChatMessages(ctx, createdChat.ID, nil)
require.NoError(t, err)
require.Equal(t, createdChat.ID, chatResult.ID)
require.Equal(t, firstUser.UserID, chatResult.OwnerID)
@@ -1447,7 +1688,7 @@ func TestArchiveChat(t *testing.T) {
require.NoError(t, err)
require.Len(t, chatsBeforeArchive, 2)
err = client.ArchiveChat(ctx, chatToArchive.ID)
err = client.UpdateChat(ctx, chatToArchive.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
require.NoError(t, err)
// Default (no filter) returns only non-archived chats.
@@ -1481,7 +1722,7 @@ func TestArchiveChat(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
err := client.ArchiveChat(ctx, uuid.New())
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
requireSDKError(t, err, http.StatusNotFound)
})
@@ -1524,7 +1765,7 @@ func TestArchiveChat(t *testing.T) {
require.NoError(t, err)
// Archive the parent via the API.
err = client.ArchiveChat(ctx, parentChat.ID)
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
require.NoError(t, err)
// archived:false should exclude the entire archived family.
@@ -1571,7 +1812,7 @@ func TestUnarchiveChat(t *testing.T) {
require.NoError(t, err)
// Archive the chat first.
err = client.ArchiveChat(ctx, chat.ID)
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
require.NoError(t, err)
// Verify it's archived.
@@ -1582,7 +1823,7 @@ func TestUnarchiveChat(t *testing.T) {
require.Len(t, archivedChats, 1)
require.True(t, archivedChats[0].Archived)
// Unarchive the chat.
err = client.UnarchiveChat(ctx, chat.ID)
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
require.NoError(t, err)
// Verify it's no longer archived.
@@ -1621,10 +1862,9 @@ func TestUnarchiveChat(t *testing.T) {
require.NoError(t, err)
// Trying to unarchive a non-archived chat should fail.
err = client.UnarchiveChat(ctx, chat.ID)
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
requireSDKError(t, err, http.StatusBadRequest)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
@@ -1632,7 +1872,7 @@ func TestUnarchiveChat(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
err := client.UnarchiveChat(ctx, uuid.New())
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
requireSDKError(t, err, http.StatusNotFound)
})
}
@@ -1686,7 +1926,7 @@ func TestPostChatMessages(t *testing.T) {
require.True(t, hasTextPart(created.QueuedMessage.Content, messageText))
require.Eventually(t, func() bool {
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
if getErr != nil {
return false
}
@@ -1714,7 +1954,7 @@ func TestPostChatMessages(t *testing.T) {
require.True(t, hasTextPart(created.Message.Content, messageText))
require.Eventually(t, func() bool {
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
if getErr != nil {
return false
}
@@ -1761,6 +2001,34 @@ func TestPostChatMessages(t *testing.T) {
require.Equal(t, "content[0].text cannot be empty.", sdkErr.Detail)
})
t.Run("UsageLimitExceeded", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
_ = coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "initial message for usage-limit test",
}},
})
require.NoError(t, err)
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
_, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "over limit",
}},
})
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
})
t.Run("ChatNotFound", func(t *testing.T) {
t.Parallel()
@@ -1829,7 +2097,7 @@ func TestChatMessageWithFileReferences(t *testing.T) {
var found bool
require.Eventually(t, func() bool {
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
if getErr != nil {
return false
}
@@ -1889,7 +2157,7 @@ func TestChatMessageWithFileReferences(t *testing.T) {
}
require.Eventually(t, func() bool {
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
if getErr != nil {
return false
}
@@ -1942,7 +2210,7 @@ func TestChatMessageWithFileReferences(t *testing.T) {
}
require.Eventually(t, func() bool {
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
if getErr != nil {
return false
}
@@ -1995,7 +2263,7 @@ func TestChatMessageWithFileReferences(t *testing.T) {
}
require.Eventually(t, func() bool {
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
if getErr != nil {
return false
}
@@ -2085,7 +2353,7 @@ func TestChatMessageWithFileReferences(t *testing.T) {
}
require.Eventually(t, func() bool {
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID)
messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil)
if getErr != nil {
return false
}
@@ -2275,7 +2543,7 @@ func TestChatMessageWithFiles(t *testing.T) {
}
// Verify file parts omit inline data in the API response.
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
for _, msg := range messagesResult.Messages {
for _, part := range msg.Content {
@@ -2371,7 +2639,7 @@ func TestPatchChatMessage(t *testing.T) {
})
require.NoError(t, err)
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
var userMessageID int64
@@ -2403,7 +2671,7 @@ func TestPatchChatMessage(t *testing.T) {
}
require.True(t, foundEditedText)
messagesResult, err = client.GetChatMessages(ctx, chat.ID)
messagesResult, err = client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
foundEditedInChat := false
foundOriginalInChat := false
@@ -2456,7 +2724,7 @@ func TestPatchChatMessage(t *testing.T) {
require.NoError(t, err)
// Find the user message ID.
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
var userMessageID int64
@@ -2499,7 +2767,7 @@ func TestPatchChatMessage(t *testing.T) {
require.True(t, foundFile, "edited message should preserve file_id")
// GET the chat messages and verify the file_id persists.
messagesResult, err = client.GetChatMessages(ctx, chat.ID)
messagesResult, err = client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
var foundTextInChat, foundFileInChat bool
@@ -2521,6 +2789,46 @@ func TestPatchChatMessage(t *testing.T) {
require.True(t, foundFileInChat, "chat should preserve file_id after edit")
})
t.Run("UsageLimitExceeded", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
_ = coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "hello before edit",
}},
})
require.NoError(t, err)
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
var userMessageID int64
for _, message := range messagesResult.Messages {
if message.Role == codersdk.ChatMessageRoleUser {
userMessageID = message.ID
break
}
}
require.NotZero(t, userMessageID)
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
_, err = client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "edited over limit",
}},
})
requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt)
})
t.Run("MessageNotFound", func(t *testing.T) {
t.Parallel()
@@ -3114,7 +3422,7 @@ func TestDeleteChatQueuedMessage(t *testing.T) {
res.Body.Close()
require.Equal(t, http.StatusNoContent, res.StatusCode)
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
for _, queued := range messagesResult.QueuedMessages {
require.NotEqual(t, queuedMessage.ID, queued.ID)
@@ -3217,7 +3525,7 @@ func TestPromoteChatQueuedMessage(t *testing.T) {
}
require.True(t, foundPromotedText)
messagesResult, err := client.GetChatMessages(ctx, chat.ID)
messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
for _, queued := range messagesResult.QueuedMessages {
require.NotEqual(t, queuedMessage.ID, queued.ID)
@@ -3230,6 +3538,81 @@ func TestPromoteChatQueuedMessage(t *testing.T) {
}
})
t.Run("PromotesAlreadyQueuedMessageAfterLimitReached", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
enableDailyChatUsageLimit(ctx, t, db, 100)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "promote queued usage limit",
})
require.NoError(t, err)
const queuedText = "queued message for promote route"
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{
codersdk.ChatMessageText(queuedText),
})
require.NoError(t, err)
queuedMessage, err := db.InsertChatQueuedMessage(
dbauthz.AsSystemRestricted(ctx),
database.InsertChatQueuedMessageParams{
ChatID: chat.ID,
Content: queuedContent,
},
)
require.NoError(t, err)
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
_, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
})
require.NoError(t, err)
promoteRes, err := client.Request(
ctx,
http.MethodPost,
fmt.Sprintf("/api/experimental/chats/%s/queue/%d/promote", chat.ID, queuedMessage.ID),
nil,
)
require.NoError(t, err)
defer promoteRes.Body.Close()
require.Equal(t, http.StatusOK, promoteRes.StatusCode)
var promoted codersdk.ChatMessage
err = json.NewDecoder(promoteRes.Body).Decode(&promoted)
require.NoError(t, err)
require.NotZero(t, promoted.ID)
require.Equal(t, chat.ID, promoted.ChatID)
require.Equal(t, codersdk.ChatMessageRoleUser, promoted.Role)
foundPromotedText := false
for _, part := range promoted.Content {
if part.Type == codersdk.ChatMessagePartTypeText && part.Text == queuedText {
foundPromotedText = true
break
}
}
require.True(t, foundPromotedText)
queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID)
require.NoError(t, err)
for _, queued := range queuedMessages {
require.NotEqual(t, queuedMessage.ID, queued.ID)
}
})
t.Run("InvalidQueuedMessageID", func(t *testing.T) {
t.Parallel()
@@ -3261,6 +3644,133 @@ func TestPromoteChatQueuedMessage(t *testing.T) {
})
}
func TestChatUsageLimitOverrideRoutes(t *testing.T) {
t.Parallel()
t.Run("UpsertUserOverrideRequiresPositiveSpendLimit", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, _ := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
res, err := client.Request(
ctx,
http.MethodPut,
fmt.Sprintf("/api/experimental/chats/usage-limits/overrides/%s", member.ID),
map[string]any{},
)
require.NoError(t, err)
defer res.Body.Close()
err = codersdk.ReadBodyAsError(res)
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid chat usage limit override.", sdkErr.Message)
require.Equal(t, "Spend limit must be greater than 0.", sdkErr.Detail)
})
t.Run("UpsertUserOverrideMissingUser", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
_, err := client.UpsertChatUsageLimitOverride(ctx, uuid.New(), codersdk.UpsertChatUsageLimitOverrideRequest{
SpendLimitMicros: 7_000_000,
})
sdkErr := requireSDKError(t, err, http.StatusNotFound)
require.Equal(t, "User not found.", sdkErr.Message)
})
t.Run("DeleteUserOverrideMissingUser", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
err := client.DeleteChatUsageLimitOverride(ctx, uuid.New())
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "User not found.", sdkErr.Message)
})
t.Run("DeleteUserOverrideMissingOverride", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, client)
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
err := client.DeleteChatUsageLimitOverride(ctx, member.ID)
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Chat usage limit override not found.", sdkErr.Message)
})
t.Run("UpsertGroupOverrideIncludesMemberCount", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
_, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
group := dbgen.Group(t, db, database.Group{OrganizationID: firstUser.OrganizationID})
dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: member.ID})
dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: database.PrebuildsSystemUserID})
override, err := client.UpsertChatUsageLimitGroupOverride(ctx, group.ID, codersdk.UpsertChatUsageLimitGroupOverrideRequest{
SpendLimitMicros: 7_000_000,
})
require.NoError(t, err)
require.Equal(t, group.ID, override.GroupID)
require.EqualValues(t, 1, override.MemberCount)
require.NotNil(t, override.SpendLimitMicros)
require.EqualValues(t, 7_000_000, *override.SpendLimitMicros)
config, err := client.GetChatUsageLimitConfig(ctx)
require.NoError(t, err)
var listed *codersdk.ChatUsageLimitGroupOverride
for i := range config.GroupOverrides {
if config.GroupOverrides[i].GroupID == group.ID {
listed = &config.GroupOverrides[i]
break
}
}
require.NotNil(t, listed)
require.EqualValues(t, 1, listed.MemberCount)
})
t.Run("UpsertGroupOverrideMissingGroup", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
_, err := client.UpsertChatUsageLimitGroupOverride(ctx, uuid.New(), codersdk.UpsertChatUsageLimitGroupOverrideRequest{
SpendLimitMicros: 7_000_000,
})
sdkErr := requireSDKError(t, err, http.StatusNotFound)
require.Equal(t, "Group not found.", sdkErr.Message)
})
t.Run("DeleteGroupOverrideMissingOverride", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
group := dbgen.Group(t, db, database.Group{OrganizationID: firstUser.OrganizationID})
err := client.DeleteChatUsageLimitGroupOverride(ctx, group.ID)
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Chat usage limit group override not found.", sdkErr.Message)
})
}
func TestPostChatFile(t *testing.T) {
t.Parallel()
@@ -4002,7 +4512,7 @@ func TestChatSystemPrompt(t *testing.T) {
t.Run("AdminCanSet", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
SystemPrompt: "You are a helpful coding assistant.",
})
require.NoError(t, err)
@@ -4016,7 +4526,7 @@ func TestChatSystemPrompt(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
// Unset by sending an empty string.
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
SystemPrompt: "",
})
require.NoError(t, err)
@@ -4029,7 +4539,7 @@ func TestChatSystemPrompt(t *testing.T) {
t.Run("NonAdminFails", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
SystemPrompt: "This should fail.",
})
requireSDKError(t, err, http.StatusNotFound)
@@ -4050,7 +4560,7 @@ func TestChatSystemPrompt(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
tooLong := strings.Repeat("a", 131073)
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
SystemPrompt: tooLong,
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
@@ -4058,6 +4568,108 @@ func TestChatSystemPrompt(t *testing.T) {
})
}
func TestChatDesktopEnabled(t *testing.T) {
t.Parallel()
t.Run("ReturnsFalseWhenUnset", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
resp, err := adminClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.False(t, resp.EnableDesktop)
})
t.Run("AdminCanSetTrue", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
require.NoError(t, err)
resp, err := adminClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.True(t, resp.EnableDesktop)
})
t.Run("AdminCanSetFalse", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
// Set true first, then set false.
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
require.NoError(t, err)
err = adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: false,
})
require.NoError(t, err)
resp, err := adminClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.False(t, resp.EnableDesktop)
})
t.Run("NonAdminCanRead", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
require.NoError(t, err)
resp, err := memberClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.True(t, resp.EnableDesktop)
})
t.Run("NonAdminWriteFails", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
err := memberClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
requireSDKError(t, err, http.StatusForbidden)
})
t.Run("UnauthenticatedFails", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
anonClient := codersdk.New(adminClient.URL)
_, err := anonClient.GetChatDesktopEnabled(ctx)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
})
}
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
t.Helper()
@@ -10,6 +10,7 @@ import (
"flag"
"fmt"
"io"
"math"
"net/http"
httppprof "net/http/pprof"
"net/url"
@@ -44,6 +45,7 @@ import (
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
"github.com/coder/coder/v2/coderd/aiseats"
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
"github.com/coder/coder/v2/coderd/appearance"
"github.com/coder/coder/v2/coderd/audit"
@@ -629,7 +631,9 @@ func New(options *Options) *API {
),
dbRolluper: options.DatabaseRolluper,
ProfileCollector: defaultProfileCollector{},
AISeatTracker: aiseats.Noop{},
}
api.WorkspaceAppsProvider = workspaceapps.NewDBTokenProvider(
ctx,
options.Logger.Named("workspaceapps"),
@@ -763,17 +767,26 @@ func New(options *Options) *API {
}
api.agentProvider = stn
maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
if maxChatsPerAcquire > math.MaxInt32 {
maxChatsPerAcquire = math.MaxInt32
}
if maxChatsPerAcquire < math.MinInt32 {
maxChatsPerAcquire = math.MinInt32
}
api.chatDaemon = chatd.New(chatd.Config{
Logger: options.Logger.Named("chats"),
Database: options.Database,
ReplicaID: api.ID,
SubscribeFn: options.ChatSubscribeFn,
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher,
Logger: options.Logger.Named("chats"),
Database: options.Database,
ReplicaID: api.ID,
SubscribeFn: options.ChatSubscribeFn,
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher,
})
gitSyncLogger := options.Logger.Named("gitsync")
refresher := gitsync.NewRefresher(
@@ -1146,6 +1159,9 @@ func New(options *Options) *API {
r.Get("/summary", api.chatCostSummary)
})
})
r.Route("/insights", func(r chi.Router) {
r.Get("/pull-requests", api.prInsights)
})
r.Route("/files", func(r chi.Router) {
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
r.Post("/", api.postChatFile)
@@ -1154,6 +1170,8 @@ func New(options *Options) *API {
r.Route("/config", func(r chi.Router) {
r.Get("/system-prompt", api.getChatSystemPrompt)
r.Put("/system-prompt", api.putChatSystemPrompt)
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
r.Get("/user-prompt", api.getUserChatCustomPrompt)
r.Put("/user-prompt", api.putUserChatCustomPrompt)
})
@@ -1175,13 +1193,25 @@ func New(options *Options) *API {
r.Delete("/", api.deleteChatModelConfig)
})
})
r.Route("/usage-limits", func(r chi.Router) {
r.Get("/", api.getChatUsageLimitConfig)
r.Put("/", api.updateChatUsageLimitConfig)
r.Get("/status", api.getMyChatUsageLimitStatus)
r.Route("/overrides/{user}", func(r chi.Router) {
r.Put("/", api.upsertChatUsageLimitOverride)
r.Delete("/", api.deleteChatUsageLimitOverride)
})
r.Route("/group-overrides/{group}", func(r chi.Router) {
r.Put("/", api.upsertChatUsageLimitGroupOverride)
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
})
})
r.Route("/{chat}", func(r chi.Router) {
r.Use(httpmw.ExtractChatParam(options.Database))
r.Get("/", api.getChat)
r.Get("/git/watch", api.watchChatGit)
r.Get("/desktop", api.watchChatDesktop)
r.Post("/archive", api.archiveChat)
r.Post("/unarchive", api.unarchiveChat)
r.Patch("/", api.patchChat)
r.Get("/messages", api.getChatMessages)
r.Post("/messages", api.postChatMessages)
r.Patch("/messages/{message}", api.patchChatMessage)
@@ -2033,6 +2063,8 @@ type API struct {
dbRolluper *dbrollup.Rolluper
// chatDaemon handles background processing of pending chats.
chatDaemon *chatd.Server
// AISeatTracker records AI seat usage.
AISeatTracker aiseats.SeatTracker
// gitSyncWorker refreshes stale chat diff statuses in the
// background.
gitSyncWorker *gitsync.Worker
@@ -2245,6 +2277,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
provisionerdserver.Options{
OIDCConfig: api.OIDCConfig,
ExternalAuthConfigs: api.ExternalAuthConfigs,
AISeatTracker: api.AISeatTracker,
Clock: api.Clock,
HeartbeatFn: options.heartbeatFn,
},
@@ -879,6 +879,15 @@ func createAnotherUserRetry(t testing.TB, client *codersdk.Client, organizationI
m(&req)
}
// Service accounts cannot have a password or email and must
// use login_type=none. Enforce this after mutators so callers
// only need to set ServiceAccount=true.
if req.ServiceAccount {
req.Password = ""
req.Email = ""
req.UserLoginType = codersdk.LoginTypeNone
}
user, err := client.CreateUserWithOrgs(context.Background(), req)
var apiError *codersdk.Error
// If the user already exists by username or email conflict, try again up to "retries" times.
@@ -13,32 +13,64 @@ var _ usage.Inserter = (*UsageInserter)(nil)
type UsageInserter struct {
sync.Mutex
events []usagetypes.DiscreteEvent
discreteEvents []usagetypes.DiscreteEvent
heartbeatEvents []usagetypes.HeartbeatEvent
seenHeartbeats map[string]struct{}
}
func NewUsageInserter() *UsageInserter {
return &UsageInserter{
events: []usagetypes.DiscreteEvent{},
discreteEvents: []usagetypes.DiscreteEvent{},
seenHeartbeats: map[string]struct{}{},
heartbeatEvents: []usagetypes.HeartbeatEvent{},
}
}
func (u *UsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, event usagetypes.DiscreteEvent) error {
u.Lock()
defer u.Unlock()
u.events = append(u.events, event)
u.discreteEvents = append(u.discreteEvents, event)
return nil
}
func (u *UsageInserter) GetEvents() []usagetypes.DiscreteEvent {
func (u *UsageInserter) InsertHeartbeatUsageEvent(_ context.Context, _ database.Store, id string, event usagetypes.HeartbeatEvent) error {
u.Lock()
defer u.Unlock()
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.events))
copy(eventsCopy, u.events)
if _, seen := u.seenHeartbeats[id]; seen {
return nil
}
u.seenHeartbeats[id] = struct{}{}
u.heartbeatEvents = append(u.heartbeatEvents, event)
return nil
}
func (u *UsageInserter) GetHeartbeatEvents() []usagetypes.HeartbeatEvent {
u.Lock()
defer u.Unlock()
eventsCopy := make([]usagetypes.HeartbeatEvent, len(u.heartbeatEvents))
copy(eventsCopy, u.heartbeatEvents)
return eventsCopy
}
func (u *UsageInserter) GetDiscreteEvents() []usagetypes.DiscreteEvent {
u.Lock()
defer u.Unlock()
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.discreteEvents))
copy(eventsCopy, u.discreteEvents)
return eventsCopy
}
func (u *UsageInserter) TotalEventCount() int {
u.Lock()
defer u.Unlock()
return len(u.discreteEvents) + len(u.heartbeatEvents)
}
func (u *UsageInserter) Reset() {
u.Lock()
defer u.Unlock()
u.events = []usagetypes.DiscreteEvent{}
u.seenHeartbeats = map[string]struct{}{}
u.discreteEvents = []usagetypes.DiscreteEvent{}
u.heartbeatEvents = []usagetypes.HeartbeatEvent{}
}
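The fake inserter's heartbeat path is made idempotent per event id via the `seenHeartbeats` set, so retried deliveries are counted once. The same guard pattern in isolation — type and field names here are illustrative, not the real `UsageInserter`:

```go
package main

import (
	"fmt"
	"sync"
)

// dedupRecorder keeps at most one event per id, the same shape as
// the seenHeartbeats check in the fake inserter above.
type dedupRecorder struct {
	mu     sync.Mutex
	seen   map[string]struct{}
	events []string
}

func newDedupRecorder() *dedupRecorder {
	return &dedupRecorder{seen: map[string]struct{}{}}
}

func (r *dedupRecorder) Record(id, event string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if _, ok := r.seen[id]; ok {
		return // duplicate delivery: drop silently
	}
	r.seen[id] = struct{}{}
	r.events = append(r.events, event)
}

func main() {
	r := newDedupRecorder()
	r.Record("hb-1", "first")
	r.Record("hb-1", "retry") // ignored, same id
	r.Record("hb-2", "second")
	fmt.Println(len(r.events)) // 2
}
```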
@@ -6,22 +6,27 @@ type CheckConstraint string
// CheckConstraint enums.
const (
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersChatSpendLimitMicrosCheck CheckConstraint = "users_chat_spend_limit_micros_check" // users
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
)
@@ -21,6 +21,7 @@ import (
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/render"
@@ -194,13 +195,14 @@ func MinimalUserFromVisibleUser(user database.VisibleUser) codersdk.MinimalUser
func ReducedUser(user database.User) codersdk.ReducedUser {
return codersdk.ReducedUser{
MinimalUser: MinimalUser(user),
Email: user.Email,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
LastSeenAt: user.LastSeenAt,
Status: codersdk.UserStatus(user.Status),
LoginType: codersdk.LoginType(user.LoginType),
MinimalUser: MinimalUser(user),
Email: user.Email,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
LastSeenAt: user.LastSeenAt,
Status: codersdk.UserStatus(user.Status),
LoginType: codersdk.LoginType(user.LoginType),
IsServiceAccount: user.IsServiceAccount,
}
}
@@ -1164,3 +1166,86 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
value := v.Int64
return &value
}
// ChatDiffStatus converts a database.ChatDiffStatus to a
// codersdk.ChatDiffStatus. When status is nil, an empty value
// containing only the chatID is returned.
func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.ChatDiffStatus {
result := codersdk.ChatDiffStatus{
ChatID: chatID,
}
if status == nil {
return result
}
result.ChatID = status.ChatID
if status.Url.Valid {
u := strings.TrimSpace(status.Url.String)
if u != "" {
result.URL = &u
}
}
if result.URL == nil {
// Try to build a branch URL from the stored origin.
// Since this function does not have access to the API
// instance, we construct a GitHub provider directly as
// a best-effort fallback.
// TODO: This uses the default github.com API base URL,
// so branch URLs for GitHub Enterprise instances will
// be incorrect. To fix this, this function would need
// access to the external auth configs.
gp := gitprovider.New("github", "", nil)
if gp != nil {
if owner, repo, _, ok := gp.ParseRepositoryOrigin(status.GitRemoteOrigin); ok {
branchURL := gp.BuildBranchURL(owner, repo, status.GitBranch)
if branchURL != "" {
result.URL = &branchURL
}
}
}
}
if status.PullRequestState.Valid {
pullRequestState := strings.TrimSpace(status.PullRequestState.String)
if pullRequestState != "" {
result.PullRequestState = &pullRequestState
}
}
result.PullRequestTitle = status.PullRequestTitle
result.PullRequestDraft = status.PullRequestDraft
result.ChangesRequested = status.ChangesRequested
result.Additions = status.Additions
result.Deletions = status.Deletions
result.ChangedFiles = status.ChangedFiles
if status.AuthorLogin.Valid {
result.AuthorLogin = &status.AuthorLogin.String
}
if status.AuthorAvatarUrl.Valid {
result.AuthorAvatarURL = &status.AuthorAvatarUrl.String
}
if status.BaseBranch.Valid {
result.BaseBranch = &status.BaseBranch.String
}
if status.HeadBranch.Valid {
result.HeadBranch = &status.HeadBranch.String
}
if status.PrNumber.Valid {
result.PRNumber = &status.PrNumber.Int32
}
if status.Commits.Valid {
result.Commits = &status.Commits.Int32
}
if status.Approved.Valid {
result.Approved = &status.Approved.Bool
}
if status.ReviewerCount.Valid {
result.ReviewerCount = &status.ReviewerCount.Int32
}
if status.RefreshedAt.Valid {
refreshedAt := status.RefreshedAt.Time
result.RefreshedAt = &refreshedAt
}
staleAt := status.StaleAt
result.StaleAt = &staleAt
return result
}
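Most of the converter above is one repeated move: an `sql.Null*` column becomes an optional pointer field, with `nil` meaning "not set". A generic sketch of that pattern — `nullToPtr` is an illustrative helper, not part of the codebase:

```go
package main

import (
	"database/sql"
	"fmt"
)

// nullToPtr returns nil when the sql.Null value is invalid, and a
// pointer to the value otherwise — the same shape as the Valid
// checks in the converter above.
func nullToPtr[T any](valid bool, v T) *T {
	if !valid {
		return nil
	}
	return &v
}

func main() {
	login := sql.NullString{String: "octocat", Valid: true}
	prNum := sql.NullInt32{} // Valid: false

	fmt.Println(*nullToPtr(login.Valid, login.String)) // octocat
	fmt.Println(nullToPtr(prNum.Valid, prNum.Int32))   // <nil>
}
```

Taking `&v` is safe here because generics pass `v` by value; each call yields a fresh allocation rather than aliasing the source struct.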
@@ -1264,7 +1264,7 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID uuid.UUID, added, re
// System roles are stored in the database but have a fixed, code-defined
// meaning. Do not rewrite the name for them so the static "who can assign
// what" mapping applies.
if !rbac.SystemRoleName(roleName.Name) {
if !rolestore.IsSystemRoleName(roleName.Name) {
// To support a dynamic mapping of what roles can assign what, we need
// to store this in the database. For now, just use a static role so
// owners and org admins can assign roles.
@@ -1726,6 +1726,13 @@ func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountCon
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
}
func (q *querier) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return 0, err
}
return q.db.CountEnabledModelsWithoutPricing(ctx)
}
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
return nil, err
@@ -1854,6 +1861,20 @@ func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.Dele
return q.db.DeleteChatQueuedMessage(ctx, arg)
}
func (q *querier) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatUsageLimitGroupOverride(ctx, groupID)
}
func (q *querier) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatUsageLimitUserOverride(ctx, userID)
}
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -2124,12 +2145,12 @@ func (q *querier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) erro
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
}
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, params database.DeleteWorkspaceACLsByOrganizationParams) error {
// This is a system-only function.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
return q.db.DeleteWorkspaceACLsByOrganization(ctx, params)
}
func (q *querier) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
@@ -2327,6 +2348,13 @@ func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Tim
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysLastUsedAfter)(ctx, lastUsed)
}
func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceLicense); err != nil {
return 0, err
}
return q.db.GetActiveAISeatCount(ctx)
}
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
return nil, err
@@ -2454,6 +2482,17 @@ func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCo
return q.db.GetChatCostSummary(ctx, arg)
}
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
// The desktop-enabled flag is a deployment-wide setting read by any
// authenticated chat user and by chatd when deciding whether to expose
// computer-use tooling. We only require that an explicit actor is present
// in the context so unauthenticated calls fail closed.
if _, ok := ActorFromContext(ctx); !ok {
return false, ErrNoActor
}
return q.db.GetChatDesktopEnabled(ctx)
}
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
@@ -2532,6 +2571,14 @@ func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetC
return q.db.GetChatMessagesByChatID(ctx, arg)
}
func (q *querier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
_, err := q.GetChatByID(ctx, arg.ChatID)
if err != nil {
return nil, err
}
return q.db.GetChatMessagesByChatIDDescPaginated(ctx, arg)
}
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
@@ -2596,8 +2643,33 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
return q.db.GetChatSystemPrompt(ctx)
}
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatUsageLimitConfig{}, err
}
return q.db.GetChatUsageLimitConfig(ctx)
}
func (q *querier) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.GetChatUsageLimitGroupOverrideRow{}, err
}
return q.db.GetChatUsageLimitGroupOverride(ctx, groupID)
}
func (q *querier) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.GetChatUsageLimitUserOverrideRow{}, err
}
return q.db.GetChatUsageLimitUserOverride(ctx, userID)
}
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceChat.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.GetAuthorizedChats(ctx, arg, prep)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
@@ -3087,6 +3159,34 @@ func (q *querier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg da
return q.db.GetOrganizationsWithPrebuildStatus(ctx, arg)
}
func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetPRInsightsPerModel(ctx, arg)
}
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetPRInsightsRecentPRs(ctx, arg)
}
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.GetPRInsightsSummaryRow{}, err
}
return q.db.GetPRInsightsSummary(ctx, arg)
}
func (q *querier) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetPRInsightsTimeSeries(ctx, arg)
}
func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
version, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
if err != nil {
@@ -3750,6 +3850,13 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
return q.db.GetUserChatCustomPrompt(ctx, userID)
}
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
return 0, err
}
return q.db.GetUserChatSpendInPeriod(ctx, arg)
}
func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return 0, err
@@ -3757,6 +3864,13 @@ func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64,
return q.db.GetUserCount(ctx, includeSystem)
}
func (q *querier) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
return 0, err
}
return q.db.GetUserGroupSpendLimit(ctx, userID)
}
func (q *querier) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
// Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms.
if err := q.authorizeContext(ctx, policy.ActionViewInsights, rbac.ResourceTemplate); err != nil {
@@ -4426,6 +4540,13 @@ func (q *querier) InsertAIBridgeInterception(ctx context.Context, arg database.I
return insert(q.log, q.auth, rbac.ResourceAibridgeInterception.WithOwner(arg.InitiatorID.String()), q.db.InsertAIBridgeInterception)(ctx, arg)
}
func (q *querier) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
return database.AIBridgeModelThought{}, err
}
return q.db.InsertAIBridgeModelThought(ctx, arg)
}
func (q *querier) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
// All aibridge_token_usages records belong to the initiator of their associated interception.
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
@@ -5115,6 +5236,20 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
}
func (q *querier) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.ListChatUsageLimitGroupOverrides(ctx)
}
func (q *querier) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.ListChatUsageLimitOverrides(ctx)
}
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
}
@@ -5234,6 +5369,13 @@ func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveU
return q.db.RemoveUserFromGroups(ctx, arg)
}
func (q *querier) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
return 0, err
}
return q.db.ResolveUserChatSpendLimit(ctx, userID)
}
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
@@ -6417,6 +6559,13 @@ func (q *querier) UpdateWorkspacesTTLByTemplateID(ctx context.Context, arg datab
return q.db.UpdateWorkspacesTTLByTemplateID(ctx, arg)
}
func (q *querier) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
return false, err
}
return q.db.UpsertAISeatState(ctx, arg)
}
func (q *querier) UpsertAnnouncementBanners(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -6438,6 +6587,13 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
return q.db.UpsertBoundaryUsageStats(ctx, arg)
}
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.UpsertChatDesktopEnabled(ctx, enableDesktop)
}
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
// Authorize update on the parent chat.
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
@@ -6469,6 +6625,27 @@ func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) erro
return q.db.UpsertChatSystemPrompt(ctx, value)
}
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatUsageLimitConfig{}, err
}
return q.db.UpsertChatUsageLimitConfig(ctx, arg)
}
func (q *querier) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.UpsertChatUsageLimitGroupOverrideRow{}, err
}
return q.db.UpsertChatUsageLimitGroupOverride(ctx, arg)
}
func (q *querier) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.UpsertChatUsageLimitUserOverrideRow{}, err
}
return q.db.UpsertChatUsageLimitUserOverride(ctx, arg)
}
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return database.ConnectionLog{}, err
@@ -6656,6 +6833,13 @@ func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg databa
return q.db.UpsertWorkspaceAppAuditSession(ctx, arg)
}
func (q *querier) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceUsageEvent); err != nil {
return false, err
}
return q.db.UsageEventExistsByID(ctx, id)
}
func (q *querier) ValidateGroupIDs(ctx context.Context, groupIDs []uuid.UUID) (database.ValidateGroupIDsRow, error) {
// This check is probably overly restrictive, but the "correct" check isn't
// necessarily obvious. It's only used as a verification check for ACLs right
@@ -6751,3 +6935,7 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
// database.Store interface, so dbauthz needs to implement it.
return q.ListAIBridgeModels(ctx, arg)
}
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
return q.GetChats(ctx, arg)
}
@@ -513,6 +513,10 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatCostSummary(gomock.Any(), arg).Return(row, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(row)
}))
s.Run("CountEnabledModelsWithoutPricing", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().CountEnabledModelsWithoutPricing(gomock.Any()).Return(int64(3), nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(int64(3))
}))
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
@@ -558,6 +562,14 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetChatMessagesByChatIDDescPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
arg := database.GetChatMessagesByChatIDDescPaginatedParams{ChatID: chat.ID, BeforeID: 0, LimitVal: 50}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatMessagesByChatIDDescPaginated(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetLastChatMessageByRole", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
@@ -606,12 +618,17 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
}))
-s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
-c1 := testutil.Fake(s.T(), faker, database.Chat{})
-c2 := testutil.Fake(s.T(), faker, database.Chat{})
-params := database.GetChatsByOwnerIDParams{OwnerID: c1.OwnerID}
-dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), params).Return([]database.Chat{c1, c2}, nil).AnyTimes()
-check.Args(params).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
-}))
+s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
+params := database.GetChatsParams{}
+dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
+// No asserts here because SQLFilter.
+check.Args(params).Asserts()
+}))
+s.Run("GetAuthorizedChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
+params := database.GetChatsParams{}
+dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
+// No asserts here because it re-routes through GetChats which uses SQLFilter.
+check.Args(params, emptyPreparedAuthorized{}).Asserts()
+}))
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
@@ -624,6 +641,10 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatSystemPrompt(gomock.Any()).Return("prompt", nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("GetChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
@@ -833,6 +854,146 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("UpsertChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetUserChatSpendInPeriod", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetUserChatSpendInPeriodParams{
UserID: uuid.New(),
StartTime: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndTime: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
}
spend := int64(123)
dbm.EXPECT().GetUserChatSpendInPeriod(gomock.Any(), arg).Return(spend, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.UserID.String()), policy.ActionRead).Returns(spend)
}))
s.Run("GetUserGroupSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
limit := int64(456)
dbm.EXPECT().GetUserGroupSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
}))
s.Run("ResolveUserChatSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
limit := int64(789)
dbm.EXPECT().ResolveUserChatSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
}))
s.Run("GetChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
now := dbtime.Now()
config := database.ChatUsageLimitConfig{
ID: 1,
Singleton: true,
Enabled: true,
DefaultLimitMicros: 1_000_000,
Period: "monthly",
CreatedAt: now,
UpdatedAt: now,
}
dbm.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(config, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
groupID := uuid.New()
override := database.GetChatUsageLimitGroupOverrideRow{
GroupID: groupID,
SpendLimitMicros: sql.NullInt64{Int64: 2_000_000, Valid: true},
}
dbm.EXPECT().GetChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(override, nil).AnyTimes()
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
}))
s.Run("GetChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
override := database.GetChatUsageLimitUserOverrideRow{
UserID: userID,
SpendLimitMicros: sql.NullInt64{Int64: 3_000_000, Valid: true},
}
dbm.EXPECT().GetChatUsageLimitUserOverride(gomock.Any(), userID).Return(override, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
}))
s.Run("ListChatUsageLimitGroupOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
overrides := []database.ListChatUsageLimitGroupOverridesRow{{
GroupID: uuid.New(),
GroupName: "group-name",
GroupDisplayName: "Group Name",
GroupAvatarUrl: "https://example.com/group.png",
SpendLimitMicros: sql.NullInt64{Int64: 4_000_000, Valid: true},
MemberCount: 5,
}}
dbm.EXPECT().ListChatUsageLimitGroupOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
}))
s.Run("ListChatUsageLimitOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
overrides := []database.ListChatUsageLimitOverridesRow{{
UserID: uuid.New(),
Username: "usage-limit-user",
Name: "Usage Limit User",
AvatarURL: "https://example.com/avatar.png",
SpendLimitMicros: sql.NullInt64{Int64: 5_000_000, Valid: true},
}}
dbm.EXPECT().ListChatUsageLimitOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
}))
s.Run("UpsertChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
now := dbtime.Now()
arg := database.UpsertChatUsageLimitConfigParams{
Enabled: true,
DefaultLimitMicros: 6_000_000,
Period: "monthly",
}
config := database.ChatUsageLimitConfig{
ID: 1,
Singleton: true,
Enabled: arg.Enabled,
DefaultLimitMicros: arg.DefaultLimitMicros,
Period: arg.Period,
CreatedAt: now,
UpdatedAt: now,
}
dbm.EXPECT().UpsertChatUsageLimitConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpsertChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.UpsertChatUsageLimitGroupOverrideParams{
SpendLimitMicros: 7_000_000,
GroupID: uuid.New(),
}
override := database.UpsertChatUsageLimitGroupOverrideRow{
GroupID: arg.GroupID,
Name: "group",
DisplayName: "Group",
AvatarURL: "",
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
}
dbm.EXPECT().UpsertChatUsageLimitGroupOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
}))
s.Run("UpsertChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.UpsertChatUsageLimitUserOverrideParams{
SpendLimitMicros: 8_000_000,
UserID: uuid.New(),
}
override := database.UpsertChatUsageLimitUserOverrideRow{
UserID: arg.UserID,
Username: "user",
Name: "User",
AvatarURL: "",
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
}
dbm.EXPECT().UpsertChatUsageLimitUserOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
}))
s.Run("DeleteChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
groupID := uuid.New()
dbm.EXPECT().DeleteChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(nil).AnyTimes()
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
}
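The `check.Args(...).Asserts(...).Returns(...)` chains above are a fluent expectation recorder: each call stores what the method under test should receive, which RBAC assertions it should make, and what value it should return. A toy version of that recorder (illustrative names, far simpler than the suite's real `expects` type):

```go
package main

import "fmt"

// expects records the three pieces the suite checks for each query:
// the arguments passed in, the (resource, action) RBAC assertion pairs,
// and the value the mocked store returns.
type expects struct {
	args    []any
	asserts []any
	returns any
}

func (e *expects) Args(a ...any) *expects    { e.args = a; return e }
func (e *expects) Asserts(a ...any) *expects { e.asserts = a; return e }
func (e *expects) Returns(v any) *expects    { e.returns = v; return e }

func main() {
	// Mirrors a call like:
	// check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(int64(3))
	check := (&expects{}).Args("group-id").Asserts("DeploymentConfig", "read").Returns(int64(3))
	fmt.Println(len(check.args), len(check.asserts), check.returns)
}
```

Because each method returns the receiver, an empty `.Asserts()` (as in the SQL-filtered cases) records "no RBAC assertions expected" rather than being a no-op.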
func (s *MethodTestSuite) TestFile() {
@@ -1155,6 +1316,14 @@ func (s *MethodTestSuite) TestProvisionerJob() {
}
func (s *MethodTestSuite) TestLicense() {
s.Run("GetActiveAISeatCount", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetActiveAISeatCount(gomock.Any()).Return(int64(100), nil).AnyTimes()
check.Args().Asserts(rbac.ResourceLicense, policy.ActionRead).Returns(int64(100))
}))
s.Run("UpsertAISeatState", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertAISeatState(gomock.Any(), gomock.Any()).Return(true, nil).AnyTimes()
check.Args(database.UpsertAISeatStateParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate)
}))
s.Run("GetLicenses", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
a := database.License{ID: 1}
b := database.License{ID: 2}
@@ -1324,7 +1493,7 @@ func (s *MethodTestSuite) TestOrganization() {
org := testutil.Fake(s.T(), faker, database.Organization{})
arg := database.UpdateOrganizationWorkspaceSharingSettingsParams{
ID: org.ID,
-WorkspaceSharingDisabled: true,
+ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
}
dbm.EXPECT().GetOrganizationByID(gomock.Any(), org.ID).Return(org, nil).AnyTimes()
dbm.EXPECT().UpdateOrganizationWorkspaceSharingSettings(gomock.Any(), arg).Return(org, nil).AnyTimes()
@@ -1755,6 +1924,26 @@ func (s *MethodTestSuite) TestTemplate() {
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
}))
s.Run("GetPRInsightsSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsSummaryParams{}
dbm.EXPECT().GetPRInsightsSummary(gomock.Any(), arg).Return(database.GetPRInsightsSummaryRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetPRInsightsTimeSeries", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsTimeSeriesParams{}
dbm.EXPECT().GetPRInsightsTimeSeries(gomock.Any(), arg).Return([]database.GetPRInsightsTimeSeriesRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetPRInsightsPerModel", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsPerModelParams{}
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsRecentPRsParams{}
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetTelemetryTaskEventsParams{}
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
@@ -2243,9 +2432,12 @@ func (s *MethodTestSuite) TestWorkspace() {
check.Args(w.ID).Asserts(w, policy.ActionShare)
}))
s.Run("DeleteWorkspaceACLsByOrganization", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
-orgID := uuid.New()
-dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), orgID).Return(nil).AnyTimes()
-check.Args(orgID).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
+arg := database.DeleteWorkspaceACLsByOrganizationParams{
+OrganizationID: uuid.New(),
+ExcludeServiceAccounts: false,
+}
+dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), arg).Return(nil).AnyTimes()
+check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{})
@@ -4951,6 +5143,12 @@ func (s *MethodTestSuite) TestUsageEvents() {
check.Args(params).Asserts(rbac.ResourceUsageEvent, policy.ActionCreate)
}))
s.Run("UsageEventExistsByID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
id := uuid.NewString()
db.EXPECT().UsageEventExistsByID(gomock.Any(), id).Return(true, nil)
check.Args(id).Asserts(rbac.ResourceUsageEvent, policy.ActionRead)
}))
s.Run("SelectUsageEventsForPublishing", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
now := dbtime.Now()
db.EXPECT().SelectUsageEventsForPublishing(gomock.Any(), now).Return([]database.UsageEvent{}, nil)
@@ -5011,6 +5209,17 @@ func (s *MethodTestSuite) TestAIBridge() {
check.Args(params).Asserts(intc, policy.ActionCreate)
}))
s.Run("InsertAIBridgeModelThought", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
params := database.InsertAIBridgeModelThoughtParams{InterceptionID: intc.ID}
expected := testutil.Fake(s.T(), faker, database.AIBridgeModelThought{InterceptionID: intc.ID})
db.EXPECT().InsertAIBridgeModelThought(gomock.Any(), params).Return(expected, nil).AnyTimes()
check.Args(params).Asserts(intc, policy.ActionUpdate)
}))
s.Run("InsertAIBridgeTokenUsage", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
@@ -29,6 +29,7 @@ import (
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/rbac/regosql"
"github.com/coder/coder/v2/coderd/rbac/rolestore"
"github.com/coder/coder/v2/coderd/util/slice"
)
@@ -143,7 +144,7 @@ func (s *MethodTestSuite) Mocked(testCaseF func(dmb *dbmock.MockStore, faker *go
UUID: pair.OrganizationID,
Valid: pair.OrganizationID != uuid.Nil,
},
-IsSystem: rbac.SystemRoleName(pair.Name),
+IsSystem: rolestore.IsSystemRoleName(pair.Name),
ID: uuid.New(),
})
}
@@ -650,34 +650,26 @@ func Organization(t testing.TB, db database.Store, orig database.Organization) d
})
require.NoError(t, err, "insert organization")
-// Populate the placeholder organization-member system role (created by
-// DB trigger/migration) so org members have expected permissions.
-//nolint:gocritic // ReconcileOrgMemberRole needs the system:update
+// Populate the placeholder system roles (created by DB
+// trigger/migration) so org members have expected permissions.
+//nolint:gocritic // ReconcileSystemRole needs the system:update
// permission that `genCtx` does not have.
sysCtx := dbauthz.AsSystemRestricted(genCtx)
-_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
-Name: rbac.RoleOrgMember(),
-OrganizationID: uuid.NullUUID{
-UUID: org.ID,
-Valid: true,
-},
-}, org.WorkspaceSharingDisabled)
-if errors.Is(err, sql.ErrNoRows) {
-// The trigger that creates the placeholder role didn't run (e.g.,
-// triggers were disabled in the test). Create the role manually.
-err = rolestore.CreateOrgMemberRole(sysCtx, db, org)
-require.NoError(t, err, "create organization-member role")
-_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
-Name: rbac.RoleOrgMember(),
-OrganizationID: uuid.NullUUID{
-UUID: org.ID,
-Valid: true,
-},
-}, org.WorkspaceSharingDisabled)
-}
+for roleName := range rolestore.SystemRoleNames {
+role := database.CustomRole{
+Name: roleName,
+OrganizationID: uuid.NullUUID{UUID: org.ID, Valid: true},
+}
+_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
+if errors.Is(err, sql.ErrNoRows) {
+// The trigger that creates the placeholder role didn't run (e.g.,
+// triggers were disabled in the test). Create the role manually.
+err = rolestore.CreateSystemRole(sysCtx, db, org, roleName)
+require.NoError(t, err, "create role "+roleName)
+_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
+}
+require.NoError(t, err, "reconcile role "+roleName)
+}
-require.NoError(t, err, "reconcile organization-member role")
return org
}
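The reconcile loop above relies on a fallback: if the DB trigger that creates the placeholder role never ran, reconciliation fails with `sql.ErrNoRows`, the role is created manually, and reconciliation is retried once. That ensure-then-retry shape can be sketched in isolation (toy store and hypothetical helper names, not the rolestore API):

```go
package main

import (
	"database/sql"
	"errors"
	"fmt"
)

// store is a toy role store: reconcile fails with sql.ErrNoRows until the
// placeholder row exists, mirroring the trigger-may-not-have-run case.
type store struct{ roles map[string]bool }

func (s *store) reconcile(name string) error {
	if !s.roles[name] {
		return sql.ErrNoRows
	}
	return nil
}

func (s *store) create(name string) { s.roles[name] = true }

// ensureRole follows the same shape as the test helper: try to reconcile,
// and on ErrNoRows create the role manually and reconcile once more.
func ensureRole(s *store, name string) error {
	err := s.reconcile(name)
	if errors.Is(err, sql.ErrNoRows) {
		s.create(name)
		err = s.reconcile(name)
	}
	return err
}

func main() {
	s := &store{roles: map[string]bool{}}
	fmt.Println(ensureRole(s, "organization-member"))
}
```

A single retry is enough here because the only recoverable failure is the missing placeholder row; any error from the second reconcile is surfaced to the caller.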
@@ -288,6 +288,14 @@ func (m queryMetricsStore) CountConnectionLogs(ctx context.Context, arg database
return r0, r1
}
func (m queryMetricsStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountEnabledModelsWithoutPricing(ctx)
m.queryLatencies.WithLabelValues("CountEnabledModelsWithoutPricing").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountEnabledModelsWithoutPricing").Inc()
return r0, r1
}
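Each `queryMetricsStore` method wraps the underlying call the same way: record a start time, delegate, observe the elapsed seconds under the query's label, and return the results untouched. A self-contained sketch of that wrapper (a plain in-memory recorder stands in for the Prometheus histogram the real code uses):

```go
package main

import (
	"fmt"
	"time"
)

// latencyRecorder is a stand-in for queryLatencies.WithLabelValues(...).Observe(...).
type latencyRecorder struct {
	samples map[string][]float64
}

func (r *latencyRecorder) observe(query string, seconds float64) {
	r.samples[query] = append(r.samples[query], seconds)
}

// metricsStore wraps an inner store call: time it, record the latency
// under the query name, and pass the result through unchanged.
type metricsStore struct {
	inner func() (int64, error)
	rec   *latencyRecorder
}

func (m metricsStore) CountEnabledModelsWithoutPricing() (int64, error) {
	start := time.Now()
	r0, r1 := m.inner()
	m.rec.observe("CountEnabledModelsWithoutPricing", time.Since(start).Seconds())
	return r0, r1
}

func main() {
	rec := &latencyRecorder{samples: map[string][]float64{}}
	m := metricsStore{inner: func() (int64, error) { return 3, nil }, rec: rec}
	n, err := m.CountEnabledModelsWithoutPricing()
	fmt.Println(n, err, len(rec.samples["CountEnabledModelsWithoutPricing"]))
}
```

Because the wrapper never inspects `r0`/`r1`, these methods are safe to generate mechanically for every query in the store interface.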
func (m queryMetricsStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
start := time.Now()
r0, r1 := m.s.CountInProgressPrebuilds(ctx)
@@ -408,6 +416,22 @@ func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg data
return r0
}
func (m queryMetricsStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatUsageLimitGroupOverride(ctx, groupID)
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitGroupOverride").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatUsageLimitUserOverride(ctx, userID)
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitUserOverride").Inc()
return r0
}
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
@@ -672,10 +696,11 @@ func (m queryMetricsStore) DeleteWorkspaceACLByID(ctx context.Context, id uuid.U
return r0
}
-func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
+func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
start := time.Now()
-r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
+r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteWorkspaceACLsByOrganization").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteWorkspaceACLsByOrganization").Inc()
return r0
}
@@ -871,6 +896,14 @@ func (m queryMetricsStore) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed
return r0, r1
}
func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetActiveAISeatCount(ctx)
m.queryLatencies.WithLabelValues("GetActiveAISeatCount").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveAISeatCount").Inc()
return r0, r1
}
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
start := time.Now()
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
@@ -1015,6 +1048,14 @@ func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.
return r0, r1
}
func (m queryMetricsStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetChatDesktopEnabled(ctx)
m.queryLatencies.WithLabelValues("GetChatDesktopEnabled").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDesktopEnabled").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
@@ -1063,6 +1104,14 @@ func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID d
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesByChatIDDescPaginated(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatMessagesByChatIDDescPaginated").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatIDDescPaginated").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesForPromptByChatID(ctx, chatID)
@@ -1127,11 +1176,35 @@ func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, err
return r0, r1
}
-func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
+func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
start := time.Now()
-r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
-m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
-m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
+r0, r1 := m.s.GetChatUsageLimitConfig(ctx)
+m.queryLatencies.WithLabelValues("GetChatUsageLimitConfig").Observe(time.Since(start).Seconds())
+m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatUsageLimitGroupOverride(ctx, groupID)
m.queryLatencies.WithLabelValues("GetChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitGroupOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatUsageLimitUserOverride(ctx, userID)
m.queryLatencies.WithLabelValues("GetChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitUserOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetChats(ctx, arg)
m.queryLatencies.WithLabelValues("GetChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChats").Inc()
return r0, r1
}
@@ -1671,6 +1744,38 @@ func (m queryMetricsStore) GetOrganizationsWithPrebuildStatus(ctx context.Contex
return r0, r1
}
func (m queryMetricsStore) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
start := time.Now()
r0, r1 := m.s.GetPRInsightsPerModel(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsPerModel").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsPerModel").Inc()
return r0, r1
}
func (m queryMetricsStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
start := time.Now()
r0, r1 := m.s.GetPRInsightsRecentPRs(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsRecentPRs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsRecentPRs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
start := time.Now()
r0, r1 := m.s.GetPRInsightsSummary(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsSummary").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsSummary").Inc()
return r0, r1
}
func (m queryMetricsStore) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
start := time.Now()
r0, r1 := m.s.GetPRInsightsTimeSeries(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsTimeSeries").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsTimeSeries").Inc()
return r0, r1
}
func (m queryMetricsStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
start := time.Now()
r0, r1 := m.s.GetParameterSchemasByJobID(ctx, jobID)
@@ -2255,6 +2360,14 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
return r0, r1
}
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
m.queryLatencies.WithLabelValues("GetUserChatSpendInPeriod").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatSpendInPeriod").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserCount(ctx, includeSystem)
@@ -2263,6 +2376,14 @@ func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool)
return r0, r1
}
func (m queryMetricsStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserGroupSpendLimit(ctx, userID)
m.queryLatencies.WithLabelValues("GetUserGroupSpendLimit").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserGroupSpendLimit").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
start := time.Now()
r0, r1 := m.s.GetUserLatencyInsights(ctx, arg)
@@ -2871,6 +2992,14 @@ func (m queryMetricsStore) InsertAIBridgeInterception(ctx context.Context, arg d
return r0, r1
}
func (m queryMetricsStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
start := time.Now()
r0, r1 := m.s.InsertAIBridgeModelThought(ctx, arg)
m.queryLatencies.WithLabelValues("InsertAIBridgeModelThought").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertAIBridgeModelThought").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
start := time.Now()
r0, r1 := m.s.InsertAIBridgeTokenUsage(ctx, arg)
@@ -3495,6 +3624,22 @@ func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.
return r0, r1
}
func (m queryMetricsStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
start := time.Now()
r0, r1 := m.s.ListChatUsageLimitGroupOverrides(ctx)
m.queryLatencies.WithLabelValues("ListChatUsageLimitGroupOverrides").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitGroupOverrides").Inc()
return r0, r1
}
func (m queryMetricsStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
start := time.Now()
r0, r1 := m.s.ListChatUsageLimitOverrides(ctx)
m.queryLatencies.WithLabelValues("ListChatUsageLimitOverrides").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitOverrides").Inc()
return r0, r1
}
func (m queryMetricsStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
start := time.Now()
r0, r1 := m.s.ListProvisionerKeysByOrganization(ctx, organizationID)
@@ -3607,6 +3752,14 @@ func (m queryMetricsStore) RemoveUserFromGroups(ctx context.Context, arg databas
return r0, r1
}
func (m queryMetricsStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
start := time.Now()
r0, r1 := m.s.ResolveUserChatSpendLimit(ctx, userID)
m.queryLatencies.WithLabelValues("ResolveUserChatSpendLimit").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ResolveUserChatSpendLimit").Inc()
return r0, r1
}
func (m queryMetricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
start := time.Now()
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
@@ -3859,6 +4012,7 @@ func (m queryMetricsStore) UpdateOrganizationWorkspaceSharingSettings(ctx contex
start := time.Now()
r0, r1 := m.s.UpdateOrganizationWorkspaceSharingSettings(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateOrganizationWorkspaceSharingSettings").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateOrganizationWorkspaceSharingSettings").Inc()
return r0, r1
}
@@ -4406,6 +4560,14 @@ func (m queryMetricsStore) UpdateWorkspacesTTLByTemplateID(ctx context.Context,
return r0
}
func (m queryMetricsStore) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
start := time.Now()
r0, r1 := m.s.UpsertAISeatState(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertAISeatState").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertAISeatState").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertAnnouncementBanners(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertAnnouncementBanners(ctx, value)
@@ -4430,6 +4592,14 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
return r0, r1
}
func (m queryMetricsStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
start := time.Now()
r0 := m.s.UpsertChatDesktopEnabled(ctx, enableDesktop)
m.queryLatencies.WithLabelValues("UpsertChatDesktopEnabled").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDesktopEnabled").Inc()
return r0
}
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
@@ -4454,6 +4624,30 @@ func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value str
return r0
}
func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatUsageLimitGroupOverride(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitGroupOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatUsageLimitUserOverride(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitUserOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
start := time.Now()
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
@@ -4630,6 +4824,14 @@ func (m queryMetricsStore) UpsertWorkspaceAppAuditSession(ctx context.Context, a
return r0, r1
}
func (m queryMetricsStore) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
start := time.Now()
r0, r1 := m.s.UsageEventExistsByID(ctx, id)
m.queryLatencies.WithLabelValues("UsageEventExistsByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UsageEventExistsByID").Inc()
return r0, r1
}
func (m queryMetricsStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) {
start := time.Now()
r0, r1 := m.s.ValidateGroupIDs(ctx, groupIds)
@@ -4749,3 +4951,11 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeModels").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAuthorizedChats").Inc()
return r0, r1
}
@@ -424,6 +424,21 @@ func (mr *MockStoreMockRecorder) CountConnectionLogs(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountConnectionLogs", reflect.TypeOf((*MockStore)(nil).CountConnectionLogs), ctx, arg)
}
// CountEnabledModelsWithoutPricing mocks base method.
func (m *MockStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountEnabledModelsWithoutPricing", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountEnabledModelsWithoutPricing indicates an expected call of CountEnabledModelsWithoutPricing.
func (mr *MockStoreMockRecorder) CountEnabledModelsWithoutPricing(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEnabledModelsWithoutPricing", reflect.TypeOf((*MockStore)(nil).CountEnabledModelsWithoutPricing), ctx)
}
// CountInProgressPrebuilds mocks base method.
func (m *MockStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
m.ctrl.T.Helper()
@@ -639,6 +654,34 @@ func (mr *MockStoreMockRecorder) DeleteChatQueuedMessage(ctx, arg any) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).DeleteChatQueuedMessage), ctx, arg)
}
// DeleteChatUsageLimitGroupOverride mocks base method.
func (m *MockStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatUsageLimitGroupOverride", ctx, groupID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatUsageLimitGroupOverride indicates an expected call of DeleteChatUsageLimitGroupOverride.
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitGroupOverride), ctx, groupID)
}
// DeleteChatUsageLimitUserOverride mocks base method.
func (m *MockStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatUsageLimitUserOverride", ctx, userID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatUsageLimitUserOverride indicates an expected call of DeleteChatUsageLimitUserOverride.
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitUserOverride), ctx, userID)
}
// DeleteCryptoKey mocks base method.
func (m *MockStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
m.ctrl.T.Helper()
@@ -1112,17 +1155,17 @@ func (mr *MockStoreMockRecorder) DeleteWorkspaceACLByID(ctx, id any) *gomock.Cal
}
// DeleteWorkspaceACLsByOrganization mocks base method.
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, organizationID)
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteWorkspaceACLsByOrganization indicates an expected call of DeleteWorkspaceACLsByOrganization.
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, organizationID any) *gomock.Call {
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, organizationID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, arg)
}
// DeleteWorkspaceAgentPortShare mocks base method.
@@ -1478,6 +1521,21 @@ func (mr *MockStoreMockRecorder) GetAPIKeysLastUsedAfter(ctx, lastUsed any) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysLastUsedAfter", reflect.TypeOf((*MockStore)(nil).GetAPIKeysLastUsedAfter), ctx, lastUsed)
}
// GetActiveAISeatCount mocks base method.
func (m *MockStore) GetActiveAISeatCount(ctx context.Context) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveAISeatCount", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveAISeatCount indicates an expected call of GetActiveAISeatCount.
func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
}
// GetActivePresetPrebuildSchedules mocks base method.
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
m.ctrl.T.Helper()
@@ -1673,6 +1731,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedAuditLogsOffset(ctx, arg, prepared
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedAuditLogsOffset", reflect.TypeOf((*MockStore)(nil).GetAuthorizedAuditLogsOffset), ctx, arg, prepared)
}
// GetAuthorizedChats mocks base method.
func (m *MockStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthorizedChats", ctx, arg, prepared)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAuthorizedChats indicates an expected call of GetAuthorizedChats.
func (mr *MockStoreMockRecorder) GetAuthorizedChats(ctx, arg, prepared any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedChats", reflect.TypeOf((*MockStore)(nil).GetAuthorizedChats), ctx, arg, prepared)
}
// GetAuthorizedConnectionLogsOffset mocks base method.
func (m *MockStore) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
m.ctrl.T.Helper()
@@ -1838,6 +1911,21 @@ func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
}
// GetChatDesktopEnabled mocks base method.
func (m *MockStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDesktopEnabled", ctx)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDesktopEnabled indicates an expected call of GetChatDesktopEnabled.
func (mr *MockStoreMockRecorder) GetChatDesktopEnabled(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).GetChatDesktopEnabled), ctx)
}
// GetChatDiffStatusByChatID mocks base method.
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
@@ -1928,6 +2016,21 @@ func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, arg any) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, arg)
}
// GetChatMessagesByChatIDDescPaginated mocks base method.
func (m *MockStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatMessagesByChatIDDescPaginated", ctx, arg)
ret0, _ := ret[0].([]database.ChatMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatMessagesByChatIDDescPaginated indicates an expected call of GetChatMessagesByChatIDDescPaginated.
func (mr *MockStoreMockRecorder) GetChatMessagesByChatIDDescPaginated(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatIDDescPaginated", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatIDDescPaginated), ctx, arg)
}
// GetChatMessagesForPromptByChatID mocks base method.
func (m *MockStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
@@ -2048,19 +2151,64 @@ func (mr *MockStoreMockRecorder) GetChatSystemPrompt(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatSystemPrompt), ctx)
}
// GetChatsByOwnerID mocks base method.
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
// GetChatUsageLimitConfig mocks base method.
func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, arg)
ret := m.ctrl.Call(m, "GetChatUsageLimitConfig", ctx)
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatUsageLimitConfig indicates an expected call of GetChatUsageLimitConfig.
func (mr *MockStoreMockRecorder) GetChatUsageLimitConfig(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitConfig), ctx)
}
// GetChatUsageLimitGroupOverride mocks base method.
func (m *MockStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatUsageLimitGroupOverride", ctx, groupID)
ret0, _ := ret[0].(database.GetChatUsageLimitGroupOverrideRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatUsageLimitGroupOverride indicates an expected call of GetChatUsageLimitGroupOverride.
func (mr *MockStoreMockRecorder) GetChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitGroupOverride), ctx, groupID)
}
// GetChatUsageLimitUserOverride mocks base method.
func (m *MockStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatUsageLimitUserOverride", ctx, userID)
ret0, _ := ret[0].(database.GetChatUsageLimitUserOverrideRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatUsageLimitUserOverride indicates an expected call of GetChatUsageLimitUserOverride.
func (mr *MockStoreMockRecorder) GetChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitUserOverride), ctx, userID)
}
// GetChats mocks base method.
func (m *MockStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChats", ctx, arg)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, arg any) *gomock.Call {
// GetChats indicates an expected call of GetChats.
func (mr *MockStoreMockRecorder) GetChats(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChats", reflect.TypeOf((*MockStore)(nil).GetChats), ctx, arg)
}
// GetConnectionLogsOffset mocks base method.
@@ -3068,6 +3216,66 @@ func (mr *MockStoreMockRecorder) GetOrganizationsWithPrebuildStatus(ctx, arg any
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsWithPrebuildStatus", reflect.TypeOf((*MockStore)(nil).GetOrganizationsWithPrebuildStatus), ctx, arg)
}
// GetPRInsightsPerModel mocks base method.
func (m *MockStore) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPRInsightsPerModel", ctx, arg)
ret0, _ := ret[0].([]database.GetPRInsightsPerModelRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPRInsightsPerModel indicates an expected call of GetPRInsightsPerModel.
func (mr *MockStoreMockRecorder) GetPRInsightsPerModel(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPerModel", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPerModel), ctx, arg)
}
// GetPRInsightsRecentPRs mocks base method.
func (m *MockStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPRInsightsRecentPRs", ctx, arg)
ret0, _ := ret[0].([]database.GetPRInsightsRecentPRsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPRInsightsRecentPRs indicates an expected call of GetPRInsightsRecentPRs.
func (mr *MockStoreMockRecorder) GetPRInsightsRecentPRs(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsRecentPRs", reflect.TypeOf((*MockStore)(nil).GetPRInsightsRecentPRs), ctx, arg)
}
// GetPRInsightsSummary mocks base method.
func (m *MockStore) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPRInsightsSummary", ctx, arg)
ret0, _ := ret[0].(database.GetPRInsightsSummaryRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPRInsightsSummary indicates an expected call of GetPRInsightsSummary.
func (mr *MockStoreMockRecorder) GetPRInsightsSummary(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsSummary", reflect.TypeOf((*MockStore)(nil).GetPRInsightsSummary), ctx, arg)
}
// GetPRInsightsTimeSeries mocks base method.
func (m *MockStore) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPRInsightsTimeSeries", ctx, arg)
ret0, _ := ret[0].([]database.GetPRInsightsTimeSeriesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPRInsightsTimeSeries indicates an expected call of GetPRInsightsTimeSeries.
func (mr *MockStoreMockRecorder) GetPRInsightsTimeSeries(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsTimeSeries", reflect.TypeOf((*MockStore)(nil).GetPRInsightsTimeSeries), ctx, arg)
}
// GetParameterSchemasByJobID mocks base method.
func (m *MockStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
m.ctrl.T.Helper()
@@ -4193,6 +4401,21 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
}
// GetUserChatSpendInPeriod mocks base method.
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserChatSpendInPeriod", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserChatSpendInPeriod indicates an expected call of GetUserChatSpendInPeriod.
func (mr *MockStoreMockRecorder) GetUserChatSpendInPeriod(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatSpendInPeriod", reflect.TypeOf((*MockStore)(nil).GetUserChatSpendInPeriod), ctx, arg)
}
// GetUserCount mocks base method.
func (m *MockStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
m.ctrl.T.Helper()
@@ -4208,6 +4431,21 @@ func (mr *MockStoreMockRecorder) GetUserCount(ctx, includeSystem any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserCount", reflect.TypeOf((*MockStore)(nil).GetUserCount), ctx, includeSystem)
}
// GetUserGroupSpendLimit mocks base method.
func (m *MockStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserGroupSpendLimit", ctx, userID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserGroupSpendLimit indicates an expected call of GetUserGroupSpendLimit.
func (mr *MockStoreMockRecorder) GetUserGroupSpendLimit(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserGroupSpendLimit", reflect.TypeOf((*MockStore)(nil).GetUserGroupSpendLimit), ctx, userID)
}
// GetUserLatencyInsights mocks base method.
func (m *MockStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
m.ctrl.T.Helper()
@@ -5362,6 +5600,21 @@ func (mr *MockStoreMockRecorder) InsertAIBridgeInterception(ctx, arg any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeInterception", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeInterception), ctx, arg)
}
// InsertAIBridgeModelThought mocks base method.
func (m *MockStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertAIBridgeModelThought", ctx, arg)
ret0, _ := ret[0].(database.AIBridgeModelThought)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertAIBridgeModelThought indicates an expected call of InsertAIBridgeModelThought.
func (mr *MockStoreMockRecorder) InsertAIBridgeModelThought(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeModelThought", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeModelThought), ctx, arg)
}
// InsertAIBridgeTokenUsage mocks base method.
func (m *MockStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
m.ctrl.T.Helper()
@@ -6547,6 +6800,36 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
}
// ListChatUsageLimitGroupOverrides mocks base method.
func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListChatUsageLimitGroupOverrides", ctx)
ret0, _ := ret[0].([]database.ListChatUsageLimitGroupOverridesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListChatUsageLimitGroupOverrides indicates an expected call of ListChatUsageLimitGroupOverrides.
func (mr *MockStoreMockRecorder) ListChatUsageLimitGroupOverrides(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitGroupOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitGroupOverrides), ctx)
}
// ListChatUsageLimitOverrides mocks base method.
func (m *MockStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListChatUsageLimitOverrides", ctx)
ret0, _ := ret[0].([]database.ListChatUsageLimitOverridesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListChatUsageLimitOverrides indicates an expected call of ListChatUsageLimitOverrides.
func (mr *MockStoreMockRecorder) ListChatUsageLimitOverrides(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitOverrides), ctx)
}
// ListProvisionerKeysByOrganization mocks base method.
func (m *MockStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
m.ctrl.T.Helper()
@@ -6785,6 +7068,21 @@ func (mr *MockStoreMockRecorder) RemoveUserFromGroups(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), ctx, arg)
}
// ResolveUserChatSpendLimit mocks base method.
func (m *MockStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ResolveUserChatSpendLimit", ctx, userID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ResolveUserChatSpendLimit indicates an expected call of ResolveUserChatSpendLimit.
func (mr *MockStoreMockRecorder) ResolveUserChatSpendLimit(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveUserChatSpendLimit", reflect.TypeOf((*MockStore)(nil).ResolveUserChatSpendLimit), ctx, userID)
}
// RevokeDBCryptKey mocks base method.
func (m *MockStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
m.ctrl.T.Helper()
@@ -8229,6 +8527,21 @@ func (mr *MockStoreMockRecorder) UpdateWorkspacesTTLByTemplateID(ctx, arg any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspacesTTLByTemplateID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspacesTTLByTemplateID), ctx, arg)
}
// UpsertAISeatState mocks base method.
func (m *MockStore) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertAISeatState", ctx, arg)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertAISeatState indicates an expected call of UpsertAISeatState.
func (mr *MockStoreMockRecorder) UpsertAISeatState(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertAISeatState", reflect.TypeOf((*MockStore)(nil).UpsertAISeatState), ctx, arg)
}
// UpsertAnnouncementBanners mocks base method.
func (m *MockStore) UpsertAnnouncementBanners(ctx context.Context, value string) error {
m.ctrl.T.Helper()
@@ -8272,6 +8585,20 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
}
// UpsertChatDesktopEnabled mocks base method.
func (m *MockStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatDesktopEnabled", ctx, enableDesktop)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertChatDesktopEnabled indicates an expected call of UpsertChatDesktopEnabled.
func (mr *MockStoreMockRecorder) UpsertChatDesktopEnabled(ctx, enableDesktop any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).UpsertChatDesktopEnabled), ctx, enableDesktop)
}
// UpsertChatDiffStatus mocks base method.
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
@@ -8316,6 +8643,51 @@ func (mr *MockStoreMockRecorder) UpsertChatSystemPrompt(ctx, value any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatSystemPrompt), ctx, value)
}
// UpsertChatUsageLimitConfig mocks base method.
func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatUsageLimitConfig", ctx, arg)
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatUsageLimitConfig indicates an expected call of UpsertChatUsageLimitConfig.
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitConfig(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitConfig), ctx, arg)
}
// UpsertChatUsageLimitGroupOverride mocks base method.
func (m *MockStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatUsageLimitGroupOverride", ctx, arg)
ret0, _ := ret[0].(database.UpsertChatUsageLimitGroupOverrideRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatUsageLimitGroupOverride indicates an expected call of UpsertChatUsageLimitGroupOverride.
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitGroupOverride(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitGroupOverride), ctx, arg)
}
// UpsertChatUsageLimitUserOverride mocks base method.
func (m *MockStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatUsageLimitUserOverride", ctx, arg)
ret0, _ := ret[0].(database.UpsertChatUsageLimitUserOverrideRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatUsageLimitUserOverride indicates an expected call of UpsertChatUsageLimitUserOverride.
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitUserOverride(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitUserOverride), ctx, arg)
}
// UpsertConnectionLog mocks base method.
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
m.ctrl.T.Helper()
@@ -8633,6 +9005,21 @@ func (mr *MockStoreMockRecorder) UpsertWorkspaceAppAuditSession(ctx, arg any) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertWorkspaceAppAuditSession", reflect.TypeOf((*MockStore)(nil).UpsertWorkspaceAppAuditSession), ctx, arg)
}
// UsageEventExistsByID mocks base method.
func (m *MockStore) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UsageEventExistsByID", ctx, id)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UsageEventExistsByID indicates an expected call of UsageEventExistsByID.
func (mr *MockStoreMockRecorder) UsageEventExistsByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageEventExistsByID", reflect.TypeOf((*MockStore)(nil).UsageEventExistsByID), ctx, id)
}
// ValidateGroupIDs mocks base method.
func (m *MockStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) {
m.ctrl.T.Helper()
@@ -10,6 +10,11 @@ CREATE TYPE agent_key_scope_enum AS ENUM (
'no_user_data'
);
CREATE TYPE ai_seat_usage_reason AS ENUM (
'aibridge',
'task'
);
CREATE TYPE api_key_scope AS ENUM (
'coder:all',
'coder:application_connect',
@@ -503,7 +508,14 @@ CREATE TYPE resource_type AS ENUM (
'workspace_agent',
'workspace_app',
'prebuilds_settings',
'task'
'task',
'ai_seat'
);
CREATE TYPE shareable_workspace_owners AS ENUM (
'none',
'everyone',
'service_accounts'
);
CREATE TYPE startup_script_behavior AS ENUM (
@@ -608,28 +620,35 @@ CREATE FUNCTION aggregate_usage_event() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
-- Check for supported event types and throw error for unknown types
IF NEW.event_type NOT IN ('dc_managed_agents_v1') THEN
-- Check for supported event types and throw error for unknown types.
IF NEW.event_type NOT IN ('dc_managed_agents_v1', 'hb_ai_seats_v1') THEN
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
END IF;
INSERT INTO usage_events_daily (day, event_type, usage_data)
VALUES (
-- Extract the date from the created_at timestamp, always using UTC for
-- consistency
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
NEW.event_type,
NEW.event_data
)
ON CONFLICT (day, event_type) DO UPDATE SET
usage_data = CASE
-- Handle simple counter events by summing the count
-- Handle simple counter events by summing the count.
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
jsonb_build_object(
'count',
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
COALESCE((NEW.event_data->>'count')::bigint, 0)
)
-- Heartbeat events: keep the max value seen that day.
WHEN NEW.event_type IN ('hb_ai_seats_v1') THEN
jsonb_build_object(
'count',
GREATEST(
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0),
COALESCE((NEW.event_data->>'count')::bigint, 0)
)
)
END;
RETURN NEW;
@@ -786,7 +805,7 @@ BEGIN
END;
$$;
CREATE FUNCTION insert_org_member_system_role() RETURNS trigger
CREATE FUNCTION insert_organization_system_roles() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
@@ -801,7 +820,8 @@ BEGIN
is_system,
created_at,
updated_at
) VALUES (
) VALUES
(
'organization-member',
'',
NEW.id,
@@ -812,6 +832,18 @@ BEGIN
true,
NOW(),
NOW()
),
(
'organization-service-account',
'',
NEW.id,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
true,
NOW(),
NOW()
);
RETURN NEW;
END;
@@ -1046,6 +1078,15 @@ BEGIN
END;
$$;
CREATE TABLE ai_seat_state (
user_id uuid NOT NULL,
first_used_at timestamp with time zone NOT NULL,
last_used_at timestamp with time zone NOT NULL,
last_event_type ai_seat_usage_reason NOT NULL,
last_event_description text NOT NULL,
updated_at timestamp with time zone NOT NULL
);
CREATE TABLE aibridge_interceptions (
id uuid NOT NULL,
initiator_id uuid NOT NULL,
@@ -1071,6 +1112,15 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
CREATE TABLE aibridge_model_thoughts (
interception_id uuid NOT NULL,
content text NOT NULL,
metadata jsonb,
created_at timestamp with time zone NOT NULL
);
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
CREATE TABLE aibridge_token_usages (
id uuid NOT NULL,
interception_id uuid NOT NULL,
@@ -1239,7 +1289,8 @@ CREATE TABLE chat_messages (
compressed boolean DEFAULT false NOT NULL,
created_by uuid,
content_version smallint NOT NULL,
total_cost_micros bigint
total_cost_micros bigint,
runtime_ms bigint
);
CREATE SEQUENCE chat_messages_id_seq
@@ -1303,6 +1354,28 @@ CREATE SEQUENCE chat_queued_messages_id_seq
ALTER SEQUENCE chat_queued_messages_id_seq OWNED BY chat_queued_messages.id;
CREATE TABLE chat_usage_limit_config (
id bigint NOT NULL,
singleton boolean DEFAULT true NOT NULL,
enabled boolean DEFAULT false NOT NULL,
default_limit_micros bigint DEFAULT 0 NOT NULL,
period text DEFAULT 'month'::text NOT NULL,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
CONSTRAINT chat_usage_limit_config_default_limit_micros_check CHECK ((default_limit_micros >= 0)),
CONSTRAINT chat_usage_limit_config_period_check CHECK ((period = ANY (ARRAY['day'::text, 'week'::text, 'month'::text]))),
CONSTRAINT chat_usage_limit_config_singleton_check CHECK (singleton)
);
CREATE SEQUENCE chat_usage_limit_config_id_seq
START WITH 1
INCREMENT BY 1
NO MINVALUE
NO MAXVALUE
CACHE 1;
ALTER SEQUENCE chat_usage_limit_config_id_seq OWNED BY chat_usage_limit_config.id;
CREATE TABLE chats (
id uuid DEFAULT gen_random_uuid() NOT NULL,
owner_id uuid NOT NULL,
@@ -1459,7 +1532,9 @@ CREATE TABLE groups (
avatar_url text DEFAULT ''::text NOT NULL,
quota_allowance integer DEFAULT 0 NOT NULL,
display_name text DEFAULT ''::text NOT NULL,
source group_source DEFAULT 'user'::group_source NOT NULL
source group_source DEFAULT 'user'::group_source NOT NULL,
chat_spend_limit_micros bigint,
CONSTRAINT groups_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0)))
);
COMMENT ON COLUMN groups.display_name IS 'Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.';
@@ -1494,7 +1569,9 @@ CREATE TABLE users (
one_time_passcode_expires_at timestamp with time zone,
is_system boolean DEFAULT false NOT NULL,
is_service_account boolean DEFAULT false NOT NULL,
chat_spend_limit_micros bigint,
CONSTRAINT one_time_passcode_set CHECK ((((hashed_one_time_passcode IS NULL) AND (one_time_passcode_expires_at IS NULL)) OR ((hashed_one_time_passcode IS NOT NULL) AND (one_time_passcode_expires_at IS NOT NULL)))),
CONSTRAINT users_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0))),
CONSTRAINT users_email_not_empty CHECK (((is_service_account = true) = (email = ''::text))),
CONSTRAINT users_service_account_login_type CHECK (((is_service_account = false) OR (login_type = 'none'::login_type))),
CONSTRAINT users_username_min_length CHECK ((length(username) >= 1))
@@ -1782,9 +1859,11 @@ CREATE TABLE organizations (
display_name text NOT NULL,
icon text DEFAULT ''::text NOT NULL,
deleted boolean DEFAULT false NOT NULL,
workspace_sharing_disabled boolean DEFAULT false NOT NULL
shareable_workspace_owners shareable_workspace_owners DEFAULT 'everyone'::shareable_workspace_owners NOT NULL
);
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
CREATE TABLE parameter_schemas (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
@@ -2584,7 +2663,7 @@ CREATE TABLE usage_events (
publish_started_at timestamp with time zone,
published_at timestamp with time zone,
failure_message text,
CONSTRAINT usage_event_type_check CHECK ((event_type = 'dc_managed_agents_v1'::text))
CONSTRAINT usage_event_type_check CHECK ((event_type = ANY (ARRAY['dc_managed_agents_v1'::text, 'hb_ai_seats_v1'::text])))
);
COMMENT ON TABLE usage_events IS 'usage_events contains usage data that is collected from the product and potentially shipped to the usage collector service.';
@@ -3141,6 +3220,8 @@ ALTER TABLE ONLY chat_messages ALTER COLUMN id SET DEFAULT nextval('chat_message
ALTER TABLE ONLY chat_queued_messages ALTER COLUMN id SET DEFAULT nextval('chat_queued_messages_id_seq'::regclass);
ALTER TABLE ONLY chat_usage_limit_config ALTER COLUMN id SET DEFAULT nextval('chat_usage_limit_config_id_seq'::regclass);
ALTER TABLE ONLY licenses ALTER COLUMN id SET DEFAULT nextval('licenses_id_seq'::regclass);
ALTER TABLE ONLY provisioner_job_logs ALTER COLUMN id SET DEFAULT nextval('provisioner_job_logs_id_seq'::regclass);
@@ -3156,6 +3237,9 @@ ALTER TABLE ONLY workspace_resource_metadata ALTER COLUMN id SET DEFAULT nextval
ALTER TABLE ONLY workspace_agent_stats
ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id);
ALTER TABLE ONLY ai_seat_state
ADD CONSTRAINT ai_seat_state_pkey PRIMARY KEY (user_id);
ALTER TABLE ONLY aibridge_interceptions
ADD CONSTRAINT aibridge_interceptions_pkey PRIMARY KEY (id);
@@ -3198,6 +3282,12 @@ ALTER TABLE ONLY chat_providers
ALTER TABLE ONLY chat_queued_messages
ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_usage_limit_config
ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_usage_limit_config
ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
@@ -3510,6 +3600,8 @@ CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptio
CREATE INDEX idx_aibridge_interceptions_thread_root_id ON aibridge_interceptions USING btree (thread_root_id);
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts USING btree (interception_id);
CREATE INDEX idx_aibridge_token_usages_interception_id ON aibridge_token_usages USING btree (interception_id);
CREATE INDEX idx_aibridge_token_usages_provider_response_id ON aibridge_token_usages USING btree (provider_response_id);
@@ -3550,6 +3642,8 @@ CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USIN
CREATE INDEX idx_chat_messages_created_at ON chat_messages USING btree (created_at);
CREATE INDEX idx_chat_messages_owner_spend ON chat_messages USING btree (chat_id, created_at) WHERE (total_cost_micros IS NOT NULL);
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled);
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs USING btree (provider);
@@ -3624,6 +3718,8 @@ CREATE INDEX idx_template_versions_has_ai_task ON template_versions USING btree
CREATE UNIQUE INDEX idx_unique_preset_name ON template_version_presets USING btree (name, template_version_id);
CREATE INDEX idx_usage_events_ai_seats ON usage_events USING btree (event_type, created_at) WHERE (event_type = 'hb_ai_seats_v1'::text);
CREATE INDEX idx_usage_events_select_for_publishing ON usage_events USING btree (published_at, publish_started_at, created_at);
CREATE INDEX idx_user_deleted_deleted_at ON user_deleted USING btree (deleted_at);
@@ -3798,7 +3894,7 @@ CREATE TRIGGER trigger_delete_oauth2_provider_app_token AFTER DELETE ON oauth2_p
CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted();
CREATE TRIGGER trigger_insert_org_member_system_role AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_org_member_system_role();
CREATE TRIGGER trigger_insert_organization_system_roles AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_organization_system_roles();
CREATE TRIGGER trigger_nullify_next_start_at_on_workspace_autostart_modificati AFTER UPDATE ON workspaces FOR EACH ROW EXECUTE FUNCTION nullify_next_start_at_on_workspace_autostart_modification();
@@ -3816,6 +3912,9 @@ COMMENT ON TRIGGER workspace_agent_name_unique_trigger ON workspace_agents IS 'U
the uniqueness requirement. A trigger allows us to enforce uniqueness going
forward without requiring a migration to clean up historical data.';
ALTER TABLE ONLY ai_seat_state
ADD CONSTRAINT ai_seat_state_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY aibridge_interceptions
ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
@@ -6,6 +6,7 @@ type ForeignKeyConstraint string
// ForeignKeyConstraint enums.
const (
ForeignKeyAiSeatStateUserID ForeignKeyConstraint = "ai_seat_state_user_id_fkey" // ALTER TABLE ONLY ai_seat_state ADD CONSTRAINT ai_seat_state_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
@@ -26,6 +26,7 @@ func TestCustomQueriesSyncedRowScan(t *testing.T) {
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
"GetWorkspaces": "GetAuthorizedWorkspaces",
"GetUsers": "GetAuthorizedUsers",
"GetChats": "GetAuthorizedChats",
}
// Scan custom
@@ -0,0 +1,3 @@
DROP TABLE ai_seat_state;
DROP TYPE ai_seat_usage_reason;
@@ -0,0 +1,13 @@
CREATE TYPE ai_seat_usage_reason AS ENUM (
'aibridge',
'task'
);
CREATE TABLE ai_seat_state (
user_id uuid NOT NULL PRIMARY KEY REFERENCES users (id) ON DELETE CASCADE,
first_used_at timestamptz NOT NULL,
last_used_at timestamptz NOT NULL,
last_event_type ai_seat_usage_reason NOT NULL,
last_event_description text NOT NULL,
updated_at timestamptz NOT NULL
);
@@ -0,0 +1 @@
-- resource_type enum values cannot be removed safely; no-op.
@@ -0,0 +1 @@
ALTER TYPE resource_type ADD VALUE IF NOT EXISTS 'ai_seat';
@@ -0,0 +1,4 @@
DROP INDEX IF EXISTS idx_chat_messages_owner_spend;
ALTER TABLE groups DROP COLUMN IF EXISTS chat_spend_limit_micros;
ALTER TABLE users DROP COLUMN IF EXISTS chat_spend_limit_micros;
DROP TABLE IF EXISTS chat_usage_limit_config;
@@ -0,0 +1,32 @@
-- 1. Singleton config table
CREATE TABLE chat_usage_limit_config (
id BIGSERIAL PRIMARY KEY,
-- Only one row allowed (enforced by CHECK).
singleton BOOLEAN NOT NULL DEFAULT TRUE CHECK (singleton),
UNIQUE (singleton),
enabled BOOLEAN NOT NULL DEFAULT FALSE,
-- Limit per user per period, in micro-dollars (1 USD = 1,000,000).
default_limit_micros BIGINT NOT NULL DEFAULT 0
CHECK (default_limit_micros >= 0),
-- Period length: 'day', 'week', or 'month'.
period TEXT NOT NULL DEFAULT 'month'
CHECK (period IN ('day', 'week', 'month')),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Seed a single disabled row so reads never return empty.
INSERT INTO chat_usage_limit_config (singleton) VALUES (TRUE);
-- 2. Per-user overrides (inline on users table).
ALTER TABLE users ADD COLUMN chat_spend_limit_micros BIGINT DEFAULT NULL
CHECK (chat_spend_limit_micros IS NULL OR chat_spend_limit_micros > 0);
-- 3. Per-group overrides (inline on groups table).
ALTER TABLE groups ADD COLUMN chat_spend_limit_micros BIGINT DEFAULT NULL
CHECK (chat_spend_limit_micros IS NULL OR chat_spend_limit_micros > 0);
-- Speed up per-user spend aggregation in the usage-limit hot path.
CREATE INDEX idx_chat_messages_owner_spend
ON chat_messages (chat_id, created_at)
WHERE total_cost_micros IS NOT NULL;
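The migration above stores limits in micro-dollars (1 USD = 1,000,000 micros), which is why the fixture values below use 5000000 and 10000000. A minimal sketch of the unit conversion (illustrative helper, not code from this PR):

```go
package main

import "fmt"

// microsPerUSD is the scale used by chat_spend_limit_micros and
// default_limit_micros: 1 USD = 1,000,000 micro-dollars.
const microsPerUSD = 1_000_000

// usdToMicros converts a dollar amount to micro-dollar units.
func usdToMicros(usd float64) int64 {
	return int64(usd * microsPerUSD)
}

func main() {
	// $5 per-user override and $10 per-group override, as in the fixtures.
	fmt.Println(usdToMicros(5))
	fmt.Println(usdToMicros(10))
}
```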
@@ -0,0 +1,3 @@
DROP INDEX idx_aibridge_model_thoughts_interception_id;
DROP TABLE aibridge_model_thoughts;
@@ -0,0 +1,10 @@
CREATE TABLE aibridge_model_thoughts (
interception_id UUID NOT NULL,
content TEXT NOT NULL,
metadata jsonb,
created_at TIMESTAMPTZ NOT NULL
);
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts(interception_id);
@@ -0,0 +1,52 @@
DELETE FROM custom_roles
WHERE name = 'organization-service-account' AND is_system = true;
ALTER TABLE organizations
ADD COLUMN workspace_sharing_disabled boolean NOT NULL DEFAULT false;
-- Migrate back: 'none' -> disabled, everything else -> enabled.
UPDATE organizations
SET workspace_sharing_disabled = true
WHERE shareable_workspace_owners = 'none';
ALTER TABLE organizations DROP COLUMN shareable_workspace_owners;
DROP TYPE shareable_workspace_owners;
-- Restore the original single-role trigger from migration 408.
DROP TRIGGER IF EXISTS trigger_insert_organization_system_roles ON organizations;
DROP FUNCTION IF EXISTS insert_organization_system_roles;
CREATE OR REPLACE FUNCTION insert_org_member_system_role() RETURNS trigger AS $$
BEGIN
INSERT INTO custom_roles (
name,
display_name,
organization_id,
site_permissions,
org_permissions,
user_permissions,
member_permissions,
is_system,
created_at,
updated_at
) VALUES (
'organization-member',
'',
NEW.id,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
true,
NOW(),
NOW()
);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_insert_org_member_system_role
AFTER INSERT ON organizations
FOR EACH ROW
EXECUTE FUNCTION insert_org_member_system_role();
@@ -0,0 +1,101 @@
CREATE TYPE shareable_workspace_owners AS ENUM ('none', 'everyone', 'service_accounts');
ALTER TABLE organizations
ADD COLUMN shareable_workspace_owners shareable_workspace_owners NOT NULL DEFAULT 'everyone';
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
-- Migrate existing data from the boolean column.
UPDATE organizations
SET shareable_workspace_owners = 'none'
WHERE workspace_sharing_disabled = true;
ALTER TABLE organizations DROP COLUMN workspace_sharing_disabled;
-- Defensively rename any existing 'organization-service-account' roles
-- so they don't collide with the new system role.
UPDATE custom_roles
SET name = name || '-' || id::text
-- lower(name) is part of the existing unique index
WHERE lower(name) = 'organization-service-account';
-- Create skeleton organization-service-account system roles for all
-- existing organizations, mirroring what migration 408 did for
-- organization-member.
INSERT INTO custom_roles (
name,
display_name,
organization_id,
site_permissions,
org_permissions,
user_permissions,
member_permissions,
is_system,
created_at,
updated_at
)
SELECT
'organization-service-account',
'',
id,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
true,
NOW(),
NOW()
FROM
organizations;
-- Replace the single-role trigger with one that creates both system
-- roles when a new organization is inserted.
DROP TRIGGER IF EXISTS trigger_insert_org_member_system_role ON organizations;
DROP FUNCTION IF EXISTS insert_org_member_system_role;
CREATE OR REPLACE FUNCTION insert_organization_system_roles() RETURNS trigger AS $$
BEGIN
INSERT INTO custom_roles (
name,
display_name,
organization_id,
site_permissions,
org_permissions,
user_permissions,
member_permissions,
is_system,
created_at,
updated_at
) VALUES
(
'organization-member',
'',
NEW.id,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
true,
NOW(),
NOW()
),
(
'organization-service-account',
'',
NEW.id,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
true,
NOW(),
NOW()
);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_insert_organization_system_roles
AFTER INSERT ON organizations
FOR EACH ROW
EXECUTE FUNCTION insert_organization_system_roles();
@@ -0,0 +1,38 @@
DROP INDEX IF EXISTS idx_usage_events_ai_seats;
-- Remove hb_ai_seats_v1 rows so the original constraint can be restored.
DELETE FROM usage_events WHERE event_type = 'hb_ai_seats_v1';
DELETE FROM usage_events_daily WHERE event_type = 'hb_ai_seats_v1';
-- Restore original constraint.
ALTER TABLE usage_events
DROP CONSTRAINT usage_event_type_check,
ADD CONSTRAINT usage_event_type_check CHECK (event_type IN ('dc_managed_agents_v1'));
-- Restore the original aggregate function without hb_ai_seats_v1 support.
CREATE OR REPLACE FUNCTION aggregate_usage_event()
RETURNS TRIGGER AS $$
BEGIN
IF NEW.event_type NOT IN ('dc_managed_agents_v1') THEN
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
END IF;
INSERT INTO usage_events_daily (day, event_type, usage_data)
VALUES (
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
NEW.event_type,
NEW.event_data
)
ON CONFLICT (day, event_type) DO UPDATE SET
usage_data = CASE
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
jsonb_build_object(
'count',
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
COALESCE((NEW.event_data->>'count')::bigint, 0)
)
END;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -0,0 +1,50 @@
-- Expand the CHECK constraint to allow hb_ai_seats_v1.
ALTER TABLE usage_events
DROP CONSTRAINT usage_event_type_check,
ADD CONSTRAINT usage_event_type_check CHECK (event_type IN ('dc_managed_agents_v1', 'hb_ai_seats_v1'));
-- Partial index for efficient lookups of AI seat heartbeat events by time.
-- This will be used for the admin dashboard to see seat count over time.
CREATE INDEX idx_usage_events_ai_seats
ON usage_events (event_type, created_at)
WHERE event_type = 'hb_ai_seats_v1';
-- Update the aggregate function to handle hb_ai_seats_v1 events.
-- Heartbeat events replace the previous value for the same time period.
CREATE OR REPLACE FUNCTION aggregate_usage_event()
RETURNS TRIGGER AS $$
BEGIN
-- Check for supported event types and throw error for unknown types.
IF NEW.event_type NOT IN ('dc_managed_agents_v1', 'hb_ai_seats_v1') THEN
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
END IF;
INSERT INTO usage_events_daily (day, event_type, usage_data)
VALUES (
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
NEW.event_type,
NEW.event_data
)
ON CONFLICT (day, event_type) DO UPDATE SET
usage_data = CASE
-- Handle simple counter events by summing the count.
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
jsonb_build_object(
'count',
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
COALESCE((NEW.event_data->>'count')::bigint, 0)
)
-- Heartbeat events: keep the max value seen that day.
WHEN NEW.event_type IN ('hb_ai_seats_v1') THEN
jsonb_build_object(
'count',
GREATEST(
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0),
COALESCE((NEW.event_data->>'count')::bigint, 0)
)
)
END;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
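The ON CONFLICT branch above gives the two event families different daily-rollup semantics: counter events (`dc_managed_agents_v1`) sum, while heartbeat events (`hb_ai_seats_v1`) keep the daily maximum via GREATEST. A rough Go sketch of that merge rule (illustrative only, not part of this PR):

```go
package main

import "fmt"

// aggregateDaily mirrors the merge rule in the aggregate_usage_event
// trigger: counter events add to the existing daily count, heartbeat
// events replace it only if the new value is larger.
func aggregateDaily(existing, incoming int64, eventType string) int64 {
	switch eventType {
	case "dc_managed_agents_v1":
		return existing + incoming
	case "hb_ai_seats_v1":
		if incoming > existing {
			return incoming
		}
		return existing
	}
	panic("unhandled usage event type: " + eventType)
}

func main() {
	fmt.Println(aggregateDaily(5, 3, "dc_managed_agents_v1")) // counters sum
	fmt.Println(aggregateDaily(5, 3, "hb_ai_seats_v1"))       // heartbeats keep max
}
```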
@@ -0,0 +1 @@
ALTER TABLE chat_messages DROP COLUMN runtime_ms;
@@ -0,0 +1 @@
ALTER TABLE chat_messages ADD COLUMN runtime_ms bigint;
@@ -0,0 +1,28 @@
-- Fixture for migration 000443_three_options_for_allowed_workspace_sharing.
-- Inserts a custom role named 'Organization-Service-Account' (mixed case)
-- to ensure the migration's case-insensitive rename catches it.
INSERT INTO custom_roles (
name,
display_name,
organization_id,
site_permissions,
org_permissions,
user_permissions,
member_permissions,
is_system,
created_at,
updated_at
)
VALUES (
'Organization-Service-Account',
'User-created role',
'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1',
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
false,
NOW(),
NOW()
)
ON CONFLICT DO NOTHING;
@@ -0,0 +1,11 @@
INSERT INTO
ai_seat_state (
user_id,
first_used_at,
last_used_at,
last_event_type,
last_event_description,
updated_at
)
VALUES
('30095c71-380b-457a-8995-97b8ee6e5307', NOW(), NOW(), 'task'::ai_seat_usage_reason, 'Used for AI task', NOW());
@@ -0,0 +1,5 @@
UPDATE users SET chat_spend_limit_micros = 5000000
WHERE id = 'fc1511ef-4fcf-4a3b-98a1-8df64160e35a';
UPDATE groups SET chat_spend_limit_micros = 10000000
WHERE id = 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1';
@@ -0,0 +1,13 @@
INSERT INTO
aibridge_model_thoughts (
interception_id,
content,
metadata,
created_at
)
VALUES (
'be003e1e-b38f-43bf-847d-928074dd0aa8', -- from 000370_aibridge.up.sql
'The user is asking about their workspaces. I should use the coder_list_workspaces tool to retrieve this information.',
'{"source": "commentary"}',
'2025-09-15 12:45:19.123456+00'
);
@@ -0,0 +1,20 @@
INSERT INTO usage_events (
id,
event_type,
event_data,
created_at,
publish_started_at,
published_at,
failure_message
)
VALUES
-- Unpublished hb_ai_seats_v1 event.
(
'ai-seats-event1',
'hb_ai_seats_v1',
'{"count":3}',
'2023-06-01 00:00:00+00',
NULL,
NULL,
NULL
);
@@ -52,6 +52,7 @@ type customQuerier interface {
auditLogQuerier
connectionLogQuerier
aibridgeQuerier
chatQuerier
}
type templateQuerier interface {
@@ -451,6 +452,7 @@ func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams,
&i.OneTimePasscodeExpiresAt,
&i.IsSystem,
&i.IsServiceAccount,
&i.ChatSpendLimitMicros,
&i.Count,
); err != nil {
return nil, err
@@ -737,6 +739,68 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
return count, nil
}
type chatQuerier interface {
GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]Chat, error)
}
func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, prepared rbac.PreparedAuthorized) ([]Chat, error) {
authorizedFilter, err := prepared.CompileToSQL(ctx, rbac.ConfigChats())
if err != nil {
return nil, xerrors.Errorf("compile authorized filter: %w", err)
}
filtered, err := insertAuthorizedFilter(getChats, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return nil, xerrors.Errorf("insert authorized filter: %w", err)
}
// The name comment is for metric tracking
query := fmt.Sprintf("-- name: GetAuthorizedChats :many\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query,
arg.OwnerID,
arg.Archived,
arg.AfterID,
arg.OffsetOpt,
arg.LimitOpt,
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Chat
for rows.Next() {
var i Chat
if err := rows.Scan(
&i.ID,
&i.OwnerID,
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.ParentChatID,
&i.RootChatID,
&i.LastModelConfigID,
&i.Archived,
&i.LastError,
&i.Mode,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
type aibridgeQuerier interface {
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error)
CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error)
@@ -741,6 +741,64 @@ func AllAgentKeyScopeEnumValues() []AgentKeyScopeEnum {
}
}
type AiSeatUsageReason string
const (
AiSeatUsageReasonAibridge AiSeatUsageReason = "aibridge"
AiSeatUsageReasonTask AiSeatUsageReason = "task"
)
func (e *AiSeatUsageReason) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = AiSeatUsageReason(s)
case string:
*e = AiSeatUsageReason(s)
default:
return fmt.Errorf("unsupported scan type for AiSeatUsageReason: %T", src)
}
return nil
}
type NullAiSeatUsageReason struct {
AiSeatUsageReason AiSeatUsageReason `json:"ai_seat_usage_reason"`
Valid bool `json:"valid"` // Valid is true if AiSeatUsageReason is not NULL
}
// Scan implements the Scanner interface.
func (ns *NullAiSeatUsageReason) Scan(value interface{}) error {
if value == nil {
ns.AiSeatUsageReason, ns.Valid = "", false
return nil
}
ns.Valid = true
return ns.AiSeatUsageReason.Scan(value)
}
// Value implements the driver Valuer interface.
func (ns NullAiSeatUsageReason) Value() (driver.Value, error) {
if !ns.Valid {
return nil, nil
}
return string(ns.AiSeatUsageReason), nil
}
func (e AiSeatUsageReason) Valid() bool {
switch e {
case AiSeatUsageReasonAibridge,
AiSeatUsageReasonTask:
return true
}
return false
}
func AllAiSeatUsageReasonValues() []AiSeatUsageReason {
return []AiSeatUsageReason{
AiSeatUsageReasonAibridge,
AiSeatUsageReasonTask,
}
}
type AppSharingLevel string
const (
@@ -2969,6 +3027,7 @@ const (
ResourceTypeWorkspaceApp ResourceType = "workspace_app"
ResourceTypePrebuildsSettings ResourceType = "prebuilds_settings"
ResourceTypeTask ResourceType = "task"
ResourceTypeAiSeat ResourceType = "ai_seat"
)
func (e *ResourceType) Scan(src interface{}) error {
@@ -3033,7 +3092,8 @@ func (e ResourceType) Valid() bool {
ResourceTypeWorkspaceAgent,
ResourceTypeWorkspaceApp,
ResourceTypePrebuildsSettings,
ResourceTypeTask:
ResourceTypeTask,
ResourceTypeAiSeat:
return true
}
return false
@@ -3067,6 +3127,68 @@ func AllResourceTypeValues() []ResourceType {
ResourceTypeWorkspaceApp,
ResourceTypePrebuildsSettings,
ResourceTypeTask,
ResourceTypeAiSeat,
}
}
type ShareableWorkspaceOwners string
const (
ShareableWorkspaceOwnersNone ShareableWorkspaceOwners = "none"
ShareableWorkspaceOwnersEveryone ShareableWorkspaceOwners = "everyone"
ShareableWorkspaceOwnersServiceAccounts ShareableWorkspaceOwners = "service_accounts"
)
func (e *ShareableWorkspaceOwners) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ShareableWorkspaceOwners(s)
case string:
*e = ShareableWorkspaceOwners(s)
default:
return fmt.Errorf("unsupported scan type for ShareableWorkspaceOwners: %T", src)
}
return nil
}
type NullShareableWorkspaceOwners struct {
ShareableWorkspaceOwners ShareableWorkspaceOwners `json:"shareable_workspace_owners"`
Valid bool `json:"valid"` // Valid is true if ShareableWorkspaceOwners is not NULL
}
// Scan implements the Scanner interface.
func (ns *NullShareableWorkspaceOwners) Scan(value interface{}) error {
if value == nil {
ns.ShareableWorkspaceOwners, ns.Valid = "", false
return nil
}
ns.Valid = true
return ns.ShareableWorkspaceOwners.Scan(value)
}
// Value implements the driver Valuer interface.
func (ns NullShareableWorkspaceOwners) Value() (driver.Value, error) {
if !ns.Valid {
return nil, nil
}
return string(ns.ShareableWorkspaceOwners), nil
}
func (e ShareableWorkspaceOwners) Valid() bool {
switch e {
case ShareableWorkspaceOwnersNone,
ShareableWorkspaceOwnersEveryone,
ShareableWorkspaceOwnersServiceAccounts:
return true
}
return false
}
func AllShareableWorkspaceOwnersValues() []ShareableWorkspaceOwners {
return []ShareableWorkspaceOwners{
ShareableWorkspaceOwnersNone,
ShareableWorkspaceOwnersEveryone,
ShareableWorkspaceOwnersServiceAccounts,
}
}
@@ -3916,6 +4038,14 @@ type AIBridgeInterception struct {
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
}
// Audit log of model thinking in intercepted requests in AI Bridge
type AIBridgeModelThought struct {
InterceptionID uuid.UUID `db:"interception_id" json:"interception_id"`
Content string `db:"content" json:"content"`
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
// Audit log of tokens used by intercepted requests in AI Bridge
type AIBridgeTokenUsage struct {
ID uuid.UUID `db:"id" json:"id"`
@@ -3975,6 +4105,15 @@ type APIKey struct {
AllowList AllowList `db:"allow_list" json:"allow_list"`
}
type AiSeatState struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
FirstUsedAt time.Time `db:"first_used_at" json:"first_used_at"`
LastUsedAt time.Time `db:"last_used_at" json:"last_used_at"`
LastEventType AiSeatUsageReason `db:"last_event_type" json:"last_event_type"`
LastEventDescription string `db:"last_event_description" json:"last_event_description"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
type AuditLog struct {
ID uuid.UUID `db:"id" json:"id"`
Time time.Time `db:"time" json:"time"`
@@ -4085,6 +4224,7 @@ type ChatMessage struct {
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
ContentVersion int16 `db:"content_version" json:"content_version"`
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
RuntimeMs sql.NullInt64 `db:"runtime_ms" json:"runtime_ms"`
}
type ChatModelConfig struct {
@@ -4126,6 +4266,16 @@ type ChatQueuedMessage struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
type ChatUsageLimitConfig struct {
ID int64 `db:"id" json:"id"`
Singleton bool `db:"singleton" json:"singleton"`
Enabled bool `db:"enabled" json:"enabled"`
DefaultLimitMicros int64 `db:"default_limit_micros" json:"default_limit_micros"`
Period string `db:"period" json:"period"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
type ConnectionLog struct {
ID uuid.UUID `db:"id" json:"id"`
ConnectTime time.Time `db:"connect_time" json:"connect_time"`
@@ -4238,7 +4388,8 @@ type Group struct {
// Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.
DisplayName string `db:"display_name" json:"display_name"`
// Source indicates how the group was created. It can be created by a user manually, or through some system process like OIDC group sync.
Source GroupSource `db:"source" json:"source"`
Source GroupSource `db:"source" json:"source"`
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
}
// Joins group members with user information, organization ID, group name. Includes both regular group members and organization members (as part of the "Everyone" group).
@@ -4446,16 +4597,17 @@ type OAuth2ProviderAppToken struct {
}
type Organization struct {
ID uuid.UUID `db:"id" json:"id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
IsDefault bool `db:"is_default" json:"is_default"`
DisplayName string `db:"display_name" json:"display_name"`
Icon string `db:"icon" json:"icon"`
Deleted bool `db:"deleted" json:"deleted"`
WorkspaceSharingDisabled bool `db:"workspace_sharing_disabled" json:"workspace_sharing_disabled"`
ID uuid.UUID `db:"id" json:"id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
IsDefault bool `db:"is_default" json:"is_default"`
DisplayName string `db:"display_name" json:"display_name"`
Icon string `db:"icon" json:"icon"`
Deleted bool `db:"deleted" json:"deleted"`
// Controls whose workspaces can be shared: none, everyone, or service_accounts.
ShareableWorkspaceOwners ShareableWorkspaceOwners `db:"shareable_workspace_owners" json:"shareable_workspace_owners"`
}
type OrganizationMember struct {
@@ -5008,7 +5160,8 @@ type User struct {
// Determines if a user is a system user, and therefore cannot login or perform normal actions
IsSystem bool `db:"is_system" json:"is_system"`
// Determines if a user is an admin-managed account that cannot login
IsServiceAccount bool `db:"is_service_account" json:"is_service_account"`
IsServiceAccount bool `db:"is_service_account" json:"is_service_account"`
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
}
type UserConfig struct {
@@ -77,6 +77,9 @@ type sqlcQuerier interface {
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error)
// Counts enabled, non-deleted model configs that lack both input and
// output pricing in their JSONB options.cost configuration.
CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error)
// CountInProgressPrebuilds returns the number of in-progress prebuilds, grouped by preset ID and transition.
// Prebuild considered in-progress if it's in the "pending", "starting", "stopping", or "deleting" state.
CountInProgressPrebuilds(ctx context.Context) ([]CountInProgressPrebuildsRow, error)
@@ -99,6 +102,8 @@ type sqlcQuerier interface {
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error
DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error
DeleteCryptoKey(ctx context.Context, arg DeleteCryptoKeyParams) (CryptoKey, error)
DeleteCustomRole(ctx context.Context, arg DeleteCustomRoleParams) error
DeleteExpiredAPIKeys(ctx context.Context, arg DeleteExpiredAPIKeysParams) (int64, error)
@@ -145,7 +150,7 @@ type sqlcQuerier interface {
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error
DeleteWorkspaceACLsByOrganization(ctx context.Context, arg DeleteWorkspaceACLsByOrganizationParams) error
DeleteWorkspaceAgentPortShare(ctx context.Context, arg DeleteWorkspaceAgentPortShareParams) error
DeleteWorkspaceAgentPortSharesByTemplate(ctx context.Context, templateID uuid.UUID) error
DeleteWorkspaceSubAgentByID(ctx context.Context, id uuid.UUID) error
@@ -187,6 +192,7 @@ type sqlcQuerier interface {
GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error)
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
GetActiveAISeatCount(ctx context.Context) (int64, error)
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error)
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
@@ -228,12 +234,14 @@ type sqlcQuerier interface {
// Aggregate cost summary for a single user within a date range.
// Only counts assistant-role messages.
GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error)
GetChatDesktopEnabled(ctx context.Context) (bool, error)
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error)
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
@@ -242,7 +250,10 @@ type sqlcQuerier interface {
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
GetChatSystemPrompt(ctx context.Context) (string, error)
GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerIDParams) ([]Chat, error)
GetChatUsageLimitConfig(ctx context.Context) (ChatUsageLimitConfig, error)
GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (GetChatUsageLimitGroupOverrideRow, error)
GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (GetChatUsageLimitUserOverrideRow, error)
GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, error)
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
@@ -330,6 +341,18 @@ type sqlcQuerier interface {
// GetOrganizationsWithPrebuildStatus returns organizations with prebuilds configured and their
// membership status for the prebuilds system user (org membership, group existence, group membership).
GetOrganizationsWithPrebuildStatus(ctx context.Context, arg GetOrganizationsWithPrebuildStatusParams) ([]GetOrganizationsWithPrebuildStatusRow, error)
// Returns PR metrics grouped by the model used for each chat.
GetPRInsightsPerModel(ctx context.Context, arg GetPRInsightsPerModelParams) ([]GetPRInsightsPerModelRow, error)
// Returns individual PR rows with cost for the recent PRs table.
GetPRInsightsRecentPRs(ctx context.Context, arg GetPRInsightsRecentPRsParams) ([]GetPRInsightsRecentPRsRow, error)
// PR Insights queries for the /agents analytics dashboard.
// These aggregate data from chat_diff_statuses (PR metadata) joined
// with chats and chat_messages (cost) to power the PR Insights view.
// Returns aggregate PR metrics for the given date range.
// The handler calls this twice (current + previous period) for trends.
GetPRInsightsSummary(ctx context.Context, arg GetPRInsightsSummaryParams) (GetPRInsightsSummaryRow, error)
// Returns daily PR counts grouped by state for the chart.
GetPRInsightsTimeSeries(ctx context.Context, arg GetPRInsightsTimeSeriesParams) ([]GetPRInsightsTimeSeriesRow, error)
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
GetPrebuildMetrics(ctx context.Context) ([]GetPrebuildMetricsRow, error)
GetPrebuildsSettings(ctx context.Context) (string, error)
@@ -494,7 +517,11 @@ type sqlcQuerier interface {
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
// Returns the minimum (most restrictive) group limit for a user.
// Returns -1 if the user has no group limits applied.
GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error)
// GetUserLatencyInsights returns the median and 95th percentile connection
// latency that users have experienced. The result can be filtered on
// template_ids, meaning only user data from workspaces based on those templates
@@ -598,6 +625,7 @@ type sqlcQuerier interface {
GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]GetWorkspacesEligibleForTransitionRow, error)
GetWorkspacesForWorkspaceMetrics(ctx context.Context) ([]GetWorkspacesForWorkspaceMetricsRow, error)
InsertAIBridgeInterception(ctx context.Context, arg InsertAIBridgeInterceptionParams) (AIBridgeInterception, error)
InsertAIBridgeModelThought(ctx context.Context, arg InsertAIBridgeModelThoughtParams) (AIBridgeModelThought, error)
InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIBridgeTokenUsageParams) (AIBridgeTokenUsage, error)
InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBridgeToolUsageParams) (AIBridgeToolUsage, error)
InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIBridgeUserPromptParams) (AIBridgeUserPrompt, error)
@@ -694,6 +722,8 @@ type sqlcQuerier interface {
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
ListChatUsageLimitGroupOverrides(ctx context.Context) ([]ListChatUsageLimitGroupOverridesRow, error)
ListChatUsageLimitOverrides(ctx context.Context) ([]ListChatUsageLimitOverridesRow, error)
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
@@ -714,6 +744,12 @@ type sqlcQuerier interface {
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error)
// Resolves the effective spend limit for a user using the hierarchy:
// 1. Individual user override (highest priority)
// 2. Minimum group limit across all user's groups
// 3. Global default from config
// Returns -1 if limits are not enabled.
ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error)
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
// Note that this selects from the CTE, not the original table. The CTE is named
// the same as the original table to trick sqlc into reusing the existing struct
@@ -833,6 +869,8 @@ type sqlcQuerier interface {
UpdateWorkspaceTTL(ctx context.Context, arg UpdateWorkspaceTTLParams) error
UpdateWorkspacesDormantDeletingAtByTemplateID(ctx context.Context, arg UpdateWorkspacesDormantDeletingAtByTemplateIDParams) ([]WorkspaceTable, error)
UpdateWorkspacesTTLByTemplateID(ctx context.Context, arg UpdateWorkspacesTTLByTemplateIDParams) error
// Returns true if a new row was inserted, false otherwise.
UpsertAISeatState(ctx context.Context, arg UpsertAISeatStateParams) (bool, error)
UpsertAnnouncementBanners(ctx context.Context, value string) error
UpsertApplicationName(ctx context.Context, value string) error
// Upserts boundary usage statistics for a replica. On INSERT (new period), uses
@@ -840,9 +878,13 @@ type sqlcQuerier interface {
// cumulative values for unique counts (accurate period totals). Request counts
// are always deltas, accumulated in DB. Returns true if insert, false if update.
UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error)
UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
UpsertChatSystemPrompt(ctx context.Context, value string) error
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
// The default proxy is implied and not actually stored in the database.
// So we need to store its configuration here for display purposes.
@@ -877,6 +919,7 @@ type sqlcQuerier interface {
// was started. This means that a new row was inserted (no previous session) or
// the updated_at is older than stale interval.
UpsertWorkspaceAppAuditSession(ctx context.Context, arg UpsertWorkspaceAppAuditSessionParams) (bool, error)
UsageEventExistsByID(ctx context.Context, id string) (bool, error)
ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (ValidateGroupIDsRow, error)
ValidateUserIDs(ctx context.Context, userIds []uuid.UUID) (ValidateUserIDsRow, error)
}
@@ -1235,6 +1235,230 @@ func TestGetAuthorizedWorkspacesAndAgentsByOwnerID(t *testing.T) {
})
}
func TestGetAuthorizedChats(t *testing.T) {
t.Parallel()
if testing.Short() {
t.SkipNow()
}
sqlDB := testSQLDB(t)
err := migrations.Up(sqlDB)
require.NoError(t, err)
db := database.New(sqlDB)
authorizer := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
// Create users with different roles.
owner := dbgen.User(t, db, database.User{
RBACRoles: []string{rbac.RoleOwner().String()},
})
member := dbgen.User(t, db, database.User{})
secondMember := dbgen.User(t, db, database.User{})
// Create FK dependencies: a chat provider and model config.
ctx := testutil.Context(t, testutil.WaitMedium)
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
})
require.NoError(t, err)
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "test-model",
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 80,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
// Create 3 chats owned by owner.
for i := range 3 {
_, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: owner.ID,
LastModelConfigID: modelCfg.ID,
Title: fmt.Sprintf("owner chat %d", i+1),
})
require.NoError(t, err)
}
// Create 2 chats owned by member.
for i := range 2 {
_, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: member.ID,
LastModelConfigID: modelCfg.ID,
Title: fmt.Sprintf("member chat %d", i+1),
})
require.NoError(t, err)
}
t.Run("sqlQuerier", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
// Member should only see their own 2 chats.
memberSubject, _, err := httpmw.UserRBACSubject(ctx, db, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
preparedMember, err := authorizer.Prepare(ctx, memberSubject, policy.ActionRead, rbac.ResourceChat.Type)
require.NoError(t, err)
memberRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedMember)
require.NoError(t, err)
require.Len(t, memberRows, 2)
for _, row := range memberRows {
require.Equal(t, member.ID, row.OwnerID, "member should only see own chats")
}
// Owner should see at least the 5 pre-created chats (site-wide
// access). Parallel subtests may add more, so use GreaterOrEqual.
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, db, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
preparedOwner, err := authorizer.Prepare(ctx, ownerSubject, policy.ActionRead, rbac.ResourceChat.Type)
require.NoError(t, err)
ownerRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOwner)
require.NoError(t, err)
require.GreaterOrEqual(t, len(ownerRows), 5)
// secondMember has no chats and should see 0.
secondSubject, _, err := httpmw.UserRBACSubject(ctx, db, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
preparedSecond, err := authorizer.Prepare(ctx, secondSubject, policy.ActionRead, rbac.ResourceChat.Type)
require.NoError(t, err)
secondRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedSecond)
require.NoError(t, err)
require.Len(t, secondRows, 0)
// Org admin should NOT see other users' chats — chats are
// not org-scoped resources.
orgs, err := db.GetOrganizations(ctx, database.GetOrganizationsParams{})
require.NoError(t, err)
require.NotEmpty(t, orgs)
orgAdmin := dbgen.User(t, db, database.User{})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: orgAdmin.ID,
OrganizationID: orgs[0].ID,
Roles: []string{rbac.RoleOrgAdmin()},
})
orgAdminSubject, _, err := httpmw.UserRBACSubject(ctx, db, orgAdmin.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
preparedOrgAdmin, err := authorizer.Prepare(ctx, orgAdminSubject, policy.ActionRead, rbac.ResourceChat.Type)
require.NoError(t, err)
orgAdminRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOrgAdmin)
require.NoError(t, err)
require.Len(t, orgAdminRows, 0, "org admin with no chats should see 0 chats")
// OwnerID filter: member queries their own chats.
memberFilterSelf, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
OwnerID: member.ID,
}, preparedMember)
require.NoError(t, err)
require.Len(t, memberFilterSelf, 2)
// OwnerID filter: member queries owner's chats → sees 0.
memberFilterOwner, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
OwnerID: owner.ID,
}, preparedMember)
require.NoError(t, err)
require.Len(t, memberFilterOwner, 0)
})
t.Run("dbauthz", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
authzdb := dbauthz.New(db, authorizer, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
// As member: should see only own 2 chats.
memberSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
memberCtx := dbauthz.As(ctx, memberSubject)
memberRows, err := authzdb.GetChats(memberCtx, database.GetChatsParams{})
require.NoError(t, err)
require.Len(t, memberRows, 2)
for _, row := range memberRows {
require.Equal(t, member.ID, row.OwnerID, "member should only see own chats")
}
// As owner: should see at least the 5 pre-created chats.
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
ownerCtx := dbauthz.As(ctx, ownerSubject)
ownerRows, err := authzdb.GetChats(ownerCtx, database.GetChatsParams{})
require.NoError(t, err)
require.GreaterOrEqual(t, len(ownerRows), 5)
// As secondMember: should see 0 chats.
secondSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
secondCtx := dbauthz.As(ctx, secondSubject)
secondRows, err := authzdb.GetChats(secondCtx, database.GetChatsParams{})
require.NoError(t, err)
require.Len(t, secondRows, 0)
})
t.Run("pagination", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
// Use a dedicated user for pagination to avoid interference
// with the other parallel subtests.
paginationUser := dbgen.User(t, db, database.User{})
for i := range 7 {
_, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: paginationUser.ID,
LastModelConfigID: modelCfg.ID,
Title: fmt.Sprintf("pagination chat %d", i+1),
})
require.NoError(t, err)
}
pagUserSubject, _, err := httpmw.UserRBACSubject(ctx, db, paginationUser.ID, rbac.ExpandableScope(rbac.ScopeAll))
require.NoError(t, err)
preparedMember, err := authorizer.Prepare(ctx, pagUserSubject, policy.ActionRead, rbac.ResourceChat.Type)
require.NoError(t, err)
// Fetch first page with limit=2.
page1, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
LimitOpt: 2,
}, preparedMember)
require.NoError(t, err)
require.Len(t, page1, 2)
for _, row := range page1 {
require.Equal(t, paginationUser.ID, row.OwnerID, "paginated results must belong to pagination user")
}
// Fetch remaining pages and collect all chat IDs.
allIDs := make(map[uuid.UUID]struct{})
for _, row := range page1 {
allIDs[row.ID] = struct{}{}
}
offset := int32(2)
for {
page, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
LimitOpt: 2,
OffsetOpt: offset,
}, preparedMember)
require.NoError(t, err)
for _, row := range page {
require.Equal(t, paginationUser.ID, row.OwnerID, "paginated results must belong to pagination user")
allIDs[row.ID] = struct{}{}
}
if len(page) < 2 {
break
}
offset += int32(len(page)) //nolint:gosec // Test code, pagination values are small.
}
// All 7 member chats should be accounted for with no leakage.
require.Len(t, allIDs, 7, "pagination should return all member chats exactly once")
})
}
func TestInsertWorkspaceAgentLogs(t *testing.T) {
t.Parallel()
if testing.Short() {
@@ -2431,6 +2655,42 @@ func TestDeleteCustomRoleDoesNotDeleteSystemRole(t *testing.T) {
require.True(t, roles[0].IsSystem)
}
func TestGetAuthorizationUserRolesImpliedOrgRole(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
org := dbgen.Organization(t, db, database.Organization{})
regularUser := dbgen.User(t, db, database.User{})
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org.ID,
UserID: regularUser.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org.ID,
UserID: saUser.ID,
})
ctx := testutil.Context(t, testutil.WaitShort)
wantMember := rbac.RoleOrgMember() + ":" + org.ID.String()
wantSA := rbac.RoleOrgServiceAccount() + ":" + org.ID.String()
// Regular users get the implied organization-member role.
regularRoles, err := db.GetAuthorizationUserRoles(ctx, regularUser.ID)
require.NoError(t, err)
require.Contains(t, regularRoles.Roles, wantMember)
require.NotContains(t, regularRoles.Roles, wantSA)
// Service accounts get the implied organization-service-account role.
saRoles, err := db.GetAuthorizationUserRoles(ctx, saUser.ID)
require.NoError(t, err)
require.Contains(t, saRoles.Roles, wantSA)
require.NotContains(t, saRoles.Roles, wantMember)
}
func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
t.Parallel()
@@ -2441,82 +2701,155 @@ func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
updated, err := db.UpdateOrganizationWorkspaceSharingSettings(ctx, database.UpdateOrganizationWorkspaceSharingSettingsParams{
ID: org.ID,
WorkspaceSharingDisabled: true,
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
UpdatedAt: dbtime.Now(),
})
require.NoError(t, err)
require.True(t, updated.WorkspaceSharingDisabled)
require.Equal(t, database.ShareableWorkspaceOwnersNone, updated.ShareableWorkspaceOwners)
got, err := db.GetOrganizationByID(ctx, org.ID)
require.NoError(t, err)
require.True(t, got.WorkspaceSharingDisabled)
require.Equal(t, database.ShareableWorkspaceOwnersNone, got.ShareableWorkspaceOwners)
}
func TestDeleteWorkspaceACLsByOrganization(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
org1 := dbgen.Organization(t, db, database.Organization{})
org2 := dbgen.Organization(t, db, database.Organization{})
t.Run("DeletesAll", func(t *testing.T) {
t.Parallel()
owner1 := dbgen.User(t, db, database.User{})
owner2 := dbgen.User(t, db, database.User{})
sharedUser := dbgen.User(t, db, database.User{})
sharedGroup := dbgen.Group(t, db, database.Group{
OrganizationID: org1.ID,
db, _ := dbtestutil.NewDB(t)
org1 := dbgen.Organization(t, db, database.Organization{})
org2 := dbgen.Organization(t, db, database.Organization{})
owner1 := dbgen.User(t, db, database.User{})
owner2 := dbgen.User(t, db, database.User{})
sharedUser := dbgen.User(t, db, database.User{})
sharedGroup := dbgen.Group(t, db, database.Group{
OrganizationID: org1.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org1.ID,
UserID: owner1.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org2.ID,
UserID: owner2.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org1.ID,
UserID: sharedUser.ID,
})
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: owner1.ID,
OrganizationID: org1.ID,
UserACL: database.WorkspaceACL{
sharedUser.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
GroupACL: database.WorkspaceACL{
sharedGroup.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: owner2.ID,
OrganizationID: org2.ID,
UserACL: database.WorkspaceACL{
uuid.NewString(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
ctx := testutil.Context(t, testutil.WaitShort)
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
OrganizationID: org1.ID,
ExcludeServiceAccounts: false,
})
require.NoError(t, err)
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
require.NoError(t, err)
require.Empty(t, got1.UserACL)
require.Empty(t, got1.GroupACL)
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
require.NoError(t, err)
require.NotEmpty(t, got2.UserACL)
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org1.ID,
UserID: owner1.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org2.ID,
UserID: owner2.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org1.ID,
UserID: sharedUser.ID,
})
t.Run("ExcludesServiceAccounts", func(t *testing.T) {
t.Parallel()
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: owner1.ID,
OrganizationID: org1.ID,
UserACL: database.WorkspaceACL{
db, _ := dbtestutil.NewDB(t)
org := dbgen.Organization(t, db, database.Organization{})
regularUser := dbgen.User(t, db, database.User{})
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
sharedUser := dbgen.User(t, db, database.User{})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org.ID,
UserID: regularUser.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org.ID,
UserID: saUser.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: org.ID,
UserID: sharedUser.ID,
})
regularWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: regularUser.ID,
OrganizationID: org.ID,
UserACL: database.WorkspaceACL{
sharedUser.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
saWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: saUser.ID,
OrganizationID: org.ID,
UserACL: database.WorkspaceACL{
sharedUser.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
ctx := testutil.Context(t, testutil.WaitShort)
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
OrganizationID: org.ID,
ExcludeServiceAccounts: true,
})
require.NoError(t, err)
// Regular user workspace ACLs should be cleared.
gotRegular, err := db.GetWorkspaceByID(ctx, regularWS.ID)
require.NoError(t, err)
require.Empty(t, gotRegular.UserACL)
// Service account workspace ACLs should be preserved.
gotSA, err := db.GetWorkspaceByID(ctx, saWS.ID)
require.NoError(t, err)
require.Equal(t, database.WorkspaceACL{
sharedUser.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
GroupACL: database.WorkspaceACL{
sharedGroup.ID.String(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: owner2.ID,
OrganizationID: org2.ID,
UserACL: database.WorkspaceACL{
uuid.NewString(): {
Permissions: []policy.Action{policy.ActionRead},
},
},
}).Do().Workspace
ctx := testutil.Context(t, testutil.WaitShort)
err := db.DeleteWorkspaceACLsByOrganization(ctx, org1.ID)
require.NoError(t, err)
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
require.NoError(t, err)
require.Empty(t, got1.UserACL)
require.Empty(t, got1.GroupACL)
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
require.NoError(t, err)
require.NotEmpty(t, got2.UserACL)
}, gotSA.UserACL)
})
}
func TestAuthorizedAuditLogs(t *testing.T) {
@@ -7982,6 +8315,80 @@ func TestUsageEventsTrigger(t *testing.T) {
require.WithinDuration(t, time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), rows[1].Day, time.Second)
})
t.Run("HeartbeatAISeats", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
// Insert a heartbeat event.
err := db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
ID: "hb-1",
EventType: "hb_ai_seats_v1",
EventData: []byte(`{"count": 10}`),
CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
})
require.NoError(t, err)
rows := getDailyRows(ctx, sqlDB)
require.Len(t, rows, 1)
require.Equal(t, "hb_ai_seats_v1", rows[0].EventType)
require.JSONEq(t, `{"count": 10}`, string(rows[0].UsageData))
// Insert a higher count on the same day — should take the max.
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
ID: "hb-2",
EventType: "hb_ai_seats_v1",
EventData: []byte(`{"count": 50}`),
CreatedAt: time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC),
})
require.NoError(t, err)
rows = getDailyRows(ctx, sqlDB)
require.Len(t, rows, 1)
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
// Insert a lower count on the same day — should keep the max (50).
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
ID: "hb-3",
EventType: "hb_ai_seats_v1",
EventData: []byte(`{"count": 25}`),
CreatedAt: time.Date(2025, 1, 1, 18, 0, 0, 0, time.UTC),
})
require.NoError(t, err)
rows = getDailyRows(ctx, sqlDB)
require.Len(t, rows, 1)
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
// Insert on a different day.
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
ID: "hb-4",
EventType: "hb_ai_seats_v1",
EventData: []byte(`{"count": 5}`),
CreatedAt: time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC),
})
require.NoError(t, err)
rows = getDailyRows(ctx, sqlDB)
require.Len(t, rows, 2)
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
require.JSONEq(t, `{"count": 5}`, string(rows[1].UsageData))
// Also insert a dc_managed_agents_v1 event on the same first day to
// verify different event types get separate daily rows.
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
ID: "dc-1",
EventType: "dc_managed_agents_v1",
EventData: []byte(`{"count": 7}`),
CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
})
require.NoError(t, err)
rows = getDailyRows(ctx, sqlDB)
require.Len(t, rows, 3)
})
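The HeartbeatAISeats subtest above exercises the trigger's per-day fold: for hb_ai_seats_v1 the daily rollup keeps the maximum count seen that day rather than summing. A minimal in-memory Go sketch of that fold (names and types hypothetical, not the actual trigger code):

```go
package main

import "fmt"

// event is a simplified heartbeat: a day key and a seat count.
type event struct {
	day   string
	count int
}

// maxPerDay folds heartbeat events into one row per day, keeping the
// highest count seen for that day, mirroring the trigger's behavior.
func maxPerDay(events []event) map[string]int {
	rows := map[string]int{}
	for _, e := range events {
		if cur, ok := rows[e.day]; !ok || e.count > cur {
			rows[e.day] = e.count
		}
	}
	return rows
}

func main() {
	rows := maxPerDay([]event{
		{"2025-01-01", 10},
		{"2025-01-01", 50},
		{"2025-01-01", 25},
		{"2025-01-02", 5},
	})
	fmt.Println(rows["2025-01-01"], rows["2025-01-02"]) // 50 5
}
```

The subtest's three same-day inserts (10, 50, 25) collapse to one row holding 50, matching this fold.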
t.Run("UnknownEventType", func(t *testing.T) {
t.Parallel()
@@ -9441,3 +9848,42 @@ func TestGetWorkspaceBuildMetricsByResourceID(t *testing.T) {
require.Equal(t, "success", row.WorstStatus)
})
}
// TestUpsertAISeats verifies 'UpsertAISeatState' only returns true when a new
// row is inserted.
func TestUpsertAISeats(t *testing.T) {
t.Parallel()
sqlDB := testSQLDB(t)
err := migrations.Up(sqlDB)
require.NoError(t, err)
db := database.New(sqlDB)
ctx := testutil.Context(t, testutil.WaitShort)
now := dbtime.Now()
user := dbgen.User(t, db, database.User{})
newRow, err := db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
UserID: user.ID,
FirstUsedAt: now.Add(time.Hour * -24),
LastEventType: database.AiSeatUsageReasonTask,
})
require.NoError(t, err)
require.True(t, newRow)
alreadyExists, err := db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
UserID: user.ID,
FirstUsedAt: now.Add(time.Hour * -23),
LastEventType: database.AiSeatUsageReasonTask,
})
require.NoError(t, err)
require.False(t, alreadyExists)
alreadyExists, err = db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
UserID: user.ID,
FirstUsedAt: now,
LastEventType: database.AiSeatUsageReasonTask,
})
require.NoError(t, err)
require.False(t, alreadyExists)
}
File diff suppressed because it is too large
+14
@@ -53,6 +53,14 @@ INSERT INTO aibridge_tool_usages (
)
RETURNING *;
-- name: InsertAIBridgeModelThought :one
INSERT INTO aibridge_model_thoughts (
interception_id, content, metadata, created_at
) VALUES (
@interception_id, @content, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
)
RETURNING *;
-- name: GetAIBridgeInterceptionByID :one
SELECT
*
@@ -362,6 +370,11 @@ WITH
WHERE started_at < @before_time::timestamp with time zone
),
-- CTEs are executed in order.
model_thoughts AS (
DELETE FROM aibridge_model_thoughts
WHERE interception_id IN (SELECT id FROM to_delete)
RETURNING 1
),
tool_usages AS (
DELETE FROM aibridge_tool_usages
WHERE interception_id IN (SELECT id FROM to_delete)
@@ -384,6 +397,7 @@ WITH
)
-- Cumulative count.
SELECT (
(SELECT COUNT(*) FROM model_thoughts) +
(SELECT COUNT(*) FROM tool_usages) +
(SELECT COUNT(*) FROM token_usages) +
(SELECT COUNT(*) FROM user_prompts) +
+35
@@ -0,0 +1,35 @@
-- name: UpsertAISeatState :one
-- Returns true if a new row was inserted, false otherwise.
INSERT INTO ai_seat_state (
user_id,
first_used_at,
last_used_at,
last_event_type,
last_event_description,
updated_at
)
VALUES
($1, $2, $2, $3, $4, $2)
ON CONFLICT (user_id) DO UPDATE
SET
last_used_at = EXCLUDED.last_used_at,
last_event_type = EXCLUDED.last_event_type,
last_event_description = EXCLUDED.last_event_description,
updated_at = EXCLUDED.updated_at
RETURNING
-- Postgres voodoo to know if a row was inserted.
(xmax = 0)::boolean AS is_new;
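The (xmax = 0) trick works because a tuple freshly created by the INSERT arm has xmax 0, while a tuple rewritten by the ON CONFLICT DO UPDATE arm carries the updating transaction's ID in xmax. A minimal in-memory Go sketch of the contract the query exposes, an insert-or-update that also reports whether a new row appeared (types hypothetical):

```go
package main

import "fmt"

type seatState struct {
	firstUsedAt string
	lastUsedAt  string
}

// upsertSeat mirrors the UpsertAISeatState contract: insert a row if the
// user has none and report isNew=true; otherwise refresh last_used_at and
// report isNew=false. first_used_at is never overwritten on conflict.
func upsertSeat(seats map[string]*seatState, userID, usedAt string) (isNew bool) {
	if s, ok := seats[userID]; ok {
		s.lastUsedAt = usedAt
		return false
	}
	seats[userID] = &seatState{firstUsedAt: usedAt, lastUsedAt: usedAt}
	return true
}

func main() {
	seats := map[string]*seatState{}
	fmt.Println(upsertSeat(seats, "u1", "day1")) // true: new row
	fmt.Println(upsertSeat(seats, "u1", "day2")) // false: updated in place
	fmt.Println(seats["u1"].firstUsedAt)         // day1, preserved
}
```

This is the same true/false/false sequence TestUpsertAISeats asserts against the real query.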
-- name: GetActiveAISeatCount :one
SELECT
COUNT(*)
FROM
ai_seat_state ais
JOIN
users u
ON
ais.user_id = u.id
WHERE
u.status = 'active'::user_status
AND u.deleted = false
AND u.is_system = false;
+118
@@ -0,0 +1,118 @@
-- PR Insights queries for the /agents analytics dashboard.
-- These aggregate data from chat_diff_statuses (PR metadata) joined
-- with chats and chat_messages (cost) to power the PR Insights view.
-- name: GetPRInsightsSummary :one
-- Returns aggregate PR metrics for the given date range.
-- The handler calls this twice (current + previous period) for trends.
SELECT
COUNT(*)::bigint AS total_prs_created,
COUNT(*) FILTER (WHERE cds.pull_request_state = 'merged')::bigint AS total_prs_merged,
COUNT(*) FILTER (WHERE cds.pull_request_state = 'closed')::bigint AS total_prs_closed,
COALESCE(SUM(cds.additions), 0)::bigint AS total_additions,
COALESCE(SUM(cds.deletions), 0)::bigint AS total_deletions,
COALESCE(SUM(cc.cost_micros), 0)::bigint AS total_cost_micros,
COALESCE(SUM(cc.cost_micros) FILTER (WHERE cds.pull_request_state = 'merged'), 0)::bigint AS merged_cost_micros
FROM chat_diff_statuses cds
JOIN chats c ON c.id = cds.chat_id
LEFT JOIN (
SELECT
COALESCE(ch.root_chat_id, ch.id) AS root_id,
COALESCE(SUM(cm.total_cost_micros), 0) AS cost_micros
FROM chat_messages cm
JOIN chats ch ON ch.id = cm.chat_id
WHERE cm.total_cost_micros IS NOT NULL
GROUP BY COALESCE(ch.root_chat_id, ch.id)
) cc ON cc.root_id = COALESCE(c.root_chat_id, c.id)
WHERE cds.pull_request_state IS NOT NULL
AND c.created_at >= @start_date::timestamptz
AND c.created_at < @end_date::timestamptz
AND (sqlc.narg('owner_id')::uuid IS NULL OR c.owner_id = sqlc.narg('owner_id')::uuid);
-- name: GetPRInsightsTimeSeries :many
-- Returns daily PR counts grouped by state for the chart.
SELECT
date_trunc('day', c.created_at)::timestamptz AS date,
COUNT(*)::bigint AS prs_created,
COUNT(*) FILTER (WHERE cds.pull_request_state = 'merged')::bigint AS prs_merged,
COUNT(*) FILTER (WHERE cds.pull_request_state = 'closed')::bigint AS prs_closed
FROM chat_diff_statuses cds
JOIN chats c ON c.id = cds.chat_id
WHERE cds.pull_request_state IS NOT NULL
AND c.created_at >= @start_date::timestamptz
AND c.created_at < @end_date::timestamptz
AND (sqlc.narg('owner_id')::uuid IS NULL OR c.owner_id = sqlc.narg('owner_id')::uuid)
GROUP BY date_trunc('day', c.created_at)
ORDER BY date_trunc('day', c.created_at);
-- name: GetPRInsightsPerModel :many
-- Returns PR metrics grouped by the model used for each chat.
SELECT
cmc.id AS model_config_id,
cmc.display_name,
cmc.provider,
COUNT(*)::bigint AS total_prs,
COUNT(*) FILTER (WHERE cds.pull_request_state = 'merged')::bigint AS merged_prs,
COALESCE(SUM(cds.additions), 0)::bigint AS total_additions,
COALESCE(SUM(cds.deletions), 0)::bigint AS total_deletions,
COALESCE(SUM(cc.cost_micros), 0)::bigint AS total_cost_micros,
COALESCE(SUM(cc.cost_micros) FILTER (WHERE cds.pull_request_state = 'merged'), 0)::bigint AS merged_cost_micros
FROM chat_diff_statuses cds
JOIN chats c ON c.id = cds.chat_id
JOIN chat_model_configs cmc ON cmc.id = c.last_model_config_id
LEFT JOIN (
SELECT
COALESCE(ch.root_chat_id, ch.id) AS root_id,
COALESCE(SUM(cm.total_cost_micros), 0) AS cost_micros
FROM chat_messages cm
JOIN chats ch ON ch.id = cm.chat_id
WHERE cm.total_cost_micros IS NOT NULL
GROUP BY COALESCE(ch.root_chat_id, ch.id)
) cc ON cc.root_id = COALESCE(c.root_chat_id, c.id)
WHERE cds.pull_request_state IS NOT NULL
AND c.created_at >= @start_date::timestamptz
AND c.created_at < @end_date::timestamptz
AND (sqlc.narg('owner_id')::uuid IS NULL OR c.owner_id = sqlc.narg('owner_id')::uuid)
GROUP BY cmc.id, cmc.display_name, cmc.provider
ORDER BY total_prs DESC;
-- name: GetPRInsightsRecentPRs :many
-- Returns individual PR rows with cost for the recent PRs table.
SELECT
c.id AS chat_id,
cds.pull_request_title AS pr_title,
cds.url AS pr_url,
cds.pr_number,
cds.pull_request_state AS state,
cds.pull_request_draft AS draft,
cds.additions,
cds.deletions,
cds.changed_files,
cds.commits,
cds.approved,
cds.changes_requested,
cds.reviewer_count,
cds.author_login,
cds.author_avatar_url,
COALESCE(cds.base_branch, '')::text AS base_branch,
COALESCE(cmc.display_name, cmc.model)::text AS model_display_name,
COALESCE(cc.cost_micros, 0)::bigint AS cost_micros,
c.created_at
FROM chat_diff_statuses cds
JOIN chats c ON c.id = cds.chat_id
JOIN chat_model_configs cmc ON cmc.id = c.last_model_config_id
LEFT JOIN (
SELECT
COALESCE(ch.root_chat_id, ch.id) AS root_id,
COALESCE(SUM(cm.total_cost_micros), 0) AS cost_micros
FROM chat_messages cm
JOIN chats ch ON ch.id = cm.chat_id
WHERE cm.total_cost_micros IS NOT NULL
GROUP BY COALESCE(ch.root_chat_id, ch.id)
) cc ON cc.root_id = COALESCE(c.root_chat_id, c.id)
WHERE cds.pull_request_state IS NOT NULL
AND c.created_at >= @start_date::timestamptz
AND c.created_at < @end_date::timestamptz
AND (sqlc.narg('owner_id')::uuid IS NULL OR c.owner_id = sqlc.narg('owner_id')::uuid)
ORDER BY c.created_at DESC
LIMIT @limit_val::int;
+153 -4
@@ -40,6 +40,23 @@ WHERE
ORDER BY
created_at ASC;
-- name: GetChatMessagesByChatIDDescPaginated :many
SELECT
*
FROM
chat_messages
WHERE
chat_id = @chat_id::uuid
AND CASE
WHEN @before_id::bigint > 0 THEN id < @before_id::bigint
ELSE true
END
AND visibility IN ('user', 'both')
ORDER BY
id DESC
LIMIT
COALESCE(NULLIF(@limit_val::int, 0), 50);
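The query above is keyset (cursor) pagination on the bigint id column: pass before_id = 0 for the first page, then feed the smallest id of each page back as the next cursor. A small Go sketch of the cursor logic, assuming ids sorted ascending (helper name hypothetical):

```go
package main

import "fmt"

// page returns up to limit ids strictly below beforeID, newest first.
// beforeID == 0 means "no cursor yet", matching the SQL's
// CASE WHEN @before_id > 0 THEN id < @before_id ELSE true END.
func page(ids []int64, beforeID int64, limit int) []int64 {
	var out []int64
	// Walk from the end of the ascending slice to emit DESC order.
	for i := len(ids) - 1; i >= 0 && len(out) < limit; i-- {
		if beforeID > 0 && ids[i] >= beforeID {
			continue
		}
		out = append(out, ids[i])
	}
	return out
}

func main() {
	ids := []int64{1, 2, 3, 4, 5}
	first := page(ids, 0, 2)                    // [5 4]
	second := page(ids, first[len(first)-1], 2) // [3 2]
	fmt.Println(first, second)
}
```

Unlike OFFSET pagination, the cursor stays stable when new messages are inserted between page fetches.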
-- name: GetChatMessagesForPromptByChatID :many
WITH latest_compressed_summary AS (
SELECT
@@ -96,13 +113,16 @@ ORDER BY
created_at ASC,
id ASC;
-- name: GetChatsByOwnerID :many
-- name: GetChats :many
SELECT
*
FROM
chats
WHERE
owner_id = @owner_id::uuid
CASE
WHEN @owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN chats.owner_id = @owner_id
ELSE true
END
AND CASE
WHEN sqlc.narg('archived') :: boolean IS NULL THEN true
ELSE chats.archived = sqlc.narg('archived') :: boolean
@@ -126,6 +146,8 @@ WHERE
)
ELSE true
END
-- The authorize filter clause will be injected below in GetAuthorizedChats.
-- @authorize_filter
ORDER BY
-- Deterministic and consistent ordering of all rows, even if they share
-- a timestamp. This is to ensure consistent pagination.
@@ -183,7 +205,8 @@ INSERT INTO chat_messages (
cache_read_tokens,
context_limit,
compressed,
total_cost_micros
total_cost_micros,
runtime_ms
) VALUES (
@chat_id::uuid,
sqlc.narg('created_by')::uuid,
@@ -200,7 +223,8 @@ INSERT INTO chat_messages (
sqlc.narg('cache_read_tokens')::bigint,
sqlc.narg('context_limit')::bigint,
COALESCE(sqlc.narg('compressed')::boolean, FALSE),
sqlc.narg('total_cost_micros')::bigint
sqlc.narg('total_cost_micros')::bigint,
sqlc.narg('runtime_ms')::bigint
)
RETURNING
*;
@@ -683,3 +707,128 @@ LIMIT
sqlc.arg('page_limit')::int
OFFSET
sqlc.arg('page_offset')::int;
-- name: GetChatUsageLimitConfig :one
SELECT * FROM chat_usage_limit_config WHERE singleton = TRUE LIMIT 1;
-- name: UpsertChatUsageLimitConfig :one
INSERT INTO chat_usage_limit_config (singleton, enabled, default_limit_micros, period, updated_at)
VALUES (TRUE, @enabled::boolean, @default_limit_micros::bigint, @period::text, NOW())
ON CONFLICT (singleton) DO UPDATE SET
enabled = EXCLUDED.enabled,
default_limit_micros = EXCLUDED.default_limit_micros,
period = EXCLUDED.period,
updated_at = NOW()
RETURNING *;
-- name: ListChatUsageLimitOverrides :many
SELECT u.id AS user_id, u.username, u.name, u.avatar_url,
u.chat_spend_limit_micros AS spend_limit_micros
FROM users u
WHERE u.chat_spend_limit_micros IS NOT NULL
ORDER BY u.username ASC;
-- name: UpsertChatUsageLimitUserOverride :one
UPDATE users
SET chat_spend_limit_micros = @spend_limit_micros::bigint
WHERE id = @user_id::uuid
RETURNING id AS user_id, username, name, avatar_url, chat_spend_limit_micros AS spend_limit_micros;
-- name: DeleteChatUsageLimitUserOverride :exec
UPDATE users SET chat_spend_limit_micros = NULL WHERE id = @user_id::uuid;
-- name: GetChatUsageLimitUserOverride :one
SELECT id AS user_id, chat_spend_limit_micros AS spend_limit_micros
FROM users
WHERE id = @user_id::uuid AND chat_spend_limit_micros IS NOT NULL;
-- name: GetUserChatSpendInPeriod :one
SELECT COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_spend_micros
FROM chat_messages cm
JOIN chats c ON c.id = cm.chat_id
WHERE c.owner_id = @user_id::uuid
AND cm.created_at >= @start_time::timestamptz
AND cm.created_at < @end_time::timestamptz
AND cm.total_cost_micros IS NOT NULL;
-- name: CountEnabledModelsWithoutPricing :one
-- Counts enabled, non-deleted model configs that lack both input and
-- output pricing in their JSONB options.cost configuration.
SELECT COUNT(*)::bigint AS count
FROM chat_model_configs
WHERE enabled = TRUE
AND deleted = FALSE
AND (
options->'cost' IS NULL
OR options->'cost' = 'null'::jsonb
OR (
(options->'cost'->>'input_price_per_million_tokens' IS NULL)
AND (options->'cost'->>'output_price_per_million_tokens' IS NULL)
)
);
-- name: ListChatUsageLimitGroupOverrides :many
SELECT
g.id AS group_id,
g.name AS group_name,
g.display_name AS group_display_name,
g.avatar_url AS group_avatar_url,
g.chat_spend_limit_micros AS spend_limit_micros,
(SELECT COUNT(*)
FROM group_members_expanded gme
WHERE gme.group_id = g.id
AND gme.user_is_system = FALSE) AS member_count
FROM groups g
WHERE g.chat_spend_limit_micros IS NOT NULL
ORDER BY g.name ASC;
-- name: UpsertChatUsageLimitGroupOverride :one
UPDATE groups
SET chat_spend_limit_micros = @spend_limit_micros::bigint
WHERE id = @group_id::uuid
RETURNING id AS group_id, name, display_name, avatar_url, chat_spend_limit_micros AS spend_limit_micros;
-- name: DeleteChatUsageLimitGroupOverride :exec
UPDATE groups SET chat_spend_limit_micros = NULL WHERE id = @group_id::uuid;
-- name: GetChatUsageLimitGroupOverride :one
SELECT id AS group_id, chat_spend_limit_micros AS spend_limit_micros
FROM groups
WHERE id = @group_id::uuid AND chat_spend_limit_micros IS NOT NULL;
-- name: GetUserGroupSpendLimit :one
-- Returns the minimum (most restrictive) group limit for a user.
-- Returns -1 if the user has no group limits applied.
SELECT COALESCE(MIN(g.chat_spend_limit_micros), -1)::bigint AS limit_micros
FROM groups g
JOIN group_members_expanded gme ON gme.group_id = g.id
WHERE gme.user_id = @user_id::uuid
AND g.chat_spend_limit_micros IS NOT NULL;
-- name: ResolveUserChatSpendLimit :one
-- Resolves the effective spend limit for a user using the hierarchy:
-- 1. Individual user override (highest priority)
-- 2. Minimum group limit across all user's groups
-- 3. Global default from config
-- Returns -1 if limits are not enabled.
SELECT CASE
-- If limits are disabled, return -1.
WHEN NOT cfg.enabled THEN -1
-- Individual override takes priority.
WHEN u.chat_spend_limit_micros IS NOT NULL THEN u.chat_spend_limit_micros
-- Group limit (minimum across all user's groups) is next.
WHEN gl.limit_micros IS NOT NULL THEN gl.limit_micros
-- Fall back to global default.
ELSE cfg.default_limit_micros
END::bigint AS effective_limit_micros
FROM chat_usage_limit_config cfg
CROSS JOIN users u
LEFT JOIN LATERAL (
SELECT MIN(g.chat_spend_limit_micros) AS limit_micros
FROM groups g
JOIN group_members_expanded gme ON gme.group_id = g.id
WHERE gme.user_id = @user_id::uuid
AND g.chat_spend_limit_micros IS NOT NULL
) gl ON TRUE
WHERE u.id = @user_id::uuid
LIMIT 1;
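ResolveUserChatSpendLimit encodes a strict precedence: a disabled config short-circuits to -1, then a per-user override wins, then the minimum limit across the user's groups, then the global default. A pure-Go sketch of the same CASE hierarchy, with nil standing in for SQL NULL (function name hypothetical):

```go
package main

import "fmt"

// resolveLimit mirrors ResolveUserChatSpendLimit's CASE hierarchy:
// disabled => -1; user override; else minimum group limit; else default.
func resolveLimit(enabled bool, userOverride, minGroupLimit *int64, defaultLimit int64) int64 {
	switch {
	case !enabled:
		return -1
	case userOverride != nil:
		return *userOverride
	case minGroupLimit != nil:
		return *minGroupLimit
	default:
		return defaultLimit
	}
}

func main() {
	user := int64(500)
	group := int64(200)
	fmt.Println(resolveLimit(false, &user, &group, 1000)) // -1: limits disabled
	fmt.Println(resolveLimit(true, &user, &group, 1000))  // 500: user override wins
	fmt.Println(resolveLimit(true, nil, &group, 1000))    // 200: minimum group limit
	fmt.Println(resolveLimit(true, nil, nil, 1000))       // 1000: global default
}
```

Note the ordering consequence: a user override larger than a group limit still wins, since the individual override is checked first.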
+1 -1
@@ -147,7 +147,7 @@ WHERE
UPDATE
organizations
SET
workspace_sharing_disabled = @workspace_sharing_disabled,
shareable_workspace_owners = @shareable_workspace_owners,
updated_at = @updated_at
WHERE
id = @id
+20
@@ -140,3 +140,23 @@ SELECT
-- name: UpsertChatSystemPrompt :exec
INSERT INTO site_configs (key, value) VALUES ('agents_chat_system_prompt', $1)
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_system_prompt';
-- name: GetChatDesktopEnabled :one
SELECT
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop;
-- name: UpsertChatDesktopEnabled :exec
INSERT INTO site_configs (key, value)
VALUES (
'agents_desktop_enabled',
CASE
WHEN sqlc.arg(enable_desktop)::bool THEN 'true'
ELSE 'false'
END
)
ON CONFLICT (key) DO UPDATE
SET value = CASE
WHEN sqlc.arg(enable_desktop)::bool THEN 'true'
ELSE 'false'
END
WHERE site_configs.key = 'agents_desktop_enabled';
+5
@@ -15,6 +15,11 @@ VALUES
(@id, @event_type, @event_data, @created_at, NULL, NULL, NULL)
ON CONFLICT (id) DO NOTHING;
-- name: UsageEventExistsByID :one
SELECT EXISTS(
SELECT 1 FROM usage_events WHERE id = @id
)::bool;
-- name: SelectUsageEventsForPublishing :many
WITH usage_events AS (
UPDATE
+14 -2
@@ -391,9 +391,21 @@ SELECT
array_agg(org_roles || ':' || organization_members.organization_id::text)
FROM
organization_members,
-- All org_members get the organization-member role for their orgs
-- All org members get an implied role for their orgs. Most members
-- get organization-member, but service accounts will get
-- organization-service-account instead. They're largely the same,
-- but having them be distinct means we can allow configuring
-- service-accounts to have slightly broader permissions, such as
-- for workspace sharing.
unnest(
array_append(roles, 'organization-member')
array_append(
roles,
CASE WHEN users.is_service_account THEN
'organization-service-account'
ELSE
'organization-member'
END
)
) AS org_roles
WHERE
user_id = users.id
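The CASE above means the implied per-org role differs only by owner type before being suffixed with the organization ID, which is exactly what the dbauthz test asserts with wantMember and wantSA. A tiny Go sketch of that mapping (helper name hypothetical):

```go
package main

import "fmt"

// impliedOrgRole mirrors the CASE in the authorization-roles query:
// service accounts receive organization-service-account, everyone else
// organization-member; the role is then suffixed with the org ID.
func impliedOrgRole(isServiceAccount bool, orgID string) string {
	role := "organization-member"
	if isServiceAccount {
		role = "organization-service-account"
	}
	return role + ":" + orgID
}

func main() {
	fmt.Println(impliedOrgRole(false, "org-1")) // organization-member:org-1
	fmt.Println(impliedOrgRole(true, "org-1"))  // organization-service-account:org-1
}
```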
+7 -1
@@ -955,7 +955,13 @@ SET
group_acl = '{}'::jsonb,
user_acl = '{}'::jsonb
WHERE
organization_id = @organization_id;
organization_id = @organization_id
AND (
NOT @exclude_service_accounts::boolean
OR owner_id NOT IN (
SELECT id FROM users WHERE is_service_account = true
)
);
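The added predicate means ACLs are wiped for every workspace in the organization unless exclude_service_accounts is set and the workspace's owner is a service account, matching the ExcludesServiceAccounts subtest. An in-memory Go sketch of the same filter (types hypothetical):

```go
package main

import "fmt"

type workspace struct {
	orgID     string
	ownerIsSA bool
	userACL   map[string]bool
}

// clearACLs mirrors the UPDATE's WHERE clause: wipe ACLs for workspaces
// in the org, optionally skipping workspaces owned by service accounts.
func clearACLs(wss []*workspace, orgID string, excludeServiceAccounts bool) {
	for _, ws := range wss {
		if ws.orgID != orgID {
			continue
		}
		if excludeServiceAccounts && ws.ownerIsSA {
			continue
		}
		ws.userACL = map[string]bool{}
	}
}

func main() {
	regular := &workspace{orgID: "org1", userACL: map[string]bool{"u": true}}
	sa := &workspace{orgID: "org1", ownerIsSA: true, userACL: map[string]bool{"u": true}}
	clearACLs([]*workspace{regular, sa}, "org1", true)
	fmt.Println(len(regular.userACL), len(sa.userACL)) // 0 1
}
```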
-- name: GetRegularWorkspaceCreateMetrics :many
-- Count regular workspaces: only those whose first successful 'start' build
+1
@@ -235,6 +235,7 @@ sql:
aibridge_tool_usage: AIBridgeToolUsage
aibridge_token_usage: AIBridgeTokenUsage
aibridge_user_prompt: AIBridgeUserPrompt
aibridge_model_thought: AIBridgeModelThought
rules:
- name: do-not-use-public-schema-in-queries
message: "do not use public schema in queries"
+3
@@ -7,6 +7,7 @@ type UniqueConstraint string
// UniqueConstraint enums.
const (
UniqueAgentStatsPkey UniqueConstraint = "agent_stats_pkey" // ALTER TABLE ONLY workspace_agent_stats ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id);
UniqueAiSeatStatePkey UniqueConstraint = "ai_seat_state_pkey" // ALTER TABLE ONLY ai_seat_state ADD CONSTRAINT ai_seat_state_pkey PRIMARY KEY (user_id);
UniqueAibridgeInterceptionsPkey UniqueConstraint = "aibridge_interceptions_pkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_pkey PRIMARY KEY (id);
UniqueAibridgeTokenUsagesPkey UniqueConstraint = "aibridge_token_usages_pkey" // ALTER TABLE ONLY aibridge_token_usages ADD CONSTRAINT aibridge_token_usages_pkey PRIMARY KEY (id);
UniqueAibridgeToolUsagesPkey UniqueConstraint = "aibridge_tool_usages_pkey" // ALTER TABLE ONLY aibridge_tool_usages ADD CONSTRAINT aibridge_tool_usages_pkey PRIMARY KEY (id);
@@ -21,6 +22,8 @@ const (
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
UniqueChatProvidersProviderKey UniqueConstraint = "chat_providers_provider_key" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
UniqueChatQueuedMessagesPkey UniqueConstraint = "chat_queued_messages_pkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
UniqueChatUsageLimitConfigPkey UniqueConstraint = "chat_usage_limit_config_pkey" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
UniqueChatUsageLimitConfigSingletonKey UniqueConstraint = "chat_usage_limit_config_singleton_key" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
UniqueChatsPkey UniqueConstraint = "chats_pkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
UniqueConnectionLogsPkey UniqueConstraint = "connection_logs_pkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id);
UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence);
+3 -3
@@ -48,8 +48,8 @@ type Store interface {
UpsertChatDiffStatusReference(
ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams,
) (database.ChatDiffStatus, error)
GetChatsByOwnerID(
ctx context.Context, arg database.GetChatsByOwnerIDParams,
GetChats(
ctx context.Context, arg database.GetChatsParams,
) ([]database.Chat, error)
}
@@ -250,7 +250,7 @@ func (w *Worker) MarkStale(
return
}
chats, err := w.store.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
chats, err := w.store.GetChats(ctx, database.GetChatsParams{
OwnerID: ownerID,
})
if err != nil {
+11 -12
@@ -469,8 +469,8 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
require.Equal(t, ownerID, arg.OwnerID)
return []database.Chat{
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
@@ -478,13 +478,12 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
{ID: chatOther, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
}, nil
})
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
mu.Lock()
upsertRefCalls = append(upsertRefCalls, arg)
mu.Unlock()
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(2)
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
mu.Lock()
upsertRefCalls = append(upsertRefCalls, arg)
mu.Unlock()
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(2)
pub := func(_ context.Context, chatID uuid.UUID) error {
mu.Lock()
@@ -527,7 +526,7 @@ func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
Return([]database.Chat{
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
@@ -555,7 +554,7 @@ func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
Return([]database.Chat{
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
@@ -590,7 +589,7 @@ func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
Return(nil, fmt.Errorf("db error"))
mClock := quartz.NewMock(t)
+6
@@ -136,6 +136,12 @@ func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool
Properties: sdkTool.Schema.Properties,
Required: sdkTool.Schema.Required,
},
Annotations: mcp.ToolAnnotation{
ReadOnlyHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.ReadOnlyHint),
DestructiveHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.DestructiveHint),
IdempotentHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.IdempotentHint),
OpenWorldHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.OpenWorldHint),
},
},
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var buf bytes.Buffer
+29 -9
@@ -91,21 +91,41 @@ func TestMCPHTTP_E2E_ClientIntegration(t *testing.T) {
 	// Verify we have some expected Coder tools
 	var foundTools []string
-	for _, tool := range tools.Tools {
+	var userTool *mcp.Tool
+	var writeFileTool *mcp.Tool
+	for i := range tools.Tools {
+		tool := tools.Tools[i]
 		foundTools = append(foundTools, tool.Name)
+		switch tool.Name {
+		case toolsdk.ToolNameGetAuthenticatedUser:
+			userTool = &tools.Tools[i]
+		case toolsdk.ToolNameWorkspaceWriteFile:
+			writeFileTool = &tools.Tools[i]
+		}
 	}

 	// Check for some basic tools that should be available
 	assert.Contains(t, foundTools, toolsdk.ToolNameGetAuthenticatedUser, "Should have authenticated user tool")
+	require.NotNil(t, userTool)
+	require.NotNil(t, writeFileTool)
+
+	require.NotNil(t, userTool.Annotations.ReadOnlyHint)
+	require.NotNil(t, userTool.Annotations.DestructiveHint)
+	require.NotNil(t, userTool.Annotations.IdempotentHint)
+	require.NotNil(t, userTool.Annotations.OpenWorldHint)
+	assert.True(t, *userTool.Annotations.ReadOnlyHint)
+	assert.False(t, *userTool.Annotations.DestructiveHint)
+	assert.True(t, *userTool.Annotations.IdempotentHint)
+	assert.False(t, *userTool.Annotations.OpenWorldHint)
+
+	require.NotNil(t, writeFileTool.Annotations.ReadOnlyHint)
+	require.NotNil(t, writeFileTool.Annotations.DestructiveHint)
+	require.NotNil(t, writeFileTool.Annotations.IdempotentHint)
+	require.NotNil(t, writeFileTool.Annotations.OpenWorldHint)
+	assert.False(t, *writeFileTool.Annotations.ReadOnlyHint)
+	assert.True(t, *writeFileTool.Annotations.DestructiveHint)
+	assert.False(t, *writeFileTool.Annotations.IdempotentHint)
+	assert.False(t, *writeFileTool.Annotations.OpenWorldHint)

-	// Find and execute the authenticated user tool
-	var userTool *mcp.Tool
-	for _, tool := range tools.Tools {
-		if tool.Name == toolsdk.ToolNameGetAuthenticatedUser {
-			userTool = &tool
-			break
-		}
-	}
-	require.NotNil(t, userTool, "Expected to find "+toolsdk.ToolNameGetAuthenticatedUser+" tool")

-	// Execute the tool
+	// Execute the authenticated user tool.
@@ -34,6 +34,7 @@ const (
ServiceAgentMetricAggregator = "agent-metrics-aggregator"
// ServiceTallymanPublisher publishes usage events to coder/tallyman.
ServiceTallymanPublisher = "tallyman-publisher"
ServiceUsageEventCron = "usage-event-cron"
RequestTypeTag = "coder_request_type"
)
@@ -0,0 +1,166 @@
package provisionerdserver_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/provisionerdserver"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
)
func TestMergeExtraEnvs(t *testing.T) {
t.Parallel()
tests := []struct {
name string
initial map[string]string
envs []*sdkproto.Env
expected map[string]string
expectErr string
}{
{
name: "empty",
initial: map[string]string{},
envs: nil,
expected: map[string]string{},
},
{
name: "default_replace",
initial: map[string]string{},
envs: []*sdkproto.Env{
{Name: "FOO", Value: "bar"},
},
expected: map[string]string{"FOO": "bar"},
},
{
name: "explicit_replace",
initial: map[string]string{"FOO": "old"},
envs: []*sdkproto.Env{
{Name: "FOO", Value: "new", MergeStrategy: "replace"},
},
expected: map[string]string{"FOO": "new"},
},
{
name: "empty_strategy_defaults_to_replace",
initial: map[string]string{"FOO": "old"},
envs: []*sdkproto.Env{
{Name: "FOO", Value: "new", MergeStrategy: ""},
},
expected: map[string]string{"FOO": "new"},
},
{
name: "append_to_existing",
initial: map[string]string{"PATH": "/usr/bin"},
envs: []*sdkproto.Env{
{Name: "PATH", Value: "/custom/bin", MergeStrategy: "append"},
},
expected: map[string]string{"PATH": "/usr/bin:/custom/bin"},
},
{
name: "append_no_existing",
initial: map[string]string{},
envs: []*sdkproto.Env{
{Name: "PATH", Value: "/custom/bin", MergeStrategy: "append"},
},
expected: map[string]string{"PATH": "/custom/bin"},
},
{
name: "append_to_empty_value",
initial: map[string]string{"PATH": ""},
envs: []*sdkproto.Env{
{Name: "PATH", Value: "/custom/bin", MergeStrategy: "append"},
},
expected: map[string]string{"PATH": "/custom/bin"},
},
{
name: "prepend_to_existing",
initial: map[string]string{"PATH": "/usr/bin"},
envs: []*sdkproto.Env{
{Name: "PATH", Value: "/custom/bin", MergeStrategy: "prepend"},
},
expected: map[string]string{"PATH": "/custom/bin:/usr/bin"},
},
{
name: "prepend_no_existing",
initial: map[string]string{},
envs: []*sdkproto.Env{
{Name: "PATH", Value: "/custom/bin", MergeStrategy: "prepend"},
},
expected: map[string]string{"PATH": "/custom/bin"},
},
{
name: "error_no_duplicate",
initial: map[string]string{},
envs: []*sdkproto.Env{
{Name: "FOO", Value: "bar", MergeStrategy: "error"},
},
expected: map[string]string{"FOO": "bar"},
},
{
name: "error_with_duplicate",
initial: map[string]string{"FOO": "existing"},
envs: []*sdkproto.Env{
{Name: "FOO", Value: "new", MergeStrategy: "error"},
},
expectErr: "duplicate env var",
},
{
name: "multiple_appends_same_key",
initial: map[string]string{},
envs: []*sdkproto.Env{
{Name: "PATH", Value: "/a/bin", MergeStrategy: "append"},
{Name: "PATH", Value: "/b/bin", MergeStrategy: "append"},
},
expected: map[string]string{"PATH": "/a/bin:/b/bin"},
},
{
name: "multiple_prepends_same_key",
initial: map[string]string{},
envs: []*sdkproto.Env{
{Name: "PATH", Value: "/a/bin", MergeStrategy: "prepend"},
{Name: "PATH", Value: "/b/bin", MergeStrategy: "prepend"},
},
expected: map[string]string{"PATH": "/b/bin:/a/bin"},
},
{
name: "mixed_strategies",
initial: map[string]string{},
envs: []*sdkproto.Env{
{Name: "PATH", Value: "/first", MergeStrategy: "append"},
{Name: "PATH", Value: "/override", MergeStrategy: "replace"},
},
expected: map[string]string{"PATH": "/override"},
},
{
name: "mixed_keys",
initial: map[string]string{},
envs: []*sdkproto.Env{
{Name: "PATH", Value: "/a", MergeStrategy: "append"},
{Name: "HOME", Value: "/home/user"},
{Name: "PATH", Value: "/b", MergeStrategy: "append"},
},
expected: map[string]string{
"PATH": "/a:/b",
"HOME": "/home/user",
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
env := make(map[string]string)
for k, v := range tc.initial {
env[k] = v
}
err := provisionerdserver.MergeExtraEnvs(env, tc.envs)
if tc.expectErr != "" {
require.ErrorContains(t, err, tc.expectErr)
return
}
require.NoError(t, err)
require.Equal(t, tc.expected, env)
})
}
}
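The merge behavior exercised by the table-driven test above can be sketched as a standalone program. The `env` struct and `mergeExtraEnvs` function here are hypothetical local stand-ins (not the real `*sdkproto.Env` or `provisionerdserver.MergeExtraEnvs`), reimplementing the documented semantics: empty or `"replace"` overwrites, `"append"`/`"prepend"` join with `":"`, `"error"` rejects duplicates:

```go
package main

import "fmt"

// env is a hypothetical local stand-in for *sdkproto.Env.
type env struct {
	name, value, strategy string
}

// mergeExtraEnvs applies each env to m according to its strategy,
// mirroring the merge rules the test table above exercises.
func mergeExtraEnvs(m map[string]string, envs []env) error {
	for _, e := range envs {
		strategy := e.strategy
		if strategy == "" {
			strategy = "replace"
		}
		existing, exists := m[e.name]
		switch strategy {
		case "error":
			if exists {
				return fmt.Errorf("duplicate env var %q", e.name)
			}
			m[e.name] = e.value
		case "append":
			if exists && existing != "" {
				m[e.name] = existing + ":" + e.value
			} else {
				m[e.name] = e.value
			}
		case "prepend":
			if exists && existing != "" {
				m[e.name] = e.value + ":" + existing
			} else {
				m[e.name] = e.value
			}
		default: // "replace"
			m[e.name] = e.value
		}
	}
	return nil
}

func main() {
	m := map[string]string{"PATH": "/usr/bin"}
	_ = mergeExtraEnvs(m, []env{{name: "PATH", value: "/custom/bin", strategy: "append"}})
	fmt.Println(m["PATH"]) // /usr/bin:/custom/bin
}
```

Note that strategies compose left to right: two appends to the same key chain together, while a later `"replace"` discards whatever earlier entries built up, matching the `mixed_strategies` case.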
@@ -28,6 +28,7 @@ import (
protobuf "google.golang.org/protobuf/proto"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/aiseats"
"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
@@ -76,6 +77,7 @@ const (
type Options struct {
OIDCConfig promoauth.OAuth2Config
ExternalAuthConfigs []*externalauth.Config
AISeatTracker aiseats.SeatTracker
// Clock for testing
Clock quartz.Clock
@@ -120,6 +122,7 @@ type server struct {
NotificationsEnqueuer notifications.Enqueuer
PrebuildsOrchestrator *atomic.Pointer[prebuilds.ReconciliationOrchestrator]
UsageInserter *atomic.Pointer[usage.Inserter]
AISeatTracker aiseats.SeatTracker
Experiments codersdk.Experiments
OIDCConfig promoauth.OAuth2Config
@@ -215,6 +218,9 @@ func NewServer(
if err := tags.Valid(); err != nil {
return nil, xerrors.Errorf("invalid tags: %w", err)
}
if options.AISeatTracker == nil {
options.AISeatTracker = aiseats.Noop{}
}
if options.AcquireJobLongPollDur == 0 {
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
}
@@ -253,6 +259,7 @@ func NewServer(
heartbeatFn: options.HeartbeatFn,
PrebuildsOrchestrator: prebuildsOrchestrator,
UsageInserter: usageInserter,
AISeatTracker: options.AISeatTracker,
metrics: metrics,
Experiments: experiments,
}
@@ -2437,6 +2444,12 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro
})
}
// Record AI seat usage for successful task workspace builds.
if workspaceBuild.Transition == database.WorkspaceTransitionStart && workspace.TaskID.Valid {
s.AISeatTracker.RecordUsage(ctx, workspace.OwnerID,
aiseats.ReasonTask("task workspace build succeeded"))
}
if s.PrebuildsOrchestrator != nil && input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
// Track resource replacements, if there are any.
orchestrator := s.PrebuildsOrchestrator.Load()
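The `NewServer` hunk above defaults a nil `options.AISeatTracker` to `aiseats.Noop{}` so later call sites like `completeWorkspaceBuildJob` can invoke `RecordUsage` without nil checks. A minimal sketch of that nil-to-noop pattern, using a hypothetical reduced interface rather than the real `aiseats` package:

```go
package main

import "fmt"

// SeatTracker is a hypothetical reduction of aiseats.SeatTracker,
// kept to one method for illustration.
type SeatTracker interface {
	RecordUsage(owner, reason string)
}

// Noop satisfies SeatTracker and does nothing, standing in for
// aiseats.Noop above.
type Noop struct{}

func (Noop) RecordUsage(string, string) {}

// defaultTracker mirrors the NewServer option handling: a nil tracker
// is replaced with the no-op implementation so callers never branch.
func defaultTracker(tracker SeatTracker) SeatTracker {
	if tracker == nil {
		tracker = Noop{}
	}
	return tracker
}

func main() {
	t := defaultTracker(nil)
	t.RecordUsage("owner", "task workspace build succeeded") // safe: no nil deref
	fmt.Println("ok")
}
```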
@@ -2821,12 +2834,11 @@ func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
}
 	env := make(map[string]string)
-	// For now, we only support adding extra envs, not overriding
-	// existing ones or performing other manipulations. In future
-	// we may write these to a separate table so we can perform
-	// conditional logic on the agent.
-	for _, e := range prAgent.ExtraEnvs {
-		env[e.Name] = e.Value
-	}
+	// Apply extra envs with merge strategy support.
+	// When multiple coder_env resources define the same name,
+	// the merge_strategy controls how values are combined.
+	if err := MergeExtraEnvs(env, prAgent.ExtraEnvs); err != nil {
+		return err
+	}
// Allow the agent defined envs to override extra envs.
for k, v := range prAgent.Env {
@@ -3422,14 +3434,54 @@ func insertDevcontainerSubagent(
return subAgentID, nil
}
// MergeExtraEnvs applies extra environment variables to the given map,
// respecting the merge_strategy field on each env. When merge_strategy
// is empty or "replace", the value overwrites any existing entry.
// "append" and "prepend" join values with a ":" separator (PATH-style).
// "error" causes a failure if the key already exists.
func MergeExtraEnvs(env map[string]string, extraEnvs []*sdkproto.Env) error {
for _, e := range extraEnvs {
strategy := e.GetMergeStrategy()
if strategy == "" {
strategy = "replace"
}
existing, exists := env[e.GetName()]
switch strategy {
case "error":
if exists {
return xerrors.Errorf(
"duplicate env var %q: merge_strategy is %q but variable is already defined",
e.GetName(), strategy,
)
}
env[e.GetName()] = e.GetValue()
case "append":
if exists && existing != "" {
env[e.GetName()] = existing + ":" + e.GetValue()
} else {
env[e.GetName()] = e.GetValue()
}
case "prepend":
if exists && existing != "" {
env[e.GetName()] = e.GetValue() + ":" + existing
} else {
env[e.GetName()] = e.GetValue()
}
default: // "replace"
env[e.GetName()] = e.GetValue()
}
}
return nil
}
func encodeSubagentEnvs(envs []*sdkproto.Env) (pqtype.NullRawMessage, error) {
if len(envs) == 0 {
return pqtype.NullRawMessage{}, nil
}
 	subAgentEnvs := make(map[string]string, len(envs))
-	for _, env := range envs {
-		subAgentEnvs[env.GetName()] = env.GetValue()
-	}
+	if err := MergeExtraEnvs(subAgentEnvs, envs); err != nil {
+		return pqtype.NullRawMessage{}, err
+	}
data, err := json.Marshal(subAgentEnvs)
