Compare commits

..

5 Commits

Author SHA1 Message Date
Jon Ayers aefc75133a fix: use separate http.Transports for wsproxy tests 2026-02-24 23:02:37 +00:00
Jon Ayers b9181c3934 feat(wsproxy): add /debug/expvar endpoint for DERP server stats 2026-02-21 00:19:41 +00:00
Jon Ayers a90471db53 feat(monitoring): add wsproxy DERP section to Grafana dashboard
Adds a new 'Workspace Proxy - DERP' row with 6 panels:
- DERP Connections (current connections and home connections)
- DERP Client Breakdown (local, remote, total)
- DERP Throughput (bytes received/sent rate)
- DERP Packets (received/sent/forwarded rate)
- DERP Packet Drops (by reason label)
- DERP Queue Duration (average queue duration)
2026-02-20 23:44:24 +00:00
Jon Ayers cb71f5e789 feat(wsproxy): add DERP websocket throughput metrics
Add Prometheus metrics tracking active DERP websocket connections and
bytes relayed through the wsproxy:

- coder_wsproxy_derp_websocket_active_connections (gauge)
- coder_wsproxy_derp_websocket_bytes_total (counter, direction=read|write)

Implementation adds a DERPWebsocketMetrics hook struct and countingConn
wrapper in tailnet/, and a new WithWebsocketSupportAndMetrics function
that instruments the websocket connection lifecycle. The existing
WithWebsocketSupport function delegates to the new one with nil metrics.
2026-02-20 23:44:21 +00:00
Jon Ayers f50707bc3e feat(wsproxy): add Prometheus collector for DERP server expvar metrics
Create a prometheus.Collector that bridges the tailscale derp.Server's
expvar-based stats to Prometheus metrics with namespace coder, subsystem
wsproxy_derp. Handles counters, gauges, labeled metrics (nested
metrics.Set for drop reasons, packet types, etc.), and the average
queue duration (converted from ms to seconds).

Register the collector in the wsproxy server after derpServer creation.
2026-02-20 23:40:03 +00:00
316 changed files with 5257 additions and 7872 deletions
+1 -1
View File
@@ -23,7 +23,7 @@ jobs:
egress-policy: audit
- name: stale
uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
with:
stale-issue-label: "stale"
stale-pr-label: "stale"
-4
View File
@@ -235,10 +235,6 @@ type FakeAgentAPI struct {
pushResourcesMonitoringUsageFunc func(*agentproto.PushResourcesMonitoringUsageRequest) (*agentproto.PushResourcesMonitoringUsageResponse, error)
}
func (*FakeAgentAPI) UpdateAppStatus(context.Context, *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
panic("unimplemented")
}
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
return f.manifest, nil
}
+330 -544
View File
File diff suppressed because it is too large Load Diff
+1 -20
View File
@@ -436,7 +436,7 @@ message CreateSubAgentRequest {
}
repeated DisplayApp display_apps = 6;
optional bytes id = 7;
}
@@ -494,24 +494,6 @@ message ReportBoundaryLogsRequest {
message ReportBoundaryLogsResponse {}
// UpdateAppStatusRequest updates the given Workspace App's status. c.f. agentsdk.PatchAppStatus
message UpdateAppStatusRequest {
string slug = 1;
enum AppStatusState {
WORKING = 0;
IDLE = 1;
COMPLETE = 2;
FAILURE = 3;
}
AppStatusState state = 2;
string message = 3;
string uri = 4;
}
message UpdateAppStatusResponse {}
service Agent {
rpc GetManifest(GetManifestRequest) returns (Manifest);
rpc GetServiceBanner(GetServiceBannerRequest) returns (ServiceBanner);
@@ -530,5 +512,4 @@ service Agent {
rpc DeleteSubAgent(DeleteSubAgentRequest) returns (DeleteSubAgentResponse);
rpc ListSubAgents(ListSubAgentsRequest) returns (ListSubAgentsResponse);
rpc ReportBoundaryLogs(ReportBoundaryLogsRequest) returns (ReportBoundaryLogsResponse);
rpc UpdateAppStatus(UpdateAppStatusRequest) returns (UpdateAppStatusResponse);
}
+1 -41
View File
@@ -56,7 +56,6 @@ type DRPCAgentClient interface {
DeleteSubAgent(ctx context.Context, in *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(ctx context.Context, in *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type drpcAgentClient struct {
@@ -222,15 +221,6 @@ func (c *drpcAgentClient) ReportBoundaryLogs(ctx context.Context, in *ReportBoun
return out, nil
}
func (c *drpcAgentClient) UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
out := new(UpdateAppStatusResponse)
err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
type DRPCAgentServer interface {
GetManifest(context.Context, *GetManifestRequest) (*Manifest, error)
GetServiceBanner(context.Context, *GetServiceBannerRequest) (*ServiceBanner, error)
@@ -249,7 +239,6 @@ type DRPCAgentServer interface {
DeleteSubAgent(context.Context, *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(context.Context, *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(context.Context, *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type DRPCAgentUnimplementedServer struct{}
@@ -322,13 +311,9 @@ func (s *DRPCAgentUnimplementedServer) ReportBoundaryLogs(context.Context, *Repo
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentUnimplementedServer) UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCAgentDescription struct{}
func (DRPCAgentDescription) NumMethods() int { return 18 }
func (DRPCAgentDescription) NumMethods() int { return 17 }
func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
@@ -485,15 +470,6 @@ func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
in1.(*ReportBoundaryLogsRequest),
)
}, DRPCAgentServer.ReportBoundaryLogs, true
case 17:
return "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentServer).
UpdateAppStatus(
ctx,
in1.(*UpdateAppStatusRequest),
)
}, DRPCAgentServer.UpdateAppStatus, true
default:
return "", nil, nil, nil, false
}
@@ -774,19 +750,3 @@ func (x *drpcAgent_ReportBoundaryLogsStream) SendAndClose(m *ReportBoundaryLogsR
}
return x.CloseSend()
}
type DRPCAgent_UpdateAppStatusStream interface {
drpc.Stream
SendAndClose(*UpdateAppStatusResponse) error
}
type drpcAgent_UpdateAppStatusStream struct {
drpc.Stream
}
func (x *drpcAgent_UpdateAppStatusStream) SendAndClose(m *UpdateAppStatusResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
return err
}
return x.CloseSend()
}
+3 -7
View File
@@ -73,13 +73,9 @@ type DRPCAgentClient27 interface {
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
}
// DRPCAgentClient28 is the Agent API at v2.8. It adds
// - a SubagentId field to the WorkspaceAgentDevcontainer message
// - an Id field to the CreateSubAgentRequest message.
// - UpdateAppStatus RPC.
//
// Compatible with Coder v2.31+
// DRPCAgentClient28 is the Agent API at v2.8. It adds a SubagentId field to the
// WorkspaceAgentDevcontainer message, and a Id field to the CreateSubAgentRequest
// message. Compatible with Coder v2.31+
type DRPCAgentClient28 interface {
DRPCAgentClient27
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
+45 -50
View File
@@ -10,7 +10,6 @@ import (
"path/filepath"
"slices"
"strings"
"time"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
@@ -24,7 +23,6 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/toolsdk"
"github.com/coder/retry"
"github.com/coder/serpent"
)
@@ -541,6 +539,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
defer cancel()
defer srv.queue.Close()
cliui.Infof(inv.Stderr, "Failed to watch screen events")
// Start the reporter, watcher, and server. These are all tied to the
// lifetime of the MCP server, which is itself tied to the lifetime of the
// AI agent.
@@ -614,51 +613,48 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
}
func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
return
}
go func() {
for retrier := retry.New(time.Second, 30*time.Second); retrier.Wait(ctx); {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err == nil {
retrier.Reset()
loop:
for {
select {
case <-ctx.Done():
for {
select {
case <-ctx.Done():
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
// If the screen is stable, report idle.
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
}
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
break loop
}
}
} else {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
return
}
}
}()
@@ -696,14 +692,13 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
// Add tool dependencies.
toolOpts := []func(*toolsdk.Deps){
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
state := codersdk.WorkspaceAppStatusState(args.State)
// The agent does not reliably report idle, so when AgentAPI is
// enabled we override idle to working and let the screen watcher
// detect the real idle via StatusStable. Final states (failure,
// complete) are trusted from the agent since the screen watcher
// cannot produce them.
if s.aiAgentAPIClient != nil && state == codersdk.WorkspaceAppStatusStateIdle {
state = codersdk.WorkspaceAppStatusStateWorking
// The agent does not reliably report its status correctly. If AgentAPI
// is enabled, we will always set the status to "working" when we get an
// MCP message, and rely on the screen watcher to eventually catch the
// idle state.
state := codersdk.WorkspaceAppStatusStateWorking
if s.aiAgentAPIClient == nil {
state = codersdk.WorkspaceAppStatusState(args.State)
}
return s.queue.Push(taskReport{
link: args.Link,
+1 -185
View File
@@ -921,7 +921,7 @@ func TestExpMcpReporter(t *testing.T) {
},
},
},
// We override idle from the agent to working, but trust final states.
// We ignore the state from the agent and assume "working".
{
name: "IgnoreAgentState",
// AI agent reports that it is finished but the summary says it is doing
@@ -953,46 +953,6 @@ func TestExpMcpReporter(t *testing.T) {
Message: "finished",
},
},
// Agent reports failure; trusted even with AgentAPI enabled.
{
state: codersdk.WorkspaceAppStatusStateFailure,
summary: "something broke",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateFailure,
Message: "something broke",
},
},
// After failure, watcher reports stable -> idle.
{
event: makeStatusEvent(agentapi.StatusStable),
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "something broke",
},
},
},
},
// Final states pass through with AgentAPI enabled.
{
name: "AllowFinalStates",
tests: []test{
{
state: codersdk.WorkspaceAppStatusStateWorking,
summary: "doing work",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateWorking,
Message: "doing work",
},
},
// Agent reports complete; not overridden.
{
state: codersdk.WorkspaceAppStatusStateComplete,
summary: "all done",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateComplete,
Message: "all done",
},
},
},
},
// When AgentAPI is not being used, we accept agent state updates as-is.
@@ -1150,148 +1110,4 @@ func TestExpMcpReporter(t *testing.T) {
<-cmdDone
})
}
t.Run("Reconnect", func(t *testing.T) {
t.Parallel()
// Create a test deployment and workspace.
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user2.ID,
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
a[0].Apps = []*proto.App{
{
Slug: "vscode",
},
}
return a
}).Do()
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitLong))
// Watch the workspace for changes.
watcher, err := client.WatchWorkspace(ctx, r.Workspace.ID)
require.NoError(t, err)
var lastAppStatus codersdk.WorkspaceAppStatus
nextUpdate := func() codersdk.WorkspaceAppStatus {
for {
select {
case <-ctx.Done():
require.FailNow(t, "timed out waiting for status update")
case w, ok := <-watcher:
require.True(t, ok, "watch channel closed")
if w.LatestAppStatus != nil && w.LatestAppStatus.ID != lastAppStatus.ID {
t.Logf("Got status update: %s > %s", lastAppStatus.State, w.LatestAppStatus.State)
lastAppStatus = *w.LatestAppStatus
return lastAppStatus
}
}
}
}
// Mock AI AgentAPI server that supports disconnect/reconnect.
disconnect := make(chan struct{})
listening := make(chan func(sse codersdk.ServerSentEvent) error)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create a cancelable context so we can stop the SSE sender
// goroutine on disconnect without waiting for the HTTP
// serve loop to cancel r.Context().
sseCtx, sseCancel := context.WithCancel(r.Context())
defer sseCancel()
r = r.WithContext(sseCtx)
send, closed, err := httpapi.ServerSentEventSender(w, r)
if err != nil {
httpapi.Write(sseCtx, w, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up server-sent events.",
Detail: err.Error(),
})
return
}
// Send initial message so the watcher knows the agent is active.
send(*makeMessageEvent(0, agentapi.RoleAgent))
select {
case listening <- send:
case <-r.Context().Done():
return
}
select {
case <-closed:
case <-disconnect:
sseCancel()
<-closed
}
}))
t.Cleanup(srv.Close)
inv, _ := clitest.New(t,
"exp", "mcp", "server",
"--agent-url", client.URL.String(),
"--agent-token", r.AgentToken,
"--app-status-slug", "vscode",
"--allowed-tools=coder_report_task",
"--ai-agentapi-url", srv.URL,
)
inv = inv.WithContext(ctx)
pty := ptytest.New(t)
inv.Stdin = pty.Input()
inv.Stdout = pty.Output()
stderr := ptytest.New(t)
inv.Stderr = stderr.Output()
// Run the MCP server.
clitest.Start(t, inv)
// Initialize.
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
pty.WriteLine(payload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore init response
// Get first sender from the initial SSE connection.
sender := testutil.RequireReceive(ctx, t, listening)
// Self-report a working status via tool call.
toolPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"doing work","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got := nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "doing work", got.Message)
// Watcher sends stable, verify idle is reported.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
// Disconnect the SSE connection by signaling the handler to return.
testutil.RequireSend(ctx, t, disconnect, struct{}{})
// Wait for the watcher to reconnect and get the new sender.
sender = testutil.RequireReceive(ctx, t, listening)
// After reconnect, self-report a working status again.
toolPayload = `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"reconnected","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "reconnected", got.Message)
// Verify the watcher still processes events after reconnect.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
cancel()
})
}
-12
View File
@@ -29,7 +29,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
templateVersionJobTimeout time.Duration
prebuildWorkspaceTimeout time.Duration
noCleanup bool
provisionerTags []string
tracingFlags = &scaletestTracingFlags{}
timeoutStrategy = &timeoutFlags{}
@@ -112,16 +111,10 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
tags, err := ParseProvisionerTags(provisionerTags)
if err != nil {
return err
}
for i := range numTemplates {
id := strconv.Itoa(int(i))
cfg := prebuilds.Config{
OrganizationID: me.OrganizationIDs[0],
ProvisionerTags: tags,
NumPresets: int(numPresets),
NumPresetPrebuilds: int(numPresetPrebuilds),
TemplateVersionJobTimeout: templateVersionJobTimeout,
@@ -290,11 +283,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
Description: "Skip cleanup (deletion test) and leave resources intact.",
Value: serpent.BoolOf(&noCleanup),
},
{
Flag: "provisioner-tag",
Description: "Specify a set of tags to target provisioner daemons.",
Value: serpent.StringArrayOf(&provisionerTags),
},
}
tracingFlags.attach(&cmd.Options)
+5 -1
View File
@@ -106,7 +106,11 @@ func TestList(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
memberClient, member = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
sharedWorkspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
+1 -1
View File
@@ -297,7 +297,7 @@ func (pr *ParameterResolver) verifyConstraints(resolved []codersdk.WorkspaceBuil
return xerrors.Errorf("ephemeral parameter %q can be used only with --prompt-ephemeral-parameters or --ephemeral-parameter flag", r.Name)
}
if !tvp.Mutable && action != WorkspaceCreate && !pr.isFirstTimeUse(r.Name) {
if !tvp.Mutable && action != WorkspaceCreate {
return xerrors.Errorf("parameter %q is immutable and cannot be updated", r.Name)
}
}
+31 -7
View File
@@ -25,7 +25,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -64,8 +68,12 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
orgOwner = coderdtest.CreateFirstUser(t, client)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -119,7 +127,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -170,7 +182,11 @@ func TestSharingStatus(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -214,7 +230,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -271,7 +291,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
+1 -1
View File
@@ -120,7 +120,7 @@ func (r *RootCmd) start() *serpent.Command {
func buildWorkspaceStartRequest(inv *serpent.Invocation, client *codersdk.Client, workspace codersdk.Workspace, parameterFlags workspaceParameterFlags, buildFlags buildFlags, action WorkspaceCLIAction) (codersdk.CreateWorkspaceBuildRequest, error) {
version := workspace.LatestBuild.TemplateVersionID
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || workspace.TemplateRequireActiveVersion || action == WorkspaceUpdate {
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || action == WorkspaceUpdate {
version = workspace.TemplateActiveVersionID
if version != workspace.LatestBuild.TemplateVersionID {
action = WorkspaceUpdate
+4 -4
View File
@@ -33,7 +33,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
statefilePath := filepath.Join(t.TempDir(), "state")
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name, statefilePath)
@@ -54,7 +54,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name)
var gotState bytes.Buffer
@@ -74,7 +74,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", taUser.Username+"/"+r.Workspace.Name,
"--build", fmt.Sprintf("%d", r.Build.BuildNumber))
@@ -170,7 +170,7 @@ func TestStatePush(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(initialState).
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
Do()
wantState := []byte("updated state")
stateFile, err := os.CreateTemp(t.TempDir(), "")
+4 -3
View File
@@ -49,9 +49,10 @@ OPTIONS:
security purposes if a --wildcard-access-url is configured.
--disable-workspace-sharing bool, $CODER_DISABLE_WORKSPACE_SHARING
Disable workspace sharing. Workspace ACL checking is disabled and only
owners can have ssh, apps and terminal access to workspaces. Access
based on the 'owner' role is also allowed unless disabled via
Disable workspace sharing (requires the "workspace-sharing" experiment
to be enabled). Workspace ACL checking is disabled and only owners can
have ssh, apps and terminal access to workspaces. Access based on the
'owner' role is also allowed unless disabled via
--disable-owner-workspace-access.
--swagger-enable bool, $CODER_SWAGGER_ENABLE
+4 -4
View File
@@ -523,10 +523,10 @@ disablePathApps: false
# workspaces.
# (default: <unset>, type: bool)
disableOwnerWorkspaceAccess: false
# Disable workspace sharing. Workspace ACL checking is disabled and only owners
# can have ssh, apps and terminal access to workspaces. Access based on the
# 'owner' role is also allowed unless disabled via
# --disable-owner-workspace-access.
# Disable workspace sharing (requires the "workspace-sharing" experiment to be
# enabled). Workspace ACL checking is disabled and only owners can have ssh, apps
# and terminal access to workspaces. Access based on the 'owner' role is also
# allowed unless disabled via --disable-owner-workspace-access.
# (default: <unset>, type: bool)
disableWorkspaceSharing: false
# These options change the behavior of how clients interact with the Coder.
+15 -2
View File
@@ -241,13 +241,26 @@ func (r *RootCmd) listTokens() *serpent.Command {
}
tokens, err := client.Tokens(inv.Context(), codersdk.Me, codersdk.TokensFilter{
IncludeAll: all,
IncludeExpired: includeExpired,
IncludeAll: all,
})
if err != nil {
return xerrors.Errorf("list tokens: %w", err)
}
// Filter out expired tokens unless --include-expired is set
// TODO(Cian): This _could_ get too big for client-side filtering.
// If it causes issues, we can filter server-side.
if !includeExpired {
now := time.Now()
filtered := make([]codersdk.APIKeyWithOwner, 0, len(tokens))
for _, token := range tokens {
if token.ExpiresAt.After(now) {
filtered = append(filtered, token)
}
}
tokens = filtered
}
displayTokens = make([]tokenListRow, len(tokens))
for i, token := range tokens {
-70
View File
@@ -990,74 +990,4 @@ func TestUpdateValidateRichParameters(t *testing.T) {
_ = testutil.TryReceive(ctx, t, doneChan)
})
t.Run("NewImmutableParameterViaFlag", func(t *testing.T) {
t.Parallel()
// Create template and workspace with only a mutable parameter.
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
templateParameters := []*proto.RichParameter{
{Name: stringParameterName, Type: "string", Mutable: true, Required: true, Options: []*proto.RichParameterOption{
{Name: "First option", Description: "This is first option", Value: "1st"},
{Name: "Second option", Description: "This is second option", Value: "2nd"},
}},
}
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(templateParameters))
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
inv, root := clitest.New(t, "create", "my-workspace", "--yes", "--template", template.Name, "--parameter", fmt.Sprintf("%s=%s", stringParameterName, "1st"))
clitest.SetupConfig(t, member, root)
err := inv.Run()
require.NoError(t, err)
// Update template: add a new immutable parameter.
updatedTemplateParameters := []*proto.RichParameter{
templateParameters[0],
{Name: immutableParameterName, Type: "string", Mutable: false, Required: true, Options: []*proto.RichParameterOption{
{Name: "fir", Description: "First option for immutable parameter", Value: "I"},
{Name: "sec", Description: "Second option for immutable parameter", Value: "II"},
}},
}
updatedVersion := coderdtest.UpdateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(updatedTemplateParameters), template.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, updatedVersion.ID)
err = client.UpdateActiveTemplateVersion(context.Background(), template.ID, codersdk.UpdateActiveTemplateVersion{
ID: updatedVersion.ID,
})
require.NoError(t, err)
// Update workspace, supplying the new immutable parameter via
// the --parameter flag. This should succeed because it's the
// first time this parameter is being set.
inv, root = clitest.New(t, "update", "my-workspace",
"--parameter", fmt.Sprintf("%s=%s", immutableParameterName, "II"))
clitest.SetupConfig(t, member, root)
pty := ptytest.New(t).Attach(inv)
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatch("Planning workspace")
ctx := testutil.Context(t, testutil.WaitLong)
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify the immutable parameter was set correctly.
workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
require.NoError(t, err)
require.Contains(t, actualParameters, codersdk.WorkspaceBuildParameter{
Name: immutableParameterName,
Value: "II",
})
})
}
-2
View File
@@ -179,8 +179,6 @@ func New(opts Options, workspace database.Workspace) *API {
Database: opts.Database,
Log: opts.Log,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
Clock: opts.Clock,
NotificationsEnqueuer: opts.NotificationsEnqueuer,
}
api.MetadataAPI = &MetadataAPI{
-240
View File
@@ -2,10 +2,6 @@ package agentapi
import (
"context"
"database/sql"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
@@ -13,14 +9,7 @@ import (
"cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/notifications"
strutil "github.com/coder/coder/v2/coderd/util/strings"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
type AppsAPI struct {
@@ -28,8 +17,6 @@ type AppsAPI struct {
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
NotificationsEnqueuer notifications.Enqueuer
Clock quartz.Clock
}
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
@@ -117,230 +104,3 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
}
return &agentproto.BatchUpdateAppHealthResponse{}, nil
}
// UpdateAppStatus persists a status report for a workspace app (used by AI
// task apps to report working/idle/complete/failure), publishes a workspace
// update event, enqueues AI task state notifications, and bumps the
// workspace activity deadline when the transition warrants it.
//
// Returns a 400 error for an oversized message, an unrecognized state, an
// unknown app slug, or a workspace lookup failure; a 500 error for
// persistence or publish failures.
func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
	// 160 characters is the persisted message limit; reject up front.
	if len(req.Message) > 160 {
		return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
			Message: "Message is too long.",
			Detail:  "Message must be less than 160 characters.",
			Validations: []codersdk.ValidationError{
				{Field: "message", Detail: "Message must be less than 160 characters."},
			},
		})
	}
	// Map the protobuf state onto the database enum.
	var dbState database.WorkspaceAppStatusState
	switch req.State {
	case agentproto.UpdateAppStatusRequest_COMPLETE:
		dbState = database.WorkspaceAppStatusStateComplete
	case agentproto.UpdateAppStatusRequest_FAILURE:
		dbState = database.WorkspaceAppStatusStateFailure
	case agentproto.UpdateAppStatusRequest_WORKING:
		dbState = database.WorkspaceAppStatusStateWorking
	case agentproto.UpdateAppStatusRequest_IDLE:
		dbState = database.WorkspaceAppStatusStateIdle
	default:
		return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
			Message: "Invalid state provided.",
			Detail:  fmt.Sprintf("invalid state: %q", req.State),
			Validations: []codersdk.ValidationError{
				{Field: "state", Detail: "State must be one of: complete, failure, working, idle."},
			},
		})
	}
	workspaceAgent, err := a.AgentFn(ctx)
	if err != nil {
		return nil, err
	}
	// Resolve the app by (agent, slug); a miss is reported as a client error.
	app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
		AgentID: workspaceAgent.ID,
		Slug:    req.Slug,
	})
	if err != nil {
		return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
			Message: "Failed to get workspace app.",
			Detail:  fmt.Sprintf("No app found with slug %q", req.Slug),
		})
	}
	workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
	if err != nil {
		return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
			Message: "Failed to get workspace.",
			Detail:  err.Error(),
		})
	}
	// Treat the message as untrusted input.
	cleaned := strutil.UISanitize(req.Message)
	// Get the latest status for the workspace app to detect no-op updates
	// nolint:gocritic // This is a system restricted operation.
	latestAppStatus, err := a.Database.GetLatestWorkspaceAppStatusByAppID(dbauthz.AsSystemRestricted(ctx), app.ID)
	if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
		return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to get latest workspace app status.",
			Detail:  err.Error(),
		})
	}
	// If no rows found, latestAppStatus will be a zero-value struct (ID == uuid.Nil)
	// nolint:gocritic // This is a system restricted operation.
	_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
		ID:          uuid.New(),
		CreatedAt:   dbtime.Now(),
		WorkspaceID: workspace.ID,
		AgentID:     workspaceAgent.ID,
		AppID:       app.ID,
		State:       dbState,
		Message:     cleaned,
		// An empty URI is stored as NULL rather than "".
		Uri: sql.NullString{
			String: req.Uri,
			Valid:  req.Uri != "",
		},
	})
	if err != nil {
		return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to insert workspace app status.",
			Detail:  err.Error(),
		})
	}
	// Let subscribers (e.g. the UI) know the agent's app status changed.
	if a.PublishWorkspaceUpdateFn != nil {
		err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
		if err != nil {
			return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
				Message: "Failed to publish workspace update.",
				Detail:  err.Error(),
			})
		}
	}
	// Notify on state change to Working/Idle for AI tasks
	a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent)
	if shouldBump(dbState, latestAppStatus) {
		// We pass time.Time{} for nextAutostart since we don't have access to
		// TemplateScheduleStore here. The activity bump logic handles this by
		// defaulting to the template's activity_bump duration (typically 1 hour).
		workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{})
	}
	// just return a blank response because it doesn't contain any settable fields at present.
	return new(agentproto.UpdateAppStatusResponse), nil
}
// shouldBump reports whether the workspace's activity deadline should be
// extended for this status transition. The deadline is bumped whenever the
// agent reports it is working, and once more when it transitions away from
// working, so auto-pause is avoided during active work and users get time
// to interact after work completes.
func shouldBump(dbState database.WorkspaceAppStatusState, latestAppStatus database.WorkspaceAppStatus) bool {
	switch {
	case dbState == database.WorkspaceAppStatusStateWorking:
		// Reporting the working state always bumps.
		return true
	case latestAppStatus.ID != uuid.Nil && latestAppStatus.State == database.WorkspaceAppStatusStateWorking:
		// Transitioning away from a previously-persisted working state
		// also bumps. (A zero-value latestAppStatus means there was no
		// previous status.)
		return true
	default:
		return false
	}
}
// enqueueAITaskStateNotification enqueues a notification when an AI task's
// app transitions to a notifiable state (Working, Idle, Complete, or
// Failure).
// No-op if:
//   - the workspace agent app isn't configured as an AI task,
//   - the new state equals the latest persisted state,
//   - the workspace agent is not ready (still starting up),
//   - this is the task's very first "Working" status.
//
// Notification delivery is best-effort: lookup and enqueue failures are
// logged and swallowed rather than returned.
func (a *AppsAPI) enqueueAITaskStateNotification(
	ctx context.Context,
	appID uuid.UUID,
	latestAppStatus database.WorkspaceAppStatus,
	newAppStatus database.WorkspaceAppStatusState,
	workspace database.Workspace,
	agent database.WorkspaceAgent,
) {
	// Map the new state to its notification template; unrecognized states
	// are not notifiable.
	var notificationTemplate uuid.UUID
	switch newAppStatus {
	case database.WorkspaceAppStatusStateWorking:
		notificationTemplate = notifications.TemplateTaskWorking
	case database.WorkspaceAppStatusStateIdle:
		notificationTemplate = notifications.TemplateTaskIdle
	case database.WorkspaceAppStatusStateComplete:
		notificationTemplate = notifications.TemplateTaskCompleted
	case database.WorkspaceAppStatusStateFailure:
		notificationTemplate = notifications.TemplateTaskFailed
	default:
		// Not a notifiable state, do nothing
		return
	}
	if !workspace.TaskID.Valid {
		// Workspace has no task ID, do nothing.
		return
	}
	// Only send notifications when the agent is ready. We want to skip
	// any state transitions that occur whilst the workspace is starting
	// up as it doesn't make sense to receive them.
	if agent.LifecycleState != database.WorkspaceAgentLifecycleStateReady {
		a.Log.Debug(ctx, "skipping AI task notification because agent is not ready",
			slog.F("agent_id", agent.ID),
			slog.F("lifecycle_state", agent.LifecycleState),
			slog.F("new_app_status", newAppStatus),
		)
		return
	}
	task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
	if err != nil {
		a.Log.Warn(ctx, "failed to get task", slog.Error(err))
		return
	}
	// Only the task's configured app produces task notifications.
	if !task.WorkspaceAppID.Valid || task.WorkspaceAppID.UUID != appID {
		// Non-task app, do nothing.
		return
	}
	// Skip if the latest persisted state equals the new state (no new transition)
	// Note: uuid.Nil check is valid here. If no previous status exists,
	// GetLatestWorkspaceAppStatusByAppID returns sql.ErrNoRows and we get a zero-value struct.
	if latestAppStatus.ID != uuid.Nil && latestAppStatus.State == newAppStatus {
		return
	}
	// Skip the initial "Working" notification when the task first starts.
	// This is obvious to the user since they just created the task.
	// We still notify on the first "Idle" status and all subsequent transitions.
	if latestAppStatus.ID == uuid.Nil && newAppStatus == database.WorkspaceAppStatusStateWorking {
		return
	}
	if _, err := a.NotificationsEnqueuer.EnqueueWithData(
		// nolint:gocritic // Need notifier actor to enqueue notifications
		dbauthz.AsNotifier(ctx),
		workspace.OwnerID,
		notificationTemplate,
		map[string]string{
			"task":      task.Name,
			"workspace": workspace.Name,
		},
		map[string]any{
			// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
			// allowing identical content to resend within the same day
			// (but not more than once every 10s).
			"dedupe_bypass_ts": a.Clock.Now().UTC().Truncate(time.Minute),
		},
		"api-workspace-agent-app-status",
		// Associate this notification with related entities
		workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
	); err != nil {
		a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
		return
	}
}
-115
View File
@@ -1,115 +0,0 @@
package agentapi
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/util/ptr"
)
// TestShouldBump exercises the deadline-bump decision table: bumping always
// happens when the new state is "working", and exactly once when leaving a
// previously-persisted "working" state; all other transitions do not bump.
func TestShouldBump(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name       string
		prevState  *database.WorkspaceAppStatusState // nil means no previous state
		newState   database.WorkspaceAppStatusState
		shouldBump bool
	}{
		{
			name:       "FirstStatusBumps",
			prevState:  nil,
			newState:   database.WorkspaceAppStatusStateWorking,
			shouldBump: true,
		},
		{
			name:       "WorkingToIdleBumps",
			prevState:  ptr.Ref(database.WorkspaceAppStatusStateWorking),
			newState:   database.WorkspaceAppStatusStateIdle,
			shouldBump: true,
		},
		{
			name:       "WorkingToCompleteBumps",
			prevState:  ptr.Ref(database.WorkspaceAppStatusStateWorking),
			newState:   database.WorkspaceAppStatusStateComplete,
			shouldBump: true,
		},
		{
			name:       "CompleteToIdleNoBump",
			prevState:  ptr.Ref(database.WorkspaceAppStatusStateComplete),
			newState:   database.WorkspaceAppStatusStateIdle,
			shouldBump: false,
		},
		{
			name:       "CompleteToCompleteNoBump",
			prevState:  ptr.Ref(database.WorkspaceAppStatusStateComplete),
			newState:   database.WorkspaceAppStatusStateComplete,
			shouldBump: false,
		},
		{
			name:       "FailureToIdleNoBump",
			prevState:  ptr.Ref(database.WorkspaceAppStatusStateFailure),
			newState:   database.WorkspaceAppStatusStateIdle,
			shouldBump: false,
		},
		{
			name:       "FailureToFailureNoBump",
			prevState:  ptr.Ref(database.WorkspaceAppStatusStateFailure),
			newState:   database.WorkspaceAppStatusStateFailure,
			shouldBump: false,
		},
		{
			name:       "CompleteToWorkingBumps",
			prevState:  ptr.Ref(database.WorkspaceAppStatusStateComplete),
			newState:   database.WorkspaceAppStatusStateWorking,
			shouldBump: true,
		},
		{
			name:       "FailureToCompleteNoBump",
			prevState:  ptr.Ref(database.WorkspaceAppStatusStateFailure),
			newState:   database.WorkspaceAppStatusStateComplete,
			shouldBump: false,
		},
		{
			name:       "WorkingToFailureBumps",
			prevState:  ptr.Ref(database.WorkspaceAppStatusStateWorking),
			newState:   database.WorkspaceAppStatusStateFailure,
			shouldBump: true,
		},
		{
			name:       "IdleToIdleNoBump",
			prevState:  ptr.Ref(database.WorkspaceAppStatusStateIdle),
			newState:   database.WorkspaceAppStatusStateIdle,
			shouldBump: false,
		},
		{
			name:       "IdleToWorkingBumps",
			prevState:  ptr.Ref(database.WorkspaceAppStatusStateIdle),
			newState:   database.WorkspaceAppStatusStateWorking,
			shouldBump: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// Translate the optional previous state into the struct form
			// shouldBump expects; a zero-value struct models "no previous
			// status" (ID stays uuid.Nil).
			var prev database.WorkspaceAppStatus
			if tc.prevState != nil {
				prev.ID = uuid.UUID{1}
				prev.State = *tc.prevState
			}

			got := shouldBump(tc.newState, prev)
			if tc.shouldBump {
				require.True(t, got, "wanted deadline to bump but it didn't")
			} else {
				require.False(t, got, "wanted deadline not to bump but it did")
			}
		})
	}
}
-188
View File
@@ -2,13 +2,9 @@ package agentapi_test
import (
"context"
"database/sql"
"net/http"
"strings"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
@@ -16,12 +12,8 @@ import (
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestBatchUpdateAppHealths(t *testing.T) {
@@ -261,183 +253,3 @@ func TestBatchUpdateAppHealths(t *testing.T) {
require.Nil(t, resp)
})
}
// TestWorkspaceAgentAppStatus covers the UpdateAppStatus RPC: the happy
// path (status persisted, workspace update published, task-completed
// notification enqueued) and the client-error paths (unknown app slug,
// unknown state enum, oversized message).
func TestWorkspaceAgentAppStatus(t *testing.T) {
	t.Parallel()
	t.Run("Success", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		ctrl := gomock.NewController(t)
		mDB := dbmock.NewMockStore(ctrl)
		fEnq := &notificationstest.FakeEnqueuer{}
		mClock := quartz.NewMock(t)
		// Agent must be Ready so the task notification path is not skipped.
		agent := database.WorkspaceAgent{
			ID:             uuid.UUID{2},
			LifecycleState: database.WorkspaceAgentLifecycleStateReady,
		}
		// Buffered so the publish callback never blocks the API under test.
		workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
		api := &agentapi.AppsAPI{
			AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
				return agent, nil
			},
			Database: mDB,
			Log:      testutil.Logger(t),
			PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
				assert.Equal(t, *agnt, agent)
				testutil.AssertSend(ctx, t, workspaceUpdates, kind)
				return nil
			},
			NotificationsEnqueuer: fEnq,
			Clock:                 mClock,
		}
		app := database.WorkspaceApp{
			ID: uuid.UUID{8},
		}
		mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), database.GetWorkspaceAppByAgentIDAndSlugParams{
			AgentID: agent.ID,
			Slug:    "vscode",
		}).Times(1).Return(app, nil)
		// The task points at the same app ID, so the status transition is
		// treated as a task state change and triggers a notification.
		task := database.Task{
			ID: uuid.UUID{7},
			WorkspaceAppID: uuid.NullUUID{
				Valid: true,
				UUID:  app.ID,
			},
		}
		mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
		workspace := database.Workspace{
			ID: uuid.UUID{9},
			TaskID: uuid.NullUUID{
				Valid: true,
				UUID:  task.ID,
			},
		}
		mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
		// A non-nil previous status ID means this is not the task's very
		// first status, so the notification is not suppressed.
		appStatus := database.WorkspaceAppStatus{
			ID: uuid.UUID{6},
		}
		mDB.EXPECT().GetLatestWorkspaceAppStatusByAppID(gomock.Any(), app.ID).Times(1).Return(appStatus, nil)
		// Verify the inserted row carries the request's message, state, and URI.
		mDB.EXPECT().InsertWorkspaceAppStatus(
			gomock.Any(),
			gomock.Cond(func(params database.InsertWorkspaceAppStatusParams) bool {
				if params.AgentID == agent.ID && params.AppID == app.ID {
					assert.Equal(t, "testing", params.Message)
					assert.Equal(t, database.WorkspaceAppStatusStateComplete, params.State)
					assert.True(t, params.Uri.Valid)
					assert.Equal(t, "https://example.com", params.Uri.String)
					return true
				}
				return false
			})).Times(1).Return(database.WorkspaceAppStatus{}, nil)
		_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
			Slug:    "vscode",
			Message: "testing",
			Uri:     "https://example.com",
			State:   agentproto.UpdateAppStatusRequest_COMPLETE,
		})
		require.NoError(t, err)
		// A single app-status-update event should have been published.
		kind := testutil.RequireReceive(ctx, t, workspaceUpdates)
		require.Equal(t, wspubsub.WorkspaceEventKindAgentAppStatusUpdate, kind)
		// The COMPLETE state maps to the task-completed notification template.
		sent := fEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskCompleted))
		require.Len(t, sent, 1)
	})
	// An unknown slug is surfaced as a 400 with a descriptive message.
	t.Run("FailUnknownApp", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		ctrl := gomock.NewController(t)
		mDB := dbmock.NewMockStore(ctrl)
		agent := database.WorkspaceAgent{
			ID:             uuid.UUID{2},
			LifecycleState: database.WorkspaceAgentLifecycleStateReady,
		}
		mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), gomock.Any()).
			Times(1).
			Return(database.WorkspaceApp{}, sql.ErrNoRows)
		api := &agentapi.AppsAPI{
			AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
				return agent, nil
			},
			Database: mDB,
			Log:      testutil.Logger(t),
		}
		_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
			Slug:    "unknown",
			Message: "testing",
			Uri:     "https://example.com",
			State:   agentproto.UpdateAppStatusRequest_COMPLETE,
		})
		require.ErrorContains(t, err, "No app found with slug")
		var sdkErr *codersdk.Error
		require.ErrorAs(t, err, &sdkErr)
		require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
	})
	// A state value outside the proto enum is rejected before any DB access.
	t.Run("FailUnknownState", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		ctrl := gomock.NewController(t)
		mDB := dbmock.NewMockStore(ctrl)
		agent := database.WorkspaceAgent{
			ID:             uuid.UUID{2},
			LifecycleState: database.WorkspaceAgentLifecycleStateReady,
		}
		api := &agentapi.AppsAPI{
			AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
				return agent, nil
			},
			Database: mDB,
			Log:      testutil.Logger(t),
		}
		_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
			Slug:    "vscode",
			Message: "testing",
			Uri:     "https://example.com",
			State:   77,
		})
		require.ErrorContains(t, err, "Invalid state")
		var sdkErr *codersdk.Error
		require.ErrorAs(t, err, &sdkErr)
		require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
	})
	// Messages longer than the 160-character limit are rejected up front.
	t.Run("FailTooLong", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		ctrl := gomock.NewController(t)
		mDB := dbmock.NewMockStore(ctrl)
		agent := database.WorkspaceAgent{
			ID:             uuid.UUID{2},
			LifecycleState: database.WorkspaceAgentLifecycleStateReady,
		}
		api := &agentapi.AppsAPI{
			AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
				return agent, nil
			},
			Database: mDB,
			Log:      testutil.Logger(t),
		}
		_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
			Slug:    "vscode",
			Message: strings.Repeat("a", 161),
			Uri:     "https://example.com",
			State:   agentproto.UpdateAppStatusRequest_COMPLETE,
		})
		require.ErrorContains(t, err, "Message is too long")
		var sdkErr *codersdk.Error
		require.ErrorAs(t, err, &sdkErr)
		require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
	})
}
+2
View File
@@ -466,6 +466,7 @@ func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks
apiWorkspaces, err := convertWorkspaces(
ctx,
api.Experiments,
api.Logger,
requesterID,
workspaces,
@@ -545,6 +546,7 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
ws, err := convertWorkspace(
ctx,
api.Experiments,
api.Logger,
apiKey.UserID,
workspace,
+1 -1
View File
@@ -832,7 +832,7 @@ func TestTasks(t *testing.T) {
t.Run("SendToNonActiveStates", func(t *testing.T) {
t.Parallel()
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{})
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitMedium)
+7 -38
View File
@@ -135,34 +135,6 @@ const docTemplate = `{
}
}
},
"/aibridge/models": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": [
"application/json"
],
"tags": [
"AI Bridge"
],
"summary": "List AI Bridge models",
"operationId": "list-ai-bridge-models",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
}
},
"/appearance": {
"get": {
"security": [
@@ -8266,12 +8238,6 @@ const docTemplate = `{
"name": "user",
"in": "path",
"required": true
},
{
"type": "boolean",
"description": "Include expired tokens in the list",
"name": "include_expired",
"in": "query"
}
],
"responses": {
@@ -9579,7 +9545,6 @@ const docTemplate = `{
],
"summary": "Patch workspace agent app status",
"operationId": "patch-workspace-agent-app-status",
"deprecated": true,
"parameters": [
{
"description": "app status",
@@ -15132,7 +15097,8 @@ const docTemplate = `{
"workspace-usage",
"web-push",
"oauth2",
"mcp-server-http"
"mcp-server-http",
"workspace-sharing"
],
"x-enum-comments": {
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
@@ -15141,6 +15107,7 @@ const docTemplate = `{
"ExperimentNotifications": "Sends notifications via SMTP and webhooks following certain events.",
"ExperimentOAuth2": "Enables OAuth2 provider functionality.",
"ExperimentWebPush": "Enables web push notifications through the browser.",
"ExperimentWorkspaceSharing": "Enables updating workspace ACLs for sharing with users and groups.",
"ExperimentWorkspaceUsage": "Enables the new workspace usage tracking."
},
"x-enum-descriptions": [
@@ -15150,7 +15117,8 @@ const docTemplate = `{
"Enables the new workspace usage tracking.",
"Enables web push notifications through the browser.",
"Enables OAuth2 provider functionality.",
"Enables the MCP HTTP server functionality."
"Enables the MCP HTTP server functionality.",
"Enables updating workspace ACLs for sharing with users and groups."
],
"x-enum-varnames": [
"ExperimentExample",
@@ -15159,7 +15127,8 @@ const docTemplate = `{
"ExperimentWorkspaceUsage",
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentMCPServerHTTP"
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceSharing"
]
},
"codersdk.ExternalAPIKeyScopes": {
+7 -34
View File
@@ -112,30 +112,6 @@
}
}
},
"/aibridge/models": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"tags": ["AI Bridge"],
"summary": "List AI Bridge models",
"operationId": "list-ai-bridge-models",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
}
},
"/appearance": {
"get": {
"security": [
@@ -7309,12 +7285,6 @@
"name": "user",
"in": "path",
"required": true
},
{
"type": "boolean",
"description": "Include expired tokens in the list",
"name": "include_expired",
"in": "query"
}
],
"responses": {
@@ -8474,7 +8444,6 @@
"tags": ["Agents"],
"summary": "Patch workspace agent app status",
"operationId": "patch-workspace-agent-app-status",
"deprecated": true,
"parameters": [
{
"description": "app status",
@@ -13655,7 +13624,8 @@
"workspace-usage",
"web-push",
"oauth2",
"mcp-server-http"
"mcp-server-http",
"workspace-sharing"
],
"x-enum-comments": {
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
@@ -13664,6 +13634,7 @@
"ExperimentNotifications": "Sends notifications via SMTP and webhooks following certain events.",
"ExperimentOAuth2": "Enables OAuth2 provider functionality.",
"ExperimentWebPush": "Enables web push notifications through the browser.",
"ExperimentWorkspaceSharing": "Enables updating workspace ACLs for sharing with users and groups.",
"ExperimentWorkspaceUsage": "Enables the new workspace usage tracking."
},
"x-enum-descriptions": [
@@ -13673,7 +13644,8 @@
"Enables the new workspace usage tracking.",
"Enables web push notifications through the browser.",
"Enables OAuth2 provider functionality.",
"Enables the MCP HTTP server functionality."
"Enables the MCP HTTP server functionality.",
"Enables updating workspace ACLs for sharing with users and groups."
],
"x-enum-varnames": [
"ExperimentExample",
@@ -13682,7 +13654,8 @@
"ExperimentWorkspaceUsage",
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentMCPServerHTTP"
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceSharing"
]
},
"codersdk.ExternalAPIKeyScopes": {
+8 -14
View File
@@ -307,26 +307,20 @@ func (api *API) apiKeyByName(rw http.ResponseWriter, r *http.Request) {
// @Tags Users
// @Param user path string true "User ID, name, or me"
// @Success 200 {array} codersdk.APIKey
// @Param include_expired query bool false "Include expired tokens in the list"
// @Router /users/{user}/keys/tokens [get]
func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
user = httpmw.UserParam(r)
keys []database.APIKey
err error
queryStr = r.URL.Query().Get("include_all")
includeAll, _ = strconv.ParseBool(queryStr)
expiredStr = r.URL.Query().Get("include_expired")
includeExpired, _ = strconv.ParseBool(expiredStr)
ctx = r.Context()
user = httpmw.UserParam(r)
keys []database.APIKey
err error
queryStr = r.URL.Query().Get("include_all")
includeAll, _ = strconv.ParseBool(queryStr)
)
if includeAll {
// get tokens for all users
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.GetAPIKeysByLoginTypeParams{
LoginType: database.LoginTypeToken,
IncludeExpired: includeExpired,
})
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.LoginTypeToken)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching API keys.",
@@ -336,7 +330,7 @@ func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
}
} else {
// get user's tokens only
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID, IncludeExpired: includeExpired})
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching API keys.",
-38
View File
@@ -69,44 +69,6 @@ func TestTokenCRUD(t *testing.T) {
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[numLogs-1].Action)
}
// TestTokensFilterExpired verifies that the tokens listing hides expired
// tokens by default and only returns them when the caller opts in via the
// IncludeExpired filter.
func TestTokensFilterExpired(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	adminClient := coderdtest.New(t, nil)
	_ = coderdtest.CreateFirstUser(t, adminClient)

	// Mint a week-long token for the current user; the key ID is the part
	// of the secret before the first dash.
	created, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
		Lifetime: time.Hour * 24 * 7,
	})
	require.NoError(t, err)
	keyID, _, _ := strings.Cut(created.Key, "-")

	// A fresh token appears in the default listing.
	list, err := adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
	require.NoError(t, err)
	require.Len(t, list, 1)

	// Force the token to expire.
	require.NoError(t, adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID))

	// The default listing now omits it...
	list, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
	require.NoError(t, err)
	require.Empty(t, list)

	// ...but opting into expired tokens brings it back.
	list, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{
		IncludeExpired: true,
	})
	require.NoError(t, err)
	require.Len(t, list, 1)
	require.Equal(t, keyID, list[0].ID)
}
func TestTokenScoped(t *testing.T) {
t.Parallel()
+1 -1
View File
@@ -500,7 +500,7 @@ func (e *Executor) runOnce(t time.Time) Stats {
"task": task.Name,
"task_id": task.ID.String(),
"workspace": ws.Name,
"pause_reason": "idle timeout",
"pause_reason": "inactivity exceeded the dormancy threshold",
},
"lifecycle_executor",
ws.ID, ws.OwnerID, ws.OrganizationID,
+1 -1
View File
@@ -2082,6 +2082,6 @@ func TestExecutorTaskWorkspace(t *testing.T) {
require.Equal(t, task.Name, sent[0].Labels["task"])
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
require.Equal(t, "idle timeout", sent[0].Labels["pause_reason"])
require.Equal(t, "inactivity exceeded the dormancy threshold", sent[0].Labels["pause_reason"])
})
}
+4
View File
@@ -1526,6 +1526,10 @@ func New(options *Options) *API {
})
r.Get("/timings", api.workspaceTimings)
r.Route("/acl", func(r chi.Router) {
r.Use(
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceSharing),
)
r.Get("/", api.workspaceACL)
r.Patch("/", api.patchWorkspaceACL)
r.Delete("/", api.deleteWorkspaceACL)
+3 -70
View File
@@ -668,31 +668,6 @@ var (
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectWorkspaceBuilder = rbac.Subject{
Type: rbac.SubjectTypeWorkspaceBuilder,
FriendlyName: "Workspace Builder",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Identifier: rbac.RoleIdentifier{Name: "workspace-builder"},
DisplayName: "Workspace Builder",
Site: rbac.Permissions(map[string][]policy.Action{
// Reading provisioner daemons to check eligibility.
rbac.ResourceProvisionerDaemon.Type: {policy.ActionRead},
// Updating provisioner jobs (e.g. marking prebuild
// jobs complete).
rbac.ResourceProvisionerJobs.Type: {policy.ActionUpdate},
// Reading provisioner state requires template update
// permission.
rbac.ResourceTemplate.Type: {policy.ActionUpdate},
}),
User: []rbac.Permission{},
ByOrgID: map[string]rbac.OrgPermissions{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
)
// AsProvisionerd returns a context with an actor that has permissions required
@@ -799,14 +774,6 @@ func AsBoundaryUsageTracker(ctx context.Context) context.Context {
return As(ctx, subjectBoundaryUsageTracker)
}
// AsWorkspaceBuilder returns a context with an actor that has permissions
// required for the workspace builder to prepare workspace builds. This
// includes reading provisioner daemons, updating provisioner jobs, and
// reading provisioner state (which requires template update permission).
func AsWorkspaceBuilder(ctx context.Context) context.Context {
return As(ctx, subjectWorkspaceBuilder)
}
var AsRemoveActor = rbac.Subject{
ID: "remove-actor",
}
@@ -2155,13 +2122,6 @@ func (q *querier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UUID)
return fetch(q.log, q.auth, q.db.GetAIBridgeInterceptionByID)(ctx, id)
}
func (q *querier) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Context, toolCallID string) (database.GetAIBridgeInterceptionLineageByToolCallIDRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
return database.GetAIBridgeInterceptionLineageByToolCallIDRow{}, err
}
return q.db.GetAIBridgeInterceptionLineageByToolCallID(ctx, toolCallID)
}
func (q *querier) GetAIBridgeInterceptions(ctx context.Context) ([]database.AIBridgeInterception, error) {
fetch := func(ctx context.Context, _ any) ([]database.AIBridgeInterception, error) {
return q.db.GetAIBridgeInterceptions(ctx)
@@ -2201,12 +2161,12 @@ func (q *querier) GetAPIKeyByName(ctx context.Context, arg database.GetAPIKeyByN
return fetch(q.log, q.auth, q.db.GetAPIKeyByName)(ctx, arg)
}
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByLoginType)(ctx, loginType)
}
func (q *querier) GetAPIKeysByUserID(ctx context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, params)
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, database.GetAPIKeysByUserIDParams{LoginType: params.LoginType, UserID: params.UserID})
}
func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]database.APIKey, error) {
@@ -2297,7 +2257,7 @@ func (q *querier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditL
}
func (q *querier) GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow, error) {
// This is a system function.
// This is a system function
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow{}, err
}
@@ -3173,13 +3133,6 @@ func (q *querier) GetTelemetryItems(ctx context.Context) ([]database.TelemetryIt
return q.db.GetTelemetryItems(ctx)
}
func (q *querier) GetTelemetryTaskEvents(ctx context.Context, arg database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTask.All()); err != nil {
return nil, err
}
return q.db.GetTelemetryTaskEvents(ctx, arg)
}
func (q *querier) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
if err := q.authorizeTemplateInsights(ctx, arg.TemplateIDs); err != nil {
return nil, err
@@ -3961,11 +3914,6 @@ func (q *querier) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, wor
return q.db.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, prep)
}
func (q *querier) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, buildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
// Fetching the provisioner state requires Update permission on the template.
return fetchWithAction(q.log, q.auth, policy.ActionUpdate, q.db.GetWorkspaceBuildProvisionerStateByID)(ctx, buildID)
}
func (q *querier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
@@ -4798,14 +4746,6 @@ func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Contex
return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
}
func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep)
}
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
// This function is a system function until we implement a join for aibridge interceptions.
// Matches the behavior of the workspaces listing endpoint.
@@ -6367,10 +6307,3 @@ func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg
// database.Store interface, so dbauthz needs to implement it.
return q.CountAIBridgeInterceptions(ctx, arg)
}
func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, _ rbac.PreparedAuthorized) ([]string, error) {
// TODO: Delete this function, all ListAIBridgeModels should be authorized. For now just call ListAIBridgeModels on the authz querier.
// This cannot be deleted for now because it's included in the
// database.Store interface, so dbauthz needs to implement it.
return q.ListAIBridgeModels(ctx, arg)
}
+2 -40
View File
@@ -237,8 +237,8 @@ func (s *MethodTestSuite) TestAPIKey() {
s.Run("GetAPIKeysByLoginType", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
a := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
b := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Return([]database.APIKey{a, b}, nil).AnyTimes()
check.Args(database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.LoginTypePassword).Return([]database.APIKey{a, b}, nil).AnyTimes()
check.Args(database.LoginTypePassword).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
}))
s.Run("GetAPIKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u1 := testutil.Fake(s.T(), faker, database.User{})
@@ -1326,11 +1326,6 @@ func (s *MethodTestSuite) TestTemplate() {
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
}))
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetTelemetryTaskEventsParams{}
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceTask.All(), policy.ActionRead)
}))
s.Run("GetTemplateAppInsights", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetTemplateAppInsightsParams{}
dbm.EXPECT().GetTemplateAppInsights(gomock.Any(), arg).Return([]database.GetTemplateAppInsightsRow{}, nil).AnyTimes()
@@ -1974,15 +1969,6 @@ func (s *MethodTestSuite) TestWorkspace() {
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
check.Args(build.ID).Asserts(ws, policy.ActionRead).Returns(build)
}))
s.Run("GetWorkspaceBuildProvisionerStateByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
row := database.GetWorkspaceBuildProvisionerStateByIDRow{
ProvisionerState: []byte("state"),
TemplateID: uuid.New(),
TemplateOrganizationID: uuid.New(),
}
dbm.EXPECT().GetWorkspaceBuildProvisionerStateByID(gomock.Any(), gomock.Any()).Return(row, nil).AnyTimes()
check.Args(uuid.New()).Asserts(row, policy.ActionUpdate).Returns(row)
}))
s.Run("GetWorkspaceBuildByJobID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
build := testutil.Fake(s.T(), faker, database.WorkspaceBuild{WorkspaceID: ws.ID})
@@ -4695,16 +4681,6 @@ func (s *MethodTestSuite) TestAIBridge() {
check.Args(intID).Asserts(intc, policy.ActionRead).Returns(intc)
}))
s.Run("GetAIBridgeInterceptionLineageByToolCallID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
toolCallID := "call_123"
row := database.GetAIBridgeInterceptionLineageByToolCallIDRow{
ThreadParentID: uuid.UUID{1},
ThreadRootID: uuid.UUID{2},
}
db.EXPECT().GetAIBridgeInterceptionLineageByToolCallID(gomock.Any(), toolCallID).Return(row, nil).AnyTimes()
check.Args(toolCallID).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns(row)
}))
s.Run("GetAIBridgeInterceptions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
a := testutil.Fake(s.T(), faker, database.AIBridgeInterception{})
b := testutil.Fake(s.T(), faker, database.AIBridgeInterception{})
@@ -4770,20 +4746,6 @@ func (s *MethodTestSuite) TestAIBridge() {
check.Args(params, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("ListAIBridgeModels", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
params := database.ListAIBridgeModelsParams{}
db.EXPECT().ListAuthorizedAIBridgeModels(gomock.Any(), params, gomock.Any()).Return([]string{}, nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params).Asserts()
}))
s.Run("ListAuthorizedAIBridgeModels", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
params := database.ListAIBridgeModelsParams{}
db.EXPECT().ListAuthorizedAIBridgeModels(gomock.Any(), params, gomock.Any()).Return([]string{}, nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ids := []uuid.UUID{{1}}
db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes()
-19
View File
@@ -67,8 +67,6 @@ type WorkspaceBuildBuilder struct {
jobError string // Error message for failed jobs
jobErrorCode string // Error code for failed jobs
provisionerState []byte
}
// BuilderOption is a functional option for customizing job timestamps
@@ -140,15 +138,6 @@ func (b WorkspaceBuildBuilder) Seed(seed database.WorkspaceBuild) WorkspaceBuild
return b
}
// ProvisionerState sets the provisioner state for the workspace build.
// This is stored separately from the seed because ProvisionerState is
// not part of the WorkspaceBuild view struct.
func (b WorkspaceBuildBuilder) ProvisionerState(state []byte) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
b.provisionerState = state
return b
}
func (b WorkspaceBuildBuilder) Resource(resource ...*sdkproto.Resource) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
b.resources = append(b.resources, resource...)
@@ -475,14 +464,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
}
resp.Build = dbgen.WorkspaceBuild(b.t, b.db, b.seed)
if len(b.provisionerState) > 0 {
err = b.db.UpdateWorkspaceBuildProvisionerStateByID(ownerCtx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
ID: resp.Build.ID,
UpdatedAt: dbtime.Now(),
ProvisionerState: b.provisionerState,
})
require.NoError(b.t, err, "update provisioner state")
}
b.logger.Debug(context.Background(), "created workspace build",
slog.F("build_id", resp.Build.ID),
slog.F("workspace_id", resp.Workspace.ID),
+9 -14
View File
@@ -504,7 +504,7 @@ func WorkspaceBuild(t testing.TB, db database.Store, orig database.WorkspaceBuil
Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart),
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
JobID: jobID,
ProvisionerState: []byte{},
ProvisionerState: takeFirstSlice(orig.ProvisionerState, []byte{}),
Deadline: takeFirst(orig.Deadline, dbtime.Now().Add(time.Hour)),
MaxDeadline: takeFirst(orig.MaxDeadline, time.Time{}),
Reason: takeFirst(orig.Reason, database.BuildReasonInitiator),
@@ -1373,8 +1373,6 @@ func OAuth2ProviderAppCode(t testing.TB, db database.Store, seed database.OAuth2
ResourceUri: seed.ResourceUri,
CodeChallenge: seed.CodeChallenge,
CodeChallengeMethod: seed.CodeChallengeMethod,
StateHash: seed.StateHash,
RedirectUri: seed.RedirectUri,
})
require.NoError(t, err, "insert oauth2 app code")
return code
@@ -1585,16 +1583,14 @@ func ClaimPrebuild(
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams, endedAt *time.Time) database.AIBridgeInterception {
interception, err := db.InsertAIBridgeInterception(genCtx, database.InsertAIBridgeInterceptionParams{
ID: takeFirst(seed.ID, uuid.New()),
APIKeyID: seed.APIKeyID,
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
Provider: takeFirst(seed.Provider, "provider"),
Model: takeFirst(seed.Model, "model"),
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
Client: seed.Client,
ThreadParentInterceptionID: seed.ThreadParentInterceptionID,
ThreadRootInterceptionID: seed.ThreadRootInterceptionID,
ID: takeFirst(seed.ID, uuid.New()),
APIKeyID: seed.APIKeyID,
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
Provider: takeFirst(seed.Provider, "provider"),
Model: takeFirst(seed.Model, "model"),
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
Client: seed.Client,
})
if endedAt != nil {
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
@@ -1647,7 +1643,6 @@ func AIBridgeToolUsage(t testing.TB, db database.Store, seed database.InsertAIBr
ID: takeFirst(seed.ID, uuid.New()),
InterceptionID: takeFirst(seed.InterceptionID, uuid.New()),
ProviderResponseID: takeFirst(seed.ProviderResponseID, "provider_response_id"),
ProviderToolCallID: sql.NullString{String: takeFirst(seed.ProviderResponseID, testutil.GetRandomName(t)), Valid: true},
Tool: takeFirst(seed.Tool, "tool"),
ServerUrl: serverURL,
Input: takeFirst(seed.Input, "input"),
+1 -41
View File
@@ -726,14 +726,6 @@ func (m queryMetricsStore) GetAIBridgeInterceptionByID(ctx context.Context, id u
return r0, r1
}
func (m queryMetricsStore) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Context, toolCallID string) (database.GetAIBridgeInterceptionLineageByToolCallIDRow, error) {
start := time.Now()
r0, r1 := m.s.GetAIBridgeInterceptionLineageByToolCallID(ctx, toolCallID)
m.queryLatencies.WithLabelValues("GetAIBridgeInterceptionLineageByToolCallID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAIBridgeInterceptionLineageByToolCallID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAIBridgeInterceptions(ctx context.Context) ([]database.AIBridgeInterception, error) {
start := time.Now()
r0, r1 := m.s.GetAIBridgeInterceptions(ctx)
@@ -782,7 +774,7 @@ func (m queryMetricsStore) GetAPIKeyByName(ctx context.Context, arg database.Get
return r0, r1
}
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
start := time.Now()
r0, r1 := m.s.GetAPIKeysByLoginType(ctx, loginType)
m.queryLatencies.WithLabelValues("GetAPIKeysByLoginType").Observe(time.Since(start).Seconds())
@@ -1798,14 +1790,6 @@ func (m queryMetricsStore) GetTelemetryItems(ctx context.Context) ([]database.Te
return r0, r1
}
func (m queryMetricsStore) GetTelemetryTaskEvents(ctx context.Context, createdAfter database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTelemetryTaskEvents(ctx, createdAfter)
m.queryLatencies.WithLabelValues("GetTelemetryTaskEvents").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTelemetryTaskEvents").Inc()
return r0, r1
}
func (m queryMetricsStore) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTemplateAppInsights(ctx, arg)
@@ -2446,14 +2430,6 @@ func (m queryMetricsStore) GetWorkspaceBuildParametersByBuildIDs(ctx context.Con
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuildID)
m.queryLatencies.WithLabelValues("GetWorkspaceBuildProvisionerStateByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceBuildProvisionerStateByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceBuildStatsByTemplates(ctx, since)
@@ -3222,14 +3198,6 @@ func (m queryMetricsStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx conte
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeModels(ctx, arg)
m.queryLatencies.WithLabelValues("ListAIBridgeModels").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeModels").Inc()
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
@@ -4444,11 +4412,3 @@ func (m queryMetricsStore) CountAuthorizedAIBridgeInterceptions(ctx context.Cont
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAuthorizedAIBridgeInterceptions").Inc()
return r0, r1
}
func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error) {
start := time.Now()
r0, r1 := m.s.ListAuthorizedAIBridgeModels(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeModels").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeModels").Inc()
return r0, r1
}
+4 -79
View File
@@ -1214,21 +1214,6 @@ func (mr *MockStoreMockRecorder) GetAIBridgeInterceptionByID(ctx, id any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIBridgeInterceptionByID", reflect.TypeOf((*MockStore)(nil).GetAIBridgeInterceptionByID), ctx, id)
}
// GetAIBridgeInterceptionLineageByToolCallID mocks base method.
func (m *MockStore) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Context, toolCallID string) (database.GetAIBridgeInterceptionLineageByToolCallIDRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAIBridgeInterceptionLineageByToolCallID", ctx, toolCallID)
ret0, _ := ret[0].(database.GetAIBridgeInterceptionLineageByToolCallIDRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAIBridgeInterceptionLineageByToolCallID indicates an expected call of GetAIBridgeInterceptionLineageByToolCallID.
func (mr *MockStoreMockRecorder) GetAIBridgeInterceptionLineageByToolCallID(ctx, toolCallID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIBridgeInterceptionLineageByToolCallID", reflect.TypeOf((*MockStore)(nil).GetAIBridgeInterceptionLineageByToolCallID), ctx, toolCallID)
}
// GetAIBridgeInterceptions mocks base method.
func (m *MockStore) GetAIBridgeInterceptions(ctx context.Context) ([]database.AIBridgeInterception, error) {
m.ctrl.T.Helper()
@@ -1320,18 +1305,18 @@ func (mr *MockStoreMockRecorder) GetAPIKeyByName(ctx, arg any) *gomock.Call {
}
// GetAPIKeysByLoginType mocks base method.
func (m *MockStore) GetAPIKeysByLoginType(ctx context.Context, arg database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
func (m *MockStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAPIKeysByLoginType", ctx, arg)
ret := m.ctrl.Call(m, "GetAPIKeysByLoginType", ctx, loginType)
ret0, _ := ret[0].([]database.APIKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAPIKeysByLoginType indicates an expected call of GetAPIKeysByLoginType.
func (mr *MockStoreMockRecorder) GetAPIKeysByLoginType(ctx, arg any) *gomock.Call {
func (mr *MockStoreMockRecorder) GetAPIKeysByLoginType(ctx, loginType any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysByLoginType", reflect.TypeOf((*MockStore)(nil).GetAPIKeysByLoginType), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysByLoginType", reflect.TypeOf((*MockStore)(nil).GetAPIKeysByLoginType), ctx, loginType)
}
// GetAPIKeysByUserID mocks base method.
@@ -3329,21 +3314,6 @@ func (mr *MockStoreMockRecorder) GetTelemetryItems(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTelemetryItems", reflect.TypeOf((*MockStore)(nil).GetTelemetryItems), ctx)
}
// GetTelemetryTaskEvents mocks base method.
func (m *MockStore) GetTelemetryTaskEvents(ctx context.Context, arg database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTelemetryTaskEvents", ctx, arg)
ret0, _ := ret[0].([]database.GetTelemetryTaskEventsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTelemetryTaskEvents indicates an expected call of GetTelemetryTaskEvents.
func (mr *MockStoreMockRecorder) GetTelemetryTaskEvents(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTelemetryTaskEvents", reflect.TypeOf((*MockStore)(nil).GetTelemetryTaskEvents), ctx, arg)
}
// GetTemplateAppInsights mocks base method.
func (m *MockStore) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
m.ctrl.T.Helper()
@@ -4574,21 +4544,6 @@ func (mr *MockStoreMockRecorder) GetWorkspaceBuildParametersByBuildIDs(ctx, work
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildParametersByBuildIDs", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildParametersByBuildIDs), ctx, workspaceBuildIds)
}
// GetWorkspaceBuildProvisionerStateByID mocks base method.
func (m *MockStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWorkspaceBuildProvisionerStateByID", ctx, workspaceBuildID)
ret0, _ := ret[0].(database.GetWorkspaceBuildProvisionerStateByIDRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetWorkspaceBuildProvisionerStateByID indicates an expected call of GetWorkspaceBuildProvisionerStateByID.
func (mr *MockStoreMockRecorder) GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuildID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildProvisionerStateByID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildProvisionerStateByID), ctx, workspaceBuildID)
}
// GetWorkspaceBuildStatsByTemplates mocks base method.
func (m *MockStore) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
m.ctrl.T.Helper()
@@ -6027,21 +5982,6 @@ func (mr *MockStoreMockRecorder) ListAIBridgeInterceptionsTelemetrySummaries(ctx
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptionsTelemetrySummaries", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptionsTelemetrySummaries), ctx, arg)
}
// ListAIBridgeModels mocks base method.
func (m *MockStore) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAIBridgeModels", ctx, arg)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAIBridgeModels indicates an expected call of ListAIBridgeModels.
func (mr *MockStoreMockRecorder) ListAIBridgeModels(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModels), ctx, arg)
}
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
m.ctrl.T.Helper()
@@ -6102,21 +6042,6 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeInterceptions(ctx, arg, p
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeInterceptions), ctx, arg, prepared)
}
// ListAuthorizedAIBridgeModels mocks base method.
func (m *MockStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeModels", ctx, arg, prepared)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAuthorizedAIBridgeModels indicates an expected call of ListAuthorizedAIBridgeModels.
func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
}
// ListProvisionerKeysByOrganization mocks base method.
func (m *MockStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
m.ctrl.T.Helper()
+4 -22
View File
@@ -1024,19 +1024,13 @@ CREATE TABLE aibridge_interceptions (
metadata jsonb,
ended_at timestamp with time zone,
api_key_id text,
client character varying(64) DEFAULT 'Unknown'::character varying,
thread_parent_id uuid,
thread_root_id uuid
client character varying(64) DEFAULT 'Unknown'::character varying
);
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
COMMENT ON COLUMN aibridge_interceptions.initiator_id IS 'Relates to a users record, but FK is elided for performance.';
COMMENT ON COLUMN aibridge_interceptions.thread_parent_id IS 'The interception which directly caused this interception to occur, usually through an agentic loop or threaded conversation.';
COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interception of the thread that this interception belongs to.';
CREATE TABLE aibridge_token_usages (
id uuid NOT NULL,
interception_id uuid NOT NULL,
@@ -1061,8 +1055,7 @@ CREATE TABLE aibridge_tool_usages (
injected boolean DEFAULT false NOT NULL,
invocation_error text,
metadata jsonb,
created_at timestamp with time zone NOT NULL,
provider_tool_call_id text
created_at timestamp with time zone NOT NULL
);
COMMENT ON TABLE aibridge_tool_usages IS 'Audit log of tool calls in intercepted requests in AI Bridge';
@@ -1478,9 +1471,7 @@ CREATE TABLE oauth2_provider_app_codes (
app_id uuid NOT NULL,
resource_uri text,
code_challenge text,
code_challenge_method text,
state_hash text,
redirect_uri text
code_challenge_method text
);
COMMENT ON TABLE oauth2_provider_app_codes IS 'Codes are meant to be exchanged for access tokens.';
@@ -1491,10 +1482,6 @@ COMMENT ON COLUMN oauth2_provider_app_codes.code_challenge IS 'PKCE code challen
COMMENT ON COLUMN oauth2_provider_app_codes.code_challenge_method IS 'PKCE challenge method (S256)';
COMMENT ON COLUMN oauth2_provider_app_codes.state_hash IS 'SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.';
COMMENT ON COLUMN oauth2_provider_app_codes.redirect_uri IS 'The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).';
CREATE TABLE oauth2_provider_app_secrets (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
@@ -2715,6 +2702,7 @@ CREATE VIEW workspace_build_with_user AS
workspace_builds.build_number,
workspace_builds.transition,
workspace_builds.initiator_id,
workspace_builds.provisioner_state,
workspace_builds.job_id,
workspace_builds.deadline,
workspace_builds.reason,
@@ -3297,18 +3285,12 @@ CREATE INDEX idx_aibridge_interceptions_provider ON aibridge_interceptions USING
CREATE INDEX idx_aibridge_interceptions_started_id_desc ON aibridge_interceptions USING btree (started_at DESC, id DESC);
CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptions USING btree (thread_parent_id);
CREATE INDEX idx_aibridge_interceptions_thread_root_id ON aibridge_interceptions USING btree (thread_root_id);
CREATE INDEX idx_aibridge_token_usages_interception_id ON aibridge_token_usages USING btree (interception_id);
CREATE INDEX idx_aibridge_token_usages_provider_response_id ON aibridge_token_usages USING btree (provider_response_id);
CREATE INDEX idx_aibridge_tool_usages_interception_id ON aibridge_tool_usages USING btree (interception_id);
CREATE INDEX idx_aibridge_tool_usages_provider_tool_call_id ON aibridge_tool_usages USING btree (provider_tool_call_id);
CREATE INDEX idx_aibridge_tool_usagesprovider_response_id ON aibridge_tool_usages USING btree (provider_response_id);
CREATE INDEX idx_aibridge_user_prompts_interception_id ON aibridge_user_prompts USING btree (interception_id);
+4 -23
View File
@@ -51,34 +51,15 @@ func TestViewSubsetTemplateVersion(t *testing.T) {
}
}
// TestViewSubsetWorkspaceBuild ensures WorkspaceBuildTable is a subset of
// WorkspaceBuild, with the exception of ProvisionerState which is
// intentionally excluded from the workspace_build_with_user view to avoid
// loading the large Terraform state blob on hot paths.
// TestViewSubsetWorkspaceBuild ensures WorkspaceBuildTable is a subset of WorkspaceBuild
func TestViewSubsetWorkspaceBuild(t *testing.T) {
t.Parallel()
table := reflect.TypeOf(database.WorkspaceBuildTable{})
joined := reflect.TypeOf(database.WorkspaceBuild{})
tableFields := fieldNames(allFields(table))
joinedFields := fieldNames(allFields(joined))
// ProvisionerState is intentionally excluded from the
// workspace_build_with_user view to avoid loading multi-MB Terraform
// state blobs on hot paths. Callers that need it use
// GetWorkspaceBuildProvisionerStateByID instead.
excludedFields := map[string]bool{
"ProvisionerState": true,
}
var filtered []string
for _, name := range tableFields {
if !excludedFields[name] {
filtered = append(filtered, name)
}
}
if !assert.Subset(t, joinedFields, filtered, "table is not subset") {
tableFields := allFields(table)
joinedFields := allFields(joined)
if !assert.Subset(t, fieldNames(joinedFields), fieldNames(tableFields), "table is not subset") {
t.Log("Some fields were added to the WorkspaceBuild Table without updating the 'workspace_build_with_user' view.")
t.Log("See migration 000141_join_users_build_version.up.sql to create the view.")
}
@@ -1,3 +0,0 @@
ALTER TABLE oauth2_provider_app_codes
DROP COLUMN state_hash,
DROP COLUMN redirect_uri;
@@ -1,9 +0,0 @@
ALTER TABLE oauth2_provider_app_codes
ADD COLUMN state_hash text,
ADD COLUMN redirect_uri text;
COMMENT ON COLUMN oauth2_provider_app_codes.state_hash IS
'SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.';
COMMENT ON COLUMN oauth2_provider_app_codes.redirect_uri IS
'The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).';
@@ -1,31 +0,0 @@
-- Restore provisioner_state to workspace_build_with_user view.
DROP VIEW workspace_build_with_user;
CREATE VIEW workspace_build_with_user AS
SELECT
workspace_builds.id,
workspace_builds.created_at,
workspace_builds.updated_at,
workspace_builds.workspace_id,
workspace_builds.template_version_id,
workspace_builds.build_number,
workspace_builds.transition,
workspace_builds.initiator_id,
workspace_builds.provisioner_state,
workspace_builds.job_id,
workspace_builds.deadline,
workspace_builds.reason,
workspace_builds.daily_cost,
workspace_builds.max_deadline,
workspace_builds.template_version_preset_id,
workspace_builds.has_ai_task,
workspace_builds.has_external_agent,
COALESCE(visible_users.avatar_url, ''::text) AS initiator_by_avatar_url,
COALESCE(visible_users.username, ''::text) AS initiator_by_username,
COALESCE(visible_users.name, ''::text) AS initiator_by_name
FROM
workspace_builds
LEFT JOIN
visible_users ON workspace_builds.initiator_id = visible_users.id;
COMMENT ON VIEW workspace_build_with_user IS 'Joins in the username + avatar url of the initiated by user.';
@@ -1,33 +0,0 @@
-- Drop and recreate workspace_build_with_user to exclude provisioner_state.
-- This avoids loading the large Terraform state blob (1-5 MB per workspace)
-- on every query that uses this view. The callers that need provisioner_state
-- now fetch it separately via GetWorkspaceBuildProvisionerStateByID.
DROP VIEW workspace_build_with_user;
CREATE VIEW workspace_build_with_user AS
SELECT
workspace_builds.id,
workspace_builds.created_at,
workspace_builds.updated_at,
workspace_builds.workspace_id,
workspace_builds.template_version_id,
workspace_builds.build_number,
workspace_builds.transition,
workspace_builds.initiator_id,
workspace_builds.job_id,
workspace_builds.deadline,
workspace_builds.reason,
workspace_builds.daily_cost,
workspace_builds.max_deadline,
workspace_builds.template_version_preset_id,
workspace_builds.has_ai_task,
workspace_builds.has_external_agent,
COALESCE(visible_users.avatar_url, ''::text) AS initiator_by_avatar_url,
COALESCE(visible_users.username, ''::text) AS initiator_by_username,
COALESCE(visible_users.name, ''::text) AS initiator_by_name
FROM
workspace_builds
LEFT JOIN
visible_users ON workspace_builds.initiator_id = visible_users.id;
COMMENT ON VIEW workspace_build_with_user IS 'Joins in the username + avatar url of the initiated by user.';
@@ -1,9 +0,0 @@
DROP INDEX IF EXISTS idx_aibridge_tool_usages_provider_tool_call_id;
ALTER TABLE aibridge_tool_usages
DROP COLUMN provider_tool_call_id;
DROP INDEX IF EXISTS idx_aibridge_interceptions_thread_parent_id;
ALTER TABLE aibridge_interceptions
DROP COLUMN thread_parent_id;
@@ -1,11 +0,0 @@
ALTER TABLE aibridge_tool_usages
ADD COLUMN provider_tool_call_id text NULL; -- nullable to allow existing data to be correct
CREATE INDEX idx_aibridge_tool_usages_provider_tool_call_id ON aibridge_tool_usages (provider_tool_call_id);
ALTER TABLE aibridge_interceptions
ADD COLUMN thread_parent_id UUID NULL;
COMMENT ON COLUMN aibridge_interceptions.thread_parent_id IS 'The interception which directly caused this interception to occur, usually through an agentic loop or threaded conversation.';
CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptions (thread_parent_id);
@@ -1,4 +0,0 @@
DROP INDEX IF EXISTS idx_aibridge_interceptions_thread_root_id;
ALTER TABLE aibridge_interceptions
DROP COLUMN thread_root_id;
@@ -1,6 +0,0 @@
ALTER TABLE aibridge_interceptions
ADD COLUMN thread_root_id UUID NULL;
COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interception of the thread that this interception belongs to.';
CREATE INDEX idx_aibridge_interceptions_thread_root_id ON aibridge_interceptions (thread_root_id);
-8
View File
@@ -316,14 +316,6 @@ func (t GetFileTemplatesRow) RBACObject() rbac.Object {
WithGroupACL(t.GroupACL)
}
// RBACObject for a workspace build's provisioner state requires Update access of the template.
func (t GetWorkspaceBuildProvisionerStateByIDRow) RBACObject() rbac.Object {
return rbac.ResourceTemplate.WithID(t.TemplateID).
InOrg(t.TemplateOrganizationID).
WithACLUserList(t.UserACL).
WithGroupACL(t.GroupACL)
}
func (t Template) DeepCopy() Template {
cpy := t
cpy.UserACL = maps.Clone(t.UserACL)
-32
View File
@@ -769,7 +769,6 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
type aibridgeQuerier interface {
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error)
CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error)
ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error)
}
func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) {
@@ -813,8 +812,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
&i.AIBridgeInterception.EndedAt,
&i.AIBridgeInterception.APIKeyID,
&i.AIBridgeInterception.Client,
&i.AIBridgeInterception.ThreadParentID,
&i.AIBridgeInterception.ThreadRootID,
&i.VisibleUser.ID,
&i.VisibleUser.Username,
&i.VisibleUser.Name,
@@ -873,35 +870,6 @@ func (q *sqlQuerier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, a
return count, nil
}
func (q *sqlQuerier) ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error) {
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
VariableConverter: regosql.AIBridgeInterceptionConverter(),
})
if err != nil {
return nil, xerrors.Errorf("compile authorized filter: %w", err)
}
filtered, err := insertAuthorizedFilter(listAIBridgeModels, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return nil, xerrors.Errorf("insert authorized filter: %w", err)
}
query := fmt.Sprintf("-- name: ListAIBridgeModels :many\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query, arg.Model, arg.Offset, arg.Limit)
if err != nil {
return nil, err
}
defer rows.Close()
var items []string
for rows.Next() {
var model string
if err := rows.Scan(&model); err != nil {
return nil, err
}
items = append(items, model)
}
return items, nil
}
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
if !strings.Contains(query, authorizedQueryPlaceholder) {
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
+4 -12
View File
@@ -3643,10 +3643,6 @@ type AIBridgeInterception struct {
EndedAt sql.NullTime `db:"ended_at" json:"ended_at"`
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
Client sql.NullString `db:"client" json:"client"`
// The interception which directly caused this interception to occur, usually through an agentic loop or threaded conversation.
ThreadParentID uuid.NullUUID `db:"thread_parent_id" json:"thread_parent_id"`
// The root interception of the thread that this interception belongs to.
ThreadRootID uuid.NullUUID `db:"thread_root_id" json:"thread_root_id"`
}
// Audit log of tokens used by intercepted requests in AI Bridge
@@ -3674,10 +3670,9 @@ type AIBridgeToolUsage struct {
// Whether this tool was injected; i.e. Bridge injected these tools into the request from an MCP server. If false it means a tool was defined by the client and already existed in the request (MCP or built-in).
Injected bool `db:"injected" json:"injected"`
// Only injected tools are invoked.
InvocationError sql.NullString `db:"invocation_error" json:"invocation_error"`
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
ProviderToolCallID sql.NullString `db:"provider_tool_call_id" json:"provider_tool_call_id"`
InvocationError sql.NullString `db:"invocation_error" json:"invocation_error"`
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
// Audit log of prompts used by intercepted requests in AI Bridge
@@ -4031,10 +4026,6 @@ type OAuth2ProviderAppCode struct {
CodeChallenge sql.NullString `db:"code_challenge" json:"code_challenge"`
// PKCE challenge method (S256)
CodeChallengeMethod sql.NullString `db:"code_challenge_method" json:"code_challenge_method"`
// SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.
StateHash sql.NullString `db:"state_hash" json:"state_hash"`
// The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).
RedirectUri sql.NullString `db:"redirect_uri" json:"redirect_uri"`
}
type OAuth2ProviderAppSecret struct {
@@ -4992,6 +4983,7 @@ type WorkspaceBuild struct {
BuildNumber int32 `db:"build_number" json:"build_number"`
Transition WorkspaceTransition `db:"transition" json:"transition"`
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
ProvisionerState []byte `db:"provisioner_state" json:"provisioner_state"`
JobID uuid.UUID `db:"job_id" json:"job_id"`
Deadline time.Time `db:"deadline" json:"deadline"`
Reason BuildReason `db:"reason" json:"reason"`
+1 -29
View File
@@ -162,11 +162,6 @@ type sqlcQuerier interface {
// and returns the preset with the most parameters (largest subset).
FindMatchingPresetID(ctx context.Context, arg FindMatchingPresetIDParams) (uuid.UUID, error)
GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UUID) (AIBridgeInterception, error)
// Look up the parent interception and the root of the thread by finding
// which interception recorded a tool usage with the given tool call ID.
// COALESCE ensures that if the parent has no thread_root_id (i.e. it IS
// the root), we return its own ID as the root.
GetAIBridgeInterceptionLineageByToolCallID(ctx context.Context, toolCallID string) (GetAIBridgeInterceptionLineageByToolCallIDRow, error)
GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeInterception, error)
GetAIBridgeTokenUsagesByInterceptionID(ctx context.Context, interceptionID uuid.UUID) ([]AIBridgeTokenUsage, error)
GetAIBridgeToolUsagesByInterceptionID(ctx context.Context, interceptionID uuid.UUID) ([]AIBridgeToolUsage, error)
@@ -174,7 +169,7 @@ type sqlcQuerier interface {
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
// there is no unique constraint on empty token names
GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error)
GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error)
GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error)
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
@@ -366,23 +361,6 @@ type sqlcQuerier interface {
GetTaskSnapshot(ctx context.Context, taskID uuid.UUID) (TaskSnapshot, error)
GetTelemetryItem(ctx context.Context, key string) (TelemetryItem, error)
GetTelemetryItems(ctx context.Context) ([]TelemetryItem, error)
// Returns all data needed to build task lifecycle events for telemetry
// in a single round-trip. For each task whose workspace is in the
// given set, fetches:
// - the latest workspace app binding (task_workspace_apps)
// - the most recent stop and start builds (workspace_builds)
// - the last "working" app status (workspace_app_statuses)
// - the first app status after resume, for active workspaces
//
// Assumptions:
// - 1:1 relationship between tasks and workspaces. All builds on the
// workspace are considered task-related.
// - Idle duration approximation: If the agent reports "working", does
// work, then reports "done", we miss that working time.
// - lws and active_dur join across all historical app IDs for the task,
// because each resume cycle provisions a new app ID. This ensures
// pre-pause statuses contribute to idle duration and active duration.
GetTelemetryTaskEvents(ctx context.Context, arg GetTelemetryTaskEventsParams) ([]GetTelemetryTaskEventsRow, error)
// GetTemplateAppInsights returns the aggregate usage of each app in a given
// timeframe. The result can be filtered on template_ids, meaning only user data
// from workspaces based on those templates will be included.
@@ -528,11 +506,6 @@ type sqlcQuerier interface {
GetWorkspaceBuildMetricsByResourceID(ctx context.Context, id uuid.UUID) (GetWorkspaceBuildMetricsByResourceIDRow, error)
GetWorkspaceBuildParameters(ctx context.Context, workspaceBuildID uuid.UUID) ([]WorkspaceBuildParameter, error)
GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIds []uuid.UUID) ([]WorkspaceBuildParameter, error)
// Fetches the provisioner state of a workspace build, joined through to the
// template so that dbauthz can enforce policy.ActionUpdate on the template.
// Provisioner state contains sensitive Terraform state and should only be
// accessible to template administrators.
GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (GetWorkspaceBuildProvisionerStateByIDRow, error)
GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]GetWorkspaceBuildStatsByTemplatesRow, error)
GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg GetWorkspaceBuildsByWorkspaceIDParams) ([]WorkspaceBuild, error)
GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceBuild, error)
@@ -658,7 +631,6 @@ type sqlcQuerier interface {
// Finds all unique AI Bridge interception telemetry summaries combinations
// (provider, model, client) in the given timeframe for telemetry reporting.
ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error)
ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error)
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
+6 -9
View File
@@ -8195,9 +8195,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
// All keys are present before deletion
keys, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
LoginType: user.LoginType,
UserID: user.ID,
IncludeExpired: true,
LoginType: user.LoginType,
UserID: user.ID,
})
require.NoError(t, err)
require.Len(t, keys, len(expiredTimes)+len(unexpiredTimes))
@@ -8213,9 +8212,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
// Ensure it was deleted
remaining, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
LoginType: user.LoginType,
UserID: user.ID,
IncludeExpired: true,
LoginType: user.LoginType,
UserID: user.ID,
})
require.NoError(t, err)
require.Len(t, remaining, len(expiredTimes)+len(unexpiredTimes)-1)
@@ -8230,9 +8228,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
// Ensure only unexpired keys remain
remaining, err = db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
LoginType: user.LoginType,
UserID: user.ID,
IncludeExpired: true,
LoginType: user.LoginType,
UserID: user.ID,
})
require.NoError(t, err)
require.Len(t, remaining, len(unexpiredTimes))
+48 -403
View File
@@ -378,7 +378,7 @@ func (q *sqlQuerier) DeleteOldAIBridgeRecords(ctx context.Context, beforeTime ti
const getAIBridgeInterceptionByID = `-- name: GetAIBridgeInterceptionByID :one
SELECT
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client
FROM
aibridge_interceptions
WHERE
@@ -398,45 +398,13 @@ func (q *sqlQuerier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UU
&i.EndedAt,
&i.APIKeyID,
&i.Client,
&i.ThreadParentID,
&i.ThreadRootID,
)
return i, err
}
const getAIBridgeInterceptionLineageByToolCallID = `-- name: GetAIBridgeInterceptionLineageByToolCallID :one
WITH linked AS (
SELECT interception_id FROM aibridge_tool_usages
WHERE provider_tool_call_id = $1::text
ORDER BY created_at DESC
LIMIT 1
)
SELECT linked.interception_id AS thread_parent_id,
COALESCE(aibridge_interceptions.thread_root_id, linked.interception_id) AS thread_root_id
FROM aibridge_interceptions
INNER JOIN linked ON linked.interception_id = aibridge_interceptions.id
WHERE aibridge_interceptions.id = linked.interception_id
`
type GetAIBridgeInterceptionLineageByToolCallIDRow struct {
ThreadParentID uuid.UUID `db:"thread_parent_id" json:"thread_parent_id"`
ThreadRootID uuid.UUID `db:"thread_root_id" json:"thread_root_id"`
}
// Look up the parent interception and the root of the thread by finding
// which interception recorded a tool usage with the given tool call ID.
// COALESCE ensures that if the parent has no thread_root_id (i.e. it IS
// the root), we return its own ID as the root.
func (q *sqlQuerier) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Context, toolCallID string) (GetAIBridgeInterceptionLineageByToolCallIDRow, error) {
row := q.db.QueryRowContext(ctx, getAIBridgeInterceptionLineageByToolCallID, toolCallID)
var i GetAIBridgeInterceptionLineageByToolCallIDRow
err := row.Scan(&i.ThreadParentID, &i.ThreadRootID)
return i, err
}
const getAIBridgeInterceptions = `-- name: GetAIBridgeInterceptions :many
SELECT
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client
FROM
aibridge_interceptions
`
@@ -460,8 +428,6 @@ func (q *sqlQuerier) GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeIn
&i.EndedAt,
&i.APIKeyID,
&i.Client,
&i.ThreadParentID,
&i.ThreadRootID,
); err != nil {
return nil, err
}
@@ -519,7 +485,7 @@ func (q *sqlQuerier) GetAIBridgeTokenUsagesByInterceptionID(ctx context.Context,
const getAIBridgeToolUsagesByInterceptionID = `-- name: GetAIBridgeToolUsagesByInterceptionID :many
SELECT
id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at, provider_tool_call_id
id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at
FROM
aibridge_tool_usages
WHERE
@@ -549,7 +515,6 @@ func (q *sqlQuerier) GetAIBridgeToolUsagesByInterceptionID(ctx context.Context,
&i.InvocationError,
&i.Metadata,
&i.CreatedAt,
&i.ProviderToolCallID,
); err != nil {
return nil, err
}
@@ -608,24 +573,22 @@ func (q *sqlQuerier) GetAIBridgeUserPromptsByInterceptionID(ctx context.Context,
const insertAIBridgeInterception = `-- name: InsertAIBridgeInterception :one
INSERT INTO aibridge_interceptions (
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, thread_parent_id, thread_root_id
id, api_key_id, initiator_id, provider, model, metadata, started_at, client
) VALUES (
$1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8, $9::uuid, $10::uuid
$1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8
)
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client
`
type InsertAIBridgeInterceptionParams struct {
ID uuid.UUID `db:"id" json:"id"`
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
Provider string `db:"provider" json:"provider"`
Model string `db:"model" json:"model"`
Metadata json.RawMessage `db:"metadata" json:"metadata"`
StartedAt time.Time `db:"started_at" json:"started_at"`
Client sql.NullString `db:"client" json:"client"`
ThreadParentInterceptionID uuid.NullUUID `db:"thread_parent_interception_id" json:"thread_parent_interception_id"`
ThreadRootInterceptionID uuid.NullUUID `db:"thread_root_interception_id" json:"thread_root_interception_id"`
ID uuid.UUID `db:"id" json:"id"`
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
Provider string `db:"provider" json:"provider"`
Model string `db:"model" json:"model"`
Metadata json.RawMessage `db:"metadata" json:"metadata"`
StartedAt time.Time `db:"started_at" json:"started_at"`
Client sql.NullString `db:"client" json:"client"`
}
func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertAIBridgeInterceptionParams) (AIBridgeInterception, error) {
@@ -638,8 +601,6 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
arg.Metadata,
arg.StartedAt,
arg.Client,
arg.ThreadParentInterceptionID,
arg.ThreadRootInterceptionID,
)
var i AIBridgeInterception
err := row.Scan(
@@ -652,8 +613,6 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
&i.EndedAt,
&i.APIKeyID,
&i.Client,
&i.ThreadParentID,
&i.ThreadRootID,
)
return i, err
}
@@ -702,18 +661,17 @@ func (q *sqlQuerier) InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIB
const insertAIBridgeToolUsage = `-- name: InsertAIBridgeToolUsage :one
INSERT INTO aibridge_tool_usages (
id, interception_id, provider_response_id, provider_tool_call_id, tool, server_url, input, injected, invocation_error, metadata, created_at
id, interception_id, provider_response_id, tool, server_url, input, injected, invocation_error, metadata, created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, COALESCE($10::jsonb, '{}'::jsonb), $11
$1, $2, $3, $4, $5, $6, $7, $8, COALESCE($9::jsonb, '{}'::jsonb), $10
)
RETURNING id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at, provider_tool_call_id
RETURNING id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at
`
type InsertAIBridgeToolUsageParams struct {
ID uuid.UUID `db:"id" json:"id"`
InterceptionID uuid.UUID `db:"interception_id" json:"interception_id"`
ProviderResponseID string `db:"provider_response_id" json:"provider_response_id"`
ProviderToolCallID sql.NullString `db:"provider_tool_call_id" json:"provider_tool_call_id"`
Tool string `db:"tool" json:"tool"`
ServerUrl sql.NullString `db:"server_url" json:"server_url"`
Input string `db:"input" json:"input"`
@@ -728,7 +686,6 @@ func (q *sqlQuerier) InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBr
arg.ID,
arg.InterceptionID,
arg.ProviderResponseID,
arg.ProviderToolCallID,
arg.Tool,
arg.ServerUrl,
arg.Input,
@@ -749,7 +706,6 @@ func (q *sqlQuerier) InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBr
&i.InvocationError,
&i.Metadata,
&i.CreatedAt,
&i.ProviderToolCallID,
)
return i, err
}
@@ -795,7 +751,7 @@ func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIB
const listAIBridgeInterceptions = `-- name: ListAIBridgeInterceptions :many
SELECT
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id,
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client,
visible_users.id, visible_users.username, visible_users.name, visible_users.avatar_url
FROM
aibridge_interceptions
@@ -904,8 +860,6 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
&i.AIBridgeInterception.EndedAt,
&i.AIBridgeInterception.APIKeyID,
&i.AIBridgeInterception.Client,
&i.AIBridgeInterception.ThreadParentID,
&i.AIBridgeInterception.ThreadRootID,
&i.VisibleUser.ID,
&i.VisibleUser.Username,
&i.VisibleUser.Name,
@@ -974,60 +928,6 @@ func (q *sqlQuerier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Con
return items, nil
}
const listAIBridgeModels = `-- name: ListAIBridgeModels :many
SELECT
model
FROM
aibridge_interceptions
WHERE
-- Remove inflight interceptions (ones which lack an ended_at value).
aibridge_interceptions.ended_at IS NOT NULL
-- Filter model
AND CASE
WHEN $1::text != '' THEN aibridge_interceptions.model LIKE $1::text || '%'
ELSE true
END
-- We use an ` + "`" + `@authorize_filter` + "`" + ` as we are attempting to list models that are relevant
-- to the user and what they are allowed to see.
-- Authorize Filter clause will be injected below in ListAIBridgeModelsAuthorized
-- @authorize_filter
GROUP BY
model
ORDER BY
model ASC
LIMIT COALESCE(NULLIF($3::integer, 0), 100)
OFFSET $2
`
type ListAIBridgeModelsParams struct {
Model string `db:"model" json:"model"`
Offset int32 `db:"offset_" json:"offset_"`
Limit int32 `db:"limit_" json:"limit_"`
}
func (q *sqlQuerier) ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error) {
rows, err := q.db.QueryContext(ctx, listAIBridgeModels, arg.Model, arg.Offset, arg.Limit)
if err != nil {
return nil, err
}
defer rows.Close()
var items []string
for rows.Next() {
var model string
if err := rows.Scan(&model); err != nil {
return nil, err
}
items = append(items, model)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listAIBridgeTokenUsagesByInterceptionIDs = `-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many
SELECT
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
@@ -1073,7 +973,7 @@ func (q *sqlQuerier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Contex
const listAIBridgeToolUsagesByInterceptionIDs = `-- name: ListAIBridgeToolUsagesByInterceptionIDs :many
SELECT
id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at, provider_tool_call_id
id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at
FROM
aibridge_tool_usages
WHERE
@@ -1103,7 +1003,6 @@ func (q *sqlQuerier) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context
&i.InvocationError,
&i.Metadata,
&i.CreatedAt,
&i.ProviderToolCallID,
); err != nil {
return nil, err
}
@@ -1166,7 +1065,7 @@ UPDATE aibridge_interceptions
WHERE
id = $2::uuid
AND ended_at IS NULL
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client
`
type UpdateAIBridgeInterceptionEndedParams struct {
@@ -1187,8 +1086,6 @@ func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg Up
&i.EndedAt,
&i.APIKeyID,
&i.Client,
&i.ThreadParentID,
&i.ThreadRootID,
)
return i, err
}
@@ -1373,16 +1270,10 @@ func (q *sqlQuerier) GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNamePar
const getAPIKeysByLoginType = `-- name: GetAPIKeysByLoginType :many
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, token_name, scopes, allow_list FROM api_keys WHERE login_type = $1
AND ($2::bool OR expires_at > now())
`
type GetAPIKeysByLoginTypeParams struct {
LoginType LoginType `db:"login_type" json:"login_type"`
IncludeExpired bool `db:"include_expired" json:"include_expired"`
}
func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error) {
rows, err := q.db.QueryContext(ctx, getAPIKeysByLoginType, arg.LoginType, arg.IncludeExpired)
func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error) {
rows, err := q.db.QueryContext(ctx, getAPIKeysByLoginType, loginType)
if err != nil {
return nil, err
}
@@ -1420,17 +1311,15 @@ func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysBy
const getAPIKeysByUserID = `-- name: GetAPIKeysByUserID :many
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, token_name, scopes, allow_list FROM api_keys WHERE login_type = $1 AND user_id = $2
AND ($3::bool OR expires_at > now())
`
type GetAPIKeysByUserIDParams struct {
LoginType LoginType `db:"login_type" json:"login_type"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
IncludeExpired bool `db:"include_expired" json:"include_expired"`
LoginType LoginType `db:"login_type" json:"login_type"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
}
func (q *sqlQuerier) GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error) {
rows, err := q.db.QueryContext(ctx, getAPIKeysByUserID, arg.LoginType, arg.UserID, arg.IncludeExpired)
rows, err := q.db.QueryContext(ctx, getAPIKeysByUserID, arg.LoginType, arg.UserID)
if err != nil {
return nil, err
}
@@ -6879,7 +6768,7 @@ func (q *sqlQuerier) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context
}
const getOAuth2ProviderAppCodeByID = `-- name: GetOAuth2ProviderAppCodeByID :one
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri FROM oauth2_provider_app_codes WHERE id = $1
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method FROM oauth2_provider_app_codes WHERE id = $1
`
func (q *sqlQuerier) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) (OAuth2ProviderAppCode, error) {
@@ -6896,14 +6785,12 @@ func (q *sqlQuerier) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.U
&i.ResourceUri,
&i.CodeChallenge,
&i.CodeChallengeMethod,
&i.StateHash,
&i.RedirectUri,
)
return i, err
}
const getOAuth2ProviderAppCodeByPrefix = `-- name: GetOAuth2ProviderAppCodeByPrefix :one
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri FROM oauth2_provider_app_codes WHERE secret_prefix = $1
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method FROM oauth2_provider_app_codes WHERE secret_prefix = $1
`
func (q *sqlQuerier) GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secretPrefix []byte) (OAuth2ProviderAppCode, error) {
@@ -6920,8 +6807,6 @@ func (q *sqlQuerier) GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secre
&i.ResourceUri,
&i.CodeChallenge,
&i.CodeChallengeMethod,
&i.StateHash,
&i.RedirectUri,
)
return i, err
}
@@ -7325,9 +7210,7 @@ INSERT INTO oauth2_provider_app_codes (
user_id,
resource_uri,
code_challenge,
code_challenge_method,
state_hash,
redirect_uri
code_challenge_method
) VALUES(
$1,
$2,
@@ -7338,10 +7221,8 @@ INSERT INTO oauth2_provider_app_codes (
$7,
$8,
$9,
$10,
$11,
$12
) RETURNING id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri
$10
) RETURNING id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method
`
type InsertOAuth2ProviderAppCodeParams struct {
@@ -7355,8 +7236,6 @@ type InsertOAuth2ProviderAppCodeParams struct {
ResourceUri sql.NullString `db:"resource_uri" json:"resource_uri"`
CodeChallenge sql.NullString `db:"code_challenge" json:"code_challenge"`
CodeChallengeMethod sql.NullString `db:"code_challenge_method" json:"code_challenge_method"`
StateHash sql.NullString `db:"state_hash" json:"state_hash"`
RedirectUri sql.NullString `db:"redirect_uri" json:"redirect_uri"`
}
func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg InsertOAuth2ProviderAppCodeParams) (OAuth2ProviderAppCode, error) {
@@ -7371,8 +7250,6 @@ func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg Insert
arg.ResourceUri,
arg.CodeChallenge,
arg.CodeChallengeMethod,
arg.StateHash,
arg.RedirectUri,
)
var i OAuth2ProviderAppCode
err := row.Scan(
@@ -7386,8 +7263,6 @@ func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg Insert
&i.ResourceUri,
&i.CodeChallenge,
&i.CodeChallengeMethod,
&i.StateHash,
&i.RedirectUri,
)
return i, err
}
@@ -13437,203 +13312,6 @@ func (q *sqlQuerier) GetTaskSnapshot(ctx context.Context, taskID uuid.UUID) (Tas
return i, err
}
const getTelemetryTaskEvents = `-- name: GetTelemetryTaskEvents :many
WITH task_app_ids AS (
SELECT task_id, workspace_app_id
FROM task_workspace_apps
),
task_status_timeline AS (
-- All app statuses across every historical app for each task,
-- plus synthetic "boundary" rows at each stop/start build transition.
-- This allows us to correctly take gaps due to pause/resume into account.
SELECT tai.task_id, was.created_at, was.state::text AS state
FROM workspace_app_statuses was
JOIN task_app_ids tai ON tai.workspace_app_id = was.app_id
UNION ALL
SELECT t.id AS task_id, wb.created_at, '_boundary' AS state
FROM tasks t
JOIN workspace_builds wb ON wb.workspace_id = t.workspace_id
WHERE t.deleted_at IS NULL
AND t.workspace_id IS NOT NULL
AND wb.build_number > 1
),
task_event_data AS (
SELECT
t.id AS task_id,
t.workspace_id,
twa.workspace_app_id,
-- Latest stop build.
stop_build.created_at AS stop_build_created_at,
stop_build.reason AS stop_build_reason,
-- Latest start build (task_resume only).
start_build.created_at AS start_build_created_at,
start_build.reason AS start_build_reason,
start_build.build_number AS start_build_number,
-- Last "working" app status (for idle duration).
lws.created_at AS last_working_status_at,
-- First app status after resume (for resume-to-status duration).
-- Only populated for workspaces in an active phase (started more
-- recently than stopped).
fsar.created_at AS first_status_after_resume_at,
-- Cumulative time spent in "working" state.
active_dur.total_working_ms AS active_duration_ms
FROM tasks t
LEFT JOIN LATERAL (
SELECT task_app.workspace_app_id
FROM task_workspace_apps task_app
WHERE task_app.task_id = t.id
ORDER BY task_app.workspace_build_number DESC
LIMIT 1
) twa ON TRUE
LEFT JOIN LATERAL (
SELECT wb.created_at, wb.reason, wb.build_number
FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.transition = 'stop'
ORDER BY wb.build_number DESC
LIMIT 1
) stop_build ON TRUE
LEFT JOIN LATERAL (
SELECT wb.created_at, wb.reason, wb.build_number
FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.transition = 'start'
ORDER BY wb.build_number DESC
LIMIT 1
) start_build ON TRUE
LEFT JOIN LATERAL (
SELECT tst.created_at
FROM task_status_timeline tst
WHERE tst.task_id = t.id
AND tst.state = 'working'
-- Only consider status before the latest pause so that
-- post-resume statuses don't mask pre-pause idle time.
AND (stop_build.created_at IS NULL
OR tst.created_at <= stop_build.created_at)
ORDER BY tst.created_at DESC
LIMIT 1
) lws ON TRUE
LEFT JOIN LATERAL (
SELECT was.created_at
FROM workspace_app_statuses was
WHERE was.app_id = twa.workspace_app_id
AND was.created_at > start_build.created_at
ORDER BY was.created_at ASC
LIMIT 1
) fsar ON twa.workspace_app_id IS NOT NULL
AND start_build.created_at IS NOT NULL
AND (stop_build.created_at IS NULL
OR start_build.created_at > stop_build.created_at)
-- Active duration: cumulative time spent in "working" state across all
-- historical app IDs for this task. Uses LEAD() to convert point-in-time
-- statuses into intervals, then sums intervals where state='working'. For
-- the last status, falls back to stop_build time (if paused) or @now (if
-- still running).
LEFT JOIN LATERAL (
SELECT COALESCE(
SUM(EXTRACT(EPOCH FROM (interval_end - interval_start)) * 1000)::bigint,
0
)::bigint AS total_working_ms
FROM (
SELECT
tst.created_at AS interval_start,
COALESCE(
LEAD(tst.created_at) OVER (ORDER BY tst.created_at ASC, CASE WHEN tst.state = '_boundary' THEN 1 ELSE 0 END ASC),
CASE WHEN stop_build.created_at IS NOT NULL
AND (start_build.created_at IS NULL
OR stop_build.created_at > start_build.created_at)
THEN stop_build.created_at
ELSE $1::timestamptz
END
) AS interval_end,
tst.state
FROM task_status_timeline tst
WHERE tst.task_id = t.id
) intervals
WHERE intervals.state = 'working'
) active_dur ON TRUE
WHERE t.deleted_at IS NULL
AND t.workspace_id IS NOT NULL
AND EXISTS (
SELECT 1 FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.created_at > $2
)
)
SELECT task_id, workspace_id, workspace_app_id, stop_build_created_at, stop_build_reason, start_build_created_at, start_build_reason, start_build_number, last_working_status_at, first_status_after_resume_at, active_duration_ms FROM task_event_data
ORDER BY task_id
`
type GetTelemetryTaskEventsParams struct {
Now time.Time `db:"now" json:"now"`
CreatedAfter time.Time `db:"created_after" json:"created_after"`
}
type GetTelemetryTaskEventsRow struct {
TaskID uuid.UUID `db:"task_id" json:"task_id"`
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
WorkspaceAppID uuid.NullUUID `db:"workspace_app_id" json:"workspace_app_id"`
StopBuildCreatedAt sql.NullTime `db:"stop_build_created_at" json:"stop_build_created_at"`
StopBuildReason NullBuildReason `db:"stop_build_reason" json:"stop_build_reason"`
StartBuildCreatedAt sql.NullTime `db:"start_build_created_at" json:"start_build_created_at"`
StartBuildReason NullBuildReason `db:"start_build_reason" json:"start_build_reason"`
StartBuildNumber sql.NullInt32 `db:"start_build_number" json:"start_build_number"`
LastWorkingStatusAt sql.NullTime `db:"last_working_status_at" json:"last_working_status_at"`
FirstStatusAfterResumeAt sql.NullTime `db:"first_status_after_resume_at" json:"first_status_after_resume_at"`
ActiveDurationMs int64 `db:"active_duration_ms" json:"active_duration_ms"`
}
// Returns all data needed to build task lifecycle events for telemetry
// in a single round-trip. For each task whose workspace is in the
// given set, fetches:
// - the latest workspace app binding (task_workspace_apps)
// - the most recent stop and start builds (workspace_builds)
// - the last "working" app status (workspace_app_statuses)
// - the first app status after resume, for active workspaces
//
// Assumptions:
// - 1:1 relationship between tasks and workspaces. All builds on the
// workspace are considered task-related.
// - Idle duration approximation: If the agent reports "working", does
// work, then reports "done", we miss that working time.
// - lws and active_dur join across all historical app IDs for the task,
// because each resume cycle provisions a new app ID. This ensures
// pre-pause statuses contribute to idle duration and active duration.
func (q *sqlQuerier) GetTelemetryTaskEvents(ctx context.Context, arg GetTelemetryTaskEventsParams) ([]GetTelemetryTaskEventsRow, error) {
rows, err := q.db.QueryContext(ctx, getTelemetryTaskEvents, arg.Now, arg.CreatedAfter)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetTelemetryTaskEventsRow
for rows.Next() {
var i GetTelemetryTaskEventsRow
if err := rows.Scan(
&i.TaskID,
&i.WorkspaceID,
&i.WorkspaceAppID,
&i.StopBuildCreatedAt,
&i.StopBuildReason,
&i.StartBuildCreatedAt,
&i.StartBuildReason,
&i.StartBuildNumber,
&i.LastWorkingStatusAt,
&i.FirstStatusAfterResumeAt,
&i.ActiveDurationMs,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertTask = `-- name: InsertTask :one
INSERT INTO tasks
(id, organization_id, owner_id, name, display_name, workspace_id, template_version_id, template_parameters, prompt, created_at)
@@ -18263,7 +17941,7 @@ const getAuthenticatedWorkspaceAgentAndBuildByAuthToken = `-- name: GetAuthentic
SELECT
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl,
workspace_agents.id, workspace_agents.created_at, workspace_agents.updated_at, workspace_agents.name, workspace_agents.first_connected_at, workspace_agents.last_connected_at, workspace_agents.disconnected_at, workspace_agents.resource_id, workspace_agents.auth_token, workspace_agents.auth_instance_id, workspace_agents.architecture, workspace_agents.environment_variables, workspace_agents.operating_system, workspace_agents.instance_metadata, workspace_agents.resource_metadata, workspace_agents.directory, workspace_agents.version, workspace_agents.last_connected_replica_id, workspace_agents.connection_timeout_seconds, workspace_agents.troubleshooting_url, workspace_agents.motd_file, workspace_agents.lifecycle_state, workspace_agents.expanded_directory, workspace_agents.logs_length, workspace_agents.logs_overflowed, workspace_agents.started_at, workspace_agents.ready_at, workspace_agents.subsystems, workspace_agents.display_apps, workspace_agents.api_version, workspace_agents.display_order, workspace_agents.parent_id, workspace_agents.api_key_scope, workspace_agents.deleted,
workspace_build_with_user.id, workspace_build_with_user.created_at, workspace_build_with_user.updated_at, workspace_build_with_user.workspace_id, workspace_build_with_user.template_version_id, workspace_build_with_user.build_number, workspace_build_with_user.transition, workspace_build_with_user.initiator_id, workspace_build_with_user.job_id, workspace_build_with_user.deadline, workspace_build_with_user.reason, workspace_build_with_user.daily_cost, workspace_build_with_user.max_deadline, workspace_build_with_user.template_version_preset_id, workspace_build_with_user.has_ai_task, workspace_build_with_user.has_external_agent, workspace_build_with_user.initiator_by_avatar_url, workspace_build_with_user.initiator_by_username, workspace_build_with_user.initiator_by_name,
workspace_build_with_user.id, workspace_build_with_user.created_at, workspace_build_with_user.updated_at, workspace_build_with_user.workspace_id, workspace_build_with_user.template_version_id, workspace_build_with_user.build_number, workspace_build_with_user.transition, workspace_build_with_user.initiator_id, workspace_build_with_user.provisioner_state, workspace_build_with_user.job_id, workspace_build_with_user.deadline, workspace_build_with_user.reason, workspace_build_with_user.daily_cost, workspace_build_with_user.max_deadline, workspace_build_with_user.template_version_preset_id, workspace_build_with_user.has_ai_task, workspace_build_with_user.has_external_agent, workspace_build_with_user.initiator_by_avatar_url, workspace_build_with_user.initiator_by_username, workspace_build_with_user.initiator_by_name,
tasks.id AS task_id
FROM
workspace_agents
@@ -18401,6 +18079,7 @@ func (q *sqlQuerier) GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx conte
&i.WorkspaceBuild.BuildNumber,
&i.WorkspaceBuild.Transition,
&i.WorkspaceBuild.InitiatorID,
&i.WorkspaceBuild.ProvisionerState,
&i.WorkspaceBuild.JobID,
&i.WorkspaceBuild.Deadline,
&i.WorkspaceBuild.Reason,
@@ -21314,7 +20993,7 @@ func (q *sqlQuerier) InsertWorkspaceBuildParameters(ctx context.Context, arg Ins
}
const getActiveWorkspaceBuildsByTemplateID = `-- name: GetActiveWorkspaceBuildsByTemplateID :many
SELECT wb.id, wb.created_at, wb.updated_at, wb.workspace_id, wb.template_version_id, wb.build_number, wb.transition, wb.initiator_id, wb.job_id, wb.deadline, wb.reason, wb.daily_cost, wb.max_deadline, wb.template_version_preset_id, wb.has_ai_task, wb.has_external_agent, wb.initiator_by_avatar_url, wb.initiator_by_username, wb.initiator_by_name
SELECT wb.id, wb.created_at, wb.updated_at, wb.workspace_id, wb.template_version_id, wb.build_number, wb.transition, wb.initiator_id, wb.provisioner_state, wb.job_id, wb.deadline, wb.reason, wb.daily_cost, wb.max_deadline, wb.template_version_preset_id, wb.has_ai_task, wb.has_external_agent, wb.initiator_by_avatar_url, wb.initiator_by_username, wb.initiator_by_name
FROM (
SELECT
workspace_id, MAX(build_number) as max_build_number
@@ -21362,6 +21041,7 @@ func (q *sqlQuerier) GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, t
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21469,7 +21149,7 @@ func (q *sqlQuerier) GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, a
const getLatestWorkspaceBuildByWorkspaceID = `-- name: GetLatestWorkspaceBuildByWorkspaceID :one
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21492,6 +21172,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, w
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21510,7 +21191,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, w
const getLatestWorkspaceBuildsByWorkspaceIDs = `-- name: GetLatestWorkspaceBuildsByWorkspaceIDs :many
SELECT
DISTINCT ON (workspace_id)
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21537,6 +21218,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context,
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21564,7 +21246,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context,
const getWorkspaceBuildByID = `-- name: GetWorkspaceBuildByID :one
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21585,6 +21267,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (W
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21602,7 +21285,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (W
const getWorkspaceBuildByJobID = `-- name: GetWorkspaceBuildByJobID :one
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21623,6 +21306,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UU
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21640,7 +21324,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UU
const getWorkspaceBuildByWorkspaceIDAndBuildNumber = `-- name: GetWorkspaceBuildByWorkspaceIDAndBuildNumber :one
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21665,6 +21349,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx context.Co
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21736,48 +21421,6 @@ func (q *sqlQuerier) GetWorkspaceBuildMetricsByResourceID(ctx context.Context, i
return i, err
}
const getWorkspaceBuildProvisionerStateByID = `-- name: GetWorkspaceBuildProvisionerStateByID :one
SELECT
workspace_builds.provisioner_state,
templates.id AS template_id,
templates.organization_id AS template_organization_id,
templates.user_acl,
templates.group_acl
FROM
workspace_builds
INNER JOIN
workspaces ON workspaces.id = workspace_builds.workspace_id
INNER JOIN
templates ON templates.id = workspaces.template_id
WHERE
workspace_builds.id = $1
`
type GetWorkspaceBuildProvisionerStateByIDRow struct {
ProvisionerState []byte `db:"provisioner_state" json:"provisioner_state"`
TemplateID uuid.UUID `db:"template_id" json:"template_id"`
TemplateOrganizationID uuid.UUID `db:"template_organization_id" json:"template_organization_id"`
UserACL TemplateACL `db:"user_acl" json:"user_acl"`
GroupACL TemplateACL `db:"group_acl" json:"group_acl"`
}
// Fetches the provisioner state of a workspace build, joined through to the
// template so that dbauthz can enforce policy.ActionUpdate on the template.
// Provisioner state contains sensitive Terraform state and should only be
// accessible to template administrators.
func (q *sqlQuerier) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (GetWorkspaceBuildProvisionerStateByIDRow, error) {
row := q.db.QueryRowContext(ctx, getWorkspaceBuildProvisionerStateByID, workspaceBuildID)
var i GetWorkspaceBuildProvisionerStateByIDRow
err := row.Scan(
&i.ProvisionerState,
&i.TemplateID,
&i.TemplateOrganizationID,
&i.UserACL,
&i.GroupACL,
)
return i, err
}
const getWorkspaceBuildStatsByTemplates = `-- name: GetWorkspaceBuildStatsByTemplates :many
SELECT
w.template_id,
@@ -21847,7 +21490,7 @@ func (q *sqlQuerier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, sinc
const getWorkspaceBuildsByWorkspaceID = `-- name: GetWorkspaceBuildsByWorkspaceID :many
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
FROM
workspace_build_with_user AS workspace_builds
WHERE
@@ -21911,6 +21554,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg Ge
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
@@ -21937,7 +21581,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg Ge
}
const getWorkspaceBuildsCreatedAfter = `-- name: GetWorkspaceBuildsCreatedAfter :many
SELECT id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name FROM workspace_build_with_user WHERE created_at > $1
SELECT id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name FROM workspace_build_with_user WHERE created_at > $1
`
func (q *sqlQuerier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceBuild, error) {
@@ -21958,6 +21602,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, created
&i.BuildNumber,
&i.Transition,
&i.InitiatorID,
&i.ProvisionerState,
&i.JobID,
&i.Deadline,
&i.Reason,
+4 -46
View File
@@ -1,8 +1,8 @@
-- name: InsertAIBridgeInterception :one
INSERT INTO aibridge_interceptions (
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, thread_parent_id, thread_root_id
id, api_key_id, initiator_id, provider, model, metadata, started_at, client
) VALUES (
@id, @api_key_id, @initiator_id, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
@id, @api_key_id, @initiator_id, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client
)
RETURNING *;
@@ -14,23 +14,6 @@ WHERE
AND ended_at IS NULL
RETURNING *;
-- name: GetAIBridgeInterceptionLineageByToolCallID :one
-- Look up the parent interception and the root of the thread by finding
-- which interception recorded a tool usage with the given tool call ID.
-- COALESCE ensures that if the parent has no thread_root_id (i.e. it IS
-- the root), we return its own ID as the root.
WITH linked AS (
SELECT interception_id FROM aibridge_tool_usages
WHERE provider_tool_call_id = @tool_call_id::text
ORDER BY created_at DESC
LIMIT 1
)
SELECT linked.interception_id AS thread_parent_id,
COALESCE(aibridge_interceptions.thread_root_id, linked.interception_id) AS thread_root_id
FROM aibridge_interceptions
INNER JOIN linked ON linked.interception_id = aibridge_interceptions.id
WHERE aibridge_interceptions.id = linked.interception_id;
-- name: InsertAIBridgeTokenUsage :one
INSERT INTO aibridge_token_usages (
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
@@ -49,9 +32,9 @@ RETURNING *;
-- name: InsertAIBridgeToolUsage :one
INSERT INTO aibridge_tool_usages (
id, interception_id, provider_response_id, provider_tool_call_id, tool, server_url, input, injected, invocation_error, metadata, created_at
id, interception_id, provider_response_id, tool, server_url, input, injected, invocation_error, metadata, created_at
) VALUES (
@id, @interception_id, @provider_response_id, @provider_tool_call_id, @tool, @server_url, @input, @injected, @invocation_error, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
@id, @interception_id, @provider_response_id, @tool, @server_url, @input, @injected, @invocation_error, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
)
RETURNING *;
@@ -391,28 +374,3 @@ SELECT (
(SELECT COUNT(*) FROM user_prompts) +
(SELECT COUNT(*) FROM interceptions)
)::bigint as total_deleted;
-- name: ListAIBridgeModels :many
SELECT
model
FROM
aibridge_interceptions
WHERE
-- Remove inflight interceptions (ones which lack an ended_at value).
aibridge_interceptions.ended_at IS NOT NULL
-- Filter model
AND CASE
WHEN @model::text != '' THEN aibridge_interceptions.model LIKE @model::text || '%'
ELSE true
END
-- We use an `@authorize_filter` as we are attempting to list models that are relevant
-- to the user and what they are allowed to see.
-- Authorize Filter clause will be injected below in ListAIBridgeModelsAuthorized
-- @authorize_filter
GROUP BY
model
ORDER BY
model ASC
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
OFFSET @offset_
;
+2 -4
View File
@@ -25,12 +25,10 @@ LIMIT
SELECT * FROM api_keys WHERE last_used > $1;
-- name: GetAPIKeysByLoginType :many
SELECT * FROM api_keys WHERE login_type = $1
AND (@include_expired::bool OR expires_at > now());
SELECT * FROM api_keys WHERE login_type = $1;
-- name: GetAPIKeysByUserID :many
SELECT * FROM api_keys WHERE login_type = $1 AND user_id = $2
AND (@include_expired::bool OR expires_at > now());
SELECT * FROM api_keys WHERE login_type = $1 AND user_id = $2;
-- name: InsertAPIKey :one
INSERT INTO
+2 -6
View File
@@ -140,9 +140,7 @@ INSERT INTO oauth2_provider_app_codes (
user_id,
resource_uri,
code_challenge,
code_challenge_method,
state_hash,
redirect_uri
code_challenge_method
) VALUES(
$1,
$2,
@@ -153,9 +151,7 @@ INSERT INTO oauth2_provider_app_codes (
$7,
$8,
$9,
$10,
$11,
$12
$10
) RETURNING *;
-- name: DeleteOAuth2ProviderAppCodeByID :exec
-143
View File
@@ -100,146 +100,3 @@ FROM
task_snapshots
WHERE
task_id = $1;
-- name: GetTelemetryTaskEvents :many
-- Returns all data needed to build task lifecycle events for telemetry
-- in a single round-trip. For each task whose workspace is in the
-- given set, fetches:
-- - the latest workspace app binding (task_workspace_apps)
-- - the most recent stop and start builds (workspace_builds)
-- - the last "working" app status (workspace_app_statuses)
-- - the first app status after resume, for active workspaces
--
-- Assumptions:
-- - 1:1 relationship between tasks and workspaces. All builds on the
-- workspace are considered task-related.
-- - Idle duration approximation: If the agent reports "working", does
-- work, then reports "done", we miss that working time.
-- - lws and active_dur join across all historical app IDs for the task,
-- because each resume cycle provisions a new app ID. This ensures
-- pre-pause statuses contribute to idle duration and active duration.
WITH task_app_ids AS (
SELECT task_id, workspace_app_id
FROM task_workspace_apps
),
task_status_timeline AS (
-- All app statuses across every historical app for each task,
-- plus synthetic "boundary" rows at each stop/start build transition.
-- This allows us to correctly take gaps due to pause/resume into account.
SELECT tai.task_id, was.created_at, was.state::text AS state
FROM workspace_app_statuses was
JOIN task_app_ids tai ON tai.workspace_app_id = was.app_id
UNION ALL
SELECT t.id AS task_id, wb.created_at, '_boundary' AS state
FROM tasks t
JOIN workspace_builds wb ON wb.workspace_id = t.workspace_id
WHERE t.deleted_at IS NULL
AND t.workspace_id IS NOT NULL
AND wb.build_number > 1
),
task_event_data AS (
SELECT
t.id AS task_id,
t.workspace_id,
twa.workspace_app_id,
-- Latest stop build.
stop_build.created_at AS stop_build_created_at,
stop_build.reason AS stop_build_reason,
-- Latest start build (task_resume only).
start_build.created_at AS start_build_created_at,
start_build.reason AS start_build_reason,
start_build.build_number AS start_build_number,
-- Last "working" app status (for idle duration).
lws.created_at AS last_working_status_at,
-- First app status after resume (for resume-to-status duration).
-- Only populated for workspaces in an active phase (started more
-- recently than stopped).
fsar.created_at AS first_status_after_resume_at,
-- Cumulative time spent in "working" state.
active_dur.total_working_ms AS active_duration_ms
FROM tasks t
LEFT JOIN LATERAL (
SELECT task_app.workspace_app_id
FROM task_workspace_apps task_app
WHERE task_app.task_id = t.id
ORDER BY task_app.workspace_build_number DESC
LIMIT 1
) twa ON TRUE
LEFT JOIN LATERAL (
SELECT wb.created_at, wb.reason, wb.build_number
FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.transition = 'stop'
ORDER BY wb.build_number DESC
LIMIT 1
) stop_build ON TRUE
LEFT JOIN LATERAL (
SELECT wb.created_at, wb.reason, wb.build_number
FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.transition = 'start'
ORDER BY wb.build_number DESC
LIMIT 1
) start_build ON TRUE
LEFT JOIN LATERAL (
SELECT tst.created_at
FROM task_status_timeline tst
WHERE tst.task_id = t.id
AND tst.state = 'working'
-- Only consider status before the latest pause so that
-- post-resume statuses don't mask pre-pause idle time.
AND (stop_build.created_at IS NULL
OR tst.created_at <= stop_build.created_at)
ORDER BY tst.created_at DESC
LIMIT 1
) lws ON TRUE
LEFT JOIN LATERAL (
SELECT was.created_at
FROM workspace_app_statuses was
WHERE was.app_id = twa.workspace_app_id
AND was.created_at > start_build.created_at
ORDER BY was.created_at ASC
LIMIT 1
) fsar ON twa.workspace_app_id IS NOT NULL
AND start_build.created_at IS NOT NULL
AND (stop_build.created_at IS NULL
OR start_build.created_at > stop_build.created_at)
-- Active duration: cumulative time spent in "working" state across all
-- historical app IDs for this task. Uses LEAD() to convert point-in-time
-- statuses into intervals, then sums intervals where state='working'. For
-- the last status, falls back to stop_build time (if paused) or @now (if
-- still running).
LEFT JOIN LATERAL (
SELECT COALESCE(
SUM(EXTRACT(EPOCH FROM (interval_end - interval_start)) * 1000)::bigint,
0
)::bigint AS total_working_ms
FROM (
SELECT
tst.created_at AS interval_start,
COALESCE(
LEAD(tst.created_at) OVER (ORDER BY tst.created_at ASC, CASE WHEN tst.state = '_boundary' THEN 1 ELSE 0 END ASC),
CASE WHEN stop_build.created_at IS NOT NULL
AND (start_build.created_at IS NULL
OR stop_build.created_at > start_build.created_at)
THEN stop_build.created_at
ELSE @now::timestamptz
END
) AS interval_end,
tst.state
FROM task_status_timeline tst
WHERE tst.task_id = t.id
) intervals
WHERE intervals.state = 'working'
) active_dur ON TRUE
WHERE t.deleted_at IS NULL
AND t.workspace_id IS NOT NULL
AND EXISTS (
SELECT 1 FROM workspace_builds wb
WHERE wb.workspace_id = t.workspace_id
AND wb.created_at > @created_after
)
)
SELECT * FROM task_event_data
ORDER BY task_id;
@@ -87,4 +87,3 @@ SELECT DISTINCT ON (workspace_id)
FROM workspace_app_statuses
WHERE workspace_id = ANY(@ids :: uuid[])
ORDER BY workspace_id, created_at DESC;
@@ -271,23 +271,3 @@ JOIN workspace_resources wr ON wr.job_id = wb.job_id
JOIN workspace_agents wa ON wa.resource_id = wr.id
WHERE wb.job_id = (SELECT job_id FROM workspace_resources WHERE workspace_resources.id = $1)
GROUP BY wb.created_at, wb.transition, t.name, o.name, w.owner_id;
-- name: GetWorkspaceBuildProvisionerStateByID :one
-- Fetches the provisioner state of a workspace build, joined through to the
-- template so that dbauthz can enforce policy.ActionUpdate on the template.
-- Provisioner state contains sensitive Terraform state and should only be
-- accessible to template administrators.
SELECT
workspace_builds.provisioner_state,
templates.id AS template_id,
templates.organization_id AS template_organization_id,
templates.user_acl,
templates.group_acl
FROM
workspace_builds
INNER JOIN
workspaces ON workspaces.id = workspace_builds.workspace_id
INNER JOIN
templates ON templates.id = workspaces.template_id
WHERE
workspace_builds.id = @workspace_build_id;
-18
View File
@@ -124,24 +124,6 @@ sql:
- column: "tasks_with_status.workspace_app_health"
go_type:
type: "NullWorkspaceAppHealth"
# Workaround for sqlc not interpreting the left join correctly
# in the combined telemetry query.
- column: "task_event_data.start_build_number"
go_type: "database/sql.NullInt32"
- column: "task_event_data.stop_build_created_at"
go_type: "database/sql.NullTime"
- column: "task_event_data.stop_build_reason"
go_type:
type: "NullBuildReason"
- column: "task_event_data.start_build_created_at"
go_type: "database/sql.NullTime"
- column: "task_event_data.start_build_reason"
go_type:
type: "NullBuildReason"
- column: "task_event_data.last_working_status_at"
go_type: "database/sql.NullTime"
- column: "task_event_data.first_status_after_resume_at"
go_type: "database/sql.NullTime"
rename:
group_member: GroupMemberTable
group_members_expanded: GroupMember
+4 -3
View File
@@ -228,11 +228,12 @@ func (p *QueryParamParser) RedirectURL(vals url.Values, base *url.URL, queryPara
})
}
// OAuth 2.1 requires exact redirect URI matching.
if v.String() != base.String() {
// It can be a sub-directory but not a sub-domain, as we have apps on
// sub-domains and that seems too dangerous.
if v.Host != base.Host || !strings.HasPrefix(v.Path, base.Path) {
p.Errors = append(p.Errors, codersdk.ValidationError{
Field: queryParam,
Detail: fmt.Sprintf("Query param %q must exactly match %s", queryParam, base),
Detail: fmt.Sprintf("Query param %q must be a subset of %s", queryParam, base),
})
}
+2 -6
View File
@@ -62,12 +62,8 @@ func CSRF(cookieCfg codersdk.HTTPCookieConfig) func(next http.Handler) http.Hand
mw.ExemptRegexp(regexp.MustCompile("/organizations/[^/]+/provisionerdaemons/*"))
mw.ExemptFunc(func(r *http.Request) bool {
// Enforce CSRF on API routes and the OAuth2 authorize
// endpoint. The authorize endpoint serves a browser consent
// form whose POST must be CSRF-protected to prevent
// cross-site authorization code theft (coder/security#121).
if !strings.HasPrefix(r.URL.Path, "/api") &&
!strings.HasPrefix(r.URL.Path, "/oauth2/authorize") {
// Only enforce CSRF on API routes.
if !strings.HasPrefix(r.URL.Path, "/api") {
return true
}
-20
View File
@@ -51,26 +51,6 @@ func TestCSRFExemptList(t *testing.T) {
URL: "https://coder.com/api/v2/me",
Exempt: false,
},
{
Name: "OAuth2Authorize",
URL: "https://coder.com/oauth2/authorize",
Exempt: false,
},
{
Name: "OAuth2AuthorizeQuery",
URL: "https://coder.com/oauth2/authorize?client_id=test",
Exempt: false,
},
{
Name: "OAuth2Tokens",
URL: "https://coder.com/oauth2/tokens",
Exempt: true,
},
{
Name: "OAuth2Register",
URL: "https://coder.com/oauth2/register",
Exempt: true,
},
}
mw := httpmw.CSRF(codersdk.HTTPCookieConfig{})
+3 -11
View File
@@ -348,12 +348,8 @@ func reapJob(ctx context.Context, log slog.Logger, db database.Store, pub pubsub
// Only copy the provisioner state if there's no state in
// the current build.
currentStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
if err != nil {
return xerrors.Errorf("get workspace build provisioner state: %w", err)
}
if len(currentStateRow.ProvisionerState) == 0 {
// Get the previous build's state if it exists.
if len(build.ProvisionerState) == 0 {
// Get the previous build if it exists.
prevBuild, err := db.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
WorkspaceID: build.WorkspaceID,
BuildNumber: build.BuildNumber - 1,
@@ -362,14 +358,10 @@ func reapJob(ctx context.Context, log slog.Logger, db database.Store, pub pubsub
return xerrors.Errorf("get previous workspace build: %w", err)
}
if err == nil {
prevStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, prevBuild.ID)
if err != nil {
return xerrors.Errorf("get previous workspace build provisioner state: %w", err)
}
err = db.UpdateWorkspaceBuildProvisionerStateByID(ctx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
ID: build.ID,
UpdatedAt: dbtime.Now(),
ProvisionerState: prevStateRow.ProvisionerState,
ProvisionerState: prevBuild.ProvisionerState,
})
if err != nil {
return xerrors.Errorf("update workspace build by id: %w", err)
+22 -29
View File
@@ -126,9 +126,9 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
previousBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
OwnerID: user.ID,
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
ProvisionerState(expectedWorkspaceBuildState).
Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
ProvisionerState: expectedWorkspaceBuildState,
}).Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
Do()
// Current build (hung - running job with UpdatedAt > 5 min ago).
@@ -163,9 +163,7 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
// Check that the provisioner state was copied.
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
require.NoError(t, err)
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
@@ -196,9 +194,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
previousBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
OwnerID: user.ID,
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
ProvisionerState([]byte(`{"dean":"NOT cool","colin":"also NOT cool"}`)).
Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
ProvisionerState: []byte(`{"dean":"NOT cool","colin":"also NOT cool"}`),
}).Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
Do()
// Current build (hung - running job with UpdatedAt > 5 min ago).
@@ -206,8 +204,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
currentBuild := dbfake.WorkspaceBuild(t, db, previousBuild.Workspace).
Pubsub(pubsub).
Seed(database.WorkspaceBuild{
BuildNumber: 2,
}).ProvisionerState(expectedWorkspaceBuildState).
BuildNumber: 2,
ProvisionerState: expectedWorkspaceBuildState,
}).
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
Do()
@@ -236,9 +235,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
// Check that the provisioner state was NOT copied.
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
require.NoError(t, err)
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
@@ -269,9 +266,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T
currentBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
OwnerID: user.ID,
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
ProvisionerState(expectedWorkspaceBuildState).
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
ProvisionerState: expectedWorkspaceBuildState,
}).Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
Do()
t.Log("current job ID: ", currentBuild.Build.JobID)
@@ -298,9 +295,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T
// Check that the provisioner state was NOT updated.
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
require.NoError(t, err)
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
@@ -330,9 +325,9 @@ func TestDetectorPendingWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testin
currentBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
OwnerID: user.ID,
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
ProvisionerState(expectedWorkspaceBuildState).
Pending(dbfake.WithJobCreatedAt(thirtyFiveMinAgo), dbfake.WithJobUpdatedAt(thirtyFiveMinAgo)).
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
ProvisionerState: expectedWorkspaceBuildState,
}).Pending(dbfake.WithJobCreatedAt(thirtyFiveMinAgo), dbfake.WithJobUpdatedAt(thirtyFiveMinAgo)).
Do()
t.Log("current job ID: ", currentBuild.Build.JobID)
@@ -361,9 +356,7 @@ func TestDetectorPendingWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testin
// Check that the provisioner state was NOT updated.
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
require.NoError(t, err)
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
@@ -405,9 +398,9 @@ func TestDetectorWorkspaceBuildForDormantWorkspace(t *testing.T) {
Time: now.Add(-time.Hour),
Valid: true,
},
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
ProvisionerState(expectedWorkspaceBuildState).
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
ProvisionerState: expectedWorkspaceBuildState,
}).Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
Do()
t.Log("current job ID: ", currentBuild.Build.JobID)
+22 -40
View File
@@ -2,8 +2,6 @@ package mcp_test
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -12,7 +10,6 @@ import (
"strings"
"testing"
"github.com/google/uuid"
mcpclient "github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
@@ -27,15 +24,6 @@ import (
"github.com/coder/coder/v2/testutil"
)
// mcpGeneratePKCE creates a PKCE verifier and S256 challenge for MCP
// e2e tests. The verifier is two concatenated UUID strings (72 chars,
// within RFC 7636's 43-128 char bounds) and the challenge is the
// base64url-encoded (unpadded) SHA-256 of the verifier.
func mcpGeneratePKCE() (verifier, challenge string) {
	verifier = uuid.NewString() + uuid.NewString()
	sum := sha256.Sum256([]byte(verifier))
	return verifier, base64.RawURLEncoding.EncodeToString(sum[:])
}
func TestMCPHTTP_E2E_ClientIntegration(t *testing.T) {
t.Parallel()
@@ -565,32 +553,31 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
// In a real flow, this would be done through the browser consent page
// For testing, we'll create the code directly using the internal API
// First, we need to authorize the app (simulating user consent).
staticVerifier, staticChallenge := mcpGeneratePKCE()
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=test_state&code_challenge=%s&code_challenge_method=S256",
api.AccessURL.String(), app.ID, "http://localhost:3000/callback", staticChallenge)
// First, we need to authorize the app (simulating user consent)
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=test_state",
api.AccessURL.String(), app.ID, "http://localhost:3000/callback")
// Create an HTTP client that follows redirects but captures the final redirect.
// Create an HTTP client that follows redirects but captures the final redirect
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse // Stop following redirects
},
}
// Make the authorization request (this would normally be done in a browser).
// Make the authorization request (this would normally be done in a browser)
req, err := http.NewRequestWithContext(ctx, "GET", authURL, nil)
require.NoError(t, err)
// Use RFC 6750 Bearer token for authentication.
// Use RFC 6750 Bearer token for authentication
req.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
// The response should be a redirect to the consent page or directly to callback.
// For testing purposes, let's simulate the POST consent approval.
// The response should be a redirect to the consent page or directly to callback
// For testing purposes, let's simulate the POST consent approval
if resp.StatusCode == http.StatusOK {
// This means we got the consent page, now we need to POST consent.
// This means we got the consent page, now we need to POST consent
consentReq, err := http.NewRequestWithContext(ctx, "POST", authURL, nil)
require.NoError(t, err)
consentReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
@@ -601,7 +588,7 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
defer resp.Body.Close()
}
// Extract authorization code from redirect URL.
// Extract authorization code from redirect URL
require.True(t, resp.StatusCode >= 300 && resp.StatusCode < 400, "Expected redirect response")
location := resp.Header.Get("Location")
require.NotEmpty(t, location, "Expected Location header in redirect")
@@ -613,14 +600,13 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
t.Logf("Successfully obtained authorization code: %s", authCode[:10]+"...")
// Step 2: Exchange authorization code for access token and refresh token.
// Step 2: Exchange authorization code for access token and refresh token
tokenRequestBody := url.Values{
"grant_type": {"authorization_code"},
"client_id": {app.ID.String()},
"client_secret": {secret.ClientSecretFull},
"code": {authCode},
"redirect_uri": {"http://localhost:3000/callback"},
"code_verifier": {staticVerifier},
}
tokenReq, err := http.NewRequestWithContext(ctx, "POST", api.AccessURL.String()+"/oauth2/tokens",
@@ -882,44 +868,41 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
t.Logf("Successfully registered dynamic client: %s", clientID)
// Step 3: Perform OAuth2 authorization code flow with dynamically registered client.
dynamicVerifier, dynamicChallenge := mcpGeneratePKCE()
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=dynamic_state&code_challenge=%s&code_challenge_method=S256",
api.AccessURL.String(), clientID, "http://localhost:3000/callback", dynamicChallenge)
// Step 3: Perform OAuth2 authorization code flow with dynamically registered client
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=dynamic_state",
api.AccessURL.String(), clientID, "http://localhost:3000/callback")
// Create an HTTP client that captures redirects.
// Create an HTTP client that captures redirects
authClient := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse // Stop following redirects
},
}
// Make the authorization request with authentication.
// Make the authorization request with authentication
authReq, err := http.NewRequestWithContext(ctx, "GET", authURL, nil)
require.NoError(t, err)
authReq.Header.Set("Cookie", fmt.Sprintf("coder_session_token=%s", coderClient.SessionToken()))
authReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
authResp, err := authClient.Do(authReq)
require.NoError(t, err)
defer authResp.Body.Close()
// Handle the response - check for error first.
// Handle the response - check for error first
if authResp.StatusCode == http.StatusBadRequest {
// Read error response for debugging.
// Read error response for debugging
bodyBytes, err := io.ReadAll(authResp.Body)
require.NoError(t, err)
t.Logf("OAuth2 authorization error: %s", string(bodyBytes))
t.FailNow()
}
// Handle consent flow if needed.
// Handle consent flow if needed
if authResp.StatusCode == http.StatusOK {
// This means we got the consent page, now we need to POST consent.
// This means we got the consent page, now we need to POST consent
consentReq, err := http.NewRequestWithContext(ctx, "POST", authURL, nil)
require.NoError(t, err)
consentReq.Header.Set("Cookie", fmt.Sprintf("coder_session_token=%s", coderClient.SessionToken()))
consentReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
consentReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
authResp, err = authClient.Do(consentReq)
@@ -927,7 +910,7 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
defer authResp.Body.Close()
}
// Extract authorization code from redirect.
// Extract authorization code from redirect
require.True(t, authResp.StatusCode >= 300 && authResp.StatusCode < 400,
"Expected redirect response, got %d", authResp.StatusCode)
location := authResp.Header.Get("Location")
@@ -940,14 +923,13 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
t.Logf("Successfully obtained authorization code: %s", authCode[:10]+"...")
// Step 4: Exchange authorization code for access token.
// Step 4: Exchange authorization code for access token
tokenRequestBody := url.Values{
"grant_type": {"authorization_code"},
"client_id": {clientID},
"client_secret": {clientSecret},
"code": {authCode},
"redirect_uri": {"http://localhost:3000/callback"},
"code_verifier": {dynamicVerifier},
}
tokenReq, err := http.NewRequestWithContext(ctx, "POST", api.AccessURL.String()+"/oauth2/tokens",
+41 -92
View File
@@ -2,8 +2,6 @@ package coderd_test
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
@@ -291,6 +289,7 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
authError: "Invalid query params:",
},
{
// TODO: This is valid for now, but should it be?
name: "DifferentProtocol",
app: apps.Default,
preAuth: func(valid *oauth2.Config) {
@@ -298,7 +297,6 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
newURL.Scheme = "https"
valid.RedirectURL = newURL.String()
},
authError: "Invalid query params:",
},
{
name: "NestedPath",
@@ -308,7 +306,6 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
newURL.Path = path.Join(newURL.Path, "nested")
valid.RedirectURL = newURL.String()
},
authError: "Invalid query params:",
},
{
// Some oauth implementations allow this, but our users can host
@@ -484,12 +481,11 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
}
var code string
var verifier string
if test.defaultCode != nil {
code = *test.defaultCode
} else {
var err error
code, verifier, err = authorizationFlow(ctx, userClient, valid)
code, err = authorizationFlow(ctx, userClient, valid)
if test.authError != "" {
require.Error(t, err)
require.ErrorContains(t, err, test.authError)
@@ -504,12 +500,8 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
test.preToken(valid)
}
// Do the actual exchange. Include PKCE code_verifier when
// we obtained a code through the authorization flow.
exchangeOpts := append([]oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", verifier),
}, test.exchangeMutate...)
token, err := valid.Exchange(ctx, code, exchangeOpts...)
// Do the actual exchange.
token, err := valid.Exchange(ctx, code, test.exchangeMutate...)
if test.tokenError != "" {
require.Error(t, err)
require.ErrorContains(t, err, test.tokenError)
@@ -691,11 +683,10 @@ func TestOAuth2ProviderTokenRefresh(t *testing.T) {
}
type exchangeSetup struct {
cfg *oauth2.Config
app codersdk.OAuth2ProviderApp
secret codersdk.OAuth2ProviderAppSecretFull
code string
verifier string
cfg *oauth2.Config
app codersdk.OAuth2ProviderApp
secret codersdk.OAuth2ProviderAppSecretFull
code string
}
func TestOAuth2ProviderRevoke(t *testing.T) {
@@ -739,13 +730,11 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
name: "OverrideCodeAndToken",
fn: func(ctx context.Context, client *codersdk.Client, s exchangeSetup) {
// Generating a new code should wipe out the old code.
code, verifier, err := authorizationFlow(ctx, client, s.cfg)
code, err := authorizationFlow(ctx, client, s.cfg)
require.NoError(t, err)
// Generating a new token should wipe out the old token.
_, err = s.cfg.Exchange(ctx, code,
oauth2.SetAuthURLParam("code_verifier", verifier),
)
_, err = s.cfg.Exchange(ctx, code)
require.NoError(t, err)
},
replacesToken: true,
@@ -781,15 +770,14 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
}
// Go through the auth flow to get a code.
code, verifier, err := authorizationFlow(ctx, testClient, cfg)
code, err := authorizationFlow(ctx, testClient, cfg)
require.NoError(t, err)
return exchangeSetup{
cfg: cfg,
app: app,
secret: secret,
code: code,
verifier: verifier,
cfg: cfg,
app: app,
secret: secret,
code: code,
}
}
@@ -806,16 +794,12 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
test.fn(ctx, testClient, testEntities)
// Exchange should fail because the code should be gone.
_, err := testEntities.cfg.Exchange(ctx, testEntities.code,
oauth2.SetAuthURLParam("code_verifier", testEntities.verifier),
)
_, err := testEntities.cfg.Exchange(ctx, testEntities.code)
require.Error(t, err)
// Try again, this time letting the exchange complete first.
testEntities = setup(ctx, testClient, test.name+"-2")
token, err := testEntities.cfg.Exchange(ctx, testEntities.code,
oauth2.SetAuthURLParam("code_verifier", testEntities.verifier),
)
token, err := testEntities.cfg.Exchange(ctx, testEntities.code)
require.NoError(t, err)
// Validate the returned access token and that the app is listed.
@@ -888,38 +872,25 @@ func generateApps(ctx context.Context, t *testing.T, client *codersdk.Client, su
}
}
// generatePKCE creates a PKCE verifier and S256 challenge for testing.
// The verifier is two concatenated UUID strings (72 chars, inside RFC
// 7636's 43-128 char window); the challenge is the unpadded base64url
// SHA-256 digest of the verifier.
func generatePKCE() (verifier, challenge string) {
	verifier = uuid.NewString() + uuid.NewString()
	digest := sha256.Sum256([]byte(verifier))
	challenge = base64.RawURLEncoding.EncodeToString(digest[:])
	return
}
func authorizationFlow(ctx context.Context, client *codersdk.Client, cfg *oauth2.Config) (code, codeVerifier string, err error) {
func authorizationFlow(ctx context.Context, client *codersdk.Client, cfg *oauth2.Config) (string, error) {
state := uuid.NewString()
codeVerifier, challenge := generatePKCE()
authURL := cfg.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
)
authURL := cfg.AuthCodeURL(state)
// Make a POST request to simulate clicking "Allow" on the authorization page.
// This bypasses the HTML consent page and directly processes the authorization.
code, err = oidctest.OAuth2GetCode(
// Make a POST request to simulate clicking "Allow" on the authorization page
// This bypasses the HTML consent page and directly processes the authorization
return oidctest.OAuth2GetCode(
authURL,
func(req *http.Request) (*http.Response, error) {
// Change to POST to simulate the form submission.
// Change to POST to simulate the form submission
req.Method = http.MethodPost
// Prevent automatic redirect following.
// Prevent automatic redirect following
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
return client.Request(ctx, req.Method, req.URL.String(), nil)
},
)
return code, codeVerifier, err
}
func must[T any](value T, err error) T {
@@ -1026,15 +997,11 @@ func TestOAuth2ProviderResourceIndicators(t *testing.T) {
Scopes: []string{},
}
// Step 1: Authorization with resource parameter and PKCE.
// Step 1: Authorization with resource parameter
state := uuid.NewString()
verifier, challenge := generatePKCE()
authURL := cfg.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
)
authURL := cfg.AuthCodeURL(state)
if test.authResource != "" {
// Add resource parameter to auth URL.
// Add resource parameter to auth URL
parsedURL, err := url.Parse(authURL)
require.NoError(t, err)
query := parsedURL.Query()
@@ -1063,7 +1030,7 @@ func TestOAuth2ProviderResourceIndicators(t *testing.T) {
// Step 2: Token exchange with resource parameter
// Use custom token exchange since golang.org/x/oauth2 doesn't support resource parameter in token requests
token, err := customTokenExchange(ctx, ownerClient.URL.String(), apps.Default.ID.String(), secret.ClientSecretFull, code, apps.Default.CallbackURL, test.tokenResource, verifier)
token, err := customTokenExchange(ctx, ownerClient.URL.String(), apps.Default.ID.String(), secret.ClientSecretFull, code, apps.Default.CallbackURL, test.tokenResource)
if test.expectTokenError {
require.Error(t, err)
require.Contains(t, err.Error(), "invalid_target")
@@ -1160,13 +1127,9 @@ func TestOAuth2ProviderCrossResourceAudienceValidation(t *testing.T) {
Scopes: []string{},
}
// Authorization with resource parameter for server1 and PKCE.
// Authorization with resource parameter for server1
state := uuid.NewString()
verifier, challenge := generatePKCE()
authURL := cfg.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
)
authURL := cfg.AuthCodeURL(state)
parsedURL, err := url.Parse(authURL)
require.NoError(t, err)
query := parsedURL.Query()
@@ -1186,11 +1149,8 @@ func TestOAuth2ProviderCrossResourceAudienceValidation(t *testing.T) {
)
require.NoError(t, err)
// Exchange code for token with resource parameter and PKCE verifier.
token, err := cfg.Exchange(ctx, code,
oauth2.SetAuthURLParam("resource", resource1),
oauth2.SetAuthURLParam("code_verifier", verifier),
)
// Exchange code for token with resource parameter
token, err := cfg.Exchange(ctx, code, oauth2.SetAuthURLParam("resource", resource1))
require.NoError(t, err)
require.NotEmpty(t, token.AccessToken)
@@ -1266,11 +1226,9 @@ func TestOAuth2RefreshExpiryOutlivesAccess(t *testing.T) {
}
// Authorization and token exchange
code, verifier, err := authorizationFlow(ctx, ownerClient, cfg)
code, err := authorizationFlow(ctx, ownerClient, cfg)
require.NoError(t, err)
tok, err := cfg.Exchange(ctx, code,
oauth2.SetAuthURLParam("code_verifier", verifier),
)
tok, err := cfg.Exchange(ctx, code)
require.NoError(t, err)
require.NotEmpty(t, tok.AccessToken)
require.NotEmpty(t, tok.RefreshToken)
@@ -1295,7 +1253,7 @@ func TestOAuth2RefreshExpiryOutlivesAccess(t *testing.T) {
// customTokenExchange performs a custom OAuth2 token exchange with support for resource parameter
// This is needed because golang.org/x/oauth2 doesn't support custom parameters in token requests
func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, code, redirectURI, resource, codeVerifier string) (*oauth2.Token, error) {
func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, code, redirectURI, resource string) (*oauth2.Token, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
@@ -1305,9 +1263,6 @@ func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, c
if resource != "" {
data.Set("resource", resource)
}
if codeVerifier != "" {
data.Set("code_verifier", codeVerifier)
}
req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/oauth2/tokens", strings.NewReader(data.Encode()))
if err != nil {
@@ -1682,21 +1637,17 @@ func TestOAuth2CoderClient(t *testing.T) {
// Make a new user
client, user := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID)
// Do an OAuth2 token exchange and get a new client with an oauth token.
// Do an OAuth2 token exchange and get a new client with an oauth token
state := uuid.NewString()
verifier, challenge := generatePKCE()
// Get an OAuth2 code for a token exchange.
// Get an OAuth2 code for a token exchange
code, err := oidctest.OAuth2GetCode(
cfg.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
),
cfg.AuthCodeURL(state),
func(req *http.Request) (*http.Response, error) {
// Change to POST to simulate the form submission.
// Change to POST to simulate the form submission
req.Method = http.MethodPost
// Prevent automatic redirect following.
// Prevent automatic redirect following
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
@@ -1705,9 +1656,7 @@ func TestOAuth2CoderClient(t *testing.T) {
)
require.NoError(t, err)
token, err := cfg.Exchange(ctx, code,
oauth2.SetAuthURLParam("code_verifier", verifier),
)
token, err := cfg.Exchange(ctx, code)
require.NoError(t, err)
// Use the oauth client's authentication
+10 -64
View File
@@ -1,9 +1,7 @@
package oauth2provider
import (
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"net/http"
"net/url"
@@ -11,7 +9,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/justinas/nosurf"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
@@ -25,7 +22,6 @@ import (
type authorizeParams struct {
clientID string
redirectURL *url.URL
redirectURIProvided bool
responseType codersdk.OAuth2ProviderResponseType
scope []string
state string
@@ -38,13 +34,11 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
p := httpapi.NewQueryParamParser()
vals := r.URL.Query()
// response_type and client_id are always required.
p.RequiredNotEmpty("response_type", "client_id")
params := authorizeParams{
clientID: p.String(vals, "", "client_id"),
redirectURL: p.RedirectURL(vals, callbackURL, "redirect_uri"),
redirectURIProvided: vals.Get("redirect_uri") != "",
responseType: httpapi.ParseCustom(p, vals, "", "response_type", httpapi.ParseEnum[codersdk.OAuth2ProviderResponseType]),
scope: strings.Fields(strings.TrimSpace(p.String(vals, "", "scope"))),
state: p.String(vals, "", "state"),
@@ -52,15 +46,6 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
codeChallenge: p.String(vals, "", "code_challenge"),
codeChallengeMethod: p.String(vals, "", "code_challenge_method"),
}
// PKCE is required for authorization code flow requests.
if params.responseType == codersdk.OAuth2ProviderResponseTypeCode && params.codeChallenge == "" {
p.Errors = append(p.Errors, codersdk.ValidationError{
Field: "code_challenge",
Detail: `Query param "code_challenge" is required and cannot be empty`,
})
}
// Validate resource indicator syntax (RFC 8707): must be absolute URI without fragment
if err := validateResourceParameter(params.resource); err != nil {
p.Errors = append(p.Errors, codersdk.ValidationError{
@@ -127,22 +112,6 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
return
}
if params.responseType != codersdk.OAuth2ProviderResponseTypeCode {
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
Status: http.StatusBadRequest,
HideStatus: false,
Title: "Unsupported Response Type",
Description: "Only response_type=code is supported.",
Actions: []site.Action{
{
URL: accessURL.String(),
Text: "Back to site",
},
},
})
return
}
cancel := params.redirectURL
cancelQuery := params.redirectURL.Query()
cancelQuery.Add("error", "access_denied")
@@ -153,7 +122,6 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
AppName: app.Name,
CancelURI: cancel.String(),
RedirectURI: r.URL.String(),
CSRFToken: nosurf.Token(r),
Username: ua.FriendlyName,
})
}
@@ -179,23 +147,16 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
return
}
// OAuth 2.1 removes the implicit grant. Only
// authorization code flow is supported.
if params.responseType != codersdk.OAuth2ProviderResponseTypeCode {
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest,
codersdk.OAuth2ErrorCodeUnsupportedResponseType,
"Only response_type=code is supported")
return
}
// code_challenge is required (enforced by RequiredNotEmpty above),
// but default the method to S256 if omitted.
if params.codeChallengeMethod == "" {
params.codeChallengeMethod = string(codersdk.OAuth2PKCECodeChallengeMethodS256)
}
if err := codersdk.ValidatePKCECodeChallengeMethod(params.codeChallengeMethod); err != nil {
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
return
// Validate PKCE for public clients (MCP requirement)
if params.codeChallenge != "" {
// If code_challenge is provided but method is not, default to S256
if params.codeChallengeMethod == "" {
params.codeChallengeMethod = string(codersdk.OAuth2PKCECodeChallengeMethodS256)
}
if err := codersdk.ValidatePKCECodeChallengeMethod(params.codeChallengeMethod); err != nil {
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
return
}
}
// TODO: Ignoring scope for now, but should look into implementing.
@@ -233,8 +194,6 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
ResourceUri: sql.NullString{String: params.resource, Valid: params.resource != ""},
CodeChallenge: sql.NullString{String: params.codeChallenge, Valid: params.codeChallenge != ""},
CodeChallengeMethod: sql.NullString{String: params.codeChallengeMethod, Valid: params.codeChallengeMethod != ""},
StateHash: hashOAuth2State(params.state),
RedirectUri: sql.NullString{String: params.redirectURL.String(), Valid: params.redirectURIProvided},
})
if err != nil {
return xerrors.Errorf("insert oauth2 authorization code: %w", err)
@@ -259,16 +218,3 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
http.Redirect(rw, r, params.redirectURL.String(), http.StatusFound)
}
}
// hashOAuth2State returns the SHA-256 digest of the OAuth2 state parameter
// as a lowercase hex string wrapped in a sql.NullString. An empty state
// produces an invalid (SQL NULL) value instead of hashing the empty string.
func hashOAuth2State(state string) sql.NullString {
	if len(state) == 0 {
		return sql.NullString{}
	}
	digest := sha256.Sum256([]byte(state))
	return sql.NullString{String: hex.EncodeToString(digest[:]), Valid: true}
}
@@ -1,53 +0,0 @@
//nolint:testpackage // Internal test for unexported hashOAuth2State helper.
package oauth2provider
import (
"crypto/sha256"
"encoding/hex"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHashOAuth2State verifies the OAuth2 state hashing helper: an empty
// state yields an invalid NullString, a non-empty state yields its SHA-256
// hex digest, the function is deterministic, and distinct inputs do not
// collide.
func TestHashOAuth2State(t *testing.T) {
	t.Parallel()
	t.Run("EmptyState", func(t *testing.T) {
		t.Parallel()
		got := hashOAuth2State("")
		assert.False(t, got.Valid, "empty state should return invalid NullString")
		assert.Empty(t, got.String, "empty state should return empty string")
	})
	t.Run("NonEmptyState", func(t *testing.T) {
		t.Parallel()
		const state = "test-state-value"
		got := hashOAuth2State(state)
		require.True(t, got.Valid, "non-empty state should return valid NullString")
		// Compute the expected SHA-256 digest independently.
		want := sha256.Sum256([]byte(state))
		assert.Equal(t, hex.EncodeToString(want[:]), got.String,
			"state hash should be SHA-256 hex digest")
	})
	t.Run("DifferentStatesProduceDifferentHashes", func(t *testing.T) {
		t.Parallel()
		first := hashOAuth2State("state-a")
		second := hashOAuth2State("state-b")
		require.True(t, first.Valid)
		require.True(t, second.Valid)
		assert.NotEqual(t, first.String, second.String,
			"different states should produce different hashes")
	})
	t.Run("SameStateProducesSameHash", func(t *testing.T) {
		t.Parallel()
		first := hashOAuth2State("deterministic")
		second := hashOAuth2State("deterministic")
		require.True(t, first.Valid)
		assert.Equal(t, first.String, second.String,
			"same state should produce identical hash")
	})
}
-32
View File
@@ -1,32 +0,0 @@
package oauth2provider_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/site"
)
// TestOAuthConsentFormIncludesCSRFToken renders the OAuth consent page and
// asserts that the hidden csrf_token form field is present and carries the
// exact token value supplied by the caller.
func TestOAuthConsentFormIncludesCSRFToken(t *testing.T) {
	t.Parallel()

	const token = "csrf-field-value"

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "https://coder.com/oauth2/authorize", nil)
	site.RenderOAuthAllowPage(rec, req, site.RenderOAuthAllowData{
		AppName:     "Test OAuth App",
		CancelURI:   "https://coder.com/cancel",
		RedirectURI: "https://coder.com/oauth2/authorize?client_id=test",
		CSRFToken:   token,
		Username:    "test-user",
	})

	require.Equal(t, http.StatusOK, rec.Result().StatusCode)
	body := rec.Body.String()
	assert.Contains(t, body, `name="csrf_token"`)
	assert.Contains(t, body, `value="`+token+`"`)
}
@@ -158,9 +158,7 @@ func TestOAuth2InvalidPKCE(t *testing.T) {
)
}
// TestOAuth2WithoutPKCEIsRejected verifies that authorization requests without
// a code_challenge are rejected now that PKCE is mandatory.
func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
func TestOAuth2WithoutPKCE(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
@@ -168,15 +166,15 @@ func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
})
_ = coderdtest.CreateFirstUser(t, client)
// Create OAuth2 app.
app, _ := oauth2providertest.CreateTestOAuth2App(t, client)
// Create OAuth2 app
app, clientSecret := oauth2providertest.CreateTestOAuth2App(t, client)
t.Cleanup(func() {
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
})
state := oauth2providertest.GenerateState(t)
// Authorization without code_challenge should be rejected.
// Perform authorization without PKCE
authParams := oauth2providertest.AuthorizeParams{
ClientID: app.ID.String(),
ResponseType: "code",
@@ -184,9 +182,21 @@ func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
State: state,
}
oauth2providertest.AuthorizeOAuth2AppExpectingError(
t, client, client.URL.String(), authParams, http.StatusBadRequest,
)
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
require.NotEmpty(t, code, "should receive authorization code")
// Exchange code for token without PKCE
tokenParams := oauth2providertest.TokenExchangeParams{
GrantType: "authorization_code",
Code: code,
ClientID: app.ID.String(),
ClientSecret: clientSecret,
RedirectURI: oauth2providertest.TestRedirectURI,
}
token := oauth2providertest.ExchangeCodeForToken(t, client.URL.String(), tokenParams)
require.NotEmpty(t, token.AccessToken, "should receive access token")
require.NotEmpty(t, token.RefreshToken, "should receive refresh token")
}
func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
@@ -202,16 +212,13 @@ func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
})
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
state := oauth2providertest.GenerateState(t)
authParams := oauth2providertest.AuthorizeParams{
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
CodeChallenge: codeChallenge,
CodeChallengeMethod: "S256",
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
}
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
@@ -222,7 +229,6 @@ func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", oauth2providertest.TestRedirectURI)
data.Set("code_verifier", codeVerifier)
req, err := http.NewRequestWithContext(ctx, "POST", client.URL.String()+"/oauth2/tokens", strings.NewReader(data.Encode()))
require.NoError(t, err, "failed to create token request")
@@ -259,16 +265,13 @@ func TestOAuth2TokenExchangeClientSecretBasicInvalidSecret(t *testing.T) {
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
})
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
state := oauth2providertest.GenerateState(t)
authParams := oauth2providertest.AuthorizeParams{
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
CodeChallenge: codeChallenge,
CodeChallengeMethod: "S256",
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
}
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
@@ -279,7 +282,6 @@ func TestOAuth2TokenExchangeClientSecretBasicInvalidSecret(t *testing.T) {
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", oauth2providertest.TestRedirectURI)
data.Set("code_verifier", codeVerifier)
wrongSecret := clientSecret + "x"
@@ -347,30 +349,26 @@ func TestOAuth2ResourceParameter(t *testing.T) {
})
state := oauth2providertest.GenerateState(t)
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
// Perform authorization with resource parameter.
// Perform authorization with resource parameter
authParams := oauth2providertest.AuthorizeParams{
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
CodeChallenge: codeChallenge,
CodeChallengeMethod: "S256",
Resource: oauth2providertest.TestResourceURI,
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
Resource: oauth2providertest.TestResourceURI,
}
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
require.NotEmpty(t, code, "should receive authorization code")
// Exchange code for token with resource parameter.
// Exchange code for token with resource parameter
tokenParams := oauth2providertest.TokenExchangeParams{
GrantType: "authorization_code",
Code: code,
ClientID: app.ID.String(),
ClientSecret: clientSecret,
RedirectURI: oauth2providertest.TestRedirectURI,
CodeVerifier: codeVerifier,
Resource: oauth2providertest.TestResourceURI,
}
@@ -394,16 +392,13 @@ func TestOAuth2TokenRefresh(t *testing.T) {
})
state := oauth2providertest.GenerateState(t)
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
// Get initial token.
// Get initial token
authParams := oauth2providertest.AuthorizeParams{
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
CodeChallenge: codeChallenge,
CodeChallengeMethod: "S256",
ClientID: app.ID.String(),
ResponseType: "code",
RedirectURI: oauth2providertest.TestRedirectURI,
State: state,
}
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
@@ -414,7 +409,6 @@ func TestOAuth2TokenRefresh(t *testing.T) {
ClientID: app.ID.String(),
ClientSecret: clientSecret,
RedirectURI: oauth2providertest.TestRedirectURI,
CodeVerifier: codeVerifier,
}
initialToken := oauth2providertest.ExchangeCodeForToken(t, client.URL.String(), tokenParams)
+7 -20
View File
@@ -254,27 +254,14 @@ func authorizationCodeGrant(ctx context.Context, db database.Store, app database
return codersdk.OAuth2TokenResponse{}, errBadCode
}
// Verify redirect_uri matches the one used during authorization
// (RFC 6749 §4.1.3).
if dbCode.RedirectUri.Valid && dbCode.RedirectUri.String != "" {
if req.RedirectURI != dbCode.RedirectUri.String {
return codersdk.OAuth2TokenResponse{}, errBadCode
// Verify PKCE challenge if present
if dbCode.CodeChallenge.Valid && dbCode.CodeChallenge.String != "" {
if req.CodeVerifier == "" {
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
}
if !VerifyPKCE(dbCode.CodeChallenge.String, req.CodeVerifier) {
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
}
}
// PKCE is mandatory for all authorization code flows
// (OAuth 2.1). Verify the code verifier against the stored
// challenge.
if req.CodeVerifier == "" {
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
}
if !dbCode.CodeChallenge.Valid || dbCode.CodeChallenge.String == "" {
// Code was issued without a challenge — should not happen
// with authorize endpoint enforcement, but defend in depth.
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
}
if !VerifyPKCE(dbCode.CodeChallenge.String, req.CodeVerifier) {
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
}
// Verify resource parameter consistency (RFC 8707)
@@ -318,7 +318,6 @@ func TestExtractAuthorizeParams_Scopes(t *testing.T) {
query.Set("response_type", "code")
query.Set("client_id", "test-client")
query.Set("redirect_uri", "http://localhost:3000/callback")
query.Set("code_challenge", "test-challenge")
if tc.scopeParam != "" {
query.Set("scope", tc.scopeParam)
}
@@ -342,34 +341,6 @@ func TestExtractAuthorizeParams_Scopes(t *testing.T) {
}
}
// TestExtractAuthorizeParams_TokenResponseTypeDoesNotRequirePKCE ensures
// response_type=token is parsed without requiring PKCE fields so callers can
// return unsupported_response_type instead of invalid_request.
func TestExtractAuthorizeParams_TokenResponseTypeDoesNotRequirePKCE(t *testing.T) {
	t.Parallel()

	callbackURL, err := url.Parse("http://localhost:3000/callback")
	require.NoError(t, err)

	// Build an authorize request carrying only the non-PKCE parameters.
	q := url.Values{
		"response_type": {string(codersdk.OAuth2ProviderResponseTypeToken)},
		"client_id":     {"test-client"},
		"redirect_uri":  {"http://localhost:3000/callback"},
	}
	reqURL, err := url.Parse("http://localhost:8080/oauth2/authorize?" + q.Encode())
	require.NoError(t, err)

	params, validationErrs, err := extractAuthorizeParams(&http.Request{
		Method: http.MethodGet,
		URL:    reqURL,
	}, callbackURL)
	require.NoError(t, err)
	require.Empty(t, validationErrs)
	require.Equal(t, codersdk.OAuth2ProviderResponseTypeToken, params.responseType)
}
// TestRefreshTokenGrant_Scopes tests that scopes can be requested during refresh
func TestRefreshTokenGrant_Scopes(t *testing.T) {
t.Parallel()
@@ -19,9 +19,9 @@ import (
)
var (
templatesActiveUsersDesc = prometheus.NewDesc("coderd_insights_templates_active_users", "The number of active users of the template.", []string{"template_name", "organization_name"}, nil)
applicationsUsageSecondsDesc = prometheus.NewDesc("coderd_insights_applications_usage_seconds", "The application usage per template.", []string{"template_name", "application_name", "slug", "organization_name"}, nil)
parametersDesc = prometheus.NewDesc("coderd_insights_parameters", "The parameter usage per template.", []string{"template_name", "parameter_name", "parameter_type", "parameter_value", "organization_name"}, nil)
templatesActiveUsersDesc = prometheus.NewDesc("coderd_insights_templates_active_users", "The number of active users of the template.", []string{"template_name"}, nil)
applicationsUsageSecondsDesc = prometheus.NewDesc("coderd_insights_applications_usage_seconds", "The application usage per template.", []string{"template_name", "application_name", "slug"}, nil)
parametersDesc = prometheus.NewDesc("coderd_insights_parameters", "The parameter usage per template.", []string{"template_name", "parameter_name", "parameter_type", "parameter_value"}, nil)
)
type MetricsCollector struct {
@@ -38,8 +38,7 @@ type insightsData struct {
apps []database.GetTemplateAppInsightsByTemplateRow
params []parameterRow
templateNames map[uuid.UUID]string
organizationNames map[uuid.UUID]string // template ID → org name
templateNames map[uuid.UUID]string
}
type parameterRow struct {
@@ -138,7 +137,6 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
templateIDs := uniqueTemplateIDs(templateInsights, appInsights, paramInsights)
templateNames := make(map[uuid.UUID]string, len(templateIDs))
organizationNames := make(map[uuid.UUID]string, len(templateIDs))
if len(templateIDs) > 0 {
templates, err := mc.database.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{
IDs: templateIDs,
@@ -148,31 +146,6 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
return
}
templateNames = onlyTemplateNames(templates)
// Build org name lookup so that metrics can
// distinguish templates with the same name across
// different organizations.
orgIDs := make([]uuid.UUID, 0, len(templates))
for _, t := range templates {
orgIDs = append(orgIDs, t.OrganizationID)
}
orgIDs = slice.Unique(orgIDs)
orgs, err := mc.database.GetOrganizations(ctx, database.GetOrganizationsParams{
IDs: orgIDs,
})
if err != nil {
mc.logger.Error(ctx, "unable to fetch organizations from database", slog.Error(err))
return
}
orgNameByID := make(map[uuid.UUID]string, len(orgs))
for _, o := range orgs {
orgNameByID[o.ID] = o.Name
}
organizationNames = make(map[uuid.UUID]string, len(templates))
for _, t := range templates {
organizationNames[t.ID] = orgNameByID[t.OrganizationID]
}
}
// Refresh the collector state
@@ -181,8 +154,7 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
apps: appInsights,
params: paramInsights,
templateNames: templateNames,
organizationNames: organizationNames,
templateNames: templateNames,
})
}
@@ -222,46 +194,44 @@ func (mc *MetricsCollector) Collect(metricsCh chan<- prometheus.Metric) {
// Custom apps
for _, appRow := range data.apps {
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue, float64(appRow.UsageSeconds), data.templateNames[appRow.TemplateID],
appRow.DisplayName, appRow.SlugOrPort, data.organizationNames[appRow.TemplateID])
appRow.DisplayName, appRow.SlugOrPort)
}
// Built-in apps
for _, templateRow := range data.templates {
orgName := data.organizationNames[templateRow.TemplateID]
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
float64(templateRow.UsageVscodeSeconds),
data.templateNames[templateRow.TemplateID],
codersdk.TemplateBuiltinAppDisplayNameVSCode,
"", orgName)
"")
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
float64(templateRow.UsageJetbrainsSeconds),
data.templateNames[templateRow.TemplateID],
codersdk.TemplateBuiltinAppDisplayNameJetBrains,
"", orgName)
"")
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
float64(templateRow.UsageReconnectingPtySeconds),
data.templateNames[templateRow.TemplateID],
codersdk.TemplateBuiltinAppDisplayNameWebTerminal,
"", orgName)
"")
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
float64(templateRow.UsageSshSeconds),
data.templateNames[templateRow.TemplateID],
codersdk.TemplateBuiltinAppDisplayNameSSH,
"", orgName)
"")
}
// Templates
for _, templateRow := range data.templates {
metricsCh <- prometheus.MustNewConstMetric(templatesActiveUsersDesc, prometheus.GaugeValue, float64(templateRow.ActiveUsers), data.templateNames[templateRow.TemplateID], data.organizationNames[templateRow.TemplateID])
metricsCh <- prometheus.MustNewConstMetric(templatesActiveUsersDesc, prometheus.GaugeValue, float64(templateRow.ActiveUsers), data.templateNames[templateRow.TemplateID])
}
// Parameters
for _, parameterRow := range data.params {
metricsCh <- prometheus.MustNewConstMetric(parametersDesc, prometheus.GaugeValue, float64(parameterRow.count), data.templateNames[parameterRow.templateID], parameterRow.name, parameterRow.aType, parameterRow.value, data.organizationNames[parameterRow.templateID])
metricsCh <- prometheus.MustNewConstMetric(parametersDesc, prometheus.GaugeValue, float64(parameterRow.count), data.templateNames[parameterRow.templateID], parameterRow.name, parameterRow.aType, parameterRow.value)
}
}
@@ -1,13 +1,13 @@
{
"coderd_insights_applications_usage_seconds[application_name=JetBrains,organization_name=coder,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Visual Studio Code,organization_name=coder,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Web Terminal,organization_name=coder,slug=,template_name=golden-template]": 0,
"coderd_insights_applications_usage_seconds[application_name=SSH,organization_name=coder,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Golden Slug,organization_name=coder,slug=golden-slug,template_name=golden-template]": 180,
"coderd_insights_parameters[organization_name=coder,parameter_name=first_parameter,parameter_type=string,parameter_value=Foobar,template_name=golden-template]": 1,
"coderd_insights_parameters[organization_name=coder,parameter_name=first_parameter,parameter_type=string,parameter_value=Baz,template_name=golden-template]": 1,
"coderd_insights_parameters[organization_name=coder,parameter_name=second_parameter,parameter_type=bool,parameter_value=true,template_name=golden-template]": 2,
"coderd_insights_parameters[organization_name=coder,parameter_name=third_parameter,parameter_type=number,parameter_value=789,template_name=golden-template]": 1,
"coderd_insights_parameters[organization_name=coder,parameter_name=third_parameter,parameter_type=number,parameter_value=999,template_name=golden-template]": 1,
"coderd_insights_templates_active_users[organization_name=coder,template_name=golden-template]": 1
"coderd_insights_applications_usage_seconds[application_name=JetBrains,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Visual Studio Code,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Web Terminal,slug=,template_name=golden-template]": 0,
"coderd_insights_applications_usage_seconds[application_name=SSH,slug=,template_name=golden-template]": 60,
"coderd_insights_applications_usage_seconds[application_name=Golden Slug,slug=golden-slug,template_name=golden-template]": 180,
"coderd_insights_parameters[parameter_name=first_parameter,parameter_type=string,parameter_value=Foobar,template_name=golden-template]": 1,
"coderd_insights_parameters[parameter_name=first_parameter,parameter_type=string,parameter_value=Baz,template_name=golden-template]": 1,
"coderd_insights_parameters[parameter_name=second_parameter,parameter_type=bool,parameter_value=true,template_name=golden-template]": 2,
"coderd_insights_parameters[parameter_name=third_parameter,parameter_type=number,parameter_value=789,template_name=golden-template]": 1,
"coderd_insights_parameters[parameter_name=third_parameter,parameter_type=number,parameter_value=999,template_name=golden-template]": 1,
"coderd_insights_templates_active_users[template_name=golden-template]": 1
}
@@ -725,16 +725,11 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
}
}
provisionerStateRow, err := s.Database.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuild.ID)
if err != nil {
return nil, failJob(fmt.Sprintf("get workspace build provisioner state: %s", err))
}
protoJob.Type = &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
WorkspaceBuildId: workspaceBuild.ID.String(),
WorkspaceName: workspace.Name,
State: provisionerStateRow.ProvisionerState,
State: workspaceBuild.ProvisionerState,
RichParameterValues: convertRichParameterValues(workspaceBuildParameters),
PreviousParameterValues: convertRichParameterValues(lastWorkspaceBuildParameters),
VariableValues: asVariableValues(templateVariables),
@@ -845,11 +840,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
// Record the time the job spent waiting in the queue.
if s.metrics != nil && job.StartedAt.Valid && job.Provisioner.Valid() {
// These timestamps lose their monotonic clock component after a Postgres
// round-trip, so the subtraction is based purely on wall-clock time. Floor at
// 1ms as a defensive measure against clock adjustments producing a negative
// delta while acknowledging there's a non-zero queue time.
queueWaitSeconds := max(job.StartedAt.Time.Sub(job.CreatedAt).Seconds(), 0.001)
queueWaitSeconds := job.StartedAt.Time.Sub(job.CreatedAt).Seconds()
s.metrics.ObserveJobQueueWait(string(job.Provisioner), string(job.Type), jobTransition, jobBuildReason, queueWaitSeconds)
}
@@ -1321,9 +1321,7 @@ func TestFailJob(t *testing.T) {
<-publishedLogs
build, err := db.GetWorkspaceBuildByID(ctx, buildID)
require.NoError(t, err)
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
require.NoError(t, err)
require.Equal(t, "some state", string(provisionerStateRow.ProvisionerState))
require.Equal(t, "some state", string(build.ProvisionerState))
require.Len(t, auditor.AuditLogs(), 1)
// Assert that the workspace_id field get populated
-1
View File
@@ -81,7 +81,6 @@ const (
SubjectAibridged SubjectType = "aibridged"
SubjectTypeDBPurge SubjectType = "dbpurge"
SubjectTypeBoundaryUsageTracker SubjectType = "boundary_usage_tracker"
SubjectTypeWorkspaceBuilder SubjectType = "workspace_builder"
)
const (
-29
View File
@@ -401,35 +401,6 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string,
return filter, parser.Errors
}
func AIBridgeModels(query string, page codersdk.Pagination) (database.ListAIBridgeModelsParams, []codersdk.ValidationError) {
// nolint:exhaustruct // Empty values just means "don't filter by that field".
filter := database.ListAIBridgeModelsParams{
// #nosec G115 - Safe conversion for pagination offset which is expected to be within int32 range
Offset: int32(page.Offset),
// #nosec G115 - Safe conversion for pagination limit which is expected to be within int32 range
Limit: int32(page.Limit),
}
if query == "" {
return filter, nil
}
values, errors := searchTerms(query, func(term string, values url.Values) error {
// Defaults to the `model` if no `key:value` pair is provided.
values.Add("model", term)
return nil
})
if len(errors) > 0 {
return filter, errors
}
parser := httpapi.NewQueryParamParser()
filter.Model = parser.String(values, "", "model")
parser.ErrorExcessParams(values)
return filter, parser.Errors
}
// Tasks parses a search query for tasks.
//
// Supported query parameters:
+1 -1
View File
@@ -22,7 +22,7 @@ import (
)
const (
defaultModel = anthropic.ModelClaudeHaiku4_5
defaultModel = anthropic.ModelClaude3_5HaikuLatest
systemPrompt = `Generate a short task display name and name from this AI task prompt.
Identify the main task (the core action and subject) and base both names on it.
The task display name and name should be as similar as possible so a human can easily associate them.
+25 -163
View File
@@ -416,10 +416,9 @@ func checkIDPOrgSync(ctx context.Context, db database.Store, values *codersdk.De
func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
var (
ctx = r.ctx
now = r.options.Clock.Now()
// For resources that grow in size very quickly (like workspace builds),
// we only report events that occurred within the past hour.
createdAfter = dbtime.Time(now.Add(-1 * time.Hour)).UTC()
createdAfter = dbtime.Time(r.options.Clock.Now().Add(-1 * time.Hour)).UTC()
eg errgroup.Group
snapshot = &Snapshot{
DeploymentID: r.options.DeploymentID,
@@ -741,19 +740,17 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
return nil
})
eg.Go(func() error {
tasks, err := CollectTasks(ctx, r.options.Database)
dbTasks, err := r.options.Database.ListTasks(ctx, database.ListTasksParams{
OwnerID: uuid.Nil,
OrganizationID: uuid.Nil,
Status: "",
})
if err != nil {
return xerrors.Errorf("collect tasks telemetry: %w", err)
return err
}
snapshot.Tasks = tasks
return nil
})
eg.Go(func() error {
events, err := CollectTaskEvents(ctx, r.options.Database, createdAfter, now)
if err != nil {
return xerrors.Errorf("collect task events telemetry: %w", err)
for _, dbTask := range dbTasks {
snapshot.Tasks = append(snapshot.Tasks, ConvertTask(dbTask))
}
snapshot.TaskEvents = events
return nil
})
eg.Go(func() error {
@@ -905,129 +902,6 @@ func (r *remoteReporter) collectBoundaryUsageSummary(ctx context.Context) (*Boun
}, nil
}
// CollectTasks gathers tasks for the telemetry snapshot and converts each
// database row into its telemetry representation. The zero-value filter
// fields (uuid.Nil owner/org, empty status) are presumably treated as
// "no filter" by ListTasks so every task is returned — TODO confirm against
// the query implementation. The returned slice is always non-nil so the
// snapshot serializes as [] rather than null when no tasks exist.
func CollectTasks(ctx context.Context, db database.Store) ([]Task, error) {
	dbTasks, err := db.ListTasks(ctx, database.ListTasksParams{
		OwnerID:        uuid.Nil,
		OrganizationID: uuid.Nil,
		Status:         "",
	})
	if err != nil {
		return nil, xerrors.Errorf("list tasks: %w", err)
	}
	// make with capacity len(dbTasks) also covers the empty case: it yields
	// a non-nil empty slice, so no separate len == 0 branch is needed.
	tasks := make([]Task, 0, len(dbTasks))
	for _, dbTask := range dbTasks {
		tasks = append(tasks, ConvertTask(dbTask))
	}
	return tasks, nil
}
// buildTaskEvent constructs a TaskEvent from the combined query row for a
// single task. It derives the task's pause/resume lifecycle:
//
//   - pause fields (LastPausedAt, PauseReason, IdleDurationMS) are set
//     whenever a stop build exists;
//   - resume fields (LastResumedAt, ResumeReason) are set only when a start
//     build happened after the stop AND the start build number is > 1;
//   - PausedDurationMS is reported either for a completed pause/resume pair
//     whose resume falls after createdAfter, or — for a still-unresolved
//     pause — as the duration from the stop build up to now;
//   - ResumeToStatusMS and ActiveDurationMS come straight from the row.
//
// startedAfterStop and currentlyPaused are mutually exclusive, so the two
// PausedDurationMS assignments below can never both fire.
func buildTaskEvent(
	row database.GetTelemetryTaskEventsRow,
	createdAfter time.Time,
	now time.Time,
) TaskEvent {
	event := TaskEvent{
		TaskID: row.TaskID.String(),
	}
	var (
		// A start build exists for this task.
		hasStartBuild = row.StartBuildCreatedAt.Valid
		// Build number > 1 means the task was resumed, not freshly created.
		isResumed    = hasStartBuild && row.StartBuildNumber.Valid && row.StartBuildNumber.Int32 > 1
		hasStopBuild = row.StopBuildCreatedAt.Valid
		// The most recent start build postdates the most recent stop build,
		// i.e. the pause has been resolved by a resume.
		startedAfterStop = hasStartBuild && hasStopBuild && row.StartBuildCreatedAt.Time.After(row.StopBuildCreatedAt.Time)
		// A stop build with no later start build: the task is still paused.
		currentlyPaused = hasStopBuild && !startedAfterStop
	)
	// Pause-related fields (requires a stop build).
	if hasStopBuild {
		event.LastPausedAt = &row.StopBuildCreatedAt.Time
		// Map the stop build's reason onto the telemetry pause reason;
		// anything unrecognized (or an invalid reason) is reported as "other".
		switch {
		case row.StopBuildReason.Valid && row.StopBuildReason.BuildReason == database.BuildReasonTaskAutoPause:
			event.PauseReason = ptr.Ref("auto")
		case row.StopBuildReason.Valid && row.StopBuildReason.BuildReason == database.BuildReasonTaskManualPause:
			event.PauseReason = ptr.Ref("manual")
		default:
			event.PauseReason = ptr.Ref("other")
		}
		// Idle duration: time between last working status and the pause.
		// Only reported when the pause actually followed the working status.
		if row.LastWorkingStatusAt.Valid &&
			row.StopBuildCreatedAt.Time.After(row.LastWorkingStatusAt.Time) {
			idle := row.StopBuildCreatedAt.Time.Sub(row.LastWorkingStatusAt.Time)
			event.IdleDurationMS = ptr.Ref(idle.Milliseconds())
		}
	}
	// Resume-related fields (requires task_resume start after stop).
	if startedAfterStop {
		// Paused duration: time between pause and resume. Gated on
		// createdAfter so only recently-resolved pauses are reported.
		if row.StartBuildCreatedAt.Time.After(createdAfter) {
			paused := row.StartBuildCreatedAt.Time.Sub(row.StopBuildCreatedAt.Time)
			event.PausedDurationMS = ptr.Ref(paused.Milliseconds())
		}
		// Below only relevant for "resumed" tasks, not when initially created.
		if isResumed {
			event.LastResumedAt = &row.StartBuildCreatedAt.Time
			switch {
			// TODO(Cian): will this exist? Future readers may know better than I.
			// case row.StartBuildReason == database.BuildReasonTaskAutoResume:
			//	event.ResumeReason = ptr.Ref("auto")
			case row.StartBuildReason.BuildReason == database.BuildReasonTaskResume:
				event.ResumeReason = ptr.Ref("manual")
			default: // Task resumed by starting workspace?
				event.ResumeReason = ptr.Ref("other")
			}
		}
	}
	// Unresolved pause: report current paused duration measured up to now.
	if currentlyPaused {
		paused := now.Sub(row.StopBuildCreatedAt.Time)
		event.PausedDurationMS = ptr.Ref(paused.Milliseconds())
	}
	// Resume-to-status duration: time from the resume until the first agent
	// status observed after it. Only meaningful for resumed tasks.
	if row.FirstStatusAfterResumeAt.Valid && isResumed {
		delta := row.FirstStatusAfterResumeAt.Time.Sub(row.StartBuildCreatedAt.Time)
		event.ResumeToStatusMS = ptr.Ref(delta.Milliseconds())
	}
	// Active duration: from SQL calculation; omitted when non-positive.
	if row.ActiveDurationMs > 0 {
		event.ActiveDurationMS = ptr.Ref(row.ActiveDurationMs)
	}
	return event
}
// CollectTaskEvents gathers pause/resume lifecycle events for tasks with
// recent activity, converting each combined query row into a telemetry
// TaskEvent via buildTaskEvent.
func CollectTaskEvents(ctx context.Context, db database.Store, createdAfter, now time.Time) ([]TaskEvent, error) {
	params := database.GetTelemetryTaskEventsParams{
		CreatedAfter: createdAfter,
		Now:          now,
	}
	rows, err := db.GetTelemetryTaskEvents(ctx, params)
	if err != nil {
		return nil, xerrors.Errorf("get telemetry task events: %w", err)
	}

	events := make([]TaskEvent, 0, len(rows))
	for i := range rows {
		events = append(events, buildTaskEvent(rows[i], createdAfter, now))
	}
	return events, nil
}
// HashContent returns a SHA256 hash of the content as a hex string.
// This is useful for hashing sensitive content like prompts for telemetry.
func HashContent(content string) string {
	digest := sha256.Sum256([]byte(content))
	return fmt.Sprintf("%x", digest)
}
// ConvertAPIKey anonymizes an API key.
func ConvertAPIKey(apiKey database.APIKey) APIKey {
a := APIKey{
@@ -1496,7 +1370,6 @@ type Snapshot struct {
NetworkEvents []NetworkEvent `json:"network_events"`
Organizations []Organization `json:"organizations"`
Tasks []Task `json:"tasks"`
TaskEvents []TaskEvent `json:"task_events"`
TelemetryItems []TelemetryItem `json:"telemetry_items"`
UserTailnetConnections []UserTailnetConnection `json:"user_tailnet_connections"`
PrebuiltWorkspaces []PrebuiltWorkspace `json:"prebuilt_workspaces"`
@@ -2058,36 +1931,25 @@ type Task struct {
WorkspaceAppID *string `json:"workspace_app_id"`
TemplateVersionID string `json:"template_version_id"`
PromptHash string `json:"prompt_hash"` // Prompt is hashed for privacy.
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
Status string `json:"status"`
}
// TaskEvent represents lifecycle events for a task (pause/resume
// cycles). The createdAfter parameter gates PausedDurationMS so
// that only recent pause/resume pairs are reported.
type TaskEvent struct {
	// TaskID is the task's ID rendered as a string.
	TaskID string `json:"task_id"`
	// LastPausedAt is when the task was most recently paused; nil if never.
	LastPausedAt *time.Time `json:"last_paused_at"`
	// LastResumedAt is when the task was most recently resumed; nil if never.
	LastResumedAt *time.Time `json:"last_resumed_at"`
	// PauseReason describes why the task paused ("auto"/"manual" observed
	// in tests); nil when the task was never paused.
	PauseReason *string `json:"pause_reason"`
	// ResumeReason describes why the task resumed; nil when never resumed.
	ResumeReason *string `json:"resume_reason"`
	// IdleDurationMS is the time from the last working status to the pause,
	// in milliseconds; nil when not applicable.
	IdleDurationMS *int64 `json:"idle_duration_ms"`
	// PausedDurationMS is how long the task was (or has been) paused, in
	// milliseconds; nil when no pause falls in the reporting window.
	PausedDurationMS *int64 `json:"paused_duration_ms"`
	// ResumeToStatusMS is the time from a resume to the next reported app
	// status, in milliseconds; nil when not applicable.
	ResumeToStatusMS *int64 `json:"resume_to_status_ms"`
	// ActiveDurationMS is the total "working" time, in milliseconds,
	// as computed by the SQL query; nil when zero or unknown.
	ActiveDurationMS *int64 `json:"active_duration_ms"`
}
// ConvertTask converts a database Task to a telemetry Task.
// ConvertTask anonymizes a Task.
func ConvertTask(task database.Task) Task {
t := Task{
ID: task.ID.String(),
OrganizationID: task.OrganizationID.String(),
OwnerID: task.OwnerID.String(),
Name: task.Name,
TemplateVersionID: task.TemplateVersionID.String(),
PromptHash: HashContent(task.Prompt),
Status: string(task.Status),
CreatedAt: task.CreatedAt,
t := &Task{
ID: task.ID.String(),
OrganizationID: task.OrganizationID.String(),
OwnerID: task.OwnerID.String(),
Name: task.Name,
WorkspaceID: nil,
WorkspaceBuildNumber: nil,
WorkspaceAgentID: nil,
WorkspaceAppID: nil,
TemplateVersionID: task.TemplateVersionID.String(),
PromptHash: fmt.Sprintf("%x", sha256.Sum256([]byte(task.Prompt))),
CreatedAt: task.CreatedAt,
Status: string(task.Status),
}
if task.WorkspaceID.Valid {
t.WorkspaceID = ptr.Ref(task.WorkspaceID.UUID.String())
@@ -2101,7 +1963,7 @@ func ConvertTask(task database.Task) Task {
if task.WorkspaceAppID.Valid {
t.WorkspaceAppID = ptr.Ref(task.WorkspaceAppID.UUID.String())
}
return t
return *t
}
type telemetryItemKey string
+1 -584
View File
@@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
@@ -14,7 +13,6 @@ import (
"time"
"github.com/go-chi/chi/v5"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -23,14 +21,12 @@ import (
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/boundaryusage"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
@@ -317,17 +313,6 @@ func TestTelemetry(t *testing.T) {
require.Equal(t, string(database.WorkspaceAgentSubsystemEnvbox), wsa.Subsystems[0])
require.Equal(t, string(database.WorkspaceAgentSubsystemExectrace), wsa.Subsystems[1])
require.Len(t, snapshot.Tasks, 1)
require.Len(t, snapshot.TaskEvents, 1)
taskEvent := snapshot.TaskEvents[0]
assert.Equal(t, task.ID.String(), taskEvent.TaskID)
assert.Nil(t, taskEvent.LastResumedAt)
assert.Nil(t, taskEvent.LastPausedAt)
assert.Nil(t, taskEvent.PauseReason)
assert.Nil(t, taskEvent.ResumeReason)
assert.Nil(t, taskEvent.IdleDurationMS)
assert.Nil(t, taskEvent.PausedDurationMS)
assert.Nil(t, taskEvent.ResumeToStatusMS)
assert.Nil(t, taskEvent.ActiveDurationMS)
for _, snapTask := range snapshot.Tasks {
assert.Equal(t, task.ID.String(), snapTask.ID)
assert.Equal(t, task.OrganizationID.String(), snapTask.OrganizationID)
@@ -341,7 +326,6 @@ func TestTelemetry(t *testing.T) {
assert.Equal(t, taskWA.WorkspaceAppID.UUID.String(), *snapTask.WorkspaceAppID)
assert.Equal(t, task.TemplateVersionID.String(), snapTask.TemplateVersionID)
assert.Equal(t, "e196fe22e61cfa32d8c38749e0ce348108bb4cae29e2c36cdcce7e77faa9eb5f", snapTask.PromptHash)
assert.Equal(t, string(task.Status), snapTask.Status)
assert.Equal(t, task.CreatedAt.UTC(), snapTask.CreatedAt.UTC())
}
@@ -691,573 +675,6 @@ func TestPrebuiltWorkspacesTelemetry(t *testing.T) {
}
}
// taskTelemetryHelper is a grab bag of stuff useful in task telemetry test cases
type taskTelemetryHelper struct {
	t   *testing.T       // test handle for assertions
	ctx context.Context  // per-test context
	db  database.Store   // database under test
	org database.Organization // organization owning the fixtures
	user database.User   // user owning the fixtures
}
// createBuild creates a workspace build with the given parameters,
// handling provisioner job creation automatically. For start builds it
// also registers a fresh workspace app and links it to the task; the app
// pointer is nil for all other transitions.
func (h *taskTelemetryHelper) createBuild(
	resp dbfake.WorkspaceResponse,
	buildNumber int32,
	createdAt time.Time,
	transition database.WorkspaceTransition,
	reason database.BuildReason,
) (database.WorkspaceBuild, *database.WorkspaceApp) {
	provisionerJob := dbgen.ProvisionerJob(h.t, h.db, nil, database.ProvisionerJob{
		Provisioner:    database.ProvisionerTypeTerraform,
		StorageMethod:  database.ProvisionerStorageMethodFile,
		Type:           database.ProvisionerJobTypeWorkspaceBuild,
		OrganizationID: h.org.ID,
	})
	build := dbgen.WorkspaceBuild(h.t, h.db, database.WorkspaceBuild{
		WorkspaceID:       resp.Workspace.ID,
		TemplateVersionID: resp.TemplateVersion.ID,
		JobID:             provisionerJob.ID,
		Transition:        transition,
		Reason:            reason,
		BuildNumber:       buildNumber,
		CreatedAt:         createdAt,
		HasAITask: sql.NullBool{
			Bool:  true,
			Valid: true,
		},
	})
	if transition != database.WorkspaceTransitionStart {
		// Non-start builds carry no app.
		return build, nil
	}
	require.NotEmpty(h.t, resp.Agents, "need at least one agent")
	agent := resp.Agents[0]
	// App IDs are regenerated by provisionerd each build.
	app := dbgen.WorkspaceApp(h.t, h.db, database.WorkspaceApp{
		AgentID: agent.ID,
	})
	_, err := h.db.UpsertTaskWorkspaceApp(h.ctx, database.UpsertTaskWorkspaceAppParams{
		TaskID:               resp.Task.ID,
		WorkspaceBuildNumber: buildNumber,
		WorkspaceAgentID:     uuid.NullUUID{UUID: agent.ID, Valid: true},
		WorkspaceAppID:       uuid.NullUUID{UUID: app.ID, Valid: true},
	})
	require.NoError(h.t, err, "failed to upsert task app")
	return build, &app
}
// nolint: dupl // Test code is better WET than DRY.
// TestTasksTelemetry is a table-driven test of CollectTasks and
// CollectTaskEvents. Each case seeds a task (optionally with a workspace,
// extra builds, and app statuses at offsets relative to a fixed `now`),
// then asserts on the collected Task and TaskEvent values.
func TestTasksTelemetry(t *testing.T) {
	t.Parallel()
	// Define a fixed reference time for deterministic testing.
	now := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
	// createAppStatus inserts a workspace_app_status row at createdAt.
	createAppStatus := func(ctx context.Context, db database.Store, wsID uuid.UUID, agentID, appID uuid.UUID, state database.WorkspaceAppStatusState, message string, createdAt time.Time) {
		_, err := db.InsertWorkspaceAppStatus(ctx, database.InsertWorkspaceAppStatusParams{
			ID:          uuid.New(),
			CreatedAt:   createdAt,
			WorkspaceID: wsID,
			AgentID:     agentID,
			AppID:       appID,
			State:       state,
			Message:     message,
		})
		require.NoError(t, err)
	}
	// getApp returns the first app registered for the given agent.
	getApp := func(ctx context.Context, db database.Store, agentID uuid.UUID) database.WorkspaceApp {
		apps, err := db.GetWorkspaceAppsByAgentID(ctx, agentID)
		require.NoError(t, err)
		require.NotEmpty(t, apps, "expected at least one app")
		return apps[0]
	}
	// statusSpec describes one app status to insert, at an offset from now.
	type statusSpec struct {
		state   database.WorkspaceAppStatusState
		message string
		offset  time.Duration
	}
	// buildSpec describes an extra build to create, at an offset from now.
	type buildSpec struct {
		buildNumber int32
		offset      time.Duration
		transition  database.WorkspaceTransition
		reason      database.BuildReason
		statuses    []statusSpec // created after this build, using this build's app
	}
	tests := []struct {
		name string
		// Input: DB setup.
		skipWorkspace bool
		createdOffset time.Duration
		buildOffset   *time.Duration
		extraBuilds   []buildSpec
		appStatuses   []statusSpec
		// Expected output.
		expectEvent       bool
		lastPausedOffset  *time.Duration
		lastResumedOffset *time.Duration
		pauseReason       *string
		resumeReason      *string
		idleDurationMS    *int64
		pausedDurationMS  *int64
		resumeToStatusMS  *int64
		activeDurationMS  *int64
	}{
		{
			name:          "no workspace - all lifecycle fields nil",
			skipWorkspace: true,
			createdOffset: -1 * time.Hour,
		},
		{
			name:          "running workspace - no pause/resume events",
			createdOffset: -45 * time.Minute,
			buildOffset:   ptr.Ref(-30 * time.Minute),
			expectEvent:   true,
		},
		{
			name:          "with app status - no lifecycle events",
			createdOffset: -90 * time.Minute,
			buildOffset:   ptr.Ref(-45 * time.Minute),
			appStatuses: []statusSpec{
				{database.WorkspaceAppStatusStateWorking, "Task started", -40 * time.Minute},
			},
			expectEvent: true,
			// ResumeToStatusMS is nil because initial start (BuildReasonInitiator)
			// doesn't count - only task_resume starts are considered.
			activeDurationMS: ptr.Ref(int64(40 * time.Minute / time.Millisecond)),
		},
		{
			name:          "auto paused - LastPausedAt and PauseReason=auto",
			createdOffset: -3 * time.Hour,
			extraBuilds: []buildSpec{
				{2, -20 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
			},
			expectEvent:      true,
			lastPausedOffset: ptr.Ref(-20 * time.Minute),
			pauseReason:      ptr.Ref("auto"),
			pausedDurationMS: ptr.Ref(20 * time.Minute.Milliseconds()), // Ongoing pause.
		},
		{
			name:          "manual paused - LastPausedAt and PauseReason=manual",
			createdOffset: -4 * time.Hour,
			extraBuilds: []buildSpec{
				{2, -15 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskManualPause, nil},
			},
			expectEvent:      true,
			lastPausedOffset: ptr.Ref(-15 * time.Minute),
			pauseReason:      ptr.Ref("manual"),
			pausedDurationMS: ptr.Ref(15 * time.Minute.Milliseconds()), // Ongoing pause.
		},
		{
			name:          "paused with idle time - IdleDurationMS calculated",
			createdOffset: -5 * time.Hour,
			appStatuses: []statusSpec{
				{database.WorkspaceAppStatusStateWorking, "Working on something", -40 * time.Minute},
				{database.WorkspaceAppStatusStateIdle, "Idle now", -35 * time.Minute},
			},
			extraBuilds: []buildSpec{
				{2, -25 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
			},
			expectEvent:      true,
			lastPausedOffset: ptr.Ref(-25 * time.Minute),
			pauseReason:      ptr.Ref("auto"),
			idleDurationMS:   ptr.Ref(15 * time.Minute.Milliseconds()), // Last working (-40) to stop (-25).
			activeDurationMS: ptr.Ref(5 * time.Minute.Milliseconds()),  // -40 min (working) to -35 min (idle).
			pausedDurationMS: ptr.Ref(25 * time.Minute.Milliseconds()), // Ongoing pause: now - (-25min).
		},
		{
			name:          "paused with working status after pause - IdleDurationMS nil",
			createdOffset: -5 * time.Hour,
			appStatuses: []statusSpec{
				{database.WorkspaceAppStatusStateWorking, "Working after pause", -20 * time.Minute},
			},
			extraBuilds: []buildSpec{
				{2, -25 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
			},
			expectEvent:      true,
			lastPausedOffset: ptr.Ref(-25 * time.Minute),
			pauseReason:      ptr.Ref("auto"),
			pausedDurationMS: ptr.Ref(25 * time.Minute.Milliseconds()), // Ongoing pause.
			// IdleDurationMS is nil because "last working" is after pause.
			// ActiveDurationMS is nil because working→stop interval is negative.
		},
		{
			name:          "recently resumed - PausedDurationMS calculated",
			createdOffset: -6 * time.Hour,
			extraBuilds: []buildSpec{
				{2, -50 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
				{3, -10 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, nil},
			},
			expectEvent:       true,
			lastPausedOffset:  ptr.Ref(-50 * time.Minute),
			lastResumedOffset: ptr.Ref(-10 * time.Minute),
			pauseReason:       ptr.Ref("auto"),
			resumeReason:      ptr.Ref("manual"),
			pausedDurationMS:  ptr.Ref(40 * time.Minute.Milliseconds()),
		},
		{
			// This test verifies that we do not double-report task events outside of the window.
			name:          "resumed long ago - PausedDurationMS nil",
			createdOffset: -10 * time.Hour,
			extraBuilds: []buildSpec{
				{2, -5 * time.Hour, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
				{3, -2 * time.Hour, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, nil},
			},
			expectEvent: false,
		},
		{
			name:          "multiple cycles - captures latest pause/resume",
			createdOffset: -8 * time.Hour,
			extraBuilds: []buildSpec{
				{2, -3 * time.Hour, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
				{3, -150 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, nil},
				{4, -30 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskManualPause, nil},
			},
			expectEvent:      true,
			lastPausedOffset: ptr.Ref(-30 * time.Minute),
			pauseReason:      ptr.Ref("manual"),
			pausedDurationMS: ptr.Ref(30 * time.Minute.Milliseconds()), // Ongoing pause: now - (-30min).
		},
		{
			name:          "currently paused after recent resume - reports ongoing pause",
			createdOffset: -6 * time.Hour,
			extraBuilds: []buildSpec{
				{2, -50 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
				{3, -30 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, nil},
				{4, -10 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskManualPause, nil},
			},
			expectEvent:      true,
			lastPausedOffset: ptr.Ref(-10 * time.Minute),
			pauseReason:      ptr.Ref("manual"),
			pausedDurationMS: ptr.Ref(10 * time.Minute.Milliseconds()), // Ongoing pause: now - pause time.
		},
		{
			name:          "multiple cycles with recent resume - pairs with preceding pause",
			createdOffset: -6 * time.Hour,
			appStatuses: []statusSpec{
				{database.WorkspaceAppStatusStateWorking, "started work", -6 * time.Hour},
			},
			extraBuilds: []buildSpec{
				{2, -50 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
				{3, -30 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, []statusSpec{
					{database.WorkspaceAppStatusStateWorking, "resumed work", -25 * time.Minute},
				}},
			},
			expectEvent:       true,
			lastPausedOffset:  ptr.Ref(-50 * time.Minute),
			lastResumedOffset: ptr.Ref(-30 * time.Minute),
			pauseReason:       ptr.Ref("auto"),
			resumeReason:      ptr.Ref("manual"),
			pausedDurationMS:  ptr.Ref(20 * time.Minute.Milliseconds()),
			resumeToStatusMS:  ptr.Ref((5 * time.Minute).Milliseconds()),
			// Build 1 ("started work") -> Build 2 (stop) (5h10m) + Build 3 ("resumed work") -> now (25m)
			// TODO(cian): We define IdleDurationMS as "the time from the last working status to pause".
			// We know that the task has reported working since T-6h and got auto-paused at T-50m.
			// We can reasonably assume that it has been 'idle' from when it was stopped (T-30m) to
			// its next report at T-25m. This is covered by ResumeToStatusMS.
			// But do we consider the time since its last report (T-6h) to its being auto-paused
			// as truly "idle"?
			idleDurationMS:   ptr.Ref(310 * time.Minute.Milliseconds()),
			activeDurationMS: ptr.Ref((5*time.Hour + 10*time.Minute + 25*time.Minute).Milliseconds()),
		},
		{
			name:          "all fields populated - full lifecycle",
			createdOffset: -7 * time.Hour,
			appStatuses: []statusSpec{
				{database.WorkspaceAppStatusStateWorking, "Started working", -390 * time.Minute},
				{database.WorkspaceAppStatusStateWorking, "Still working", -45 * time.Minute},
			},
			extraBuilds: []buildSpec{
				{2, -35 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
				{3, -5 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, []statusSpec{
					{database.WorkspaceAppStatusStateWorking, "Resumed work", -3 * time.Minute},
					{database.WorkspaceAppStatusStateIdle, "Finished work", -2 * time.Minute},
				}},
			},
			expectEvent:       true,
			lastPausedOffset:  ptr.Ref(-35 * time.Minute),
			lastResumedOffset: ptr.Ref(-5 * time.Minute),
			pauseReason:       ptr.Ref("auto"),
			resumeReason:      ptr.Ref("manual"),
			idleDurationMS:    ptr.Ref(10 * time.Minute.Milliseconds()),
			pausedDurationMS:  ptr.Ref(30 * time.Minute.Milliseconds()),
			resumeToStatusMS:  ptr.Ref((2 * time.Minute).Milliseconds()),
			// Active duration: (-390 to -35) + (-3 to -2) = 355 + 1 = 356 min.
			activeDurationMS: ptr.Ref(356 * time.Minute.Milliseconds()),
		},
		{
			name:          "non-task_resume builds are tracked as other",
			createdOffset: -4 * time.Hour,
			extraBuilds: []buildSpec{
				{2, -60 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
				{3, -30 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonInitiator, nil},
			},
			expectEvent:      true,
			lastPausedOffset: ptr.Ref(-60 * time.Minute),
			pauseReason:      ptr.Ref("auto"),
			resumeReason:     ptr.Ref("other"),
			// LastResumedAt is set because isResumed is true (build_number > 1)
			// even though the start reason isn't task_resume.
			lastResumedOffset: ptr.Ref(-30 * time.Minute),
			// PausedDurationMS reports ongoing pause: now - (-60min) = 60min.
			pausedDurationMS: ptr.Ref(30 * time.Minute.Milliseconds()),
		},
		{
			name:          "simple ongoing pause reports duration",
			createdOffset: -3 * time.Hour,
			extraBuilds: []buildSpec{
				{2, -45 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
			},
			expectEvent:      true,
			lastPausedOffset: ptr.Ref(-45 * time.Minute),
			pauseReason:      ptr.Ref("auto"),
			// No resume, so ongoing pause: now - (-45min) = 45min.
			pausedDurationMS: ptr.Ref(45 * time.Minute.Milliseconds()),
		},
		{
			name:          "active duration with paused task",
			createdOffset: -2 * time.Hour,
			buildOffset:   ptr.Ref(-2 * time.Hour),
			appStatuses: []statusSpec{
				{database.WorkspaceAppStatusStateWorking, "Started", -90 * time.Minute},
				{database.WorkspaceAppStatusStateIdle, "Thinking", -60 * time.Minute}, // 30min working
				{database.WorkspaceAppStatusStateWorking, "Resumed", -45 * time.Minute},
				{database.WorkspaceAppStatusStateComplete, "Done", -30 * time.Minute}, // 15min working
			},
			extraBuilds: []buildSpec{
				{2, -25 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
			},
			expectEvent:      true,
			lastPausedOffset: ptr.Ref(-25 * time.Minute),
			pauseReason:      ptr.Ref("auto"),
			idleDurationMS:   ptr.Ref(20 * time.Minute.Milliseconds()), // Last working (-45) to stop (-25).
			activeDurationMS: ptr.Ref(45 * time.Minute.Milliseconds()), // 30 + 15 = 45min of "working".
			pausedDurationMS: ptr.Ref(25 * time.Minute.Milliseconds()), // Ongoing pause.
		},
		{
			// When a workspace_app_status and a workspace_build share
			// the exact same created_at timestamp, the ordering inside
			// task_status_timeline is ambiguous. The boundary row must
			// sort after real statuses so that LEAD() and the lws
			// lateral join produce deterministic results.
			name:          "status and build at same timestamp - deterministic ordering",
			createdOffset: -3 * time.Hour,
			buildOffset:   ptr.Ref(-2 * time.Hour),
			appStatuses: []statusSpec{
				{database.WorkspaceAppStatusStateWorking, "Started work", -90 * time.Minute},
				// This status has the exact same timestamp as the
				// stop build below, exercising the tiebreaker.
				{database.WorkspaceAppStatusStateWorking, "Last update before pause", -30 * time.Minute},
			},
			extraBuilds: []buildSpec{
				{2, -30 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
			},
			expectEvent:      true,
			lastPausedOffset: ptr.Ref(-30 * time.Minute),
			pauseReason:      ptr.Ref("auto"),
			// IdleDurationMS is nil: the Go code requires
			// stop.After(lastWorking), which is false when equal.
			// Active: -90m (working) → -30m (boundary/stop) = 60 min.
			activeDurationMS: ptr.Ref(60 * time.Minute.Milliseconds()),
			pausedDurationMS: ptr.Ref(30 * time.Minute.Milliseconds()),
		},
		{
			// SQL filter: EXISTS (workspace_builds.created_at > createdAfter).
			// This task has only old builds (7 days ago), so it won't match
			// the 1-hour createdAfter filter and should not return an event.
			name:          "old task with no recent builds - not returned",
			createdOffset: -7 * 24 * time.Hour,
			buildOffset:   ptr.Ref(-7 * 24 * time.Hour),
			expectEvent:   false,
		},
		{
			// SQL filter: EXISTS (workspace_builds.created_at > createdAfter).
			// This task was created 7 days ago, but has a recent stop build,
			// so it should match the filter and return an event.
			name:          "old task with recent build - returned",
			createdOffset: -7 * 24 * time.Hour,
			buildOffset:   ptr.Ref(-7 * 24 * time.Hour),
			extraBuilds: []buildSpec{
				{2, -30 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
			},
			expectEvent:      true,
			lastPausedOffset: ptr.Ref(-30 * time.Minute),
			pauseReason:      ptr.Ref("auto"),
			pausedDurationMS: ptr.Ref(30 * time.Minute.Milliseconds()), // Ongoing pause.
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			ctx := testutil.Context(t, testutil.WaitMedium)
			db, _ := dbtestutil.NewDB(t)
			org, err := db.GetDefaultOrganization(ctx)
			require.NoError(t, err)
			user := dbgen.User(t, db, database.User{})
			_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
				UserID:         user.ID,
				OrganizationID: org.ID,
			})
			h := &taskTelemetryHelper{
				t:    t,
				ctx:  ctx,
				db:   db,
				org:  org,
				user: user,
			}
			// Create a deleted task. This is a test antagonist that should never show up in results.
			deletedTaskResp := dbfake.WorkspaceBuild(h.t, h.db, database.WorkspaceTable{
				OrganizationID: h.org.ID,
				OwnerID:        h.user.ID,
			}).WithTask(database.TaskTable{
				Prompt:    fmt.Sprintf("deleted-task-%s", t.Name()),
				CreatedAt: now.Add(-100 * time.Hour),
			}, nil).Seed(database.WorkspaceBuild{
				Transition:  database.WorkspaceTransitionStart,
				Reason:      database.BuildReasonInitiator,
				BuildNumber: 1,
				CreatedAt:   now.Add(-100 * time.Hour),
			}).Succeeded().Do()
			_, err = db.DeleteTask(h.ctx, database.DeleteTaskParams{
				DeletedAt: now.Add(-99 * time.Hour),
				ID:        deletedTaskResp.Task.ID,
			})
			require.NoError(h.t, err, "creating deleted task antagonist")
			var expectedTask telemetry.Task
			if tt.skipWorkspace {
				// A workspace-less (pending) task: built directly via dbgen,
				// so no builds or app statuses exist for it.
				tv := dbgen.TemplateVersion(t, h.db, database.TemplateVersion{
					OrganizationID: h.org.ID,
					CreatedBy:      h.user.ID,
					HasAITask:      sql.NullBool{Bool: true, Valid: true},
				})
				task := dbgen.Task(h.t, h.db, database.TaskTable{
					OwnerID:           h.user.ID,
					OrganizationID:    h.org.ID,
					WorkspaceID:       uuid.NullUUID{},
					TemplateVersionID: tv.ID,
					Prompt:            fmt.Sprintf("pending-task-%s", t.Name()),
					CreatedAt:         now.Add(tt.createdOffset),
				})
				expectedTask = telemetry.Task{
					ID:                task.ID.String(),
					OrganizationID:    h.org.ID.String(),
					OwnerID:           h.user.ID.String(),
					Name:              task.Name,
					TemplateVersionID: tv.ID.String(),
					PromptHash:        telemetry.HashContent(task.Prompt),
					Status:            "pending",
					CreatedAt:         task.CreatedAt,
				}
			} else {
				buildCreatedAt := now.Add(tt.createdOffset)
				if tt.buildOffset != nil {
					buildCreatedAt = now.Add(*tt.buildOffset)
				}
				resp := dbfake.WorkspaceBuild(h.t, h.db, database.WorkspaceTable{
					OrganizationID: h.org.ID,
					OwnerID:        h.user.ID,
				}).WithTask(database.TaskTable{
					Prompt:    fmt.Sprintf("task-%s", t.Name()),
					CreatedAt: now.Add(tt.createdOffset),
				}, nil).Seed(database.WorkspaceBuild{
					Transition:  database.WorkspaceTransitionStart,
					Reason:      database.BuildReasonInitiator,
					BuildNumber: 1,
					CreatedAt:   buildCreatedAt,
				}).Succeeded().Do()
				app := getApp(h.ctx, h.db, resp.Agents[0].ID)
				for _, s := range tt.appStatuses {
					createAppStatus(h.ctx, h.db, resp.Workspace.ID, resp.Agents[0].ID, app.ID, s.state, s.message, now.Add(s.offset))
				}
				for _, b := range tt.extraBuilds {
					bld, bldApp := h.createBuild(resp, b.buildNumber, now.Add(b.offset), b.transition, b.reason)
					_ = bld
					if bldApp != nil {
						for _, s := range b.statuses {
							createAppStatus(h.ctx, h.db, resp.Workspace.ID, resp.Agents[0].ID, bldApp.ID, s.state, s.message, now.Add(s.offset))
						}
					}
				}
				// Refresh the task
				updated, err := h.db.GetTaskByID(ctx, resp.Task.ID)
				require.NoError(t, err, "fetching updated task")
				expectedTask = telemetry.Task{
					ID:                   updated.ID.String(),
					OrganizationID:       updated.OrganizationID.String(),
					OwnerID:              updated.OwnerID.String(),
					Name:                 updated.Name,
					WorkspaceID:          ptr.Ref(updated.WorkspaceID.UUID.String()),
					WorkspaceBuildNumber: ptr.Ref(int64(updated.WorkspaceBuildNumber.Int32)),
					WorkspaceAgentID:     ptr.Ref(updated.WorkspaceAgentID.UUID.String()),
					WorkspaceAppID:       ptr.Ref(updated.WorkspaceAppID.UUID.String()),
					TemplateVersionID:    updated.TemplateVersionID.String(),
					PromptHash:           telemetry.HashContent(updated.Prompt),
					Status:               string(updated.Status),
					CreatedAt:            updated.CreatedAt,
				}
			}
			actualTasks, err := telemetry.CollectTasks(h.ctx, h.db)
			require.NoError(t, err, "unexpected error collecting tasks telemetry")
			// Invariant: deleted tasks should NEVER appear in results.
			require.Len(t, actualTasks, 1, "expected exactly one task")
			if diff := cmp.Diff(expectedTask, actualTasks[0]); diff != "" {
				t.Fatalf("test case %q: task diff (-want +got):\n%s", tt.name, diff)
			}
			actualEvents, err := telemetry.CollectTaskEvents(h.ctx, h.db, now.Add(-1*time.Hour), now)
			require.NoError(t, err)
			if !tt.expectEvent {
				require.Empty(t, actualEvents)
			} else {
				expectedEvent := telemetry.TaskEvent{
					TaskID: expectedTask.ID,
				}
				if tt.lastPausedOffset != nil {
					// Note: `t` here deliberately shadows *testing.T.
					t := now.Add(*tt.lastPausedOffset)
					expectedEvent.LastPausedAt = &t
				}
				if tt.lastResumedOffset != nil {
					t := now.Add(*tt.lastResumedOffset)
					expectedEvent.LastResumedAt = &t
				}
				expectedEvent.PauseReason = tt.pauseReason
				expectedEvent.ResumeReason = tt.resumeReason
				expectedEvent.IdleDurationMS = tt.idleDurationMS
				expectedEvent.PausedDurationMS = tt.pausedDurationMS
				expectedEvent.ResumeToStatusMS = tt.resumeToStatusMS
				expectedEvent.ActiveDurationMS = tt.activeDurationMS
				// Each test case creates exactly one workspace with lifecycle
				// activity, so we expect exactly one event.
				require.Len(t, actualEvents, 1)
				actual := actualEvents[0]
				if diff := cmp.Diff(expectedEvent, actual); diff != "" {
					t.Fatalf("test case %q: event diff (-want +got):\n%s", tt.name, diff)
				}
			}
		})
	}
}
// mockDB is a test double that embeds the database.Store interface so
// only the methods a test needs have to be implemented.
type mockDB struct {
	database.Store
}
@@ -1350,7 +767,7 @@ func TestRecordTelemetryStatus(t *testing.T) {
require.Nil(t, snapshot1)
}
for range 3 {
for i := 0; i < 3; i++ {
// Whatever happens, subsequent calls should not report if telemetryEnabled didn't change
snapshot2, err := telemetry.RecordTelemetryStatus(ctx, logger, db, testCase.telemetryEnabled)
require.NoError(t, err)
+205 -39
View File
@@ -25,7 +25,6 @@ import (
"tailscale.com/tailcfg"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
@@ -36,11 +35,14 @@ import (
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/telemetry"
maputil "github.com/coder/coder/v2/coderd/util/maps"
strutil "github.com/coder/coder/v2/coderd/util/strings"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
@@ -293,7 +295,6 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request)
// @Param request body agentsdk.PatchAppStatus true "app status"
// @Success 200 {object} codersdk.Response
// @Router /workspaceagents/me/app-status [patch]
// @Deprecated Use UpdateAppStatus on the Agent API instead.
func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
@@ -303,6 +304,45 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req
return
}
app, err := api.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: workspaceAgent.ID,
Slug: req.AppSlug,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace app.",
Detail: fmt.Sprintf("No app found with slug %q", req.AppSlug),
})
return
}
if len(req.Message) > 160 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Message is too long.",
Detail: "Message must be less than 160 characters.",
Validations: []codersdk.ValidationError{
{Field: "message", Detail: "Message must be less than 160 characters."},
},
})
return
}
switch req.State {
case codersdk.WorkspaceAppStatusStateComplete,
codersdk.WorkspaceAppStatusStateFailure,
codersdk.WorkspaceAppStatusStateWorking,
codersdk.WorkspaceAppStatusStateIdle: // valid states
default:
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid state provided.",
Detail: fmt.Sprintf("invalid state: %q", req.State),
Validations: []codersdk.ValidationError{
{Field: "state", Detail: "State must be one of: complete, failure, working."},
},
})
return
}
workspace, err := api.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@@ -312,50 +352,176 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req
return
}
// This functionality has been moved to the AppsAPI in the agentapi. We keep this HTTP handler around for back
// compatibility with older agents. We'll translate the request into the protobuf so there is only one primary
// implementation.
appAPI := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return workspaceAgent, nil
},
Database: api.Database,
Log: api.Logger,
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{
Kind: kind,
WorkspaceID: workspace.ID,
AgentID: &agent.ID,
})
return nil
},
NotificationsEnqueuer: api.NotificationsEnqueuer,
Clock: api.Clock,
}
protoReq, err := agentsdk.ProtoFromPatchAppStatus(req)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to parse request.",
Detail: err.Error(),
})
return
}
_, err = appAPI.UpdateAppStatus(r.Context(), protoReq)
if err != nil {
sdkErr := new(codersdk.Error)
if xerrors.As(err, &sdkErr) {
httpapi.Write(ctx, rw, sdkErr.StatusCode(), sdkErr.Response)
return
}
// Treat the message as untrusted input.
cleaned := strutil.UISanitize(req.Message)
// Get the latest status for the workspace app to detect no-op updates
// nolint:gocritic // This is a system restricted operation.
latestAppStatus, err := api.Database.GetLatestWorkspaceAppStatusByAppID(dbauthz.AsSystemRestricted(ctx), app.ID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to update app status.",
Message: "Failed to get latest workspace app status.",
Detail: err.Error(),
})
return
}
// If no rows found, latestAppStatus will be a zero-value struct (ID == uuid.Nil)
// nolint:gocritic // This is a system restricted operation.
_, err = api.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
ID: uuid.New(),
CreatedAt: dbtime.Now(),
WorkspaceID: workspace.ID,
AgentID: workspaceAgent.ID,
AppID: app.ID,
State: database.WorkspaceAppStatusState(req.State),
Message: cleaned,
Uri: sql.NullString{
String: req.URI,
Valid: req.URI != "",
},
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to insert workspace app status.",
Detail: err.Error(),
})
return
}
api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{
Kind: wspubsub.WorkspaceEventKindAgentAppStatusUpdate,
WorkspaceID: workspace.ID,
AgentID: &workspaceAgent.ID,
})
// Notify on state change to Working/Idle for AI tasks
api.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, req.State, workspace, workspaceAgent)
// Bump deadline when agent reports working or transitions away from working.
// This prevents auto-pause during active work and gives users time to interact
// after work completes.
shouldBump := false
newState := database.WorkspaceAppStatusState(req.State)
// Bump if reporting working state.
if newState == database.WorkspaceAppStatusStateWorking {
shouldBump = true
}
// Bump if transitioning away from working state.
if latestAppStatus.ID != uuid.Nil {
prevState := latestAppStatus.State
if prevState == database.WorkspaceAppStatusStateWorking {
shouldBump = true
}
}
if shouldBump {
// We pass time.Time{} for nextAutostart since we don't have access to
// TemplateScheduleStore here. The activity bump logic handles this by
// defaulting to the template's activity_bump duration (typically 1 hour).
workspacestats.ActivityBumpWorkspace(ctx, api.Logger, api.Database, workspace.ID, time.Time{})
}
httpapi.Write(ctx, rw, http.StatusOK, nil)
}
// enqueueAITaskStateNotification enqueues a notification when an AI task's app
// transitions to Working, Idle, Complete, or Failure.
//
// No-op if:
//   - the new state is not one of the notifiable states above,
//   - the workspace has no associated task, or the app is not the task's app,
//   - the workspace agent is not ready (still starting up),
//   - the new state equals the latest persisted state (no transition),
//   - this is the initial "Working" status of a freshly created task.
//
// Failures to look up the task or enqueue the notification are logged and
// swallowed; notification delivery is best-effort and must not fail the
// caller's request.
func (api *API) enqueueAITaskStateNotification(
	ctx context.Context,
	appID uuid.UUID,
	// latestAppStatus is the most recent persisted status for the app. A
	// zero value (ID == uuid.Nil) means no previous status exists.
	latestAppStatus database.WorkspaceAppStatus,
	newAppStatus codersdk.WorkspaceAppStatusState,
	workspace database.Workspace,
	agent database.WorkspaceAgent,
) {
	// Select notification template based on the new state.
	var notificationTemplate uuid.UUID
	switch newAppStatus {
	case codersdk.WorkspaceAppStatusStateWorking:
		notificationTemplate = notifications.TemplateTaskWorking
	case codersdk.WorkspaceAppStatusStateIdle:
		notificationTemplate = notifications.TemplateTaskIdle
	case codersdk.WorkspaceAppStatusStateComplete:
		notificationTemplate = notifications.TemplateTaskCompleted
	case codersdk.WorkspaceAppStatusStateFailure:
		notificationTemplate = notifications.TemplateTaskFailed
	default:
		// Not a notifiable state, do nothing.
		return
	}
	if !workspace.TaskID.Valid {
		// Workspace has no task ID, do nothing.
		return
	}
	// Only send notifications when the agent is ready. We want to skip
	// any state transitions that occur whilst the workspace is starting
	// up as it doesn't make sense to receive them.
	if agent.LifecycleState != database.WorkspaceAgentLifecycleStateReady {
		api.Logger.Debug(ctx, "skipping AI task notification because agent is not ready",
			slog.F("agent_id", agent.ID),
			slog.F("lifecycle_state", agent.LifecycleState),
			slog.F("new_app_status", newAppStatus),
		)
		return
	}
	task, err := api.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
	if err != nil {
		// Best-effort: log and bail rather than failing the caller.
		api.Logger.Warn(ctx, "failed to get task", slog.Error(err))
		return
	}
	if !task.WorkspaceAppID.Valid || task.WorkspaceAppID.UUID != appID {
		// Non-task app, do nothing.
		return
	}
	// Skip if the latest persisted state equals the new state (no new transition).
	// Note: uuid.Nil check is valid here. If no previous status exists,
	// GetLatestWorkspaceAppStatusByAppID returns sql.ErrNoRows and we get a zero-value struct.
	if latestAppStatus.ID != uuid.Nil && latestAppStatus.State == database.WorkspaceAppStatusState(newAppStatus) {
		return
	}
	// Skip the initial "Working" notification when task first starts.
	// This is obvious to the user since they just created the task.
	// We still notify on first "Idle" status and all subsequent transitions.
	if latestAppStatus.ID == uuid.Nil && newAppStatus == codersdk.WorkspaceAppStatusStateWorking {
		return
	}
	if _, err := api.NotificationsEnqueuer.EnqueueWithData(
		// nolint:gocritic // Need notifier actor to enqueue notifications
		dbauthz.AsNotifier(ctx),
		workspace.OwnerID,
		notificationTemplate,
		map[string]string{
			"task":      task.Name,
			"workspace": workspace.Name,
		},
		map[string]any{
			// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
			// allowing identical content to resend within the same day
			// (but not more than once every 10s).
			"dedupe_bypass_ts": api.Clock.Now().UTC().Truncate(time.Minute),
		},
		"api-workspace-agent-app-status",
		// Associate this notification with related entities.
		workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
	); err != nil {
		api.Logger.Warn(ctx, "failed to notify of task state", slog.Error(err))
		return
	}
}
// workspaceAgentLogs returns the logs associated with a workspace agent
//
// @Summary Get logs by workspace agent
@@ -2045,7 +2211,7 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
// No point in trying to validate the same token over and over again.
if previousToken.OAuthAccessToken == externalAuthLink.OAuthAccessToken &&
previousToken.OAuthRefreshToken == externalAuthLink.OAuthRefreshToken &&
previousToken.OAuthExpiry.Equal(externalAuthLink.OAuthExpiry) {
previousToken.OAuthExpiry == externalAuthLink.OAuthExpiry {
continue
}
+3 -3
View File
@@ -2784,12 +2784,12 @@ func TestWorkspaceAgentExternalAuthListen(t *testing.T) {
const providerID = "fake-idp"
// Count all the times we call validate
var validateCalls atomic.Int32
validateCalls := 0
fake := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithMiddlewares(func(handler http.Handler) http.Handler {
return http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Count all the validate calls
if strings.Contains(r.URL.Path, "/external-auth-validate/") {
validateCalls.Add(1)
validateCalls++
}
handler.ServeHTTP(w, r)
}))
@@ -2852,7 +2852,7 @@ func TestWorkspaceAgentExternalAuthListen(t *testing.T) {
// other should be skipped.
// In a failed test, you will likely see 9, as the last one
// gets canceled.
require.EqualValues(t, 1, validateCalls.Load(), "validate calls duplicated on same token")
require.Equal(t, 1, validateCalls, "validate calls duplicated on same token")
})
}
+1 -4
View File
@@ -19,10 +19,7 @@ var (
appURL = regexp.MustCompile(fmt.Sprintf(
`^(?P<AppSlug>%[1]s)(?:--(?P<AgentName>%[1]s))?--(?P<WorkspaceName>%[1]s)--(?P<Username>%[1]s)$`,
nameRegex))
// PortRegex should not be able to be greater than 65535. In usage though, if a
// user tries to use a greater port, the proxy will just block it and not cause
// any issues. This is a good enough regex check.
PortRegex = regexp.MustCompile(`^\d{4,5}s?$`)
PortRegex = regexp.MustCompile(`^\d{4}s?$`)
validHostnameLabelRegex = regexp.MustCompile(`^[a-z0-9]([-a-z0-9]*[a-z0-9])?$`)
)
+3 -21
View File
@@ -193,16 +193,6 @@ func TestParseSubdomainAppURL(t *testing.T) {
Username: "user",
},
},
{
Name: "Port(5)--Agent--Workspace--User",
Subdomain: "12412--agent--workspace--user",
Expected: appurl.ApplicationURL{
AppSlugOrPort: "12412",
AgentName: "agent",
WorkspaceName: "workspace",
Username: "user",
},
},
{
Name: "Port--Agent--Workspace--User",
Subdomain: "8080s--agent--workspace--user",
@@ -235,11 +225,11 @@ func TestParseSubdomainAppURL(t *testing.T) {
},
},
{
Name: "5DigitPort--agent--Workspace--User",
Subdomain: "30000--agent--workspace--user",
Name: "5DigitAppSlug--Workspace--User",
Subdomain: "30000--workspace--user",
Expected: appurl.ApplicationURL{
AppSlugOrPort: "30000",
AgentName: "agent",
AgentName: "",
WorkspaceName: "workspace",
Username: "user",
},
@@ -609,14 +599,6 @@ func TestURLGenerationVsParsing(t *testing.T) {
Name: "5DigitAppSlug_AgentOmittedInParsing",
AppSlugOrPort: "30000",
AgentName: "agent",
ExpectedParsed: "agent",
},
{
// 6 digits is not a valid port, so it is treated as an app slug.
// App slugs do not require the agent name, so it is dropped
Name: "6DigitAppSlug_AgentOmittedInParsing",
AppSlugOrPort: "300000",
AgentName: "agent",
ExpectedParsed: "",
},
}
+17 -9
View File
@@ -856,24 +856,32 @@ func (api *API) workspaceBuildLogs(rw http.ResponseWriter, r *http.Request) {
func (api *API) workspaceBuildState(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceBuild := httpmw.WorkspaceBuildParam(r)
// The dbauthz layer enforces policy.ActionUpdate on the template.
row, err := api.Database.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuild.ID)
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner state.",
Message: "No workspace exists for this job.",
})
return
}
template, err := api.Database.GetTemplateByID(ctx, workspace.TemplateID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get template",
Detail: err.Error(),
})
return
}
// You must have update permissions on the template to get the state.
// This matches a push!
if !api.Authorize(r, policy.ActionUpdate, template.RBACObject()) {
httpapi.ResourceNotFound(rw)
return
}
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write(row.ProvisionerState)
_, _ = rw.Write(workspaceBuild.ProvisionerState)
}
// @Summary Update workspace build state
+17 -2
View File
@@ -114,6 +114,7 @@ func (api *API) workspace(rw http.ResponseWriter, r *http.Request) {
w, err := convertWorkspace(
ctx,
api.Experiments,
api.Logger,
apiKey.UserID,
workspace,
@@ -239,6 +240,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
wss, err := convertWorkspaces(
ctx,
api.Experiments,
api.Logger,
apiKey.UserID,
workspaces,
@@ -334,6 +336,7 @@ func (api *API) workspaceByOwnerAndName(rw http.ResponseWriter, r *http.Request)
w, err := convertWorkspace(
ctx,
api.Experiments,
api.Logger,
apiKey.UserID,
workspace,
@@ -865,6 +868,7 @@ func createWorkspace(
w, err := convertWorkspace(
ctx,
api.Experiments,
api.Logger,
initiatorID,
workspace,
@@ -1510,6 +1514,7 @@ func (api *API) putWorkspaceDormant(rw http.ResponseWriter, r *http.Request) {
w, err := convertWorkspace(
ctx,
api.Experiments,
api.Logger,
apiKey.UserID,
workspace,
@@ -2089,6 +2094,7 @@ func (api *API) watchWorkspace(
}
w, err := convertWorkspace(
ctx,
api.Experiments,
api.Logger,
apiKey.UserID,
workspace,
@@ -2230,7 +2236,8 @@ func (api *API) workspaceACL(rw http.ResponseWriter, r *http.Request) {
// the case here. This data goes directly to an unauthorized user. We are
// just straight up breaking security promises.
//
// TODO: This needs to be fixed before GA. Currently in beta.
// Fine for now while behind the shared-workspaces experiment, but needs to
// be fixed before GA.
// Fetch all of the users and their organization memberships
userIDs := make([]uuid.UUID, 0, len(workspaceACL.Users))
@@ -2588,6 +2595,7 @@ func (api *API) workspaceData(ctx context.Context, workspaces []database.Workspa
func convertWorkspaces(
ctx context.Context,
experiments codersdk.Experiments,
logger slog.Logger,
requesterID uuid.UUID,
workspaces []database.Workspace,
@@ -2625,6 +2633,7 @@ func convertWorkspaces(
w, err := convertWorkspace(
ctx,
experiments,
logger,
requesterID,
workspace,
@@ -2644,6 +2653,7 @@ func convertWorkspaces(
func convertWorkspace(
ctx context.Context,
experiments codersdk.Experiments,
logger slog.Logger,
requesterID uuid.UUID,
workspace database.Workspace,
@@ -2742,15 +2752,20 @@ func convertWorkspace(
NextStartAt: nextStartAt,
IsPrebuild: workspace.IsPrebuild(),
TaskID: workspace.TaskID,
SharedWith: sharedWorkspaceActors(ctx, logger, workspace),
SharedWith: sharedWorkspaceActors(ctx, experiments, logger, workspace),
}, nil
}
func sharedWorkspaceActors(
ctx context.Context,
experiments codersdk.Experiments,
logger slog.Logger,
workspace database.Workspace,
) []codersdk.SharedWorkspaceActor {
if !experiments.Enabled(codersdk.ExperimentWorkspaceSharing) {
return nil
}
out := make([]codersdk.SharedWorkspaceActor, 0, len(workspace.UserACL)+len(workspace.GroupACL))
// Users
+25 -7
View File
@@ -1899,6 +1899,7 @@ func TestWorkspaceFilter(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
var (
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
@@ -1936,6 +1937,7 @@ func TestWorkspaceFilter(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
var (
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
@@ -1973,6 +1975,7 @@ func TestWorkspaceFilter(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
var (
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
@@ -2010,6 +2013,7 @@ func TestWorkspaceFilter(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
var (
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
@@ -5245,7 +5249,7 @@ func TestUpdateWorkspaceACL(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
adminClient := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
DeploymentValues: dv,
@@ -5281,7 +5285,7 @@ func TestUpdateWorkspaceACL(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
adminClient := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
DeploymentValues: dv,
@@ -5314,7 +5318,7 @@ func TestUpdateWorkspaceACL(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
adminClient := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
DeploymentValues: dv,
@@ -5354,7 +5358,7 @@ func TestUpdateWorkspaceACL(t *testing.T) {
t.Cleanup(func() { rbac.SetWorkspaceACLDisabled(prevWorkspaceACLDisabled) })
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
adminClient := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
DeploymentValues: dv,
@@ -5422,7 +5426,11 @@ func TestDeleteWorkspaceACL(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
admin = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
_, toShareWithUser = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
@@ -5453,7 +5461,11 @@ func TestDeleteWorkspaceACL(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
admin = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
sharedUseClient, toShareWithUser = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
@@ -5492,7 +5504,11 @@ func TestWorkspaceReadCanListACL(t *testing.T) {
t.Cleanup(func() { rbac.SetWorkspaceACLDisabled(prevWorkspaceACLDisabled) })
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
admin = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
sharedUserClientA, sharedUserA = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
@@ -5542,6 +5558,7 @@ func TestWorkspaceSharingDisabled(t *testing.T) {
var (
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
// DisableWorkspaceSharing is false (default)
}),
})
@@ -5575,6 +5592,7 @@ func TestWorkspaceSharingDisabled(t *testing.T) {
var (
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
dv.DisableWorkspaceSharing = true
}),
})
+4 -9
View File
@@ -452,7 +452,7 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object
// to read all provisioner daemons. We need to retrieve the eligible
// provisioner daemons for this job to show in the UI if there is no
// matching provisioner daemon.
provisionerDaemons, err := b.store.GetEligibleProvisionerDaemonsByProvisionerJobIDs(dbauthz.AsWorkspaceBuilder(b.ctx), []uuid.UUID{provisionerJob.ID})
provisionerDaemons, err := b.store.GetEligibleProvisionerDaemonsByProvisionerJobIDs(dbauthz.AsSystemReadProvisionerDaemons(b.ctx), []uuid.UUID{provisionerJob.ID})
if err != nil {
// NOTE: we do **not** want to fail a workspace build if we fail to
// retrieve provisioner daemons. This is just to show in the UI if there
@@ -570,8 +570,8 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object
}
}
if b.state.orphan && !hasActiveEligibleProvisioner {
// nolint: gocritic // User won't necessarily have the permission to do this so we act as a system user.
if err := store.UpdateProvisionerJobWithCompleteWithStartedAtByID(dbauthz.AsWorkspaceBuilder(b.ctx), database.UpdateProvisionerJobWithCompleteWithStartedAtByIDParams{
// nolint: gocritic // At this moment, we are pretending to be provisionerd.
if err := store.UpdateProvisionerJobWithCompleteWithStartedAtByID(dbauthz.AsProvisionerd(b.ctx), database.UpdateProvisionerJobWithCompleteWithStartedAtByIDParams{
CompletedAt: sql.NullTime{Valid: true, Time: now},
Error: sql.NullString{Valid: false},
ErrorCode: sql.NullString{Valid: false},
@@ -815,12 +815,7 @@ func (b *Builder) getState() ([]byte, error) {
if err != nil {
return nil, xerrors.Errorf("get last build to get state: %w", err)
}
// nolint: gocritic // Workspace builder needs to read provisioner state for the new build.
state, err := b.store.GetWorkspaceBuildProvisionerStateByID(dbauthz.AsWorkspaceBuilder(b.ctx), bld.ID)
if err != nil {
return nil, xerrors.Errorf("get workspace build provisioner state: %w", err)
}
return state.ProvisionerState, nil
return bld.ProvisionerState, nil
}
func (b *Builder) getParameters() (names, values []string, err error) {
+1 -20
View File
@@ -65,7 +65,6 @@ func TestBuilder_NoOptions(t *testing.T) {
withTemplate,
withInactiveVersion(nil),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(inactiveVersionID, nil),
withRichParameters(nil),
withParameterSchemas(inactiveJobID, nil),
@@ -125,7 +124,6 @@ func TestBuilder_Initiator(t *testing.T) {
withTemplate,
withInactiveVersion(nil),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(inactiveVersionID, nil),
withRichParameters(nil),
withParameterSchemas(inactiveJobID, nil),
@@ -176,7 +174,6 @@ func TestBuilder_Baggage(t *testing.T) {
withTemplate,
withInactiveVersion(nil),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(inactiveVersionID, nil),
withRichParameters(nil),
withParameterSchemas(inactiveJobID, nil),
@@ -219,7 +216,6 @@ func TestBuilder_Reason(t *testing.T) {
withTemplate,
withInactiveVersion(nil),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(inactiveVersionID, nil),
withRichParameters(nil),
withParameterSchemas(inactiveJobID, nil),
@@ -369,7 +365,6 @@ func TestWorkspaceBuildWithTags(t *testing.T) {
withTemplate,
withInactiveVersion(richParameters),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(inactiveVersionID, templateVersionVariables),
withRichParameters(nil),
withParameterSchemas(inactiveJobID, nil),
@@ -469,7 +464,6 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
withTemplate,
withInactiveVersion(richParameters),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(inactiveVersionID, nil),
withRichParameters(initialBuildParameters),
withParameterSchemas(inactiveJobID, nil),
@@ -521,7 +515,6 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
withTemplate,
withInactiveVersion(richParameters),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(inactiveVersionID, nil),
withRichParameters(initialBuildParameters),
withParameterSchemas(inactiveJobID, nil),
@@ -668,7 +661,6 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
withTemplate,
withActiveVersion(version2params),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(activeVersionID, nil),
withRichParameters(initialBuildParameters),
withParameterSchemas(activeJobID, nil),
@@ -735,7 +727,6 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
withTemplate,
withActiveVersion(version2params),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(activeVersionID, nil),
withRichParameters(initialBuildParameters),
withParameterSchemas(activeJobID, nil),
@@ -800,7 +791,6 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
withTemplate,
withActiveVersion(version2params),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(activeVersionID, nil),
withRichParameters(initialBuildParameters),
withParameterSchemas(activeJobID, nil),
@@ -1072,7 +1062,6 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) {
withTemplate,
withInactiveVersion(nil),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(inactiveVersionID, nil),
withRichParameters(nil),
withParameterSchemas(inactiveJobID, nil),
@@ -1186,7 +1175,6 @@ func TestWorkspaceBuildWithTask(t *testing.T) {
withTemplate,
withInactiveVersion(nil),
withLastBuildFound,
withLastBuildState,
withTemplateVersionVariables(inactiveVersionID, nil),
withRichParameters(nil),
withParameterSchemas(inactiveJobID, nil),
@@ -1390,6 +1378,7 @@ func withLastBuildFound(mTx *dbmock.MockStore) {
Transition: database.WorkspaceTransitionStart,
InitiatorID: userID,
JobID: lastBuildJobID,
ProvisionerState: []byte("last build state"),
Reason: database.BuildReasonInitiator,
}, nil)
@@ -1409,14 +1398,6 @@ func withLastBuildFound(mTx *dbmock.MockStore) {
}, nil)
}
func withLastBuildState(mTx *dbmock.MockStore) {
mTx.EXPECT().GetWorkspaceBuildProvisionerStateByID(gomock.Any(), lastBuildID).
Times(1).
Return(database.GetWorkspaceBuildProvisionerStateByIDRow{
ProvisionerState: []byte("last build state"),
}, nil)
}
func withLastBuildNotFound(mTx *dbmock.MockStore) {
mTx.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).
Times(1).
+19 -4
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"sync"
"time"
@@ -320,15 +321,21 @@ func (c *Client) connectRPCVersion(ctx context.Context, version *apiversion.APIV
}
rpcURL.RawQuery = q.Encode()
jar, err := cookiejar.New(nil)
if err != nil {
return nil, xerrors.Errorf("create cookie jar: %w", err)
}
jar.SetCookies(rpcURL, []*http.Cookie{{
Name: codersdk.SessionTokenCookie,
Value: c.SDK.SessionToken(),
}})
httpClient := &http.Client{
Jar: jar,
Transport: c.SDK.HTTPClient.Transport,
}
// nolint:bodyclose
conn, res, err := websocket.Dial(ctx, rpcURL.String(), &websocket.DialOptions{
HTTPClient: httpClient,
HTTPHeader: http.Header{
codersdk.SessionTokenHeader: []string{c.SDK.SessionToken()},
},
})
if err != nil {
if res == nil {
@@ -702,7 +709,16 @@ func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, err
return nil, xerrors.Errorf("parse url: %w", err)
}
jar, err := cookiejar.New(nil)
if err != nil {
return nil, xerrors.Errorf("create cookie jar: %w", err)
}
jar.SetCookies(rpcURL, []*http.Cookie{{
Name: codersdk.SessionTokenCookie,
Value: c.SDK.SessionToken(),
}})
httpClient := &http.Client{
Jar: jar,
Transport: c.SDK.HTTPClient.Transport,
}
@@ -710,7 +726,6 @@ func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, err
if err != nil {
return nil, xerrors.Errorf("build request: %w", err)
}
req.Header[codersdk.SessionTokenHeader] = []string{c.SDK.SessionToken()}
res, err := httpClient.Do(req)
if err != nil {
-13
View File
@@ -464,16 +464,3 @@ func ProtoFromDevcontainer(dc codersdk.WorkspaceAgentDevcontainer) *proto.Worksp
SubagentId: subagentID,
}
}
func ProtoFromPatchAppStatus(pas PatchAppStatus) (*proto.UpdateAppStatusRequest, error) {
state, ok := proto.UpdateAppStatusRequest_AppStatusState_value[strings.ToUpper(string(pas.State))]
if !ok {
return nil, xerrors.Errorf("Invalid state: %s", pas.State)
}
return &proto.UpdateAppStatusRequest{
Slug: pas.AppSlug,
State: proto.UpdateAppStatusRequest_AppStatusState(state),
Message: pas.Message,
Uri: pas.URI,
}, nil
}
+1 -3
View File
@@ -94,8 +94,7 @@ func (c *Client) CreateAPIKey(ctx context.Context, user string) (GenerateAPIKeyR
}
type TokensFilter struct {
IncludeAll bool `json:"include_all"`
IncludeExpired bool `json:"include_expired"`
IncludeAll bool `json:"include_all"`
}
type APIKeyWithOwner struct {
@@ -113,7 +112,6 @@ func (f TokensFilter) asRequestOption() RequestOption {
return func(r *http.Request) {
q := r.URL.Query()
q.Set("include_all", fmt.Sprintf("%t", f.IncludeAll))
q.Set("include_expired", fmt.Sprintf("%t", f.IncludeExpired))
r.URL.RawQuery = q.Encode()
}
}
-8
View File
@@ -535,14 +535,6 @@ func NewTestError(statusCode int, method string, u string) *Error {
}
}
// NewError creates a new Error with the response and status code.
func NewError(statusCode int, response Response) *Error {
return &Error{
statusCode: statusCode,
Response: response,
}
}
type closeFunc func() error
func (c closeFunc) Close() error {
+5 -1
View File
@@ -3042,7 +3042,7 @@ func (c *DeploymentValues) Options() serpent.OptionSet {
},
{
Name: "Disable Workspace Sharing",
Description: `Disable workspace sharing. Workspace ACL checking is disabled and only owners can have ssh, apps and terminal access to workspaces. Access based on the 'owner' role is also allowed unless disabled via --disable-owner-workspace-access.`,
Description: `Disable workspace sharing (requires the "workspace-sharing" experiment to be enabled). Workspace ACL checking is disabled and only owners can have ssh, apps and terminal access to workspaces. Access based on the 'owner' role is also allowed unless disabled via --disable-owner-workspace-access.`,
Flag: "disable-workspace-sharing",
Env: "CODER_DISABLE_WORKSPACE_SHARING",
@@ -4265,6 +4265,7 @@ const (
ExperimentWebPush Experiment = "web-push" // Enables web push notifications through the browser.
ExperimentOAuth2 Experiment = "oauth2" // Enables OAuth2 provider functionality.
ExperimentMCPServerHTTP Experiment = "mcp-server-http" // Enables the MCP HTTP server functionality.
ExperimentWorkspaceSharing Experiment = "workspace-sharing" // Enables updating workspace ACLs for sharing with users and groups.
)
func (e Experiment) DisplayName() string {
@@ -4283,6 +4284,8 @@ func (e Experiment) DisplayName() string {
return "OAuth2 Provider Functionality"
case ExperimentMCPServerHTTP:
return "MCP HTTP Server Functionality"
case ExperimentWorkspaceSharing:
return "Workspace Sharing"
default:
// Split on hyphen and convert to title case
// e.g. "web-push" -> "Web Push", "mcp-server-http" -> "Mcp Server Http"
@@ -4300,6 +4303,7 @@ var ExperimentsKnown = Experiments{
ExperimentWebPush,
ExperimentOAuth2,
ExperimentMCPServerHTTP,
ExperimentWorkspaceSharing,
}
// ExperimentsSafe should include all experiments that are safe for
+5 -1
View File
@@ -217,7 +217,11 @@ const (
)
func (e OAuth2ProviderResponseType) Valid() bool {
return e == OAuth2ProviderResponseTypeCode || e == OAuth2ProviderResponseTypeToken
switch e {
case OAuth2ProviderResponseTypeCode, OAuth2ProviderResponseTypeToken:
return true
}
return false
}
type OAuth2TokenEndpointAuthMethod string
+2 -3
View File
@@ -265,7 +265,7 @@ Bad Tasks
Use the "state" field to indicate your progress. Periodically report
progress with state "working" to keep the user updated. It is not possible to send too many updates!
ONLY report a "complete", "idle", or "failure" state if you have FULLY completed the task.
ONLY report an "idle" or "failure" state if you have FULLY completed the task.
`,
Schema: aisdk.Schema{
Properties: map[string]any{
@@ -279,10 +279,9 @@ ONLY report a "complete", "idle", or "failure" state if you have FULLY completed
},
"state": map[string]any{
"type": "string",
"description": "The state of your task. This can be one of the following: working, complete, idle, or failure. Select the state that best represents your current progress.",
"description": "The state of your task. This can be one of the following: working, idle, or failure. Select the state that best represents your current progress.",
"enum": []string{
string(codersdk.WorkspaceAppStatusStateWorking),
string(codersdk.WorkspaceAppStatusStateComplete),
string(codersdk.WorkspaceAppStatusStateIdle),
string(codersdk.WorkspaceAppStatusStateFailure),
},
+4 -21
View File
@@ -128,16 +128,9 @@ If client authentication fails, the token endpoint returns **HTTP 401** with an
"$CODER_URL/api/v2/users/me"
```
> [!NOTE]
> The PKCE flow below is the **required** integration path. The example
> above is shown for reference but omits the mandatory `code_challenge`
> parameter. See [PKCE Flow](#pkce-flow-required) for the complete flow.
### PKCE Flow (Recommended)
### PKCE Flow (Required)
PKCE is **required** for all OAuth2 authorization code flows. Coder enforces
PKCE in compliance with the OAuth 2.1 specification. Both public and
confidential clients must include PKCE parameters:
Use PKCE for enhanced security (recommended for both public and confidential clients):
1. Generate a code verifier and challenge:
@@ -258,8 +251,7 @@ Verify that the `code_verifier` used in the token request matches the one used t
## Security Considerations
- **Use HTTPS**: Always use HTTPS in production to protect tokens in transit
- **Implement PKCE**: PKCE is mandatory for all authorization code clients
(public and confidential)
- **Implement PKCE**: Use PKCE for all public clients (mobile apps, SPAs)
- **Validate redirect URLs**: Only register trusted redirect URIs for your applications
- **Rotate secrets**: Periodically rotate client secrets using the management API
@@ -269,20 +261,11 @@ As an experimental feature, the current implementation has limitations:
- No scope system - all tokens have full API access
- No client credentials grant support
- Implicit grant (`response_type=token`) is not supported; OAuth 2.1
deprecated this flow due to token leakage risks, and requests return
`unsupported_response_type`
- Limited to opaque access tokens (no JWT support)
## Standards Compliance
This implementation follows established OAuth2 standards including
[RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749) (OAuth2 core),
[RFC 7636](https://datatracker.ietf.org/doc/html/rfc7636) (PKCE), and the
[OAuth 2.1 draft](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12).
Coder enforces OAuth 2.1 requirements including mandatory PKCE for all
authorization code grants, exact redirect URI string matching, rejection
of the implicit grant, and CSRF protections on consent pages.
This implementation follows established OAuth2 standards including [RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749) (OAuth2 core), [RFC 7636](https://datatracker.ietf.org/doc/html/rfc7636) (PKCE), and related specifications for discovery and client registration.
## Next Steps
+3 -3
View File
@@ -175,9 +175,9 @@ deployment. They will always be available from the agent.
| `coderd_dbpurge_iteration_duration_seconds` | histogram | Duration of each dbpurge iteration in seconds. | `success` |
| `coderd_dbpurge_records_purged_total` | counter | Total number of records purged by type. | `record_type` |
| `coderd_experiments` | gauge | Indicates whether each experiment is enabled (1) or not (0) | `experiment` |
| `coderd_insights_applications_usage_seconds` | gauge | The application usage per template. | `application_name` `organization_name` `slug` `template_name` |
| `coderd_insights_parameters` | gauge | The parameter usage per template. | `organization_name` `parameter_name` `parameter_type` `parameter_value` `template_name` |
| `coderd_insights_templates_active_users` | gauge | The number of active users of the template. | `organization_name` `template_name` |
| `coderd_insights_applications_usage_seconds` | gauge | The application usage per template. | `application_name` `slug` `template_name` |
| `coderd_insights_parameters` | gauge | The parameter usage per template. | `parameter_name` `parameter_type` `parameter_value` `template_name` |
| `coderd_insights_templates_active_users` | gauge | The number of active users of the template. | `template_name` |
| `coderd_license_active_users` | gauge | The number of active users. | |
| `coderd_license_errors` | gauge | The number of active license errors. | |
| `coderd_license_limit_users` | gauge | The user seats limit based on the active Coder license. | |

Some files were not shown because too many files have changed in this diff Show More