Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| aefc75133a | |||
| b9181c3934 | |||
| a90471db53 | |||
| cb71f5e789 | |||
| f50707bc3e |
@@ -23,7 +23,7 @@ jobs:
|
||||
egress-policy: audit
|
||||
|
||||
- name: stale
|
||||
uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
|
||||
uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
|
||||
with:
|
||||
stale-issue-label: "stale"
|
||||
stale-pr-label: "stale"
|
||||
|
||||
@@ -235,10 +235,6 @@ type FakeAgentAPI struct {
|
||||
pushResourcesMonitoringUsageFunc func(*agentproto.PushResourcesMonitoringUsageRequest) (*agentproto.PushResourcesMonitoringUsageResponse, error)
|
||||
}
|
||||
|
||||
func (*FakeAgentAPI) UpdateAppStatus(context.Context, *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
|
||||
return f.manifest, nil
|
||||
}
|
||||
|
||||
+330
-544
File diff suppressed because it is too large
Load Diff
+1
-20
@@ -436,7 +436,7 @@ message CreateSubAgentRequest {
|
||||
}
|
||||
|
||||
repeated DisplayApp display_apps = 6;
|
||||
|
||||
|
||||
optional bytes id = 7;
|
||||
}
|
||||
|
||||
@@ -494,24 +494,6 @@ message ReportBoundaryLogsRequest {
|
||||
|
||||
message ReportBoundaryLogsResponse {}
|
||||
|
||||
// UpdateAppStatusRequest updates the given Workspace App's status. c.f. agentsdk.PatchAppStatus
|
||||
message UpdateAppStatusRequest {
|
||||
string slug = 1;
|
||||
|
||||
enum AppStatusState {
|
||||
WORKING = 0;
|
||||
IDLE = 1;
|
||||
COMPLETE = 2;
|
||||
FAILURE = 3;
|
||||
}
|
||||
AppStatusState state = 2;
|
||||
|
||||
string message = 3;
|
||||
string uri = 4;
|
||||
}
|
||||
|
||||
message UpdateAppStatusResponse {}
|
||||
|
||||
service Agent {
|
||||
rpc GetManifest(GetManifestRequest) returns (Manifest);
|
||||
rpc GetServiceBanner(GetServiceBannerRequest) returns (ServiceBanner);
|
||||
@@ -530,5 +512,4 @@ service Agent {
|
||||
rpc DeleteSubAgent(DeleteSubAgentRequest) returns (DeleteSubAgentResponse);
|
||||
rpc ListSubAgents(ListSubAgentsRequest) returns (ListSubAgentsResponse);
|
||||
rpc ReportBoundaryLogs(ReportBoundaryLogsRequest) returns (ReportBoundaryLogsResponse);
|
||||
rpc UpdateAppStatus(UpdateAppStatusRequest) returns (UpdateAppStatusResponse);
|
||||
}
|
||||
|
||||
@@ -56,7 +56,6 @@ type DRPCAgentClient interface {
|
||||
DeleteSubAgent(ctx context.Context, in *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
|
||||
ListSubAgents(ctx context.Context, in *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
|
||||
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
|
||||
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
|
||||
}
|
||||
|
||||
type drpcAgentClient struct {
|
||||
@@ -222,15 +221,6 @@ func (c *drpcAgentClient) ReportBoundaryLogs(ctx context.Context, in *ReportBoun
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentClient) UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
|
||||
out := new(UpdateAppStatusResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type DRPCAgentServer interface {
|
||||
GetManifest(context.Context, *GetManifestRequest) (*Manifest, error)
|
||||
GetServiceBanner(context.Context, *GetServiceBannerRequest) (*ServiceBanner, error)
|
||||
@@ -249,7 +239,6 @@ type DRPCAgentServer interface {
|
||||
DeleteSubAgent(context.Context, *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
|
||||
ListSubAgents(context.Context, *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
|
||||
ReportBoundaryLogs(context.Context, *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
|
||||
UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
|
||||
}
|
||||
|
||||
type DRPCAgentUnimplementedServer struct{}
|
||||
@@ -322,13 +311,9 @@ func (s *DRPCAgentUnimplementedServer) ReportBoundaryLogs(context.Context, *Repo
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentUnimplementedServer) UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
type DRPCAgentDescription struct{}
|
||||
|
||||
func (DRPCAgentDescription) NumMethods() int { return 18 }
|
||||
func (DRPCAgentDescription) NumMethods() int { return 17 }
|
||||
|
||||
func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
@@ -485,15 +470,6 @@ func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
|
||||
in1.(*ReportBoundaryLogsRequest),
|
||||
)
|
||||
}, DRPCAgentServer.ReportBoundaryLogs, true
|
||||
case 17:
|
||||
return "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentServer).
|
||||
UpdateAppStatus(
|
||||
ctx,
|
||||
in1.(*UpdateAppStatusRequest),
|
||||
)
|
||||
}, DRPCAgentServer.UpdateAppStatus, true
|
||||
default:
|
||||
return "", nil, nil, nil, false
|
||||
}
|
||||
@@ -774,19 +750,3 @@ func (x *drpcAgent_ReportBoundaryLogsStream) SendAndClose(m *ReportBoundaryLogsR
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgent_UpdateAppStatusStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*UpdateAppStatusResponse) error
|
||||
}
|
||||
|
||||
type drpcAgent_UpdateAppStatusStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgent_UpdateAppStatusStream) SendAndClose(m *UpdateAppStatusResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
@@ -73,13 +73,9 @@ type DRPCAgentClient27 interface {
|
||||
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
|
||||
}
|
||||
|
||||
// DRPCAgentClient28 is the Agent API at v2.8. It adds
|
||||
// - a SubagentId field to the WorkspaceAgentDevcontainer message
|
||||
// - an Id field to the CreateSubAgentRequest message.
|
||||
// - UpdateAppStatus RPC.
|
||||
//
|
||||
// Compatible with Coder v2.31+
|
||||
// DRPCAgentClient28 is the Agent API at v2.8. It adds a SubagentId field to the
|
||||
// WorkspaceAgentDevcontainer message, and a Id field to the CreateSubAgentRequest
|
||||
// message. Compatible with Coder v2.31+
|
||||
type DRPCAgentClient28 interface {
|
||||
DRPCAgentClient27
|
||||
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
|
||||
}
|
||||
|
||||
+45
-50
@@ -10,7 +10,6 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
@@ -24,7 +23,6 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/codersdk/toolsdk"
|
||||
"github.com/coder/retry"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
@@ -541,6 +539,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
|
||||
defer cancel()
|
||||
defer srv.queue.Close()
|
||||
|
||||
cliui.Infof(inv.Stderr, "Failed to watch screen events")
|
||||
// Start the reporter, watcher, and server. These are all tied to the
|
||||
// lifetime of the MCP server, which is itself tied to the lifetime of the
|
||||
// AI agent.
|
||||
@@ -614,51 +613,48 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
|
||||
}
|
||||
|
||||
func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
|
||||
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for retrier := retry.New(time.Second, 30*time.Second); retrier.Wait(ctx); {
|
||||
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
|
||||
if err == nil {
|
||||
retrier.Reset()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case event := <-eventsCh:
|
||||
switch ev := event.(type) {
|
||||
case agentapi.EventStatusChange:
|
||||
// If the screen is stable, report idle.
|
||||
state := codersdk.WorkspaceAppStatusStateWorking
|
||||
if ev.Status == agentapi.StatusStable {
|
||||
state = codersdk.WorkspaceAppStatusStateIdle
|
||||
}
|
||||
err := s.queue.Push(taskReport{
|
||||
state: state,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
case event := <-eventsCh:
|
||||
switch ev := event.(type) {
|
||||
case agentapi.EventStatusChange:
|
||||
state := codersdk.WorkspaceAppStatusStateWorking
|
||||
if ev.Status == agentapi.StatusStable {
|
||||
state = codersdk.WorkspaceAppStatusStateIdle
|
||||
}
|
||||
err := s.queue.Push(taskReport{
|
||||
state: state,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
}
|
||||
case agentapi.EventMessageUpdate:
|
||||
if ev.Role == agentapi.RoleUser {
|
||||
err := s.queue.Push(taskReport{
|
||||
messageID: &ev.Id,
|
||||
state: codersdk.WorkspaceAppStatusStateWorking,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
case agentapi.EventMessageUpdate:
|
||||
if ev.Role == agentapi.RoleUser {
|
||||
err := s.queue.Push(taskReport{
|
||||
messageID: &ev.Id,
|
||||
state: codersdk.WorkspaceAppStatusStateWorking,
|
||||
})
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
|
||||
return
|
||||
}
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
|
||||
}
|
||||
break loop
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -696,14 +692,13 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
|
||||
// Add tool dependencies.
|
||||
toolOpts := []func(*toolsdk.Deps){
|
||||
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
|
||||
state := codersdk.WorkspaceAppStatusState(args.State)
|
||||
// The agent does not reliably report idle, so when AgentAPI is
|
||||
// enabled we override idle to working and let the screen watcher
|
||||
// detect the real idle via StatusStable. Final states (failure,
|
||||
// complete) are trusted from the agent since the screen watcher
|
||||
// cannot produce them.
|
||||
if s.aiAgentAPIClient != nil && state == codersdk.WorkspaceAppStatusStateIdle {
|
||||
state = codersdk.WorkspaceAppStatusStateWorking
|
||||
// The agent does not reliably report its status correctly. If AgentAPI
|
||||
// is enabled, we will always set the status to "working" when we get an
|
||||
// MCP message, and rely on the screen watcher to eventually catch the
|
||||
// idle state.
|
||||
state := codersdk.WorkspaceAppStatusStateWorking
|
||||
if s.aiAgentAPIClient == nil {
|
||||
state = codersdk.WorkspaceAppStatusState(args.State)
|
||||
}
|
||||
return s.queue.Push(taskReport{
|
||||
link: args.Link,
|
||||
|
||||
+1
-185
@@ -921,7 +921,7 @@ func TestExpMcpReporter(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
// We override idle from the agent to working, but trust final states.
|
||||
// We ignore the state from the agent and assume "working".
|
||||
{
|
||||
name: "IgnoreAgentState",
|
||||
// AI agent reports that it is finished but the summary says it is doing
|
||||
@@ -953,46 +953,6 @@ func TestExpMcpReporter(t *testing.T) {
|
||||
Message: "finished",
|
||||
},
|
||||
},
|
||||
// Agent reports failure; trusted even with AgentAPI enabled.
|
||||
{
|
||||
state: codersdk.WorkspaceAppStatusStateFailure,
|
||||
summary: "something broke",
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateFailure,
|
||||
Message: "something broke",
|
||||
},
|
||||
},
|
||||
// After failure, watcher reports stable -> idle.
|
||||
{
|
||||
event: makeStatusEvent(agentapi.StatusStable),
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "something broke",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Final states pass through with AgentAPI enabled.
|
||||
{
|
||||
name: "AllowFinalStates",
|
||||
tests: []test{
|
||||
{
|
||||
state: codersdk.WorkspaceAppStatusStateWorking,
|
||||
summary: "doing work",
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateWorking,
|
||||
Message: "doing work",
|
||||
},
|
||||
},
|
||||
// Agent reports complete; not overridden.
|
||||
{
|
||||
state: codersdk.WorkspaceAppStatusStateComplete,
|
||||
summary: "all done",
|
||||
expected: &codersdk.WorkspaceAppStatus{
|
||||
State: codersdk.WorkspaceAppStatusStateComplete,
|
||||
Message: "all done",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// When AgentAPI is not being used, we accept agent state updates as-is.
|
||||
@@ -1150,148 +1110,4 @@ func TestExpMcpReporter(t *testing.T) {
|
||||
<-cmdDone
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Reconnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a test deployment and workspace.
|
||||
client, db := coderdtest.NewWithDatabase(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
|
||||
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user2.ID,
|
||||
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
|
||||
a[0].Apps = []*proto.App{
|
||||
{
|
||||
Slug: "vscode",
|
||||
},
|
||||
}
|
||||
return a
|
||||
}).Do()
|
||||
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitLong))
|
||||
|
||||
// Watch the workspace for changes.
|
||||
watcher, err := client.WatchWorkspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
var lastAppStatus codersdk.WorkspaceAppStatus
|
||||
nextUpdate := func() codersdk.WorkspaceAppStatus {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
require.FailNow(t, "timed out waiting for status update")
|
||||
case w, ok := <-watcher:
|
||||
require.True(t, ok, "watch channel closed")
|
||||
if w.LatestAppStatus != nil && w.LatestAppStatus.ID != lastAppStatus.ID {
|
||||
t.Logf("Got status update: %s > %s", lastAppStatus.State, w.LatestAppStatus.State)
|
||||
lastAppStatus = *w.LatestAppStatus
|
||||
return lastAppStatus
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mock AI AgentAPI server that supports disconnect/reconnect.
|
||||
disconnect := make(chan struct{})
|
||||
listening := make(chan func(sse codersdk.ServerSentEvent) error)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create a cancelable context so we can stop the SSE sender
|
||||
// goroutine on disconnect without waiting for the HTTP
|
||||
// serve loop to cancel r.Context().
|
||||
sseCtx, sseCancel := context.WithCancel(r.Context())
|
||||
defer sseCancel()
|
||||
r = r.WithContext(sseCtx)
|
||||
|
||||
send, closed, err := httpapi.ServerSentEventSender(w, r)
|
||||
if err != nil {
|
||||
httpapi.Write(sseCtx, w, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error setting up server-sent events.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Send initial message so the watcher knows the agent is active.
|
||||
send(*makeMessageEvent(0, agentapi.RoleAgent))
|
||||
select {
|
||||
case listening <- send:
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-closed:
|
||||
case <-disconnect:
|
||||
sseCancel()
|
||||
<-closed
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
inv, _ := clitest.New(t,
|
||||
"exp", "mcp", "server",
|
||||
"--agent-url", client.URL.String(),
|
||||
"--agent-token", r.AgentToken,
|
||||
"--app-status-slug", "vscode",
|
||||
"--allowed-tools=coder_report_task",
|
||||
"--ai-agentapi-url", srv.URL,
|
||||
)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
pty := ptytest.New(t)
|
||||
inv.Stdin = pty.Input()
|
||||
inv.Stdout = pty.Output()
|
||||
stderr := ptytest.New(t)
|
||||
inv.Stderr = stderr.Output()
|
||||
|
||||
// Run the MCP server.
|
||||
clitest.Start(t, inv)
|
||||
|
||||
// Initialize.
|
||||
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
|
||||
pty.WriteLine(payload)
|
||||
_ = pty.ReadLine(ctx) // ignore echo
|
||||
_ = pty.ReadLine(ctx) // ignore init response
|
||||
|
||||
// Get first sender from the initial SSE connection.
|
||||
sender := testutil.RequireReceive(ctx, t, listening)
|
||||
|
||||
// Self-report a working status via tool call.
|
||||
toolPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"doing work","link":""}}}`
|
||||
pty.WriteLine(toolPayload)
|
||||
_ = pty.ReadLine(ctx) // ignore echo
|
||||
_ = pty.ReadLine(ctx) // ignore response
|
||||
got := nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
|
||||
require.Equal(t, "doing work", got.Message)
|
||||
|
||||
// Watcher sends stable, verify idle is reported.
|
||||
err = sender(*makeStatusEvent(agentapi.StatusStable))
|
||||
require.NoError(t, err)
|
||||
got = nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
|
||||
|
||||
// Disconnect the SSE connection by signaling the handler to return.
|
||||
testutil.RequireSend(ctx, t, disconnect, struct{}{})
|
||||
|
||||
// Wait for the watcher to reconnect and get the new sender.
|
||||
sender = testutil.RequireReceive(ctx, t, listening)
|
||||
|
||||
// After reconnect, self-report a working status again.
|
||||
toolPayload = `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"reconnected","link":""}}}`
|
||||
pty.WriteLine(toolPayload)
|
||||
_ = pty.ReadLine(ctx) // ignore echo
|
||||
_ = pty.ReadLine(ctx) // ignore response
|
||||
got = nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
|
||||
require.Equal(t, "reconnected", got.Message)
|
||||
|
||||
// Verify the watcher still processes events after reconnect.
|
||||
err = sender(*makeStatusEvent(agentapi.StatusStable))
|
||||
require.NoError(t, err)
|
||||
got = nextUpdate()
|
||||
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
|
||||
|
||||
cancel()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -29,7 +29,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
|
||||
templateVersionJobTimeout time.Duration
|
||||
prebuildWorkspaceTimeout time.Duration
|
||||
noCleanup bool
|
||||
provisionerTags []string
|
||||
|
||||
tracingFlags = &scaletestTracingFlags{}
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
@@ -112,16 +111,10 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
|
||||
|
||||
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
|
||||
|
||||
tags, err := ParseProvisionerTags(provisionerTags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range numTemplates {
|
||||
id := strconv.Itoa(int(i))
|
||||
cfg := prebuilds.Config{
|
||||
OrganizationID: me.OrganizationIDs[0],
|
||||
ProvisionerTags: tags,
|
||||
NumPresets: int(numPresets),
|
||||
NumPresetPrebuilds: int(numPresetPrebuilds),
|
||||
TemplateVersionJobTimeout: templateVersionJobTimeout,
|
||||
@@ -290,11 +283,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
|
||||
Description: "Skip cleanup (deletion test) and leave resources intact.",
|
||||
Value: serpent.BoolOf(&noCleanup),
|
||||
},
|
||||
{
|
||||
Flag: "provisioner-tag",
|
||||
Description: "Specify a set of tags to target provisioner daemons.",
|
||||
Value: serpent.StringArrayOf(&provisionerTags),
|
||||
},
|
||||
}
|
||||
|
||||
tracingFlags.attach(&cmd.Options)
|
||||
|
||||
+5
-1
@@ -106,7 +106,11 @@ func TestList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, member = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
sharedWorkspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
|
||||
@@ -297,7 +297,7 @@ func (pr *ParameterResolver) verifyConstraints(resolved []codersdk.WorkspaceBuil
|
||||
return xerrors.Errorf("ephemeral parameter %q can be used only with --prompt-ephemeral-parameters or --ephemeral-parameter flag", r.Name)
|
||||
}
|
||||
|
||||
if !tvp.Mutable && action != WorkspaceCreate && !pr.isFirstTimeUse(r.Name) {
|
||||
if !tvp.Mutable && action != WorkspaceCreate {
|
||||
return xerrors.Errorf("parameter %q is immutable and cannot be updated", r.Name)
|
||||
}
|
||||
}
|
||||
|
||||
+31
-7
@@ -25,7 +25,11 @@ func TestSharingShare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -64,8 +68,12 @@ func TestSharingShare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -119,7 +127,11 @@ func TestSharingShare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -170,7 +182,11 @@ func TestSharingStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -214,7 +230,11 @@ func TestSharingRemove(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
@@ -271,7 +291,11 @@ func TestSharingRemove(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
|
||||
+1
-1
@@ -120,7 +120,7 @@ func (r *RootCmd) start() *serpent.Command {
|
||||
func buildWorkspaceStartRequest(inv *serpent.Invocation, client *codersdk.Client, workspace codersdk.Workspace, parameterFlags workspaceParameterFlags, buildFlags buildFlags, action WorkspaceCLIAction) (codersdk.CreateWorkspaceBuildRequest, error) {
|
||||
version := workspace.LatestBuild.TemplateVersionID
|
||||
|
||||
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || workspace.TemplateRequireActiveVersion || action == WorkspaceUpdate {
|
||||
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || action == WorkspaceUpdate {
|
||||
version = workspace.TemplateActiveVersionID
|
||||
if version != workspace.LatestBuild.TemplateVersionID {
|
||||
action = WorkspaceUpdate
|
||||
|
||||
+4
-4
@@ -33,7 +33,7 @@ func TestStatePull(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
|
||||
Do()
|
||||
statefilePath := filepath.Join(t.TempDir(), "state")
|
||||
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name, statefilePath)
|
||||
@@ -54,7 +54,7 @@ func TestStatePull(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
|
||||
Do()
|
||||
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name)
|
||||
var gotState bytes.Buffer
|
||||
@@ -74,7 +74,7 @@ func TestStatePull(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
|
||||
Do()
|
||||
inv, root := clitest.New(t, "state", "pull", taUser.Username+"/"+r.Workspace.Name,
|
||||
"--build", fmt.Sprintf("%d", r.Build.BuildNumber))
|
||||
@@ -170,7 +170,7 @@ func TestStatePush(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taUser.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{}).ProvisionerState(initialState).
|
||||
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
|
||||
Do()
|
||||
wantState := []byte("updated state")
|
||||
stateFile, err := os.CreateTemp(t.TempDir(), "")
|
||||
|
||||
+4
-3
@@ -49,9 +49,10 @@ OPTIONS:
|
||||
security purposes if a --wildcard-access-url is configured.
|
||||
|
||||
--disable-workspace-sharing bool, $CODER_DISABLE_WORKSPACE_SHARING
|
||||
Disable workspace sharing. Workspace ACL checking is disabled and only
|
||||
owners can have ssh, apps and terminal access to workspaces. Access
|
||||
based on the 'owner' role is also allowed unless disabled via
|
||||
Disable workspace sharing (requires the "workspace-sharing" experiment
|
||||
to be enabled). Workspace ACL checking is disabled and only owners can
|
||||
have ssh, apps and terminal access to workspaces. Access based on the
|
||||
'owner' role is also allowed unless disabled via
|
||||
--disable-owner-workspace-access.
|
||||
|
||||
--swagger-enable bool, $CODER_SWAGGER_ENABLE
|
||||
|
||||
+4
-4
@@ -523,10 +523,10 @@ disablePathApps: false
|
||||
# workspaces.
|
||||
# (default: <unset>, type: bool)
|
||||
disableOwnerWorkspaceAccess: false
|
||||
# Disable workspace sharing. Workspace ACL checking is disabled and only owners
|
||||
# can have ssh, apps and terminal access to workspaces. Access based on the
|
||||
# 'owner' role is also allowed unless disabled via
|
||||
# --disable-owner-workspace-access.
|
||||
# Disable workspace sharing (requires the "workspace-sharing" experiment to be
|
||||
# enabled). Workspace ACL checking is disabled and only owners can have ssh, apps
|
||||
# and terminal access to workspaces. Access based on the 'owner' role is also
|
||||
# allowed unless disabled via --disable-owner-workspace-access.
|
||||
# (default: <unset>, type: bool)
|
||||
disableWorkspaceSharing: false
|
||||
# These options change the behavior of how clients interact with the Coder.
|
||||
|
||||
+15
-2
@@ -241,13 +241,26 @@ func (r *RootCmd) listTokens() *serpent.Command {
|
||||
}
|
||||
|
||||
tokens, err := client.Tokens(inv.Context(), codersdk.Me, codersdk.TokensFilter{
|
||||
IncludeAll: all,
|
||||
IncludeExpired: includeExpired,
|
||||
IncludeAll: all,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("list tokens: %w", err)
|
||||
}
|
||||
|
||||
// Filter out expired tokens unless --include-expired is set
|
||||
// TODO(Cian): This _could_ get too big for client-side filtering.
|
||||
// If it causes issues, we can filter server-side.
|
||||
if !includeExpired {
|
||||
now := time.Now()
|
||||
filtered := make([]codersdk.APIKeyWithOwner, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
if token.ExpiresAt.After(now) {
|
||||
filtered = append(filtered, token)
|
||||
}
|
||||
}
|
||||
tokens = filtered
|
||||
}
|
||||
|
||||
displayTokens = make([]tokenListRow, len(tokens))
|
||||
|
||||
for i, token := range tokens {
|
||||
|
||||
@@ -990,74 +990,4 @@ func TestUpdateValidateRichParameters(t *testing.T) {
|
||||
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
})
|
||||
|
||||
t.Run("NewImmutableParameterViaFlag", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create template and workspace with only a mutable parameter.
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
templateParameters := []*proto.RichParameter{
|
||||
{Name: stringParameterName, Type: "string", Mutable: true, Required: true, Options: []*proto.RichParameterOption{
|
||||
{Name: "First option", Description: "This is first option", Value: "1st"},
|
||||
{Name: "Second option", Description: "This is second option", Value: "2nd"},
|
||||
}},
|
||||
}
|
||||
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(templateParameters))
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
|
||||
|
||||
inv, root := clitest.New(t, "create", "my-workspace", "--yes", "--template", template.Name, "--parameter", fmt.Sprintf("%s=%s", stringParameterName, "1st"))
|
||||
clitest.SetupConfig(t, member, root)
|
||||
err := inv.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update template: add a new immutable parameter.
|
||||
updatedTemplateParameters := []*proto.RichParameter{
|
||||
templateParameters[0],
|
||||
{Name: immutableParameterName, Type: "string", Mutable: false, Required: true, Options: []*proto.RichParameterOption{
|
||||
{Name: "fir", Description: "First option for immutable parameter", Value: "I"},
|
||||
{Name: "sec", Description: "Second option for immutable parameter", Value: "II"},
|
||||
}},
|
||||
}
|
||||
|
||||
updatedVersion := coderdtest.UpdateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(updatedTemplateParameters), template.ID)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, updatedVersion.ID)
|
||||
err = client.UpdateActiveTemplateVersion(context.Background(), template.ID, codersdk.UpdateActiveTemplateVersion{
|
||||
ID: updatedVersion.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update workspace, supplying the new immutable parameter via
|
||||
// the --parameter flag. This should succeed because it's the
|
||||
// first time this parameter is being set.
|
||||
inv, root = clitest.New(t, "update", "my-workspace",
|
||||
"--parameter", fmt.Sprintf("%s=%s", immutableParameterName, "II"))
|
||||
clitest.SetupConfig(t, member, root)
|
||||
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
doneChan := make(chan struct{})
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
pty.ExpectMatch("Planning workspace")
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
|
||||
// Verify the immutable parameter was set correctly.
|
||||
workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), "my-workspace", codersdk.WorkspaceOptions{})
|
||||
require.NoError(t, err)
|
||||
actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, actualParameters, codersdk.WorkspaceBuildParameter{
|
||||
Name: immutableParameterName,
|
||||
Value: "II",
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -179,8 +179,6 @@ func New(opts Options, workspace database.Workspace) *API {
|
||||
Database: opts.Database,
|
||||
Log: opts.Log,
|
||||
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
|
||||
Clock: opts.Clock,
|
||||
NotificationsEnqueuer: opts.NotificationsEnqueuer,
|
||||
}
|
||||
|
||||
api.MetadataAPI = &MetadataAPI{
|
||||
|
||||
@@ -2,10 +2,6 @@ package agentapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -13,14 +9,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
strutil "github.com/coder/coder/v2/coderd/util/strings"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
type AppsAPI struct {
|
||||
@@ -28,8 +17,6 @@ type AppsAPI struct {
|
||||
Database database.Store
|
||||
Log slog.Logger
|
||||
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
|
||||
NotificationsEnqueuer notifications.Enqueuer
|
||||
Clock quartz.Clock
|
||||
}
|
||||
|
||||
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
|
||||
@@ -117,230 +104,3 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
|
||||
}
|
||||
return &agentproto.BatchUpdateAppHealthResponse{}, nil
|
||||
}
|
||||
|
||||
func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
|
||||
if len(req.Message) > 160 {
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Message is too long.",
|
||||
Detail: "Message must be less than 160 characters.",
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "message", Detail: "Message must be less than 160 characters."},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
var dbState database.WorkspaceAppStatusState
|
||||
switch req.State {
|
||||
case agentproto.UpdateAppStatusRequest_COMPLETE:
|
||||
dbState = database.WorkspaceAppStatusStateComplete
|
||||
case agentproto.UpdateAppStatusRequest_FAILURE:
|
||||
dbState = database.WorkspaceAppStatusStateFailure
|
||||
case agentproto.UpdateAppStatusRequest_WORKING:
|
||||
dbState = database.WorkspaceAppStatusStateWorking
|
||||
case agentproto.UpdateAppStatusRequest_IDLE:
|
||||
dbState = database.WorkspaceAppStatusStateIdle
|
||||
default:
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid state provided.",
|
||||
Detail: fmt.Sprintf("invalid state: %q", req.State),
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "state", Detail: "State must be one of: complete, failure, working, idle."},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
workspaceAgent, err := a.AgentFn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
|
||||
AgentID: workspaceAgent.ID,
|
||||
Slug: req.Slug,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to get workspace app.",
|
||||
Detail: fmt.Sprintf("No app found with slug %q", req.Slug),
|
||||
})
|
||||
}
|
||||
|
||||
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to get workspace.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Treat the message as untrusted input.
|
||||
cleaned := strutil.UISanitize(req.Message)
|
||||
|
||||
// Get the latest status for the workspace app to detect no-op updates
|
||||
// nolint:gocritic // This is a system restricted operation.
|
||||
latestAppStatus, err := a.Database.GetLatestWorkspaceAppStatusByAppID(dbauthz.AsSystemRestricted(ctx), app.ID)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get latest workspace app status.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
// If no rows found, latestAppStatus will be a zero-value struct (ID == uuid.Nil)
|
||||
|
||||
// nolint:gocritic // This is a system restricted operation.
|
||||
_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
|
||||
ID: uuid.New(),
|
||||
CreatedAt: dbtime.Now(),
|
||||
WorkspaceID: workspace.ID,
|
||||
AgentID: workspaceAgent.ID,
|
||||
AppID: app.ID,
|
||||
State: dbState,
|
||||
Message: cleaned,
|
||||
Uri: sql.NullString{
|
||||
String: req.Uri,
|
||||
Valid: req.Uri != "",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to insert workspace app status.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
if a.PublishWorkspaceUpdateFn != nil {
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to publish workspace update.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Notify on state change to Working/Idle for AI tasks
|
||||
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent)
|
||||
|
||||
if shouldBump(dbState, latestAppStatus) {
|
||||
// We pass time.Time{} for nextAutostart since we don't have access to
|
||||
// TemplateScheduleStore here. The activity bump logic handles this by
|
||||
// defaulting to the template's activity_bump duration (typically 1 hour).
|
||||
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{})
|
||||
}
|
||||
// just return a blank response because it doesn't contain any settable fields at present.
|
||||
return new(agentproto.UpdateAppStatusResponse), nil
|
||||
}
|
||||
|
||||
func shouldBump(dbState database.WorkspaceAppStatusState, latestAppStatus database.WorkspaceAppStatus) bool {
|
||||
// Bump deadline when agent reports working or transitions away from working.
|
||||
// This prevents auto-pause during active work and gives users time to interact
|
||||
// after work completes.
|
||||
|
||||
// Bump if reporting working state.
|
||||
if dbState == database.WorkspaceAppStatusStateWorking {
|
||||
return true
|
||||
}
|
||||
|
||||
// Bump if transitioning away from working state.
|
||||
if latestAppStatus.ID != uuid.Nil {
|
||||
prevState := latestAppStatus.State
|
||||
if prevState == database.WorkspaceAppStatusStateWorking {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// enqueueAITaskStateNotification enqueues a notification when an AI task's app
|
||||
// transitions to Working or Idle.
|
||||
// No-op if:
|
||||
// - the workspace agent app isn't configured as an AI task,
|
||||
// - the new state equals the latest persisted state,
|
||||
// - the workspace agent is not ready (still starting up).
|
||||
func (a *AppsAPI) enqueueAITaskStateNotification(
|
||||
ctx context.Context,
|
||||
appID uuid.UUID,
|
||||
latestAppStatus database.WorkspaceAppStatus,
|
||||
newAppStatus database.WorkspaceAppStatusState,
|
||||
workspace database.Workspace,
|
||||
agent database.WorkspaceAgent,
|
||||
) {
|
||||
var notificationTemplate uuid.UUID
|
||||
switch newAppStatus {
|
||||
case database.WorkspaceAppStatusStateWorking:
|
||||
notificationTemplate = notifications.TemplateTaskWorking
|
||||
case database.WorkspaceAppStatusStateIdle:
|
||||
notificationTemplate = notifications.TemplateTaskIdle
|
||||
case database.WorkspaceAppStatusStateComplete:
|
||||
notificationTemplate = notifications.TemplateTaskCompleted
|
||||
case database.WorkspaceAppStatusStateFailure:
|
||||
notificationTemplate = notifications.TemplateTaskFailed
|
||||
default:
|
||||
// Not a notifiable state, do nothing
|
||||
return
|
||||
}
|
||||
|
||||
if !workspace.TaskID.Valid {
|
||||
// Workspace has no task ID, do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
// Only send notifications when the agent is ready. We want to skip
|
||||
// any state transitions that occur whilst the workspace is starting
|
||||
// up as it doesn't make sense to receive them.
|
||||
if agent.LifecycleState != database.WorkspaceAgentLifecycleStateReady {
|
||||
a.Log.Debug(ctx, "skipping AI task notification because agent is not ready",
|
||||
slog.F("agent_id", agent.ID),
|
||||
slog.F("lifecycle_state", agent.LifecycleState),
|
||||
slog.F("new_app_status", newAppStatus),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
|
||||
if err != nil {
|
||||
a.Log.Warn(ctx, "failed to get task", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if !task.WorkspaceAppID.Valid || task.WorkspaceAppID.UUID != appID {
|
||||
// Non-task app, do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if the latest persisted state equals the new state (no new transition)
|
||||
// Note: uuid.Nil check is valid here. If no previous status exists,
|
||||
// GetLatestWorkspaceAppStatusByAppID returns sql.ErrNoRows and we get a zero-value struct.
|
||||
if latestAppStatus.ID != uuid.Nil && latestAppStatus.State == newAppStatus {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip the initial "Working" notification when the task first starts.
|
||||
// This is obvious to the user since they just created the task.
|
||||
// We still notify on the first "Idle" status and all subsequent transitions.
|
||||
if latestAppStatus.ID == uuid.Nil && newAppStatus == database.WorkspaceAppStatusStateWorking {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := a.NotificationsEnqueuer.EnqueueWithData(
|
||||
// nolint:gocritic // Need notifier actor to enqueue notifications
|
||||
dbauthz.AsNotifier(ctx),
|
||||
workspace.OwnerID,
|
||||
notificationTemplate,
|
||||
map[string]string{
|
||||
"task": task.Name,
|
||||
"workspace": workspace.Name,
|
||||
},
|
||||
map[string]any{
|
||||
// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
|
||||
// allowing identical content to resend within the same day
|
||||
// (but not more than once every 10s).
|
||||
"dedupe_bypass_ts": a.Clock.Now().UTC().Truncate(time.Minute),
|
||||
},
|
||||
"api-workspace-agent-app-status",
|
||||
// Associate this notification with related entities
|
||||
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
|
||||
); err != nil {
|
||||
a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
package agentapi
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
)
|
||||
|
||||
func TestShouldBump(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prevState *database.WorkspaceAppStatusState // nil means no previous state
|
||||
newState database.WorkspaceAppStatusState
|
||||
shouldBump bool
|
||||
}{
|
||||
{
|
||||
name: "FirstStatusBumps",
|
||||
prevState: nil,
|
||||
newState: database.WorkspaceAppStatusStateWorking,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "WorkingToIdleBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "WorkingToCompleteBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
|
||||
newState: database.WorkspaceAppStatusStateComplete,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "CompleteToIdleNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "CompleteToCompleteNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
|
||||
newState: database.WorkspaceAppStatusStateComplete,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "FailureToIdleNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "FailureToFailureNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
|
||||
newState: database.WorkspaceAppStatusStateFailure,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "CompleteToWorkingBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
|
||||
newState: database.WorkspaceAppStatusStateWorking,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "FailureToCompleteNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
|
||||
newState: database.WorkspaceAppStatusStateComplete,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "WorkingToFailureBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
|
||||
newState: database.WorkspaceAppStatusStateFailure,
|
||||
shouldBump: true,
|
||||
},
|
||||
{
|
||||
name: "IdleToIdleNoBump",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
|
||||
newState: database.WorkspaceAppStatusStateIdle,
|
||||
shouldBump: false,
|
||||
},
|
||||
{
|
||||
name: "IdleToWorkingBumps",
|
||||
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
|
||||
newState: database.WorkspaceAppStatusStateWorking,
|
||||
shouldBump: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var prevAppStatus database.WorkspaceAppStatus
|
||||
// If there's a previous state, report it first.
|
||||
if tt.prevState != nil {
|
||||
prevAppStatus.ID = uuid.UUID{1}
|
||||
prevAppStatus.State = *tt.prevState
|
||||
}
|
||||
|
||||
didBump := shouldBump(tt.newState, prevAppStatus)
|
||||
if tt.shouldBump {
|
||||
require.True(t, didBump, "wanted deadline to bump but it didn't")
|
||||
} else {
|
||||
require.False(t, didBump, "wanted deadline not to bump but it did")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,13 +2,9 @@ package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
@@ -16,12 +12,8 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
@@ -261,183 +253,3 @@ func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkspaceAgentAppStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
fEnq := ¬ificationstest.FakeEnqueuer{}
|
||||
mClock := quartz.NewMock(t)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
assert.Equal(t, *agnt, agent)
|
||||
testutil.AssertSend(ctx, t, workspaceUpdates, kind)
|
||||
return nil
|
||||
},
|
||||
NotificationsEnqueuer: fEnq,
|
||||
Clock: mClock,
|
||||
}
|
||||
|
||||
app := database.WorkspaceApp{
|
||||
ID: uuid.UUID{8},
|
||||
}
|
||||
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), database.GetWorkspaceAppByAgentIDAndSlugParams{
|
||||
AgentID: agent.ID,
|
||||
Slug: "vscode",
|
||||
}).Times(1).Return(app, nil)
|
||||
task := database.Task{
|
||||
ID: uuid.UUID{7},
|
||||
WorkspaceAppID: uuid.NullUUID{
|
||||
Valid: true,
|
||||
UUID: app.ID,
|
||||
},
|
||||
}
|
||||
mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
|
||||
workspace := database.Workspace{
|
||||
ID: uuid.UUID{9},
|
||||
TaskID: uuid.NullUUID{
|
||||
Valid: true,
|
||||
UUID: task.ID,
|
||||
},
|
||||
}
|
||||
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
|
||||
appStatus := database.WorkspaceAppStatus{
|
||||
ID: uuid.UUID{6},
|
||||
}
|
||||
mDB.EXPECT().GetLatestWorkspaceAppStatusByAppID(gomock.Any(), app.ID).Times(1).Return(appStatus, nil)
|
||||
mDB.EXPECT().InsertWorkspaceAppStatus(
|
||||
gomock.Any(),
|
||||
gomock.Cond(func(params database.InsertWorkspaceAppStatusParams) bool {
|
||||
if params.AgentID == agent.ID && params.AppID == app.ID {
|
||||
assert.Equal(t, "testing", params.Message)
|
||||
assert.Equal(t, database.WorkspaceAppStatusStateComplete, params.State)
|
||||
assert.True(t, params.Uri.Valid)
|
||||
assert.Equal(t, "https://example.com", params.Uri.String)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})).Times(1).Return(database.WorkspaceAppStatus{}, nil)
|
||||
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "vscode",
|
||||
Message: "testing",
|
||||
Uri: "https://example.com",
|
||||
State: agentproto.UpdateAppStatusRequest_COMPLETE,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
kind := testutil.RequireReceive(ctx, t, workspaceUpdates)
|
||||
require.Equal(t, wspubsub.WorkspaceEventKindAgentAppStatusUpdate, kind)
|
||||
sent := fEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskCompleted))
|
||||
require.Len(t, sent, 1)
|
||||
})
|
||||
|
||||
t.Run("FailUnknownApp", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
|
||||
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), gomock.Any()).
|
||||
Times(1).
|
||||
Return(database.WorkspaceApp{}, sql.ErrNoRows)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "unknown",
|
||||
Message: "testing",
|
||||
Uri: "https://example.com",
|
||||
State: agentproto.UpdateAppStatusRequest_COMPLETE,
|
||||
})
|
||||
require.ErrorContains(t, err, "No app found with slug")
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("FailUnknownState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "vscode",
|
||||
Message: "testing",
|
||||
Uri: "https://example.com",
|
||||
State: 77,
|
||||
})
|
||||
require.ErrorContains(t, err, "Invalid state")
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("FailTooLong", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.UUID{2},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
|
||||
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
|
||||
Slug: "vscode",
|
||||
Message: strings.Repeat("a", 161),
|
||||
Uri: "https://example.com",
|
||||
State: agentproto.UpdateAppStatusRequest_COMPLETE,
|
||||
})
|
||||
require.ErrorContains(t, err, "Message is too long")
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -466,6 +466,7 @@ func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks
|
||||
|
||||
apiWorkspaces, err := convertWorkspaces(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
requesterID,
|
||||
workspaces,
|
||||
@@ -545,6 +546,7 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ws, err := convertWorkspace(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
apiKey.UserID,
|
||||
workspace,
|
||||
|
||||
@@ -832,7 +832,7 @@ func TestTasks(t *testing.T) {
|
||||
t.Run("SendToNonActiveStates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{})
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
|
||||
Generated
+7
-38
@@ -135,34 +135,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/aibridge/models": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"AI Bridge"
|
||||
],
|
||||
"summary": "List AI Bridge models",
|
||||
"operationId": "list-ai-bridge-models",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/appearance": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -8266,12 +8238,6 @@ const docTemplate = `{
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "boolean",
|
||||
"description": "Include expired tokens in the list",
|
||||
"name": "include_expired",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -9579,7 +9545,6 @@ const docTemplate = `{
|
||||
],
|
||||
"summary": "Patch workspace agent app status",
|
||||
"operationId": "patch-workspace-agent-app-status",
|
||||
"deprecated": true,
|
||||
"parameters": [
|
||||
{
|
||||
"description": "app status",
|
||||
@@ -15132,7 +15097,8 @@ const docTemplate = `{
|
||||
"workspace-usage",
|
||||
"web-push",
|
||||
"oauth2",
|
||||
"mcp-server-http"
|
||||
"mcp-server-http",
|
||||
"workspace-sharing"
|
||||
],
|
||||
"x-enum-comments": {
|
||||
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
|
||||
@@ -15141,6 +15107,7 @@ const docTemplate = `{
|
||||
"ExperimentNotifications": "Sends notifications via SMTP and webhooks following certain events.",
|
||||
"ExperimentOAuth2": "Enables OAuth2 provider functionality.",
|
||||
"ExperimentWebPush": "Enables web push notifications through the browser.",
|
||||
"ExperimentWorkspaceSharing": "Enables updating workspace ACLs for sharing with users and groups.",
|
||||
"ExperimentWorkspaceUsage": "Enables the new workspace usage tracking."
|
||||
},
|
||||
"x-enum-descriptions": [
|
||||
@@ -15150,7 +15117,8 @@ const docTemplate = `{
|
||||
"Enables the new workspace usage tracking.",
|
||||
"Enables web push notifications through the browser.",
|
||||
"Enables OAuth2 provider functionality.",
|
||||
"Enables the MCP HTTP server functionality."
|
||||
"Enables the MCP HTTP server functionality.",
|
||||
"Enables updating workspace ACLs for sharing with users and groups."
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ExperimentExample",
|
||||
@@ -15159,7 +15127,8 @@ const docTemplate = `{
|
||||
"ExperimentWorkspaceUsage",
|
||||
"ExperimentWebPush",
|
||||
"ExperimentOAuth2",
|
||||
"ExperimentMCPServerHTTP"
|
||||
"ExperimentMCPServerHTTP",
|
||||
"ExperimentWorkspaceSharing"
|
||||
]
|
||||
},
|
||||
"codersdk.ExternalAPIKeyScopes": {
|
||||
|
||||
Generated
+7
-34
@@ -112,30 +112,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/aibridge/models": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["AI Bridge"],
|
||||
"summary": "List AI Bridge models",
|
||||
"operationId": "list-ai-bridge-models",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/appearance": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -7309,12 +7285,6 @@
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "boolean",
|
||||
"description": "Include expired tokens in the list",
|
||||
"name": "include_expired",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -8474,7 +8444,6 @@
|
||||
"tags": ["Agents"],
|
||||
"summary": "Patch workspace agent app status",
|
||||
"operationId": "patch-workspace-agent-app-status",
|
||||
"deprecated": true,
|
||||
"parameters": [
|
||||
{
|
||||
"description": "app status",
|
||||
@@ -13655,7 +13624,8 @@
|
||||
"workspace-usage",
|
||||
"web-push",
|
||||
"oauth2",
|
||||
"mcp-server-http"
|
||||
"mcp-server-http",
|
||||
"workspace-sharing"
|
||||
],
|
||||
"x-enum-comments": {
|
||||
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
|
||||
@@ -13664,6 +13634,7 @@
|
||||
"ExperimentNotifications": "Sends notifications via SMTP and webhooks following certain events.",
|
||||
"ExperimentOAuth2": "Enables OAuth2 provider functionality.",
|
||||
"ExperimentWebPush": "Enables web push notifications through the browser.",
|
||||
"ExperimentWorkspaceSharing": "Enables updating workspace ACLs for sharing with users and groups.",
|
||||
"ExperimentWorkspaceUsage": "Enables the new workspace usage tracking."
|
||||
},
|
||||
"x-enum-descriptions": [
|
||||
@@ -13673,7 +13644,8 @@
|
||||
"Enables the new workspace usage tracking.",
|
||||
"Enables web push notifications through the browser.",
|
||||
"Enables OAuth2 provider functionality.",
|
||||
"Enables the MCP HTTP server functionality."
|
||||
"Enables the MCP HTTP server functionality.",
|
||||
"Enables updating workspace ACLs for sharing with users and groups."
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ExperimentExample",
|
||||
@@ -13682,7 +13654,8 @@
|
||||
"ExperimentWorkspaceUsage",
|
||||
"ExperimentWebPush",
|
||||
"ExperimentOAuth2",
|
||||
"ExperimentMCPServerHTTP"
|
||||
"ExperimentMCPServerHTTP",
|
||||
"ExperimentWorkspaceSharing"
|
||||
]
|
||||
},
|
||||
"codersdk.ExternalAPIKeyScopes": {
|
||||
|
||||
+8
-14
@@ -307,26 +307,20 @@ func (api *API) apiKeyByName(rw http.ResponseWriter, r *http.Request) {
|
||||
// @Tags Users
|
||||
// @Param user path string true "User ID, name, or me"
|
||||
// @Success 200 {array} codersdk.APIKey
|
||||
// @Param include_expired query bool false "Include expired tokens in the list"
|
||||
// @Router /users/{user}/keys/tokens [get]
|
||||
func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
keys []database.APIKey
|
||||
err error
|
||||
queryStr = r.URL.Query().Get("include_all")
|
||||
includeAll, _ = strconv.ParseBool(queryStr)
|
||||
expiredStr = r.URL.Query().Get("include_expired")
|
||||
includeExpired, _ = strconv.ParseBool(expiredStr)
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
keys []database.APIKey
|
||||
err error
|
||||
queryStr = r.URL.Query().Get("include_all")
|
||||
includeAll, _ = strconv.ParseBool(queryStr)
|
||||
)
|
||||
|
||||
if includeAll {
|
||||
// get tokens for all users
|
||||
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.GetAPIKeysByLoginTypeParams{
|
||||
LoginType: database.LoginTypeToken,
|
||||
IncludeExpired: includeExpired,
|
||||
})
|
||||
keys, err = api.Database.GetAPIKeysByLoginType(ctx, database.LoginTypeToken)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching API keys.",
|
||||
@@ -336,7 +330,7 @@ func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
} else {
|
||||
// get user's tokens only
|
||||
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID, IncludeExpired: includeExpired})
|
||||
keys, err = api.Database.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: user.ID})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching API keys.",
|
||||
|
||||
@@ -69,44 +69,6 @@ func TestTokenCRUD(t *testing.T) {
|
||||
require.Equal(t, database.AuditActionDelete, auditor.AuditLogs()[numLogs-1].Action)
|
||||
}
|
||||
|
||||
func TestTokensFilterExpired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
adminClient := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
// Create a token.
|
||||
res, err := adminClient.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
|
||||
Lifetime: time.Hour * 24 * 7,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
keyID := strings.Split(res.Key, "-")[0]
|
||||
|
||||
// List tokens without including expired - should see the token.
|
||||
keys, err := adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
|
||||
// Expire the token.
|
||||
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List tokens without including expired - should NOT see expired token.
|
||||
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, keys)
|
||||
|
||||
// List tokens WITH including expired - should see expired token.
|
||||
keys, err = adminClient.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{
|
||||
IncludeExpired: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, keyID, keys[0].ID)
|
||||
}
|
||||
|
||||
func TestTokenScoped(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -500,7 +500,7 @@ func (e *Executor) runOnce(t time.Time) Stats {
|
||||
"task": task.Name,
|
||||
"task_id": task.ID.String(),
|
||||
"workspace": ws.Name,
|
||||
"pause_reason": "idle timeout",
|
||||
"pause_reason": "inactivity exceeded the dormancy threshold",
|
||||
},
|
||||
"lifecycle_executor",
|
||||
ws.ID, ws.OwnerID, ws.OrganizationID,
|
||||
|
||||
@@ -2082,6 +2082,6 @@ func TestExecutorTaskWorkspace(t *testing.T) {
|
||||
require.Equal(t, task.Name, sent[0].Labels["task"])
|
||||
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
|
||||
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
|
||||
require.Equal(t, "idle timeout", sent[0].Labels["pause_reason"])
|
||||
require.Equal(t, "inactivity exceeded the dormancy threshold", sent[0].Labels["pause_reason"])
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1526,6 +1526,10 @@ func New(options *Options) *API {
|
||||
})
|
||||
r.Get("/timings", api.workspaceTimings)
|
||||
r.Route("/acl", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceSharing),
|
||||
)
|
||||
|
||||
r.Get("/", api.workspaceACL)
|
||||
r.Patch("/", api.patchWorkspaceACL)
|
||||
r.Delete("/", api.deleteWorkspaceACL)
|
||||
|
||||
@@ -668,31 +668,6 @@ var (
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
|
||||
subjectWorkspaceBuilder = rbac.Subject{
|
||||
Type: rbac.SubjectTypeWorkspaceBuilder,
|
||||
FriendlyName: "Workspace Builder",
|
||||
ID: uuid.Nil.String(),
|
||||
Roles: rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleIdentifier{Name: "workspace-builder"},
|
||||
DisplayName: "Workspace Builder",
|
||||
Site: rbac.Permissions(map[string][]policy.Action{
|
||||
// Reading provisioner daemons to check eligibility.
|
||||
rbac.ResourceProvisionerDaemon.Type: {policy.ActionRead},
|
||||
// Updating provisioner jobs (e.g. marking prebuild
|
||||
// jobs complete).
|
||||
rbac.ResourceProvisionerJobs.Type: {policy.ActionUpdate},
|
||||
// Reading provisioner state requires template update
|
||||
// permission.
|
||||
rbac.ResourceTemplate.Type: {policy.ActionUpdate},
|
||||
}),
|
||||
User: []rbac.Permission{},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{},
|
||||
},
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
)
|
||||
|
||||
// AsProvisionerd returns a context with an actor that has permissions required
|
||||
@@ -799,14 +774,6 @@ func AsBoundaryUsageTracker(ctx context.Context) context.Context {
|
||||
return As(ctx, subjectBoundaryUsageTracker)
|
||||
}
|
||||
|
||||
// AsWorkspaceBuilder returns a context with an actor that has permissions
|
||||
// required for the workspace builder to prepare workspace builds. This
|
||||
// includes reading provisioner daemons, updating provisioner jobs, and
|
||||
// reading provisioner state (which requires template update permission).
|
||||
func AsWorkspaceBuilder(ctx context.Context) context.Context {
|
||||
return As(ctx, subjectWorkspaceBuilder)
|
||||
}
|
||||
|
||||
var AsRemoveActor = rbac.Subject{
|
||||
ID: "remove-actor",
|
||||
}
|
||||
@@ -2155,13 +2122,6 @@ func (q *querier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UUID)
|
||||
return fetch(q.log, q.auth, q.db.GetAIBridgeInterceptionByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Context, toolCallID string) (database.GetAIBridgeInterceptionLineageByToolCallIDRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
|
||||
return database.GetAIBridgeInterceptionLineageByToolCallIDRow{}, err
|
||||
}
|
||||
return q.db.GetAIBridgeInterceptionLineageByToolCallID(ctx, toolCallID)
|
||||
}
|
||||
|
||||
func (q *querier) GetAIBridgeInterceptions(ctx context.Context) ([]database.AIBridgeInterception, error) {
|
||||
fetch := func(ctx context.Context, _ any) ([]database.AIBridgeInterception, error) {
|
||||
return q.db.GetAIBridgeInterceptions(ctx)
|
||||
@@ -2201,12 +2161,12 @@ func (q *querier) GetAPIKeyByName(ctx context.Context, arg database.GetAPIKeyByN
|
||||
return fetch(q.log, q.auth, q.db.GetAPIKeyByName)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
|
||||
func (q *querier) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByLoginType)(ctx, loginType)
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysByUserID(ctx context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, params)
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysByUserID)(ctx, database.GetAPIKeysByUserIDParams{LoginType: params.LoginType, UserID: params.UserID})
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]database.APIKey, error) {
|
||||
@@ -2297,7 +2257,7 @@ func (q *querier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditL
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow, error) {
|
||||
// This is a system function.
|
||||
// This is a system function
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return database.GetAuthenticatedWorkspaceAgentAndBuildByAuthTokenRow{}, err
|
||||
}
|
||||
@@ -3173,13 +3133,6 @@ func (q *querier) GetTelemetryItems(ctx context.Context) ([]database.TelemetryIt
|
||||
return q.db.GetTelemetryItems(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetTelemetryTaskEvents(ctx context.Context, arg database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTask.All()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetTelemetryTaskEvents(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
|
||||
if err := q.authorizeTemplateInsights(ctx, arg.TemplateIDs); err != nil {
|
||||
return nil, err
|
||||
@@ -3961,11 +3914,6 @@ func (q *querier) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, wor
|
||||
return q.db.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, prep)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, buildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
|
||||
// Fetching the provisioner state requires Update permission on the template.
|
||||
return fetchWithAction(q.log, q.auth, policy.ActionUpdate, q.db.GetWorkspaceBuildProvisionerStateByID)(ctx, buildID)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
@@ -4798,14 +4746,6 @@ func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Contex
|
||||
return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
// This function is a system function until we implement a join for aibridge interceptions.
|
||||
// Matches the behavior of the workspaces listing endpoint.
|
||||
@@ -6367,10 +6307,3 @@ func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.CountAIBridgeInterceptions(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, _ rbac.PreparedAuthorized) ([]string, error) {
|
||||
// TODO: Delete this function, all ListAIBridgeModels should be authorized. For now just call ListAIBridgeModels on the authz querier.
|
||||
// This cannot be deleted for now because it's included in the
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.ListAIBridgeModels(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -237,8 +237,8 @@ func (s *MethodTestSuite) TestAPIKey() {
|
||||
s.Run("GetAPIKeysByLoginType", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
a := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
|
||||
b := testutil.Fake(s.T(), faker, database.APIKey{LoginType: database.LoginTypePassword})
|
||||
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Return([]database.APIKey{a, b}, nil).AnyTimes()
|
||||
check.Args(database.GetAPIKeysByLoginTypeParams{LoginType: database.LoginTypePassword}).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
|
||||
dbm.EXPECT().GetAPIKeysByLoginType(gomock.Any(), database.LoginTypePassword).Return([]database.APIKey{a, b}, nil).AnyTimes()
|
||||
check.Args(database.LoginTypePassword).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
|
||||
}))
|
||||
s.Run("GetAPIKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u1 := testutil.Fake(s.T(), faker, database.User{})
|
||||
@@ -1326,11 +1326,6 @@ func (s *MethodTestSuite) TestTemplate() {
|
||||
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
|
||||
}))
|
||||
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetTelemetryTaskEventsParams{}
|
||||
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceTask.All(), policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTemplateAppInsights", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetTemplateAppInsightsParams{}
|
||||
dbm.EXPECT().GetTemplateAppInsights(gomock.Any(), arg).Return([]database.GetTemplateAppInsightsRow{}, nil).AnyTimes()
|
||||
@@ -1974,15 +1969,6 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
|
||||
check.Args(build.ID).Asserts(ws, policy.ActionRead).Returns(build)
|
||||
}))
|
||||
s.Run("GetWorkspaceBuildProvisionerStateByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
row := database.GetWorkspaceBuildProvisionerStateByIDRow{
|
||||
ProvisionerState: []byte("state"),
|
||||
TemplateID: uuid.New(),
|
||||
TemplateOrganizationID: uuid.New(),
|
||||
}
|
||||
dbm.EXPECT().GetWorkspaceBuildProvisionerStateByID(gomock.Any(), gomock.Any()).Return(row, nil).AnyTimes()
|
||||
check.Args(uuid.New()).Asserts(row, policy.ActionUpdate).Returns(row)
|
||||
}))
|
||||
s.Run("GetWorkspaceBuildByJobID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ws := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
build := testutil.Fake(s.T(), faker, database.WorkspaceBuild{WorkspaceID: ws.ID})
|
||||
@@ -4695,16 +4681,6 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
check.Args(intID).Asserts(intc, policy.ActionRead).Returns(intc)
|
||||
}))
|
||||
|
||||
s.Run("GetAIBridgeInterceptionLineageByToolCallID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
toolCallID := "call_123"
|
||||
row := database.GetAIBridgeInterceptionLineageByToolCallIDRow{
|
||||
ThreadParentID: uuid.UUID{1},
|
||||
ThreadRootID: uuid.UUID{2},
|
||||
}
|
||||
db.EXPECT().GetAIBridgeInterceptionLineageByToolCallID(gomock.Any(), toolCallID).Return(row, nil).AnyTimes()
|
||||
check.Args(toolCallID).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns(row)
|
||||
}))
|
||||
|
||||
s.Run("GetAIBridgeInterceptions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
a := testutil.Fake(s.T(), faker, database.AIBridgeInterception{})
|
||||
b := testutil.Fake(s.T(), faker, database.AIBridgeInterception{})
|
||||
@@ -4770,20 +4746,6 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeModels", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeModelsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeModels(gomock.Any(), params, gomock.Any()).Return([]string{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAuthorizedAIBridgeModels", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeModelsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeModels(gomock.Any(), params, gomock.Any()).Return([]string{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ids := []uuid.UUID{{1}}
|
||||
db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes()
|
||||
|
||||
@@ -67,8 +67,6 @@ type WorkspaceBuildBuilder struct {
|
||||
|
||||
jobError string // Error message for failed jobs
|
||||
jobErrorCode string // Error code for failed jobs
|
||||
|
||||
provisionerState []byte
|
||||
}
|
||||
|
||||
// BuilderOption is a functional option for customizing job timestamps
|
||||
@@ -140,15 +138,6 @@ func (b WorkspaceBuildBuilder) Seed(seed database.WorkspaceBuild) WorkspaceBuild
|
||||
return b
|
||||
}
|
||||
|
||||
// ProvisionerState sets the provisioner state for the workspace build.
|
||||
// This is stored separately from the seed because ProvisionerState is
|
||||
// not part of the WorkspaceBuild view struct.
|
||||
func (b WorkspaceBuildBuilder) ProvisionerState(state []byte) WorkspaceBuildBuilder {
|
||||
//nolint: revive // returns modified struct
|
||||
b.provisionerState = state
|
||||
return b
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) Resource(resource ...*sdkproto.Resource) WorkspaceBuildBuilder {
|
||||
//nolint: revive // returns modified struct
|
||||
b.resources = append(b.resources, resource...)
|
||||
@@ -475,14 +464,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
|
||||
}
|
||||
|
||||
resp.Build = dbgen.WorkspaceBuild(b.t, b.db, b.seed)
|
||||
if len(b.provisionerState) > 0 {
|
||||
err = b.db.UpdateWorkspaceBuildProvisionerStateByID(ownerCtx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
|
||||
ID: resp.Build.ID,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
ProvisionerState: b.provisionerState,
|
||||
})
|
||||
require.NoError(b.t, err, "update provisioner state")
|
||||
}
|
||||
b.logger.Debug(context.Background(), "created workspace build",
|
||||
slog.F("build_id", resp.Build.ID),
|
||||
slog.F("workspace_id", resp.Workspace.ID),
|
||||
|
||||
@@ -504,7 +504,7 @@ func WorkspaceBuild(t testing.TB, db database.Store, orig database.WorkspaceBuil
|
||||
Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart),
|
||||
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
|
||||
JobID: jobID,
|
||||
ProvisionerState: []byte{},
|
||||
ProvisionerState: takeFirstSlice(orig.ProvisionerState, []byte{}),
|
||||
Deadline: takeFirst(orig.Deadline, dbtime.Now().Add(time.Hour)),
|
||||
MaxDeadline: takeFirst(orig.MaxDeadline, time.Time{}),
|
||||
Reason: takeFirst(orig.Reason, database.BuildReasonInitiator),
|
||||
@@ -1373,8 +1373,6 @@ func OAuth2ProviderAppCode(t testing.TB, db database.Store, seed database.OAuth2
|
||||
ResourceUri: seed.ResourceUri,
|
||||
CodeChallenge: seed.CodeChallenge,
|
||||
CodeChallengeMethod: seed.CodeChallengeMethod,
|
||||
StateHash: seed.StateHash,
|
||||
RedirectUri: seed.RedirectUri,
|
||||
})
|
||||
require.NoError(t, err, "insert oauth2 app code")
|
||||
return code
|
||||
@@ -1585,16 +1583,14 @@ func ClaimPrebuild(
|
||||
|
||||
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams, endedAt *time.Time) database.AIBridgeInterception {
|
||||
interception, err := db.InsertAIBridgeInterception(genCtx, database.InsertAIBridgeInterceptionParams{
|
||||
ID: takeFirst(seed.ID, uuid.New()),
|
||||
APIKeyID: seed.APIKeyID,
|
||||
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
|
||||
Provider: takeFirst(seed.Provider, "provider"),
|
||||
Model: takeFirst(seed.Model, "model"),
|
||||
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
|
||||
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
|
||||
Client: seed.Client,
|
||||
ThreadParentInterceptionID: seed.ThreadParentInterceptionID,
|
||||
ThreadRootInterceptionID: seed.ThreadRootInterceptionID,
|
||||
ID: takeFirst(seed.ID, uuid.New()),
|
||||
APIKeyID: seed.APIKeyID,
|
||||
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
|
||||
Provider: takeFirst(seed.Provider, "provider"),
|
||||
Model: takeFirst(seed.Model, "model"),
|
||||
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
|
||||
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
|
||||
Client: seed.Client,
|
||||
})
|
||||
if endedAt != nil {
|
||||
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
@@ -1647,7 +1643,6 @@ func AIBridgeToolUsage(t testing.TB, db database.Store, seed database.InsertAIBr
|
||||
ID: takeFirst(seed.ID, uuid.New()),
|
||||
InterceptionID: takeFirst(seed.InterceptionID, uuid.New()),
|
||||
ProviderResponseID: takeFirst(seed.ProviderResponseID, "provider_response_id"),
|
||||
ProviderToolCallID: sql.NullString{String: takeFirst(seed.ProviderResponseID, testutil.GetRandomName(t)), Valid: true},
|
||||
Tool: takeFirst(seed.Tool, "tool"),
|
||||
ServerUrl: serverURL,
|
||||
Input: takeFirst(seed.Input, "input"),
|
||||
|
||||
@@ -726,14 +726,6 @@ func (m queryMetricsStore) GetAIBridgeInterceptionByID(ctx context.Context, id u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Context, toolCallID string) (database.GetAIBridgeInterceptionLineageByToolCallIDRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAIBridgeInterceptionLineageByToolCallID(ctx, toolCallID)
|
||||
m.queryLatencies.WithLabelValues("GetAIBridgeInterceptionLineageByToolCallID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAIBridgeInterceptionLineageByToolCallID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAIBridgeInterceptions(ctx context.Context) ([]database.AIBridgeInterception, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAIBridgeInterceptions(ctx)
|
||||
@@ -782,7 +774,7 @@ func (m queryMetricsStore) GetAPIKeyByName(ctx context.Context, arg database.Get
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
|
||||
func (m queryMetricsStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAPIKeysByLoginType(ctx, loginType)
|
||||
m.queryLatencies.WithLabelValues("GetAPIKeysByLoginType").Observe(time.Since(start).Seconds())
|
||||
@@ -1798,14 +1790,6 @@ func (m queryMetricsStore) GetTelemetryItems(ctx context.Context) ([]database.Te
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetTelemetryTaskEvents(ctx context.Context, createdAfter database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetTelemetryTaskEvents(ctx, createdAfter)
|
||||
m.queryLatencies.WithLabelValues("GetTelemetryTaskEvents").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTelemetryTaskEvents").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetTemplateAppInsights(ctx, arg)
|
||||
@@ -2446,14 +2430,6 @@ func (m queryMetricsStore) GetWorkspaceBuildParametersByBuildIDs(ctx context.Con
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuildID)
|
||||
m.queryLatencies.WithLabelValues("GetWorkspaceBuildProvisionerStateByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceBuildProvisionerStateByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetWorkspaceBuildStatsByTemplates(ctx, since)
|
||||
@@ -3222,14 +3198,6 @@ func (m queryMetricsStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx conte
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeModels(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("ListAIBridgeModels").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeModels").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
|
||||
@@ -4444,11 +4412,3 @@ func (m queryMetricsStore) CountAuthorizedAIBridgeInterceptions(ctx context.Cont
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAuthorizedAIBridgeInterceptions").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAuthorizedAIBridgeModels(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeModels").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeModels").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -1214,21 +1214,6 @@ func (mr *MockStoreMockRecorder) GetAIBridgeInterceptionByID(ctx, id any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIBridgeInterceptionByID", reflect.TypeOf((*MockStore)(nil).GetAIBridgeInterceptionByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetAIBridgeInterceptionLineageByToolCallID mocks base method.
|
||||
func (m *MockStore) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Context, toolCallID string) (database.GetAIBridgeInterceptionLineageByToolCallIDRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAIBridgeInterceptionLineageByToolCallID", ctx, toolCallID)
|
||||
ret0, _ := ret[0].(database.GetAIBridgeInterceptionLineageByToolCallIDRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAIBridgeInterceptionLineageByToolCallID indicates an expected call of GetAIBridgeInterceptionLineageByToolCallID.
|
||||
func (mr *MockStoreMockRecorder) GetAIBridgeInterceptionLineageByToolCallID(ctx, toolCallID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIBridgeInterceptionLineageByToolCallID", reflect.TypeOf((*MockStore)(nil).GetAIBridgeInterceptionLineageByToolCallID), ctx, toolCallID)
|
||||
}
|
||||
|
||||
// GetAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) GetAIBridgeInterceptions(ctx context.Context) ([]database.AIBridgeInterception, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1320,18 +1305,18 @@ func (mr *MockStoreMockRecorder) GetAPIKeyByName(ctx, arg any) *gomock.Call {
|
||||
}
|
||||
|
||||
// GetAPIKeysByLoginType mocks base method.
|
||||
func (m *MockStore) GetAPIKeysByLoginType(ctx context.Context, arg database.GetAPIKeysByLoginTypeParams) ([]database.APIKey, error) {
|
||||
func (m *MockStore) GetAPIKeysByLoginType(ctx context.Context, loginType database.LoginType) ([]database.APIKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAPIKeysByLoginType", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "GetAPIKeysByLoginType", ctx, loginType)
|
||||
ret0, _ := ret[0].([]database.APIKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAPIKeysByLoginType indicates an expected call of GetAPIKeysByLoginType.
|
||||
func (mr *MockStoreMockRecorder) GetAPIKeysByLoginType(ctx, arg any) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) GetAPIKeysByLoginType(ctx, loginType any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysByLoginType", reflect.TypeOf((*MockStore)(nil).GetAPIKeysByLoginType), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysByLoginType", reflect.TypeOf((*MockStore)(nil).GetAPIKeysByLoginType), ctx, loginType)
|
||||
}
|
||||
|
||||
// GetAPIKeysByUserID mocks base method.
|
||||
@@ -3329,21 +3314,6 @@ func (mr *MockStoreMockRecorder) GetTelemetryItems(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTelemetryItems", reflect.TypeOf((*MockStore)(nil).GetTelemetryItems), ctx)
|
||||
}
|
||||
|
||||
// GetTelemetryTaskEvents mocks base method.
|
||||
func (m *MockStore) GetTelemetryTaskEvents(ctx context.Context, arg database.GetTelemetryTaskEventsParams) ([]database.GetTelemetryTaskEventsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTelemetryTaskEvents", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetTelemetryTaskEventsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetTelemetryTaskEvents indicates an expected call of GetTelemetryTaskEvents.
|
||||
func (mr *MockStoreMockRecorder) GetTelemetryTaskEvents(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTelemetryTaskEvents", reflect.TypeOf((*MockStore)(nil).GetTelemetryTaskEvents), ctx, arg)
|
||||
}
|
||||
|
||||
// GetTemplateAppInsights mocks base method.
|
||||
func (m *MockStore) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4574,21 +4544,6 @@ func (mr *MockStoreMockRecorder) GetWorkspaceBuildParametersByBuildIDs(ctx, work
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildParametersByBuildIDs", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildParametersByBuildIDs), ctx, workspaceBuildIds)
|
||||
}
|
||||
|
||||
// GetWorkspaceBuildProvisionerStateByID mocks base method.
|
||||
func (m *MockStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetWorkspaceBuildProvisionerStateByID", ctx, workspaceBuildID)
|
||||
ret0, _ := ret[0].(database.GetWorkspaceBuildProvisionerStateByIDRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetWorkspaceBuildProvisionerStateByID indicates an expected call of GetWorkspaceBuildProvisionerStateByID.
|
||||
func (mr *MockStoreMockRecorder) GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuildID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildProvisionerStateByID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildProvisionerStateByID), ctx, workspaceBuildID)
|
||||
}
|
||||
|
||||
// GetWorkspaceBuildStatsByTemplates mocks base method.
|
||||
func (m *MockStore) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6027,21 +5982,6 @@ func (mr *MockStoreMockRecorder) ListAIBridgeInterceptionsTelemetrySummaries(ctx
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptionsTelemetrySummaries", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptionsTelemetrySummaries), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeModels mocks base method.
|
||||
func (m *MockStore) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAIBridgeModels", ctx, arg)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAIBridgeModels indicates an expected call of ListAIBridgeModels.
|
||||
func (mr *MockStoreMockRecorder) ListAIBridgeModels(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModels), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
|
||||
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6102,21 +6042,6 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeInterceptions(ctx, arg, p
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeInterceptions), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeModels mocks base method.
|
||||
func (m *MockStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeModels", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeModels indicates an expected call of ListAuthorizedAIBridgeModels.
|
||||
func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListProvisionerKeysByOrganization mocks base method.
|
||||
func (m *MockStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+4
-22
@@ -1024,19 +1024,13 @@ CREATE TABLE aibridge_interceptions (
|
||||
metadata jsonb,
|
||||
ended_at timestamp with time zone,
|
||||
api_key_id text,
|
||||
client character varying(64) DEFAULT 'Unknown'::character varying,
|
||||
thread_parent_id uuid,
|
||||
thread_root_id uuid
|
||||
client character varying(64) DEFAULT 'Unknown'::character varying
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.initiator_id IS 'Relates to a users record, but FK is elided for performance.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.thread_parent_id IS 'The interception which directly caused this interception to occur, usually through an agentic loop or threaded conversation.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interception of the thread that this interception belongs to.';
|
||||
|
||||
CREATE TABLE aibridge_token_usages (
|
||||
id uuid NOT NULL,
|
||||
interception_id uuid NOT NULL,
|
||||
@@ -1061,8 +1055,7 @@ CREATE TABLE aibridge_tool_usages (
|
||||
injected boolean DEFAULT false NOT NULL,
|
||||
invocation_error text,
|
||||
metadata jsonb,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
provider_tool_call_id text
|
||||
created_at timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_tool_usages IS 'Audit log of tool calls in intercepted requests in AI Bridge';
|
||||
@@ -1478,9 +1471,7 @@ CREATE TABLE oauth2_provider_app_codes (
|
||||
app_id uuid NOT NULL,
|
||||
resource_uri text,
|
||||
code_challenge text,
|
||||
code_challenge_method text,
|
||||
state_hash text,
|
||||
redirect_uri text
|
||||
code_challenge_method text
|
||||
);
|
||||
|
||||
COMMENT ON TABLE oauth2_provider_app_codes IS 'Codes are meant to be exchanged for access tokens.';
|
||||
@@ -1491,10 +1482,6 @@ COMMENT ON COLUMN oauth2_provider_app_codes.code_challenge IS 'PKCE code challen
|
||||
|
||||
COMMENT ON COLUMN oauth2_provider_app_codes.code_challenge_method IS 'PKCE challenge method (S256)';
|
||||
|
||||
COMMENT ON COLUMN oauth2_provider_app_codes.state_hash IS 'SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.';
|
||||
|
||||
COMMENT ON COLUMN oauth2_provider_app_codes.redirect_uri IS 'The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).';
|
||||
|
||||
CREATE TABLE oauth2_provider_app_secrets (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
@@ -2715,6 +2702,7 @@ CREATE VIEW workspace_build_with_user AS
|
||||
workspace_builds.build_number,
|
||||
workspace_builds.transition,
|
||||
workspace_builds.initiator_id,
|
||||
workspace_builds.provisioner_state,
|
||||
workspace_builds.job_id,
|
||||
workspace_builds.deadline,
|
||||
workspace_builds.reason,
|
||||
@@ -3297,18 +3285,12 @@ CREATE INDEX idx_aibridge_interceptions_provider ON aibridge_interceptions USING
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_started_id_desc ON aibridge_interceptions USING btree (started_at DESC, id DESC);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptions USING btree (thread_parent_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_thread_root_id ON aibridge_interceptions USING btree (thread_root_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_token_usages_interception_id ON aibridge_token_usages USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_token_usages_provider_response_id ON aibridge_token_usages USING btree (provider_response_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_tool_usages_interception_id ON aibridge_tool_usages USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_tool_usages_provider_tool_call_id ON aibridge_tool_usages USING btree (provider_tool_call_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_tool_usagesprovider_response_id ON aibridge_tool_usages USING btree (provider_response_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_user_prompts_interception_id ON aibridge_user_prompts USING btree (interception_id);
|
||||
|
||||
@@ -51,34 +51,15 @@ func TestViewSubsetTemplateVersion(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestViewSubsetWorkspaceBuild ensures WorkspaceBuildTable is a subset of
|
||||
// WorkspaceBuild, with the exception of ProvisionerState which is
|
||||
// intentionally excluded from the workspace_build_with_user view to avoid
|
||||
// loading the large Terraform state blob on hot paths.
|
||||
// TestViewSubsetWorkspaceBuild ensures WorkspaceBuildTable is a subset of WorkspaceBuild
|
||||
func TestViewSubsetWorkspaceBuild(t *testing.T) {
|
||||
t.Parallel()
|
||||
table := reflect.TypeOf(database.WorkspaceBuildTable{})
|
||||
joined := reflect.TypeOf(database.WorkspaceBuild{})
|
||||
|
||||
tableFields := fieldNames(allFields(table))
|
||||
joinedFields := fieldNames(allFields(joined))
|
||||
|
||||
// ProvisionerState is intentionally excluded from the
|
||||
// workspace_build_with_user view to avoid loading multi-MB Terraform
|
||||
// state blobs on hot paths. Callers that need it use
|
||||
// GetWorkspaceBuildProvisionerStateByID instead.
|
||||
excludedFields := map[string]bool{
|
||||
"ProvisionerState": true,
|
||||
}
|
||||
|
||||
var filtered []string
|
||||
for _, name := range tableFields {
|
||||
if !excludedFields[name] {
|
||||
filtered = append(filtered, name)
|
||||
}
|
||||
}
|
||||
|
||||
if !assert.Subset(t, joinedFields, filtered, "table is not subset") {
|
||||
tableFields := allFields(table)
|
||||
joinedFields := allFields(joined)
|
||||
if !assert.Subset(t, fieldNames(joinedFields), fieldNames(tableFields), "table is not subset") {
|
||||
t.Log("Some fields were added to the WorkspaceBuild Table without updating the 'workspace_build_with_user' view.")
|
||||
t.Log("See migration 000141_join_users_build_version.up.sql to create the view.")
|
||||
}
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
ALTER TABLE oauth2_provider_app_codes
|
||||
DROP COLUMN state_hash,
|
||||
DROP COLUMN redirect_uri;
|
||||
@@ -1,9 +0,0 @@
|
||||
ALTER TABLE oauth2_provider_app_codes
|
||||
ADD COLUMN state_hash text,
|
||||
ADD COLUMN redirect_uri text;
|
||||
|
||||
COMMENT ON COLUMN oauth2_provider_app_codes.state_hash IS
|
||||
'SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.';
|
||||
|
||||
COMMENT ON COLUMN oauth2_provider_app_codes.redirect_uri IS
|
||||
'The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).';
|
||||
-31
@@ -1,31 +0,0 @@
|
||||
-- Restore provisioner_state to workspace_build_with_user view.
|
||||
DROP VIEW workspace_build_with_user;
|
||||
|
||||
CREATE VIEW workspace_build_with_user AS
|
||||
SELECT
|
||||
workspace_builds.id,
|
||||
workspace_builds.created_at,
|
||||
workspace_builds.updated_at,
|
||||
workspace_builds.workspace_id,
|
||||
workspace_builds.template_version_id,
|
||||
workspace_builds.build_number,
|
||||
workspace_builds.transition,
|
||||
workspace_builds.initiator_id,
|
||||
workspace_builds.provisioner_state,
|
||||
workspace_builds.job_id,
|
||||
workspace_builds.deadline,
|
||||
workspace_builds.reason,
|
||||
workspace_builds.daily_cost,
|
||||
workspace_builds.max_deadline,
|
||||
workspace_builds.template_version_preset_id,
|
||||
workspace_builds.has_ai_task,
|
||||
workspace_builds.has_external_agent,
|
||||
COALESCE(visible_users.avatar_url, ''::text) AS initiator_by_avatar_url,
|
||||
COALESCE(visible_users.username, ''::text) AS initiator_by_username,
|
||||
COALESCE(visible_users.name, ''::text) AS initiator_by_name
|
||||
FROM
|
||||
workspace_builds
|
||||
LEFT JOIN
|
||||
visible_users ON workspace_builds.initiator_id = visible_users.id;
|
||||
|
||||
COMMENT ON VIEW workspace_build_with_user IS 'Joins in the username + avatar url of the initiated by user.';
|
||||
@@ -1,33 +0,0 @@
|
||||
-- Drop and recreate workspace_build_with_user to exclude provisioner_state.
|
||||
-- This avoids loading the large Terraform state blob (1-5 MB per workspace)
|
||||
-- on every query that uses this view. The callers that need provisioner_state
|
||||
-- now fetch it separately via GetWorkspaceBuildProvisionerStateByID.
|
||||
DROP VIEW workspace_build_with_user;
|
||||
|
||||
CREATE VIEW workspace_build_with_user AS
|
||||
SELECT
|
||||
workspace_builds.id,
|
||||
workspace_builds.created_at,
|
||||
workspace_builds.updated_at,
|
||||
workspace_builds.workspace_id,
|
||||
workspace_builds.template_version_id,
|
||||
workspace_builds.build_number,
|
||||
workspace_builds.transition,
|
||||
workspace_builds.initiator_id,
|
||||
workspace_builds.job_id,
|
||||
workspace_builds.deadline,
|
||||
workspace_builds.reason,
|
||||
workspace_builds.daily_cost,
|
||||
workspace_builds.max_deadline,
|
||||
workspace_builds.template_version_preset_id,
|
||||
workspace_builds.has_ai_task,
|
||||
workspace_builds.has_external_agent,
|
||||
COALESCE(visible_users.avatar_url, ''::text) AS initiator_by_avatar_url,
|
||||
COALESCE(visible_users.username, ''::text) AS initiator_by_username,
|
||||
COALESCE(visible_users.name, ''::text) AS initiator_by_name
|
||||
FROM
|
||||
workspace_builds
|
||||
LEFT JOIN
|
||||
visible_users ON workspace_builds.initiator_id = visible_users.id;
|
||||
|
||||
COMMENT ON VIEW workspace_build_with_user IS 'Joins in the username + avatar url of the initiated by user.';
|
||||
@@ -1,9 +0,0 @@
|
||||
DROP INDEX IF EXISTS idx_aibridge_tool_usages_provider_tool_call_id;
|
||||
|
||||
ALTER TABLE aibridge_tool_usages
|
||||
DROP COLUMN provider_tool_call_id;
|
||||
|
||||
DROP INDEX IF EXISTS idx_aibridge_interceptions_thread_parent_id;
|
||||
|
||||
ALTER TABLE aibridge_interceptions
|
||||
DROP COLUMN thread_parent_id;
|
||||
@@ -1,11 +0,0 @@
|
||||
ALTER TABLE aibridge_tool_usages
|
||||
ADD COLUMN provider_tool_call_id text NULL; -- nullable to allow existing data to be correct
|
||||
|
||||
CREATE INDEX idx_aibridge_tool_usages_provider_tool_call_id ON aibridge_tool_usages (provider_tool_call_id);
|
||||
|
||||
ALTER TABLE aibridge_interceptions
|
||||
ADD COLUMN thread_parent_id UUID NULL;
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.thread_parent_id IS 'The interception which directly caused this interception to occur, usually through an agentic loop or threaded conversation.';
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptions (thread_parent_id);
|
||||
@@ -1,4 +0,0 @@
|
||||
DROP INDEX IF EXISTS idx_aibridge_interceptions_thread_root_id;
|
||||
|
||||
ALTER TABLE aibridge_interceptions
|
||||
DROP COLUMN thread_root_id;
|
||||
@@ -1,6 +0,0 @@
|
||||
ALTER TABLE aibridge_interceptions
|
||||
ADD COLUMN thread_root_id UUID NULL;
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interception of the thread that this interception belongs to.';
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_thread_root_id ON aibridge_interceptions (thread_root_id);
|
||||
@@ -316,14 +316,6 @@ func (t GetFileTemplatesRow) RBACObject() rbac.Object {
|
||||
WithGroupACL(t.GroupACL)
|
||||
}
|
||||
|
||||
// RBACObject for a workspace build's provisioner state requires Update access of the template.
|
||||
func (t GetWorkspaceBuildProvisionerStateByIDRow) RBACObject() rbac.Object {
|
||||
return rbac.ResourceTemplate.WithID(t.TemplateID).
|
||||
InOrg(t.TemplateOrganizationID).
|
||||
WithACLUserList(t.UserACL).
|
||||
WithGroupACL(t.GroupACL)
|
||||
}
|
||||
|
||||
func (t Template) DeepCopy() Template {
|
||||
cpy := t
|
||||
cpy.UserACL = maps.Clone(t.UserACL)
|
||||
|
||||
@@ -769,7 +769,6 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
|
||||
type aibridgeQuerier interface {
|
||||
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error)
|
||||
CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) {
|
||||
@@ -813,8 +812,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
|
||||
&i.AIBridgeInterception.EndedAt,
|
||||
&i.AIBridgeInterception.APIKeyID,
|
||||
&i.AIBridgeInterception.Client,
|
||||
&i.AIBridgeInterception.ThreadParentID,
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -873,35 +870,6 @@ func (q *sqlQuerier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, a
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
|
||||
VariableConverter: regosql.AIBridgeInterceptionConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
filtered, err := insertAuthorizedFilter(listAIBridgeModels, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("-- name: ListAIBridgeModels :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query, arg.Model, arg.Offset, arg.Limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []string
|
||||
for rows.Next() {
|
||||
var model string
|
||||
if err := rows.Scan(&model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, model)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
|
||||
if !strings.Contains(query, authorizedQueryPlaceholder) {
|
||||
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
|
||||
|
||||
@@ -3643,10 +3643,6 @@ type AIBridgeInterception struct {
|
||||
EndedAt sql.NullTime `db:"ended_at" json:"ended_at"`
|
||||
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
|
||||
Client sql.NullString `db:"client" json:"client"`
|
||||
// The interception which directly caused this interception to occur, usually through an agentic loop or threaded conversation.
|
||||
ThreadParentID uuid.NullUUID `db:"thread_parent_id" json:"thread_parent_id"`
|
||||
// The root interception of the thread that this interception belongs to.
|
||||
ThreadRootID uuid.NullUUID `db:"thread_root_id" json:"thread_root_id"`
|
||||
}
|
||||
|
||||
// Audit log of tokens used by intercepted requests in AI Bridge
|
||||
@@ -3674,10 +3670,9 @@ type AIBridgeToolUsage struct {
|
||||
// Whether this tool was injected; i.e. Bridge injected these tools into the request from an MCP server. If false it means a tool was defined by the client and already existed in the request (MCP or built-in).
|
||||
Injected bool `db:"injected" json:"injected"`
|
||||
// Only injected tools are invoked.
|
||||
InvocationError sql.NullString `db:"invocation_error" json:"invocation_error"`
|
||||
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
ProviderToolCallID sql.NullString `db:"provider_tool_call_id" json:"provider_tool_call_id"`
|
||||
InvocationError sql.NullString `db:"invocation_error" json:"invocation_error"`
|
||||
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
// Audit log of prompts used by intercepted requests in AI Bridge
|
||||
@@ -4031,10 +4026,6 @@ type OAuth2ProviderAppCode struct {
|
||||
CodeChallenge sql.NullString `db:"code_challenge" json:"code_challenge"`
|
||||
// PKCE challenge method (S256)
|
||||
CodeChallengeMethod sql.NullString `db:"code_challenge_method" json:"code_challenge_method"`
|
||||
// SHA-256 hash of the OAuth2 state parameter, stored to prevent state reflection attacks.
|
||||
StateHash sql.NullString `db:"state_hash" json:"state_hash"`
|
||||
// The redirect_uri provided during authorization, to be verified during token exchange (RFC 6749 §4.1.3).
|
||||
RedirectUri sql.NullString `db:"redirect_uri" json:"redirect_uri"`
|
||||
}
|
||||
|
||||
type OAuth2ProviderAppSecret struct {
|
||||
@@ -4992,6 +4983,7 @@ type WorkspaceBuild struct {
|
||||
BuildNumber int32 `db:"build_number" json:"build_number"`
|
||||
Transition WorkspaceTransition `db:"transition" json:"transition"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
ProvisionerState []byte `db:"provisioner_state" json:"provisioner_state"`
|
||||
JobID uuid.UUID `db:"job_id" json:"job_id"`
|
||||
Deadline time.Time `db:"deadline" json:"deadline"`
|
||||
Reason BuildReason `db:"reason" json:"reason"`
|
||||
|
||||
@@ -162,11 +162,6 @@ type sqlcQuerier interface {
|
||||
// and returns the preset with the most parameters (largest subset).
|
||||
FindMatchingPresetID(ctx context.Context, arg FindMatchingPresetIDParams) (uuid.UUID, error)
|
||||
GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UUID) (AIBridgeInterception, error)
|
||||
// Look up the parent interception and the root of the thread by finding
|
||||
// which interception recorded a tool usage with the given tool call ID.
|
||||
// COALESCE ensures that if the parent has no thread_root_id (i.e. it IS
|
||||
// the root), we return its own ID as the root.
|
||||
GetAIBridgeInterceptionLineageByToolCallID(ctx context.Context, toolCallID string) (GetAIBridgeInterceptionLineageByToolCallIDRow, error)
|
||||
GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeInterception, error)
|
||||
GetAIBridgeTokenUsagesByInterceptionID(ctx context.Context, interceptionID uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
GetAIBridgeToolUsagesByInterceptionID(ctx context.Context, interceptionID uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
@@ -174,7 +169,7 @@ type sqlcQuerier interface {
|
||||
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
|
||||
// there is no unique constraint on empty token names
|
||||
GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error)
|
||||
GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error)
|
||||
GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error)
|
||||
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
|
||||
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
|
||||
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
|
||||
@@ -366,23 +361,6 @@ type sqlcQuerier interface {
|
||||
GetTaskSnapshot(ctx context.Context, taskID uuid.UUID) (TaskSnapshot, error)
|
||||
GetTelemetryItem(ctx context.Context, key string) (TelemetryItem, error)
|
||||
GetTelemetryItems(ctx context.Context) ([]TelemetryItem, error)
|
||||
// Returns all data needed to build task lifecycle events for telemetry
|
||||
// in a single round-trip. For each task whose workspace is in the
|
||||
// given set, fetches:
|
||||
// - the latest workspace app binding (task_workspace_apps)
|
||||
// - the most recent stop and start builds (workspace_builds)
|
||||
// - the last "working" app status (workspace_app_statuses)
|
||||
// - the first app status after resume, for active workspaces
|
||||
//
|
||||
// Assumptions:
|
||||
// - 1:1 relationship between tasks and workspaces. All builds on the
|
||||
// workspace are considered task-related.
|
||||
// - Idle duration approximation: If the agent reports "working", does
|
||||
// work, then reports "done", we miss that working time.
|
||||
// - lws and active_dur join across all historical app IDs for the task,
|
||||
// because each resume cycle provisions a new app ID. This ensures
|
||||
// pre-pause statuses contribute to idle duration and active duration.
|
||||
GetTelemetryTaskEvents(ctx context.Context, arg GetTelemetryTaskEventsParams) ([]GetTelemetryTaskEventsRow, error)
|
||||
// GetTemplateAppInsights returns the aggregate usage of each app in a given
|
||||
// timeframe. The result can be filtered on template_ids, meaning only user data
|
||||
// from workspaces based on those templates will be included.
|
||||
@@ -528,11 +506,6 @@ type sqlcQuerier interface {
|
||||
GetWorkspaceBuildMetricsByResourceID(ctx context.Context, id uuid.UUID) (GetWorkspaceBuildMetricsByResourceIDRow, error)
|
||||
GetWorkspaceBuildParameters(ctx context.Context, workspaceBuildID uuid.UUID) ([]WorkspaceBuildParameter, error)
|
||||
GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIds []uuid.UUID) ([]WorkspaceBuildParameter, error)
|
||||
// Fetches the provisioner state of a workspace build, joined through to the
|
||||
// template so that dbauthz can enforce policy.ActionUpdate on the template.
|
||||
// Provisioner state contains sensitive Terraform state and should only be
|
||||
// accessible to template administrators.
|
||||
GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (GetWorkspaceBuildProvisionerStateByIDRow, error)
|
||||
GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]GetWorkspaceBuildStatsByTemplatesRow, error)
|
||||
GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg GetWorkspaceBuildsByWorkspaceIDParams) ([]WorkspaceBuild, error)
|
||||
GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceBuild, error)
|
||||
@@ -658,7 +631,6 @@ type sqlcQuerier interface {
|
||||
// Finds all unique AI Bridge interception telemetry summaries combinations
|
||||
// (provider, model, client) in the given timeframe for telemetry reporting.
|
||||
ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error)
|
||||
ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error)
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
|
||||
|
||||
@@ -8195,9 +8195,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
|
||||
// All keys are present before deletion
|
||||
keys, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
IncludeExpired: true,
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, len(expiredTimes)+len(unexpiredTimes))
|
||||
@@ -8213,9 +8212,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
|
||||
// Ensure it was deleted
|
||||
remaining, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
IncludeExpired: true,
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, remaining, len(expiredTimes)+len(unexpiredTimes)-1)
|
||||
@@ -8230,9 +8228,8 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
|
||||
// Ensure only unexpired keys remain
|
||||
remaining, err = db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
IncludeExpired: true,
|
||||
LoginType: user.LoginType,
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, remaining, len(unexpiredTimes))
|
||||
|
||||
+48
-403
@@ -378,7 +378,7 @@ func (q *sqlQuerier) DeleteOldAIBridgeRecords(ctx context.Context, beforeTime ti
|
||||
|
||||
const getAIBridgeInterceptionByID = `-- name: GetAIBridgeInterceptionByID :one
|
||||
SELECT
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
@@ -398,45 +398,13 @@ func (q *sqlQuerier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UU
|
||||
&i.EndedAt,
|
||||
&i.APIKeyID,
|
||||
&i.Client,
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getAIBridgeInterceptionLineageByToolCallID = `-- name: GetAIBridgeInterceptionLineageByToolCallID :one
|
||||
WITH linked AS (
|
||||
SELECT interception_id FROM aibridge_tool_usages
|
||||
WHERE provider_tool_call_id = $1::text
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
)
|
||||
SELECT linked.interception_id AS thread_parent_id,
|
||||
COALESCE(aibridge_interceptions.thread_root_id, linked.interception_id) AS thread_root_id
|
||||
FROM aibridge_interceptions
|
||||
INNER JOIN linked ON linked.interception_id = aibridge_interceptions.id
|
||||
WHERE aibridge_interceptions.id = linked.interception_id
|
||||
`
|
||||
|
||||
type GetAIBridgeInterceptionLineageByToolCallIDRow struct {
|
||||
ThreadParentID uuid.UUID `db:"thread_parent_id" json:"thread_parent_id"`
|
||||
ThreadRootID uuid.UUID `db:"thread_root_id" json:"thread_root_id"`
|
||||
}
|
||||
|
||||
// Look up the parent interception and the root of the thread by finding
|
||||
// which interception recorded a tool usage with the given tool call ID.
|
||||
// COALESCE ensures that if the parent has no thread_root_id (i.e. it IS
|
||||
// the root), we return its own ID as the root.
|
||||
func (q *sqlQuerier) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Context, toolCallID string) (GetAIBridgeInterceptionLineageByToolCallIDRow, error) {
|
||||
row := q.db.QueryRowContext(ctx, getAIBridgeInterceptionLineageByToolCallID, toolCallID)
|
||||
var i GetAIBridgeInterceptionLineageByToolCallIDRow
|
||||
err := row.Scan(&i.ThreadParentID, &i.ThreadRootID)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getAIBridgeInterceptions = `-- name: GetAIBridgeInterceptions :many
|
||||
SELECT
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
`
|
||||
@@ -460,8 +428,6 @@ func (q *sqlQuerier) GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeIn
|
||||
&i.EndedAt,
|
||||
&i.APIKeyID,
|
||||
&i.Client,
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -519,7 +485,7 @@ func (q *sqlQuerier) GetAIBridgeTokenUsagesByInterceptionID(ctx context.Context,
|
||||
|
||||
const getAIBridgeToolUsagesByInterceptionID = `-- name: GetAIBridgeToolUsagesByInterceptionID :many
|
||||
SELECT
|
||||
id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at, provider_tool_call_id
|
||||
id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at
|
||||
FROM
|
||||
aibridge_tool_usages
|
||||
WHERE
|
||||
@@ -549,7 +515,6 @@ func (q *sqlQuerier) GetAIBridgeToolUsagesByInterceptionID(ctx context.Context,
|
||||
&i.InvocationError,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.ProviderToolCallID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -608,24 +573,22 @@ func (q *sqlQuerier) GetAIBridgeUserPromptsByInterceptionID(ctx context.Context,
|
||||
|
||||
const insertAIBridgeInterception = `-- name: InsertAIBridgeInterception :one
|
||||
INSERT INTO aibridge_interceptions (
|
||||
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, thread_parent_id, thread_root_id
|
||||
id, api_key_id, initiator_id, provider, model, metadata, started_at, client
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8, $9::uuid, $10::uuid
|
||||
$1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8
|
||||
)
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client
|
||||
`
|
||||
|
||||
type InsertAIBridgeInterceptionParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Metadata json.RawMessage `db:"metadata" json:"metadata"`
|
||||
StartedAt time.Time `db:"started_at" json:"started_at"`
|
||||
Client sql.NullString `db:"client" json:"client"`
|
||||
ThreadParentInterceptionID uuid.NullUUID `db:"thread_parent_interception_id" json:"thread_parent_interception_id"`
|
||||
ThreadRootInterceptionID uuid.NullUUID `db:"thread_root_interception_id" json:"thread_root_interception_id"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Metadata json.RawMessage `db:"metadata" json:"metadata"`
|
||||
StartedAt time.Time `db:"started_at" json:"started_at"`
|
||||
Client sql.NullString `db:"client" json:"client"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertAIBridgeInterceptionParams) (AIBridgeInterception, error) {
|
||||
@@ -638,8 +601,6 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
|
||||
arg.Metadata,
|
||||
arg.StartedAt,
|
||||
arg.Client,
|
||||
arg.ThreadParentInterceptionID,
|
||||
arg.ThreadRootInterceptionID,
|
||||
)
|
||||
var i AIBridgeInterception
|
||||
err := row.Scan(
|
||||
@@ -652,8 +613,6 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
|
||||
&i.EndedAt,
|
||||
&i.APIKeyID,
|
||||
&i.Client,
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -702,18 +661,17 @@ func (q *sqlQuerier) InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIB
|
||||
|
||||
const insertAIBridgeToolUsage = `-- name: InsertAIBridgeToolUsage :one
|
||||
INSERT INTO aibridge_tool_usages (
|
||||
id, interception_id, provider_response_id, provider_tool_call_id, tool, server_url, input, injected, invocation_error, metadata, created_at
|
||||
id, interception_id, provider_response_id, tool, server_url, input, injected, invocation_error, metadata, created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, COALESCE($10::jsonb, '{}'::jsonb), $11
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, COALESCE($9::jsonb, '{}'::jsonb), $10
|
||||
)
|
||||
RETURNING id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at, provider_tool_call_id
|
||||
RETURNING id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at
|
||||
`
|
||||
|
||||
type InsertAIBridgeToolUsageParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
InterceptionID uuid.UUID `db:"interception_id" json:"interception_id"`
|
||||
ProviderResponseID string `db:"provider_response_id" json:"provider_response_id"`
|
||||
ProviderToolCallID sql.NullString `db:"provider_tool_call_id" json:"provider_tool_call_id"`
|
||||
Tool string `db:"tool" json:"tool"`
|
||||
ServerUrl sql.NullString `db:"server_url" json:"server_url"`
|
||||
Input string `db:"input" json:"input"`
|
||||
@@ -728,7 +686,6 @@ func (q *sqlQuerier) InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBr
|
||||
arg.ID,
|
||||
arg.InterceptionID,
|
||||
arg.ProviderResponseID,
|
||||
arg.ProviderToolCallID,
|
||||
arg.Tool,
|
||||
arg.ServerUrl,
|
||||
arg.Input,
|
||||
@@ -749,7 +706,6 @@ func (q *sqlQuerier) InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBr
|
||||
&i.InvocationError,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.ProviderToolCallID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -795,7 +751,7 @@ func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIB
|
||||
|
||||
const listAIBridgeInterceptions = `-- name: ListAIBridgeInterceptions :many
|
||||
SELECT
|
||||
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id,
|
||||
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client,
|
||||
visible_users.id, visible_users.username, visible_users.name, visible_users.avatar_url
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
@@ -904,8 +860,6 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
|
||||
&i.AIBridgeInterception.EndedAt,
|
||||
&i.AIBridgeInterception.APIKeyID,
|
||||
&i.AIBridgeInterception.Client,
|
||||
&i.AIBridgeInterception.ThreadParentID,
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -974,60 +928,6 @@ func (q *sqlQuerier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Con
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listAIBridgeModels = `-- name: ListAIBridgeModels :many
|
||||
SELECT
|
||||
model
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN $1::text != '' THEN aibridge_interceptions.model LIKE $1::text || '%'
|
||||
ELSE true
|
||||
END
|
||||
-- We use an ` + "`" + `@authorize_filter` + "`" + ` as we are attempting to list models that are relevant
|
||||
-- to the user and what they are allowed to see.
|
||||
-- Authorize Filter clause will be injected below in ListAIBridgeModelsAuthorized
|
||||
-- @authorize_filter
|
||||
GROUP BY
|
||||
model
|
||||
ORDER BY
|
||||
model ASC
|
||||
LIMIT COALESCE(NULLIF($3::integer, 0), 100)
|
||||
OFFSET $2
|
||||
`
|
||||
|
||||
type ListAIBridgeModelsParams struct {
|
||||
Model string `db:"model" json:"model"`
|
||||
Offset int32 `db:"offset_" json:"offset_"`
|
||||
Limit int32 `db:"limit_" json:"limit_"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listAIBridgeModels, arg.Model, arg.Offset, arg.Limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []string
|
||||
for rows.Next() {
|
||||
var model string
|
||||
if err := rows.Scan(&model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, model)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listAIBridgeTokenUsagesByInterceptionIDs = `-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many
|
||||
SELECT
|
||||
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
|
||||
@@ -1073,7 +973,7 @@ func (q *sqlQuerier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Contex
|
||||
|
||||
const listAIBridgeToolUsagesByInterceptionIDs = `-- name: ListAIBridgeToolUsagesByInterceptionIDs :many
|
||||
SELECT
|
||||
id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at, provider_tool_call_id
|
||||
id, interception_id, provider_response_id, server_url, tool, input, injected, invocation_error, metadata, created_at
|
||||
FROM
|
||||
aibridge_tool_usages
|
||||
WHERE
|
||||
@@ -1103,7 +1003,6 @@ func (q *sqlQuerier) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context
|
||||
&i.InvocationError,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.ProviderToolCallID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1166,7 +1065,7 @@ UPDATE aibridge_interceptions
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
AND ended_at IS NULL
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client
|
||||
`
|
||||
|
||||
type UpdateAIBridgeInterceptionEndedParams struct {
|
||||
@@ -1187,8 +1086,6 @@ func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg Up
|
||||
&i.EndedAt,
|
||||
&i.APIKeyID,
|
||||
&i.Client,
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -1373,16 +1270,10 @@ func (q *sqlQuerier) GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNamePar
|
||||
|
||||
const getAPIKeysByLoginType = `-- name: GetAPIKeysByLoginType :many
|
||||
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, token_name, scopes, allow_list FROM api_keys WHERE login_type = $1
|
||||
AND ($2::bool OR expires_at > now())
|
||||
`
|
||||
|
||||
type GetAPIKeysByLoginTypeParams struct {
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
IncludeExpired bool `db:"include_expired" json:"include_expired"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysByLoginTypeParams) ([]APIKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByLoginType, arg.LoginType, arg.IncludeExpired)
|
||||
func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByLoginType, loginType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1420,17 +1311,15 @@ func (q *sqlQuerier) GetAPIKeysByLoginType(ctx context.Context, arg GetAPIKeysBy
|
||||
|
||||
const getAPIKeysByUserID = `-- name: GetAPIKeysByUserID :many
|
||||
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, token_name, scopes, allow_list FROM api_keys WHERE login_type = $1 AND user_id = $2
|
||||
AND ($3::bool OR expires_at > now())
|
||||
`
|
||||
|
||||
type GetAPIKeysByUserIDParams struct {
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
IncludeExpired bool `db:"include_expired" json:"include_expired"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByUserID, arg.LoginType, arg.UserID, arg.IncludeExpired)
|
||||
rows, err := q.db.QueryContext(ctx, getAPIKeysByUserID, arg.LoginType, arg.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -6879,7 +6768,7 @@ func (q *sqlQuerier) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context
|
||||
}
|
||||
|
||||
const getOAuth2ProviderAppCodeByID = `-- name: GetOAuth2ProviderAppCodeByID :one
|
||||
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri FROM oauth2_provider_app_codes WHERE id = $1
|
||||
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method FROM oauth2_provider_app_codes WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) (OAuth2ProviderAppCode, error) {
|
||||
@@ -6896,14 +6785,12 @@ func (q *sqlQuerier) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.U
|
||||
&i.ResourceUri,
|
||||
&i.CodeChallenge,
|
||||
&i.CodeChallengeMethod,
|
||||
&i.StateHash,
|
||||
&i.RedirectUri,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOAuth2ProviderAppCodeByPrefix = `-- name: GetOAuth2ProviderAppCodeByPrefix :one
|
||||
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri FROM oauth2_provider_app_codes WHERE secret_prefix = $1
|
||||
SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method FROM oauth2_provider_app_codes WHERE secret_prefix = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secretPrefix []byte) (OAuth2ProviderAppCode, error) {
|
||||
@@ -6920,8 +6807,6 @@ func (q *sqlQuerier) GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secre
|
||||
&i.ResourceUri,
|
||||
&i.CodeChallenge,
|
||||
&i.CodeChallengeMethod,
|
||||
&i.StateHash,
|
||||
&i.RedirectUri,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -7325,9 +7210,7 @@ INSERT INTO oauth2_provider_app_codes (
|
||||
user_id,
|
||||
resource_uri,
|
||||
code_challenge,
|
||||
code_challenge_method,
|
||||
state_hash,
|
||||
redirect_uri
|
||||
code_challenge_method
|
||||
) VALUES(
|
||||
$1,
|
||||
$2,
|
||||
@@ -7338,10 +7221,8 @@ INSERT INTO oauth2_provider_app_codes (
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12
|
||||
) RETURNING id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method, state_hash, redirect_uri
|
||||
$10
|
||||
) RETURNING id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id, resource_uri, code_challenge, code_challenge_method
|
||||
`
|
||||
|
||||
type InsertOAuth2ProviderAppCodeParams struct {
|
||||
@@ -7355,8 +7236,6 @@ type InsertOAuth2ProviderAppCodeParams struct {
|
||||
ResourceUri sql.NullString `db:"resource_uri" json:"resource_uri"`
|
||||
CodeChallenge sql.NullString `db:"code_challenge" json:"code_challenge"`
|
||||
CodeChallengeMethod sql.NullString `db:"code_challenge_method" json:"code_challenge_method"`
|
||||
StateHash sql.NullString `db:"state_hash" json:"state_hash"`
|
||||
RedirectUri sql.NullString `db:"redirect_uri" json:"redirect_uri"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg InsertOAuth2ProviderAppCodeParams) (OAuth2ProviderAppCode, error) {
|
||||
@@ -7371,8 +7250,6 @@ func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg Insert
|
||||
arg.ResourceUri,
|
||||
arg.CodeChallenge,
|
||||
arg.CodeChallengeMethod,
|
||||
arg.StateHash,
|
||||
arg.RedirectUri,
|
||||
)
|
||||
var i OAuth2ProviderAppCode
|
||||
err := row.Scan(
|
||||
@@ -7386,8 +7263,6 @@ func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg Insert
|
||||
&i.ResourceUri,
|
||||
&i.CodeChallenge,
|
||||
&i.CodeChallengeMethod,
|
||||
&i.StateHash,
|
||||
&i.RedirectUri,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -13437,203 +13312,6 @@ func (q *sqlQuerier) GetTaskSnapshot(ctx context.Context, taskID uuid.UUID) (Tas
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTelemetryTaskEvents = `-- name: GetTelemetryTaskEvents :many
|
||||
WITH task_app_ids AS (
|
||||
SELECT task_id, workspace_app_id
|
||||
FROM task_workspace_apps
|
||||
),
|
||||
task_status_timeline AS (
|
||||
-- All app statuses across every historical app for each task,
|
||||
-- plus synthetic "boundary" rows at each stop/start build transition.
|
||||
-- This allows us to correctly take gaps due to pause/resume into account.
|
||||
SELECT tai.task_id, was.created_at, was.state::text AS state
|
||||
FROM workspace_app_statuses was
|
||||
JOIN task_app_ids tai ON tai.workspace_app_id = was.app_id
|
||||
UNION ALL
|
||||
SELECT t.id AS task_id, wb.created_at, '_boundary' AS state
|
||||
FROM tasks t
|
||||
JOIN workspace_builds wb ON wb.workspace_id = t.workspace_id
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.workspace_id IS NOT NULL
|
||||
AND wb.build_number > 1
|
||||
),
|
||||
task_event_data AS (
|
||||
SELECT
|
||||
t.id AS task_id,
|
||||
t.workspace_id,
|
||||
twa.workspace_app_id,
|
||||
-- Latest stop build.
|
||||
stop_build.created_at AS stop_build_created_at,
|
||||
stop_build.reason AS stop_build_reason,
|
||||
-- Latest start build (task_resume only).
|
||||
start_build.created_at AS start_build_created_at,
|
||||
start_build.reason AS start_build_reason,
|
||||
start_build.build_number AS start_build_number,
|
||||
-- Last "working" app status (for idle duration).
|
||||
lws.created_at AS last_working_status_at,
|
||||
-- First app status after resume (for resume-to-status duration).
|
||||
-- Only populated for workspaces in an active phase (started more
|
||||
-- recently than stopped).
|
||||
fsar.created_at AS first_status_after_resume_at,
|
||||
-- Cumulative time spent in "working" state.
|
||||
active_dur.total_working_ms AS active_duration_ms
|
||||
FROM tasks t
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT task_app.workspace_app_id
|
||||
FROM task_workspace_apps task_app
|
||||
WHERE task_app.task_id = t.id
|
||||
ORDER BY task_app.workspace_build_number DESC
|
||||
LIMIT 1
|
||||
) twa ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT wb.created_at, wb.reason, wb.build_number
|
||||
FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.transition = 'stop'
|
||||
ORDER BY wb.build_number DESC
|
||||
LIMIT 1
|
||||
) stop_build ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT wb.created_at, wb.reason, wb.build_number
|
||||
FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.transition = 'start'
|
||||
ORDER BY wb.build_number DESC
|
||||
LIMIT 1
|
||||
) start_build ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT tst.created_at
|
||||
FROM task_status_timeline tst
|
||||
WHERE tst.task_id = t.id
|
||||
AND tst.state = 'working'
|
||||
-- Only consider status before the latest pause so that
|
||||
-- post-resume statuses don't mask pre-pause idle time.
|
||||
AND (stop_build.created_at IS NULL
|
||||
OR tst.created_at <= stop_build.created_at)
|
||||
ORDER BY tst.created_at DESC
|
||||
LIMIT 1
|
||||
) lws ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT was.created_at
|
||||
FROM workspace_app_statuses was
|
||||
WHERE was.app_id = twa.workspace_app_id
|
||||
AND was.created_at > start_build.created_at
|
||||
ORDER BY was.created_at ASC
|
||||
LIMIT 1
|
||||
) fsar ON twa.workspace_app_id IS NOT NULL
|
||||
AND start_build.created_at IS NOT NULL
|
||||
AND (stop_build.created_at IS NULL
|
||||
OR start_build.created_at > stop_build.created_at)
|
||||
-- Active duration: cumulative time spent in "working" state across all
|
||||
-- historical app IDs for this task. Uses LEAD() to convert point-in-time
|
||||
-- statuses into intervals, then sums intervals where state='working'. For
|
||||
-- the last status, falls back to stop_build time (if paused) or @now (if
|
||||
-- still running).
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT COALESCE(
|
||||
SUM(EXTRACT(EPOCH FROM (interval_end - interval_start)) * 1000)::bigint,
|
||||
0
|
||||
)::bigint AS total_working_ms
|
||||
FROM (
|
||||
SELECT
|
||||
tst.created_at AS interval_start,
|
||||
COALESCE(
|
||||
LEAD(tst.created_at) OVER (ORDER BY tst.created_at ASC, CASE WHEN tst.state = '_boundary' THEN 1 ELSE 0 END ASC),
|
||||
CASE WHEN stop_build.created_at IS NOT NULL
|
||||
AND (start_build.created_at IS NULL
|
||||
OR stop_build.created_at > start_build.created_at)
|
||||
THEN stop_build.created_at
|
||||
ELSE $1::timestamptz
|
||||
END
|
||||
) AS interval_end,
|
||||
tst.state
|
||||
FROM task_status_timeline tst
|
||||
WHERE tst.task_id = t.id
|
||||
) intervals
|
||||
WHERE intervals.state = 'working'
|
||||
) active_dur ON TRUE
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.workspace_id IS NOT NULL
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.created_at > $2
|
||||
)
|
||||
)
|
||||
SELECT task_id, workspace_id, workspace_app_id, stop_build_created_at, stop_build_reason, start_build_created_at, start_build_reason, start_build_number, last_working_status_at, first_status_after_resume_at, active_duration_ms FROM task_event_data
|
||||
ORDER BY task_id
|
||||
`
|
||||
|
||||
type GetTelemetryTaskEventsParams struct {
|
||||
Now time.Time `db:"now" json:"now"`
|
||||
CreatedAfter time.Time `db:"created_after" json:"created_after"`
|
||||
}
|
||||
|
||||
type GetTelemetryTaskEventsRow struct {
|
||||
TaskID uuid.UUID `db:"task_id" json:"task_id"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
WorkspaceAppID uuid.NullUUID `db:"workspace_app_id" json:"workspace_app_id"`
|
||||
StopBuildCreatedAt sql.NullTime `db:"stop_build_created_at" json:"stop_build_created_at"`
|
||||
StopBuildReason NullBuildReason `db:"stop_build_reason" json:"stop_build_reason"`
|
||||
StartBuildCreatedAt sql.NullTime `db:"start_build_created_at" json:"start_build_created_at"`
|
||||
StartBuildReason NullBuildReason `db:"start_build_reason" json:"start_build_reason"`
|
||||
StartBuildNumber sql.NullInt32 `db:"start_build_number" json:"start_build_number"`
|
||||
LastWorkingStatusAt sql.NullTime `db:"last_working_status_at" json:"last_working_status_at"`
|
||||
FirstStatusAfterResumeAt sql.NullTime `db:"first_status_after_resume_at" json:"first_status_after_resume_at"`
|
||||
ActiveDurationMs int64 `db:"active_duration_ms" json:"active_duration_ms"`
|
||||
}
|
||||
|
||||
// Returns all data needed to build task lifecycle events for telemetry
|
||||
// in a single round-trip. For each task whose workspace is in the
|
||||
// given set, fetches:
|
||||
// - the latest workspace app binding (task_workspace_apps)
|
||||
// - the most recent stop and start builds (workspace_builds)
|
||||
// - the last "working" app status (workspace_app_statuses)
|
||||
// - the first app status after resume, for active workspaces
|
||||
//
|
||||
// Assumptions:
|
||||
// - 1:1 relationship between tasks and workspaces. All builds on the
|
||||
// workspace are considered task-related.
|
||||
// - Idle duration approximation: If the agent reports "working", does
|
||||
// work, then reports "done", we miss that working time.
|
||||
// - lws and active_dur join across all historical app IDs for the task,
|
||||
// because each resume cycle provisions a new app ID. This ensures
|
||||
// pre-pause statuses contribute to idle duration and active duration.
|
||||
func (q *sqlQuerier) GetTelemetryTaskEvents(ctx context.Context, arg GetTelemetryTaskEventsParams) ([]GetTelemetryTaskEventsRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getTelemetryTaskEvents, arg.Now, arg.CreatedAfter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetTelemetryTaskEventsRow
|
||||
for rows.Next() {
|
||||
var i GetTelemetryTaskEventsRow
|
||||
if err := rows.Scan(
|
||||
&i.TaskID,
|
||||
&i.WorkspaceID,
|
||||
&i.WorkspaceAppID,
|
||||
&i.StopBuildCreatedAt,
|
||||
&i.StopBuildReason,
|
||||
&i.StartBuildCreatedAt,
|
||||
&i.StartBuildReason,
|
||||
&i.StartBuildNumber,
|
||||
&i.LastWorkingStatusAt,
|
||||
&i.FirstStatusAfterResumeAt,
|
||||
&i.ActiveDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertTask = `-- name: InsertTask :one
|
||||
INSERT INTO tasks
|
||||
(id, organization_id, owner_id, name, display_name, workspace_id, template_version_id, template_parameters, prompt, created_at)
|
||||
@@ -18263,7 +17941,7 @@ const getAuthenticatedWorkspaceAgentAndBuildByAuthToken = `-- name: GetAuthentic
|
||||
SELECT
|
||||
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite, workspaces.next_start_at, workspaces.group_acl, workspaces.user_acl,
|
||||
workspace_agents.id, workspace_agents.created_at, workspace_agents.updated_at, workspace_agents.name, workspace_agents.first_connected_at, workspace_agents.last_connected_at, workspace_agents.disconnected_at, workspace_agents.resource_id, workspace_agents.auth_token, workspace_agents.auth_instance_id, workspace_agents.architecture, workspace_agents.environment_variables, workspace_agents.operating_system, workspace_agents.instance_metadata, workspace_agents.resource_metadata, workspace_agents.directory, workspace_agents.version, workspace_agents.last_connected_replica_id, workspace_agents.connection_timeout_seconds, workspace_agents.troubleshooting_url, workspace_agents.motd_file, workspace_agents.lifecycle_state, workspace_agents.expanded_directory, workspace_agents.logs_length, workspace_agents.logs_overflowed, workspace_agents.started_at, workspace_agents.ready_at, workspace_agents.subsystems, workspace_agents.display_apps, workspace_agents.api_version, workspace_agents.display_order, workspace_agents.parent_id, workspace_agents.api_key_scope, workspace_agents.deleted,
|
||||
workspace_build_with_user.id, workspace_build_with_user.created_at, workspace_build_with_user.updated_at, workspace_build_with_user.workspace_id, workspace_build_with_user.template_version_id, workspace_build_with_user.build_number, workspace_build_with_user.transition, workspace_build_with_user.initiator_id, workspace_build_with_user.job_id, workspace_build_with_user.deadline, workspace_build_with_user.reason, workspace_build_with_user.daily_cost, workspace_build_with_user.max_deadline, workspace_build_with_user.template_version_preset_id, workspace_build_with_user.has_ai_task, workspace_build_with_user.has_external_agent, workspace_build_with_user.initiator_by_avatar_url, workspace_build_with_user.initiator_by_username, workspace_build_with_user.initiator_by_name,
|
||||
workspace_build_with_user.id, workspace_build_with_user.created_at, workspace_build_with_user.updated_at, workspace_build_with_user.workspace_id, workspace_build_with_user.template_version_id, workspace_build_with_user.build_number, workspace_build_with_user.transition, workspace_build_with_user.initiator_id, workspace_build_with_user.provisioner_state, workspace_build_with_user.job_id, workspace_build_with_user.deadline, workspace_build_with_user.reason, workspace_build_with_user.daily_cost, workspace_build_with_user.max_deadline, workspace_build_with_user.template_version_preset_id, workspace_build_with_user.has_ai_task, workspace_build_with_user.has_external_agent, workspace_build_with_user.initiator_by_avatar_url, workspace_build_with_user.initiator_by_username, workspace_build_with_user.initiator_by_name,
|
||||
tasks.id AS task_id
|
||||
FROM
|
||||
workspace_agents
|
||||
@@ -18401,6 +18079,7 @@ func (q *sqlQuerier) GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx conte
|
||||
&i.WorkspaceBuild.BuildNumber,
|
||||
&i.WorkspaceBuild.Transition,
|
||||
&i.WorkspaceBuild.InitiatorID,
|
||||
&i.WorkspaceBuild.ProvisionerState,
|
||||
&i.WorkspaceBuild.JobID,
|
||||
&i.WorkspaceBuild.Deadline,
|
||||
&i.WorkspaceBuild.Reason,
|
||||
@@ -21314,7 +20993,7 @@ func (q *sqlQuerier) InsertWorkspaceBuildParameters(ctx context.Context, arg Ins
|
||||
}
|
||||
|
||||
const getActiveWorkspaceBuildsByTemplateID = `-- name: GetActiveWorkspaceBuildsByTemplateID :many
|
||||
SELECT wb.id, wb.created_at, wb.updated_at, wb.workspace_id, wb.template_version_id, wb.build_number, wb.transition, wb.initiator_id, wb.job_id, wb.deadline, wb.reason, wb.daily_cost, wb.max_deadline, wb.template_version_preset_id, wb.has_ai_task, wb.has_external_agent, wb.initiator_by_avatar_url, wb.initiator_by_username, wb.initiator_by_name
|
||||
SELECT wb.id, wb.created_at, wb.updated_at, wb.workspace_id, wb.template_version_id, wb.build_number, wb.transition, wb.initiator_id, wb.provisioner_state, wb.job_id, wb.deadline, wb.reason, wb.daily_cost, wb.max_deadline, wb.template_version_preset_id, wb.has_ai_task, wb.has_external_agent, wb.initiator_by_avatar_url, wb.initiator_by_username, wb.initiator_by_name
|
||||
FROM (
|
||||
SELECT
|
||||
workspace_id, MAX(build_number) as max_build_number
|
||||
@@ -21362,6 +21041,7 @@ func (q *sqlQuerier) GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, t
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21469,7 +21149,7 @@ func (q *sqlQuerier) GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, a
|
||||
|
||||
const getLatestWorkspaceBuildByWorkspaceID = `-- name: GetLatestWorkspaceBuildByWorkspaceID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21492,6 +21172,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, w
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21510,7 +21191,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, w
|
||||
const getLatestWorkspaceBuildsByWorkspaceIDs = `-- name: GetLatestWorkspaceBuildsByWorkspaceIDs :many
|
||||
SELECT
|
||||
DISTINCT ON (workspace_id)
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21537,6 +21218,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context,
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21564,7 +21246,7 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context,
|
||||
|
||||
const getWorkspaceBuildByID = `-- name: GetWorkspaceBuildByID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21585,6 +21267,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (W
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21602,7 +21285,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (W
|
||||
|
||||
const getWorkspaceBuildByJobID = `-- name: GetWorkspaceBuildByJobID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21623,6 +21306,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UU
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21640,7 +21324,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UU
|
||||
|
||||
const getWorkspaceBuildByWorkspaceIDAndBuildNumber = `-- name: GetWorkspaceBuildByWorkspaceIDAndBuildNumber :one
|
||||
SELECT
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21665,6 +21349,7 @@ func (q *sqlQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx context.Co
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21736,48 +21421,6 @@ func (q *sqlQuerier) GetWorkspaceBuildMetricsByResourceID(ctx context.Context, i
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceBuildProvisionerStateByID = `-- name: GetWorkspaceBuildProvisionerStateByID :one
|
||||
SELECT
|
||||
workspace_builds.provisioner_state,
|
||||
templates.id AS template_id,
|
||||
templates.organization_id AS template_organization_id,
|
||||
templates.user_acl,
|
||||
templates.group_acl
|
||||
FROM
|
||||
workspace_builds
|
||||
INNER JOIN
|
||||
workspaces ON workspaces.id = workspace_builds.workspace_id
|
||||
INNER JOIN
|
||||
templates ON templates.id = workspaces.template_id
|
||||
WHERE
|
||||
workspace_builds.id = $1
|
||||
`
|
||||
|
||||
type GetWorkspaceBuildProvisionerStateByIDRow struct {
|
||||
ProvisionerState []byte `db:"provisioner_state" json:"provisioner_state"`
|
||||
TemplateID uuid.UUID `db:"template_id" json:"template_id"`
|
||||
TemplateOrganizationID uuid.UUID `db:"template_organization_id" json:"template_organization_id"`
|
||||
UserACL TemplateACL `db:"user_acl" json:"user_acl"`
|
||||
GroupACL TemplateACL `db:"group_acl" json:"group_acl"`
|
||||
}
|
||||
|
||||
// Fetches the provisioner state of a workspace build, joined through to the
|
||||
// template so that dbauthz can enforce policy.ActionUpdate on the template.
|
||||
// Provisioner state contains sensitive Terraform state and should only be
|
||||
// accessible to template administrators.
|
||||
func (q *sqlQuerier) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (GetWorkspaceBuildProvisionerStateByIDRow, error) {
|
||||
row := q.db.QueryRowContext(ctx, getWorkspaceBuildProvisionerStateByID, workspaceBuildID)
|
||||
var i GetWorkspaceBuildProvisionerStateByIDRow
|
||||
err := row.Scan(
|
||||
&i.ProvisionerState,
|
||||
&i.TemplateID,
|
||||
&i.TemplateOrganizationID,
|
||||
&i.UserACL,
|
||||
&i.GroupACL,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceBuildStatsByTemplates = `-- name: GetWorkspaceBuildStatsByTemplates :many
|
||||
SELECT
|
||||
w.template_id,
|
||||
@@ -21847,7 +21490,7 @@ func (q *sqlQuerier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, sinc
|
||||
|
||||
const getWorkspaceBuildsByWorkspaceID = `-- name: GetWorkspaceBuildsByWorkspaceID :many
|
||||
SELECT
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name
|
||||
FROM
|
||||
workspace_build_with_user AS workspace_builds
|
||||
WHERE
|
||||
@@ -21911,6 +21554,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg Ge
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
@@ -21937,7 +21581,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg Ge
|
||||
}
|
||||
|
||||
const getWorkspaceBuildsCreatedAfter = `-- name: GetWorkspaceBuildsCreatedAfter :many
|
||||
SELECT id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name FROM workspace_build_with_user WHERE created_at > $1
|
||||
SELECT id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, has_external_agent, initiator_by_avatar_url, initiator_by_username, initiator_by_name FROM workspace_build_with_user WHERE created_at > $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceBuild, error) {
|
||||
@@ -21958,6 +21602,7 @@ func (q *sqlQuerier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, created
|
||||
&i.BuildNumber,
|
||||
&i.Transition,
|
||||
&i.InitiatorID,
|
||||
&i.ProvisionerState,
|
||||
&i.JobID,
|
||||
&i.Deadline,
|
||||
&i.Reason,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
-- name: InsertAIBridgeInterception :one
|
||||
INSERT INTO aibridge_interceptions (
|
||||
id, api_key_id, initiator_id, provider, model, metadata, started_at, client, thread_parent_id, thread_root_id
|
||||
id, api_key_id, initiator_id, provider, model, metadata, started_at, client
|
||||
) VALUES (
|
||||
@id, @api_key_id, @initiator_id, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
|
||||
@id, @api_key_id, @initiator_id, @provider, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
@@ -14,23 +14,6 @@ WHERE
|
||||
AND ended_at IS NULL
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetAIBridgeInterceptionLineageByToolCallID :one
|
||||
-- Look up the parent interception and the root of the thread by finding
|
||||
-- which interception recorded a tool usage with the given tool call ID.
|
||||
-- COALESCE ensures that if the parent has no thread_root_id (i.e. it IS
|
||||
-- the root), we return its own ID as the root.
|
||||
WITH linked AS (
|
||||
SELECT interception_id FROM aibridge_tool_usages
|
||||
WHERE provider_tool_call_id = @tool_call_id::text
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
)
|
||||
SELECT linked.interception_id AS thread_parent_id,
|
||||
COALESCE(aibridge_interceptions.thread_root_id, linked.interception_id) AS thread_root_id
|
||||
FROM aibridge_interceptions
|
||||
INNER JOIN linked ON linked.interception_id = aibridge_interceptions.id
|
||||
WHERE aibridge_interceptions.id = linked.interception_id;
|
||||
|
||||
-- name: InsertAIBridgeTokenUsage :one
|
||||
INSERT INTO aibridge_token_usages (
|
||||
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
|
||||
@@ -49,9 +32,9 @@ RETURNING *;
|
||||
|
||||
-- name: InsertAIBridgeToolUsage :one
|
||||
INSERT INTO aibridge_tool_usages (
|
||||
id, interception_id, provider_response_id, provider_tool_call_id, tool, server_url, input, injected, invocation_error, metadata, created_at
|
||||
id, interception_id, provider_response_id, tool, server_url, input, injected, invocation_error, metadata, created_at
|
||||
) VALUES (
|
||||
@id, @interception_id, @provider_response_id, @provider_tool_call_id, @tool, @server_url, @input, @injected, @invocation_error, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
|
||||
@id, @interception_id, @provider_response_id, @tool, @server_url, @input, @injected, @invocation_error, COALESCE(@metadata::jsonb, '{}'::jsonb), @created_at
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
@@ -391,28 +374,3 @@ SELECT (
|
||||
(SELECT COUNT(*) FROM user_prompts) +
|
||||
(SELECT COUNT(*) FROM interceptions)
|
||||
)::bigint as total_deleted;
|
||||
|
||||
-- name: ListAIBridgeModels :many
|
||||
SELECT
|
||||
model
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN @model::text != '' THEN aibridge_interceptions.model LIKE @model::text || '%'
|
||||
ELSE true
|
||||
END
|
||||
-- We use an `@authorize_filter` as we are attempting to list models that are relevant
|
||||
-- to the user and what they are allowed to see.
|
||||
-- Authorize Filter clause will be injected below in ListAIBridgeModelsAuthorized
|
||||
-- @authorize_filter
|
||||
GROUP BY
|
||||
model
|
||||
ORDER BY
|
||||
model ASC
|
||||
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
|
||||
OFFSET @offset_
|
||||
;
|
||||
|
||||
@@ -25,12 +25,10 @@ LIMIT
|
||||
SELECT * FROM api_keys WHERE last_used > $1;
|
||||
|
||||
-- name: GetAPIKeysByLoginType :many
|
||||
SELECT * FROM api_keys WHERE login_type = $1
|
||||
AND (@include_expired::bool OR expires_at > now());
|
||||
SELECT * FROM api_keys WHERE login_type = $1;
|
||||
|
||||
-- name: GetAPIKeysByUserID :many
|
||||
SELECT * FROM api_keys WHERE login_type = $1 AND user_id = $2
|
||||
AND (@include_expired::bool OR expires_at > now());
|
||||
SELECT * FROM api_keys WHERE login_type = $1 AND user_id = $2;
|
||||
|
||||
-- name: InsertAPIKey :one
|
||||
INSERT INTO
|
||||
|
||||
@@ -140,9 +140,7 @@ INSERT INTO oauth2_provider_app_codes (
|
||||
user_id,
|
||||
resource_uri,
|
||||
code_challenge,
|
||||
code_challenge_method,
|
||||
state_hash,
|
||||
redirect_uri
|
||||
code_challenge_method
|
||||
) VALUES(
|
||||
$1,
|
||||
$2,
|
||||
@@ -153,9 +151,7 @@ INSERT INTO oauth2_provider_app_codes (
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12
|
||||
$10
|
||||
) RETURNING *;
|
||||
|
||||
-- name: DeleteOAuth2ProviderAppCodeByID :exec
|
||||
|
||||
@@ -100,146 +100,3 @@ FROM
|
||||
task_snapshots
|
||||
WHERE
|
||||
task_id = $1;
|
||||
|
||||
-- name: GetTelemetryTaskEvents :many
|
||||
-- Returns all data needed to build task lifecycle events for telemetry
|
||||
-- in a single round-trip. For each task whose workspace is in the
|
||||
-- given set, fetches:
|
||||
-- - the latest workspace app binding (task_workspace_apps)
|
||||
-- - the most recent stop and start builds (workspace_builds)
|
||||
-- - the last "working" app status (workspace_app_statuses)
|
||||
-- - the first app status after resume, for active workspaces
|
||||
--
|
||||
-- Assumptions:
|
||||
-- - 1:1 relationship between tasks and workspaces. All builds on the
|
||||
-- workspace are considered task-related.
|
||||
-- - Idle duration approximation: If the agent reports "working", does
|
||||
-- work, then reports "done", we miss that working time.
|
||||
-- - lws and active_dur join across all historical app IDs for the task,
|
||||
-- because each resume cycle provisions a new app ID. This ensures
|
||||
-- pre-pause statuses contribute to idle duration and active duration.
|
||||
WITH task_app_ids AS (
|
||||
SELECT task_id, workspace_app_id
|
||||
FROM task_workspace_apps
|
||||
),
|
||||
task_status_timeline AS (
|
||||
-- All app statuses across every historical app for each task,
|
||||
-- plus synthetic "boundary" rows at each stop/start build transition.
|
||||
-- This allows us to correctly take gaps due to pause/resume into account.
|
||||
SELECT tai.task_id, was.created_at, was.state::text AS state
|
||||
FROM workspace_app_statuses was
|
||||
JOIN task_app_ids tai ON tai.workspace_app_id = was.app_id
|
||||
UNION ALL
|
||||
SELECT t.id AS task_id, wb.created_at, '_boundary' AS state
|
||||
FROM tasks t
|
||||
JOIN workspace_builds wb ON wb.workspace_id = t.workspace_id
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.workspace_id IS NOT NULL
|
||||
AND wb.build_number > 1
|
||||
),
|
||||
task_event_data AS (
|
||||
SELECT
|
||||
t.id AS task_id,
|
||||
t.workspace_id,
|
||||
twa.workspace_app_id,
|
||||
-- Latest stop build.
|
||||
stop_build.created_at AS stop_build_created_at,
|
||||
stop_build.reason AS stop_build_reason,
|
||||
-- Latest start build (task_resume only).
|
||||
start_build.created_at AS start_build_created_at,
|
||||
start_build.reason AS start_build_reason,
|
||||
start_build.build_number AS start_build_number,
|
||||
-- Last "working" app status (for idle duration).
|
||||
lws.created_at AS last_working_status_at,
|
||||
-- First app status after resume (for resume-to-status duration).
|
||||
-- Only populated for workspaces in an active phase (started more
|
||||
-- recently than stopped).
|
||||
fsar.created_at AS first_status_after_resume_at,
|
||||
-- Cumulative time spent in "working" state.
|
||||
active_dur.total_working_ms AS active_duration_ms
|
||||
FROM tasks t
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT task_app.workspace_app_id
|
||||
FROM task_workspace_apps task_app
|
||||
WHERE task_app.task_id = t.id
|
||||
ORDER BY task_app.workspace_build_number DESC
|
||||
LIMIT 1
|
||||
) twa ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT wb.created_at, wb.reason, wb.build_number
|
||||
FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.transition = 'stop'
|
||||
ORDER BY wb.build_number DESC
|
||||
LIMIT 1
|
||||
) stop_build ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT wb.created_at, wb.reason, wb.build_number
|
||||
FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.transition = 'start'
|
||||
ORDER BY wb.build_number DESC
|
||||
LIMIT 1
|
||||
) start_build ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT tst.created_at
|
||||
FROM task_status_timeline tst
|
||||
WHERE tst.task_id = t.id
|
||||
AND tst.state = 'working'
|
||||
-- Only consider status before the latest pause so that
|
||||
-- post-resume statuses don't mask pre-pause idle time.
|
||||
AND (stop_build.created_at IS NULL
|
||||
OR tst.created_at <= stop_build.created_at)
|
||||
ORDER BY tst.created_at DESC
|
||||
LIMIT 1
|
||||
) lws ON TRUE
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT was.created_at
|
||||
FROM workspace_app_statuses was
|
||||
WHERE was.app_id = twa.workspace_app_id
|
||||
AND was.created_at > start_build.created_at
|
||||
ORDER BY was.created_at ASC
|
||||
LIMIT 1
|
||||
) fsar ON twa.workspace_app_id IS NOT NULL
|
||||
AND start_build.created_at IS NOT NULL
|
||||
AND (stop_build.created_at IS NULL
|
||||
OR start_build.created_at > stop_build.created_at)
|
||||
-- Active duration: cumulative time spent in "working" state across all
|
||||
-- historical app IDs for this task. Uses LEAD() to convert point-in-time
|
||||
-- statuses into intervals, then sums intervals where state='working'. For
|
||||
-- the last status, falls back to stop_build time (if paused) or @now (if
|
||||
-- still running).
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT COALESCE(
|
||||
SUM(EXTRACT(EPOCH FROM (interval_end - interval_start)) * 1000)::bigint,
|
||||
0
|
||||
)::bigint AS total_working_ms
|
||||
FROM (
|
||||
SELECT
|
||||
tst.created_at AS interval_start,
|
||||
COALESCE(
|
||||
LEAD(tst.created_at) OVER (ORDER BY tst.created_at ASC, CASE WHEN tst.state = '_boundary' THEN 1 ELSE 0 END ASC),
|
||||
CASE WHEN stop_build.created_at IS NOT NULL
|
||||
AND (start_build.created_at IS NULL
|
||||
OR stop_build.created_at > start_build.created_at)
|
||||
THEN stop_build.created_at
|
||||
ELSE @now::timestamptz
|
||||
END
|
||||
) AS interval_end,
|
||||
tst.state
|
||||
FROM task_status_timeline tst
|
||||
WHERE tst.task_id = t.id
|
||||
) intervals
|
||||
WHERE intervals.state = 'working'
|
||||
) active_dur ON TRUE
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.workspace_id IS NOT NULL
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM workspace_builds wb
|
||||
WHERE wb.workspace_id = t.workspace_id
|
||||
AND wb.created_at > @created_after
|
||||
)
|
||||
)
|
||||
SELECT * FROM task_event_data
|
||||
ORDER BY task_id;
|
||||
|
||||
|
||||
@@ -87,4 +87,3 @@ SELECT DISTINCT ON (workspace_id)
|
||||
FROM workspace_app_statuses
|
||||
WHERE workspace_id = ANY(@ids :: uuid[])
|
||||
ORDER BY workspace_id, created_at DESC;
|
||||
|
||||
|
||||
@@ -271,23 +271,3 @@ JOIN workspace_resources wr ON wr.job_id = wb.job_id
|
||||
JOIN workspace_agents wa ON wa.resource_id = wr.id
|
||||
WHERE wb.job_id = (SELECT job_id FROM workspace_resources WHERE workspace_resources.id = $1)
|
||||
GROUP BY wb.created_at, wb.transition, t.name, o.name, w.owner_id;
|
||||
|
||||
-- name: GetWorkspaceBuildProvisionerStateByID :one
|
||||
-- Fetches the provisioner state of a workspace build, joined through to the
|
||||
-- template so that dbauthz can enforce policy.ActionUpdate on the template.
|
||||
-- Provisioner state contains sensitive Terraform state and should only be
|
||||
-- accessible to template administrators.
|
||||
SELECT
|
||||
workspace_builds.provisioner_state,
|
||||
templates.id AS template_id,
|
||||
templates.organization_id AS template_organization_id,
|
||||
templates.user_acl,
|
||||
templates.group_acl
|
||||
FROM
|
||||
workspace_builds
|
||||
INNER JOIN
|
||||
workspaces ON workspaces.id = workspace_builds.workspace_id
|
||||
INNER JOIN
|
||||
templates ON templates.id = workspaces.template_id
|
||||
WHERE
|
||||
workspace_builds.id = @workspace_build_id;
|
||||
|
||||
@@ -124,24 +124,6 @@ sql:
|
||||
- column: "tasks_with_status.workspace_app_health"
|
||||
go_type:
|
||||
type: "NullWorkspaceAppHealth"
|
||||
# Workaround for sqlc not interpreting the left join correctly
|
||||
# in the combined telemetry query.
|
||||
- column: "task_event_data.start_build_number"
|
||||
go_type: "database/sql.NullInt32"
|
||||
- column: "task_event_data.stop_build_created_at"
|
||||
go_type: "database/sql.NullTime"
|
||||
- column: "task_event_data.stop_build_reason"
|
||||
go_type:
|
||||
type: "NullBuildReason"
|
||||
- column: "task_event_data.start_build_created_at"
|
||||
go_type: "database/sql.NullTime"
|
||||
- column: "task_event_data.start_build_reason"
|
||||
go_type:
|
||||
type: "NullBuildReason"
|
||||
- column: "task_event_data.last_working_status_at"
|
||||
go_type: "database/sql.NullTime"
|
||||
- column: "task_event_data.first_status_after_resume_at"
|
||||
go_type: "database/sql.NullTime"
|
||||
rename:
|
||||
group_member: GroupMemberTable
|
||||
group_members_expanded: GroupMember
|
||||
|
||||
@@ -228,11 +228,12 @@ func (p *QueryParamParser) RedirectURL(vals url.Values, base *url.URL, queryPara
|
||||
})
|
||||
}
|
||||
|
||||
// OAuth 2.1 requires exact redirect URI matching.
|
||||
if v.String() != base.String() {
|
||||
// It can be a sub-directory but not a sub-domain, as we have apps on
|
||||
// sub-domains and that seems too dangerous.
|
||||
if v.Host != base.Host || !strings.HasPrefix(v.Path, base.Path) {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
Field: queryParam,
|
||||
Detail: fmt.Sprintf("Query param %q must exactly match %s", queryParam, base),
|
||||
Detail: fmt.Sprintf("Query param %q must be a subset of %s", queryParam, base),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -62,12 +62,8 @@ func CSRF(cookieCfg codersdk.HTTPCookieConfig) func(next http.Handler) http.Hand
|
||||
mw.ExemptRegexp(regexp.MustCompile("/organizations/[^/]+/provisionerdaemons/*"))
|
||||
|
||||
mw.ExemptFunc(func(r *http.Request) bool {
|
||||
// Enforce CSRF on API routes and the OAuth2 authorize
|
||||
// endpoint. The authorize endpoint serves a browser consent
|
||||
// form whose POST must be CSRF-protected to prevent
|
||||
// cross-site authorization code theft (coder/security#121).
|
||||
if !strings.HasPrefix(r.URL.Path, "/api") &&
|
||||
!strings.HasPrefix(r.URL.Path, "/oauth2/authorize") {
|
||||
// Only enforce CSRF on API routes.
|
||||
if !strings.HasPrefix(r.URL.Path, "/api") {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -51,26 +51,6 @@ func TestCSRFExemptList(t *testing.T) {
|
||||
URL: "https://coder.com/api/v2/me",
|
||||
Exempt: false,
|
||||
},
|
||||
{
|
||||
Name: "OAuth2Authorize",
|
||||
URL: "https://coder.com/oauth2/authorize",
|
||||
Exempt: false,
|
||||
},
|
||||
{
|
||||
Name: "OAuth2AuthorizeQuery",
|
||||
URL: "https://coder.com/oauth2/authorize?client_id=test",
|
||||
Exempt: false,
|
||||
},
|
||||
{
|
||||
Name: "OAuth2Tokens",
|
||||
URL: "https://coder.com/oauth2/tokens",
|
||||
Exempt: true,
|
||||
},
|
||||
{
|
||||
Name: "OAuth2Register",
|
||||
URL: "https://coder.com/oauth2/register",
|
||||
Exempt: true,
|
||||
},
|
||||
}
|
||||
|
||||
mw := httpmw.CSRF(codersdk.HTTPCookieConfig{})
|
||||
|
||||
@@ -348,12 +348,8 @@ func reapJob(ctx context.Context, log slog.Logger, db database.Store, pub pubsub
|
||||
|
||||
// Only copy the provisioner state if there's no state in
|
||||
// the current build.
|
||||
currentStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get workspace build provisioner state: %w", err)
|
||||
}
|
||||
if len(currentStateRow.ProvisionerState) == 0 {
|
||||
// Get the previous build's state if it exists.
|
||||
if len(build.ProvisionerState) == 0 {
|
||||
// Get the previous build if it exists.
|
||||
prevBuild, err := db.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
|
||||
WorkspaceID: build.WorkspaceID,
|
||||
BuildNumber: build.BuildNumber - 1,
|
||||
@@ -362,14 +358,10 @@ func reapJob(ctx context.Context, log slog.Logger, db database.Store, pub pubsub
|
||||
return xerrors.Errorf("get previous workspace build: %w", err)
|
||||
}
|
||||
if err == nil {
|
||||
prevStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, prevBuild.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get previous workspace build provisioner state: %w", err)
|
||||
}
|
||||
err = db.UpdateWorkspaceBuildProvisionerStateByID(ctx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
|
||||
ID: build.ID,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
ProvisionerState: prevStateRow.ProvisionerState,
|
||||
ProvisionerState: prevBuild.ProvisionerState,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update workspace build by id: %w", err)
|
||||
|
||||
@@ -126,9 +126,9 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
|
||||
previousBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
|
||||
ProvisionerState(expectedWorkspaceBuildState).
|
||||
Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
|
||||
ProvisionerState: expectedWorkspaceBuildState,
|
||||
}).Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
|
||||
Do()
|
||||
|
||||
// Current build (hung - running job with UpdatedAt > 5 min ago).
|
||||
@@ -163,9 +163,7 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
|
||||
// Check that the provisioner state was copied.
|
||||
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
|
||||
require.NoError(t, err)
|
||||
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
|
||||
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
|
||||
|
||||
detector.Close()
|
||||
detector.Wait()
|
||||
@@ -196,9 +194,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
|
||||
previousBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
|
||||
ProvisionerState([]byte(`{"dean":"NOT cool","colin":"also NOT cool"}`)).
|
||||
Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
|
||||
ProvisionerState: []byte(`{"dean":"NOT cool","colin":"also NOT cool"}`),
|
||||
}).Succeeded(dbfake.WithJobCompletedAt(twentyMinAgo)).
|
||||
Do()
|
||||
|
||||
// Current build (hung - running job with UpdatedAt > 5 min ago).
|
||||
@@ -206,8 +204,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
|
||||
currentBuild := dbfake.WorkspaceBuild(t, db, previousBuild.Workspace).
|
||||
Pubsub(pubsub).
|
||||
Seed(database.WorkspaceBuild{
|
||||
BuildNumber: 2,
|
||||
}).ProvisionerState(expectedWorkspaceBuildState).
|
||||
BuildNumber: 2,
|
||||
ProvisionerState: expectedWorkspaceBuildState,
|
||||
}).
|
||||
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
|
||||
Do()
|
||||
|
||||
@@ -236,9 +235,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
|
||||
// Check that the provisioner state was NOT copied.
|
||||
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
|
||||
require.NoError(t, err)
|
||||
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
|
||||
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
|
||||
|
||||
detector.Close()
|
||||
detector.Wait()
|
||||
@@ -269,9 +266,9 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T
|
||||
currentBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
|
||||
ProvisionerState(expectedWorkspaceBuildState).
|
||||
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
|
||||
ProvisionerState: expectedWorkspaceBuildState,
|
||||
}).Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
|
||||
Do()
|
||||
|
||||
t.Log("current job ID: ", currentBuild.Build.JobID)
|
||||
@@ -298,9 +295,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T
|
||||
// Check that the provisioner state was NOT updated.
|
||||
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
|
||||
require.NoError(t, err)
|
||||
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
|
||||
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
|
||||
|
||||
detector.Close()
|
||||
detector.Wait()
|
||||
@@ -330,9 +325,9 @@ func TestDetectorPendingWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testin
|
||||
currentBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
|
||||
ProvisionerState(expectedWorkspaceBuildState).
|
||||
Pending(dbfake.WithJobCreatedAt(thirtyFiveMinAgo), dbfake.WithJobUpdatedAt(thirtyFiveMinAgo)).
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
|
||||
ProvisionerState: expectedWorkspaceBuildState,
|
||||
}).Pending(dbfake.WithJobCreatedAt(thirtyFiveMinAgo), dbfake.WithJobUpdatedAt(thirtyFiveMinAgo)).
|
||||
Do()
|
||||
|
||||
t.Log("current job ID: ", currentBuild.Build.JobID)
|
||||
@@ -361,9 +356,7 @@ func TestDetectorPendingWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testin
|
||||
// Check that the provisioner state was NOT updated.
|
||||
build, err := db.GetWorkspaceBuildByID(ctx, currentBuild.Build.ID)
|
||||
require.NoError(t, err)
|
||||
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedWorkspaceBuildState, provisionerStateRow.ProvisionerState)
|
||||
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
|
||||
|
||||
detector.Close()
|
||||
detector.Wait()
|
||||
@@ -405,9 +398,9 @@ func TestDetectorWorkspaceBuildForDormantWorkspace(t *testing.T) {
|
||||
Time: now.Add(-time.Hour),
|
||||
Valid: true,
|
||||
},
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{}).
|
||||
ProvisionerState(expectedWorkspaceBuildState).
|
||||
Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
|
||||
}).Pubsub(pubsub).Seed(database.WorkspaceBuild{
|
||||
ProvisionerState: expectedWorkspaceBuildState,
|
||||
}).Starting(dbfake.WithJobStartedAt(tenMinAgo), dbfake.WithJobUpdatedAt(sixMinAgo)).
|
||||
Do()
|
||||
|
||||
t.Log("current job ID: ", currentBuild.Build.JobID)
|
||||
|
||||
+22
-40
@@ -2,8 +2,6 @@ package mcp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -12,7 +10,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
mcpclient "github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
@@ -27,15 +24,6 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// mcpGeneratePKCE creates a PKCE verifier and S256 challenge for MCP
|
||||
// e2e tests.
|
||||
func mcpGeneratePKCE() (verifier, challenge string) {
|
||||
verifier = uuid.NewString() + uuid.NewString()
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return verifier, challenge
|
||||
}
|
||||
|
||||
func TestMCPHTTP_E2E_ClientIntegration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -565,32 +553,31 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
// In a real flow, this would be done through the browser consent page
|
||||
// For testing, we'll create the code directly using the internal API
|
||||
|
||||
// First, we need to authorize the app (simulating user consent).
|
||||
staticVerifier, staticChallenge := mcpGeneratePKCE()
|
||||
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=test_state&code_challenge=%s&code_challenge_method=S256",
|
||||
api.AccessURL.String(), app.ID, "http://localhost:3000/callback", staticChallenge)
|
||||
// First, we need to authorize the app (simulating user consent)
|
||||
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=test_state",
|
||||
api.AccessURL.String(), app.ID, "http://localhost:3000/callback")
|
||||
|
||||
// Create an HTTP client that follows redirects but captures the final redirect.
|
||||
// Create an HTTP client that follows redirects but captures the final redirect
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse // Stop following redirects
|
||||
},
|
||||
}
|
||||
|
||||
// Make the authorization request (this would normally be done in a browser).
|
||||
// Make the authorization request (this would normally be done in a browser)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", authURL, nil)
|
||||
require.NoError(t, err)
|
||||
// Use RFC 6750 Bearer token for authentication.
|
||||
// Use RFC 6750 Bearer token for authentication
|
||||
req.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// The response should be a redirect to the consent page or directly to callback.
|
||||
// For testing purposes, let's simulate the POST consent approval.
|
||||
// The response should be a redirect to the consent page or directly to callback
|
||||
// For testing purposes, let's simulate the POST consent approval
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
// This means we got the consent page, now we need to POST consent.
|
||||
// This means we got the consent page, now we need to POST consent
|
||||
consentReq, err := http.NewRequestWithContext(ctx, "POST", authURL, nil)
|
||||
require.NoError(t, err)
|
||||
consentReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
|
||||
@@ -601,7 +588,7 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// Extract authorization code from redirect URL.
|
||||
// Extract authorization code from redirect URL
|
||||
require.True(t, resp.StatusCode >= 300 && resp.StatusCode < 400, "Expected redirect response")
|
||||
location := resp.Header.Get("Location")
|
||||
require.NotEmpty(t, location, "Expected Location header in redirect")
|
||||
@@ -613,14 +600,13 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
|
||||
t.Logf("Successfully obtained authorization code: %s", authCode[:10]+"...")
|
||||
|
||||
// Step 2: Exchange authorization code for access token and refresh token.
|
||||
// Step 2: Exchange authorization code for access token and refresh token
|
||||
tokenRequestBody := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {app.ID.String()},
|
||||
"client_secret": {secret.ClientSecretFull},
|
||||
"code": {authCode},
|
||||
"redirect_uri": {"http://localhost:3000/callback"},
|
||||
"code_verifier": {staticVerifier},
|
||||
}
|
||||
|
||||
tokenReq, err := http.NewRequestWithContext(ctx, "POST", api.AccessURL.String()+"/oauth2/tokens",
|
||||
@@ -882,44 +868,41 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
|
||||
t.Logf("Successfully registered dynamic client: %s", clientID)
|
||||
|
||||
// Step 3: Perform OAuth2 authorization code flow with dynamically registered client.
|
||||
dynamicVerifier, dynamicChallenge := mcpGeneratePKCE()
|
||||
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=dynamic_state&code_challenge=%s&code_challenge_method=S256",
|
||||
api.AccessURL.String(), clientID, "http://localhost:3000/callback", dynamicChallenge)
|
||||
// Step 3: Perform OAuth2 authorization code flow with dynamically registered client
|
||||
authURL := fmt.Sprintf("%s/oauth2/authorize?client_id=%s&response_type=code&redirect_uri=%s&state=dynamic_state",
|
||||
api.AccessURL.String(), clientID, "http://localhost:3000/callback")
|
||||
|
||||
// Create an HTTP client that captures redirects.
|
||||
// Create an HTTP client that captures redirects
|
||||
authClient := &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse // Stop following redirects
|
||||
},
|
||||
}
|
||||
|
||||
// Make the authorization request with authentication.
|
||||
// Make the authorization request with authentication
|
||||
authReq, err := http.NewRequestWithContext(ctx, "GET", authURL, nil)
|
||||
require.NoError(t, err)
|
||||
authReq.Header.Set("Cookie", fmt.Sprintf("coder_session_token=%s", coderClient.SessionToken()))
|
||||
authReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
|
||||
|
||||
authResp, err := authClient.Do(authReq)
|
||||
require.NoError(t, err)
|
||||
defer authResp.Body.Close()
|
||||
|
||||
// Handle the response - check for error first.
|
||||
// Handle the response - check for error first
|
||||
if authResp.StatusCode == http.StatusBadRequest {
|
||||
// Read error response for debugging.
|
||||
// Read error response for debugging
|
||||
bodyBytes, err := io.ReadAll(authResp.Body)
|
||||
require.NoError(t, err)
|
||||
t.Logf("OAuth2 authorization error: %s", string(bodyBytes))
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Handle consent flow if needed.
|
||||
// Handle consent flow if needed
|
||||
if authResp.StatusCode == http.StatusOK {
|
||||
// This means we got the consent page, now we need to POST consent.
|
||||
// This means we got the consent page, now we need to POST consent
|
||||
consentReq, err := http.NewRequestWithContext(ctx, "POST", authURL, nil)
|
||||
require.NoError(t, err)
|
||||
consentReq.Header.Set("Cookie", fmt.Sprintf("coder_session_token=%s", coderClient.SessionToken()))
|
||||
consentReq.Header.Set("Authorization", "Bearer "+coderClient.SessionToken())
|
||||
consentReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
authResp, err = authClient.Do(consentReq)
|
||||
@@ -927,7 +910,7 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
defer authResp.Body.Close()
|
||||
}
|
||||
|
||||
// Extract authorization code from redirect.
|
||||
// Extract authorization code from redirect
|
||||
require.True(t, authResp.StatusCode >= 300 && authResp.StatusCode < 400,
|
||||
"Expected redirect response, got %d", authResp.StatusCode)
|
||||
location := authResp.Header.Get("Location")
|
||||
@@ -940,14 +923,13 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
||||
|
||||
t.Logf("Successfully obtained authorization code: %s", authCode[:10]+"...")
|
||||
|
||||
// Step 4: Exchange authorization code for access token.
|
||||
// Step 4: Exchange authorization code for access token
|
||||
tokenRequestBody := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {clientID},
|
||||
"client_secret": {clientSecret},
|
||||
"code": {authCode},
|
||||
"redirect_uri": {"http://localhost:3000/callback"},
|
||||
"code_verifier": {dynamicVerifier},
|
||||
}
|
||||
|
||||
tokenReq, err := http.NewRequestWithContext(ctx, "POST", api.AccessURL.String()+"/oauth2/tokens",
|
||||
|
||||
+41
-92
@@ -2,8 +2,6 @@ package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -291,6 +289,7 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
|
||||
authError: "Invalid query params:",
|
||||
},
|
||||
{
|
||||
// TODO: This is valid for now, but should it be?
|
||||
name: "DifferentProtocol",
|
||||
app: apps.Default,
|
||||
preAuth: func(valid *oauth2.Config) {
|
||||
@@ -298,7 +297,6 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
|
||||
newURL.Scheme = "https"
|
||||
valid.RedirectURL = newURL.String()
|
||||
},
|
||||
authError: "Invalid query params:",
|
||||
},
|
||||
{
|
||||
name: "NestedPath",
|
||||
@@ -308,7 +306,6 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
|
||||
newURL.Path = path.Join(newURL.Path, "nested")
|
||||
valid.RedirectURL = newURL.String()
|
||||
},
|
||||
authError: "Invalid query params:",
|
||||
},
|
||||
{
|
||||
// Some oauth implementations allow this, but our users can host
|
||||
@@ -484,12 +481,11 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
|
||||
}
|
||||
|
||||
var code string
|
||||
var verifier string
|
||||
if test.defaultCode != nil {
|
||||
code = *test.defaultCode
|
||||
} else {
|
||||
var err error
|
||||
code, verifier, err = authorizationFlow(ctx, userClient, valid)
|
||||
code, err = authorizationFlow(ctx, userClient, valid)
|
||||
if test.authError != "" {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, test.authError)
|
||||
@@ -504,12 +500,8 @@ func TestOAuth2ProviderTokenExchange(t *testing.T) {
|
||||
test.preToken(valid)
|
||||
}
|
||||
|
||||
// Do the actual exchange. Include PKCE code_verifier when
|
||||
// we obtained a code through the authorization flow.
|
||||
exchangeOpts := append([]oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier),
|
||||
}, test.exchangeMutate...)
|
||||
token, err := valid.Exchange(ctx, code, exchangeOpts...)
|
||||
// Do the actual exchange.
|
||||
token, err := valid.Exchange(ctx, code, test.exchangeMutate...)
|
||||
if test.tokenError != "" {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, test.tokenError)
|
||||
@@ -691,11 +683,10 @@ func TestOAuth2ProviderTokenRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
type exchangeSetup struct {
|
||||
cfg *oauth2.Config
|
||||
app codersdk.OAuth2ProviderApp
|
||||
secret codersdk.OAuth2ProviderAppSecretFull
|
||||
code string
|
||||
verifier string
|
||||
cfg *oauth2.Config
|
||||
app codersdk.OAuth2ProviderApp
|
||||
secret codersdk.OAuth2ProviderAppSecretFull
|
||||
code string
|
||||
}
|
||||
|
||||
func TestOAuth2ProviderRevoke(t *testing.T) {
|
||||
@@ -739,13 +730,11 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
|
||||
name: "OverrideCodeAndToken",
|
||||
fn: func(ctx context.Context, client *codersdk.Client, s exchangeSetup) {
|
||||
// Generating a new code should wipe out the old code.
|
||||
code, verifier, err := authorizationFlow(ctx, client, s.cfg)
|
||||
code, err := authorizationFlow(ctx, client, s.cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generating a new token should wipe out the old token.
|
||||
_, err = s.cfg.Exchange(ctx, code,
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier),
|
||||
)
|
||||
_, err = s.cfg.Exchange(ctx, code)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
replacesToken: true,
|
||||
@@ -781,15 +770,14 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
|
||||
}
|
||||
|
||||
// Go through the auth flow to get a code.
|
||||
code, verifier, err := authorizationFlow(ctx, testClient, cfg)
|
||||
code, err := authorizationFlow(ctx, testClient, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
return exchangeSetup{
|
||||
cfg: cfg,
|
||||
app: app,
|
||||
secret: secret,
|
||||
code: code,
|
||||
verifier: verifier,
|
||||
cfg: cfg,
|
||||
app: app,
|
||||
secret: secret,
|
||||
code: code,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -806,16 +794,12 @@ func TestOAuth2ProviderRevoke(t *testing.T) {
|
||||
test.fn(ctx, testClient, testEntities)
|
||||
|
||||
// Exchange should fail because the code should be gone.
|
||||
_, err := testEntities.cfg.Exchange(ctx, testEntities.code,
|
||||
oauth2.SetAuthURLParam("code_verifier", testEntities.verifier),
|
||||
)
|
||||
_, err := testEntities.cfg.Exchange(ctx, testEntities.code)
|
||||
require.Error(t, err)
|
||||
|
||||
// Try again, this time letting the exchange complete first.
|
||||
testEntities = setup(ctx, testClient, test.name+"-2")
|
||||
token, err := testEntities.cfg.Exchange(ctx, testEntities.code,
|
||||
oauth2.SetAuthURLParam("code_verifier", testEntities.verifier),
|
||||
)
|
||||
token, err := testEntities.cfg.Exchange(ctx, testEntities.code)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate the returned access token and that the app is listed.
|
||||
@@ -888,38 +872,25 @@ func generateApps(ctx context.Context, t *testing.T, client *codersdk.Client, su
|
||||
}
|
||||
}
|
||||
|
||||
// generatePKCE creates a PKCE verifier and S256 challenge for testing.
|
||||
func generatePKCE() (verifier, challenge string) {
|
||||
verifier = uuid.NewString() + uuid.NewString()
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return verifier, challenge
|
||||
}
|
||||
|
||||
func authorizationFlow(ctx context.Context, client *codersdk.Client, cfg *oauth2.Config) (code, codeVerifier string, err error) {
|
||||
func authorizationFlow(ctx context.Context, client *codersdk.Client, cfg *oauth2.Config) (string, error) {
|
||||
state := uuid.NewString()
|
||||
codeVerifier, challenge := generatePKCE()
|
||||
authURL := cfg.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
)
|
||||
authURL := cfg.AuthCodeURL(state)
|
||||
|
||||
// Make a POST request to simulate clicking "Allow" on the authorization page.
|
||||
// This bypasses the HTML consent page and directly processes the authorization.
|
||||
code, err = oidctest.OAuth2GetCode(
|
||||
// Make a POST request to simulate clicking "Allow" on the authorization page
|
||||
// This bypasses the HTML consent page and directly processes the authorization
|
||||
return oidctest.OAuth2GetCode(
|
||||
authURL,
|
||||
func(req *http.Request) (*http.Response, error) {
|
||||
// Change to POST to simulate the form submission.
|
||||
// Change to POST to simulate the form submission
|
||||
req.Method = http.MethodPost
|
||||
|
||||
// Prevent automatic redirect following.
|
||||
// Prevent automatic redirect following
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
return client.Request(ctx, req.Method, req.URL.String(), nil)
|
||||
},
|
||||
)
|
||||
return code, codeVerifier, err
|
||||
}
|
||||
|
||||
func must[T any](value T, err error) T {
|
||||
@@ -1026,15 +997,11 @@ func TestOAuth2ProviderResourceIndicators(t *testing.T) {
|
||||
Scopes: []string{},
|
||||
}
|
||||
|
||||
// Step 1: Authorization with resource parameter and PKCE.
|
||||
// Step 1: Authorization with resource parameter
|
||||
state := uuid.NewString()
|
||||
verifier, challenge := generatePKCE()
|
||||
authURL := cfg.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
)
|
||||
authURL := cfg.AuthCodeURL(state)
|
||||
if test.authResource != "" {
|
||||
// Add resource parameter to auth URL.
|
||||
// Add resource parameter to auth URL
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
require.NoError(t, err)
|
||||
query := parsedURL.Query()
|
||||
@@ -1063,7 +1030,7 @@ func TestOAuth2ProviderResourceIndicators(t *testing.T) {
|
||||
|
||||
// Step 2: Token exchange with resource parameter
|
||||
// Use custom token exchange since golang.org/x/oauth2 doesn't support resource parameter in token requests
|
||||
token, err := customTokenExchange(ctx, ownerClient.URL.String(), apps.Default.ID.String(), secret.ClientSecretFull, code, apps.Default.CallbackURL, test.tokenResource, verifier)
|
||||
token, err := customTokenExchange(ctx, ownerClient.URL.String(), apps.Default.ID.String(), secret.ClientSecretFull, code, apps.Default.CallbackURL, test.tokenResource)
|
||||
if test.expectTokenError {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid_target")
|
||||
@@ -1160,13 +1127,9 @@ func TestOAuth2ProviderCrossResourceAudienceValidation(t *testing.T) {
|
||||
Scopes: []string{},
|
||||
}
|
||||
|
||||
// Authorization with resource parameter for server1 and PKCE.
|
||||
// Authorization with resource parameter for server1
|
||||
state := uuid.NewString()
|
||||
verifier, challenge := generatePKCE()
|
||||
authURL := cfg.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
)
|
||||
authURL := cfg.AuthCodeURL(state)
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
require.NoError(t, err)
|
||||
query := parsedURL.Query()
|
||||
@@ -1186,11 +1149,8 @@ func TestOAuth2ProviderCrossResourceAudienceValidation(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exchange code for token with resource parameter and PKCE verifier.
|
||||
token, err := cfg.Exchange(ctx, code,
|
||||
oauth2.SetAuthURLParam("resource", resource1),
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier),
|
||||
)
|
||||
// Exchange code for token with resource parameter
|
||||
token, err := cfg.Exchange(ctx, code, oauth2.SetAuthURLParam("resource", resource1))
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token.AccessToken)
|
||||
|
||||
@@ -1266,11 +1226,9 @@ func TestOAuth2RefreshExpiryOutlivesAccess(t *testing.T) {
|
||||
}
|
||||
|
||||
// Authorization and token exchange
|
||||
code, verifier, err := authorizationFlow(ctx, ownerClient, cfg)
|
||||
code, err := authorizationFlow(ctx, ownerClient, cfg)
|
||||
require.NoError(t, err)
|
||||
tok, err := cfg.Exchange(ctx, code,
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier),
|
||||
)
|
||||
tok, err := cfg.Exchange(ctx, code)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, tok.AccessToken)
|
||||
require.NotEmpty(t, tok.RefreshToken)
|
||||
@@ -1295,7 +1253,7 @@ func TestOAuth2RefreshExpiryOutlivesAccess(t *testing.T) {
|
||||
|
||||
// customTokenExchange performs a custom OAuth2 token exchange with support for resource parameter
|
||||
// This is needed because golang.org/x/oauth2 doesn't support custom parameters in token requests
|
||||
func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, code, redirectURI, resource, codeVerifier string) (*oauth2.Token, error) {
|
||||
func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, code, redirectURI, resource string) (*oauth2.Token, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
@@ -1305,9 +1263,6 @@ func customTokenExchange(ctx context.Context, baseURL, clientID, clientSecret, c
|
||||
if resource != "" {
|
||||
data.Set("resource", resource)
|
||||
}
|
||||
if codeVerifier != "" {
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/oauth2/tokens", strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
@@ -1682,21 +1637,17 @@ func TestOAuth2CoderClient(t *testing.T) {
|
||||
// Make a new user
|
||||
client, user := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID)
|
||||
|
||||
// Do an OAuth2 token exchange and get a new client with an oauth token.
|
||||
// Do an OAuth2 token exchange and get a new client with an oauth token
|
||||
state := uuid.NewString()
|
||||
verifier, challenge := generatePKCE()
|
||||
|
||||
// Get an OAuth2 code for a token exchange.
|
||||
// Get an OAuth2 code for a token exchange
|
||||
code, err := oidctest.OAuth2GetCode(
|
||||
cfg.AuthCodeURL(state,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
),
|
||||
cfg.AuthCodeURL(state),
|
||||
func(req *http.Request) (*http.Response, error) {
|
||||
// Change to POST to simulate the form submission.
|
||||
// Change to POST to simulate the form submission
|
||||
req.Method = http.MethodPost
|
||||
|
||||
// Prevent automatic redirect following.
|
||||
// Prevent automatic redirect following
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
@@ -1705,9 +1656,7 @@ func TestOAuth2CoderClient(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := cfg.Exchange(ctx, code,
|
||||
oauth2.SetAuthURLParam("code_verifier", verifier),
|
||||
)
|
||||
token, err := cfg.Exchange(ctx, code)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use the oauth client's authentication
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package oauth2provider
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -11,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/justinas/nosurf"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -25,7 +22,6 @@ import (
|
||||
type authorizeParams struct {
|
||||
clientID string
|
||||
redirectURL *url.URL
|
||||
redirectURIProvided bool
|
||||
responseType codersdk.OAuth2ProviderResponseType
|
||||
scope []string
|
||||
state string
|
||||
@@ -38,13 +34,11 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
|
||||
p := httpapi.NewQueryParamParser()
|
||||
vals := r.URL.Query()
|
||||
|
||||
// response_type and client_id are always required.
|
||||
p.RequiredNotEmpty("response_type", "client_id")
|
||||
|
||||
params := authorizeParams{
|
||||
clientID: p.String(vals, "", "client_id"),
|
||||
redirectURL: p.RedirectURL(vals, callbackURL, "redirect_uri"),
|
||||
redirectURIProvided: vals.Get("redirect_uri") != "",
|
||||
responseType: httpapi.ParseCustom(p, vals, "", "response_type", httpapi.ParseEnum[codersdk.OAuth2ProviderResponseType]),
|
||||
scope: strings.Fields(strings.TrimSpace(p.String(vals, "", "scope"))),
|
||||
state: p.String(vals, "", "state"),
|
||||
@@ -52,15 +46,6 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
|
||||
codeChallenge: p.String(vals, "", "code_challenge"),
|
||||
codeChallengeMethod: p.String(vals, "", "code_challenge_method"),
|
||||
}
|
||||
|
||||
// PKCE is required for authorization code flow requests.
|
||||
if params.responseType == codersdk.OAuth2ProviderResponseTypeCode && params.codeChallenge == "" {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
Field: "code_challenge",
|
||||
Detail: `Query param "code_challenge" is required and cannot be empty`,
|
||||
})
|
||||
}
|
||||
|
||||
// Validate resource indicator syntax (RFC 8707): must be absolute URI without fragment
|
||||
if err := validateResourceParameter(params.resource); err != nil {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
@@ -127,22 +112,6 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if params.responseType != codersdk.OAuth2ProviderResponseTypeCode {
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
Status: http.StatusBadRequest,
|
||||
HideStatus: false,
|
||||
Title: "Unsupported Response Type",
|
||||
Description: "Only response_type=code is supported.",
|
||||
Actions: []site.Action{
|
||||
{
|
||||
URL: accessURL.String(),
|
||||
Text: "Back to site",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cancel := params.redirectURL
|
||||
cancelQuery := params.redirectURL.Query()
|
||||
cancelQuery.Add("error", "access_denied")
|
||||
@@ -153,7 +122,6 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
|
||||
AppName: app.Name,
|
||||
CancelURI: cancel.String(),
|
||||
RedirectURI: r.URL.String(),
|
||||
CSRFToken: nosurf.Token(r),
|
||||
Username: ua.FriendlyName,
|
||||
})
|
||||
}
|
||||
@@ -179,23 +147,16 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// OAuth 2.1 removes the implicit grant. Only
|
||||
// authorization code flow is supported.
|
||||
if params.responseType != codersdk.OAuth2ProviderResponseTypeCode {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest,
|
||||
codersdk.OAuth2ErrorCodeUnsupportedResponseType,
|
||||
"Only response_type=code is supported")
|
||||
return
|
||||
}
|
||||
|
||||
// code_challenge is required (enforced by RequiredNotEmpty above),
|
||||
// but default the method to S256 if omitted.
|
||||
if params.codeChallengeMethod == "" {
|
||||
params.codeChallengeMethod = string(codersdk.OAuth2PKCECodeChallengeMethodS256)
|
||||
}
|
||||
if err := codersdk.ValidatePKCECodeChallengeMethod(params.codeChallengeMethod); err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
|
||||
return
|
||||
// Validate PKCE for public clients (MCP requirement)
|
||||
if params.codeChallenge != "" {
|
||||
// If code_challenge is provided but method is not, default to S256
|
||||
if params.codeChallengeMethod == "" {
|
||||
params.codeChallengeMethod = string(codersdk.OAuth2PKCECodeChallengeMethodS256)
|
||||
}
|
||||
if err := codersdk.ValidatePKCECodeChallengeMethod(params.codeChallengeMethod); err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Ignoring scope for now, but should look into implementing.
|
||||
@@ -233,8 +194,6 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
ResourceUri: sql.NullString{String: params.resource, Valid: params.resource != ""},
|
||||
CodeChallenge: sql.NullString{String: params.codeChallenge, Valid: params.codeChallenge != ""},
|
||||
CodeChallengeMethod: sql.NullString{String: params.codeChallengeMethod, Valid: params.codeChallengeMethod != ""},
|
||||
StateHash: hashOAuth2State(params.state),
|
||||
RedirectUri: sql.NullString{String: params.redirectURL.String(), Valid: params.redirectURIProvided},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert oauth2 authorization code: %w", err)
|
||||
@@ -259,16 +218,3 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
http.Redirect(rw, r, params.redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
// hashOAuth2State returns a SHA-256 hash of the OAuth2 state parameter. If
|
||||
// the state is empty, it returns a null string.
|
||||
func hashOAuth2State(state string) sql.NullString {
|
||||
if state == "" {
|
||||
return sql.NullString{}
|
||||
}
|
||||
hash := sha256.Sum256([]byte(state))
|
||||
return sql.NullString{
|
||||
String: hex.EncodeToString(hash[:]),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
//nolint:testpackage // Internal test for unexported hashOAuth2State helper.
|
||||
package oauth2provider
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHashOAuth2State(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptyState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := hashOAuth2State("")
|
||||
assert.False(t, result.Valid, "empty state should return invalid NullString")
|
||||
assert.Empty(t, result.String, "empty state should return empty string")
|
||||
})
|
||||
|
||||
t.Run("NonEmptyState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
state := "test-state-value"
|
||||
result := hashOAuth2State(state)
|
||||
require.True(t, result.Valid, "non-empty state should return valid NullString")
|
||||
|
||||
// Verify it's a proper SHA-256 hash.
|
||||
expected := sha256.Sum256([]byte(state))
|
||||
assert.Equal(t, hex.EncodeToString(expected[:]), result.String,
|
||||
"state hash should be SHA-256 hex digest")
|
||||
})
|
||||
|
||||
t.Run("DifferentStatesProduceDifferentHashes", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hash1 := hashOAuth2State("state-a")
|
||||
hash2 := hashOAuth2State("state-b")
|
||||
require.True(t, hash1.Valid)
|
||||
require.True(t, hash2.Valid)
|
||||
assert.NotEqual(t, hash1.String, hash2.String,
|
||||
"different states should produce different hashes")
|
||||
})
|
||||
|
||||
t.Run("SameStateProducesSameHash", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hash1 := hashOAuth2State("deterministic")
|
||||
hash2 := hashOAuth2State("deterministic")
|
||||
require.True(t, hash1.Valid)
|
||||
assert.Equal(t, hash1.String, hash2.String,
|
||||
"same state should produce identical hash")
|
||||
})
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package oauth2provider_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/site"
|
||||
)
|
||||
|
||||
func TestOAuthConsentFormIncludesCSRFToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const csrfFieldValue = "csrf-field-value"
|
||||
req := httptest.NewRequest(http.MethodGet, "https://coder.com/oauth2/authorize", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
site.RenderOAuthAllowPage(rec, req, site.RenderOAuthAllowData{
|
||||
AppName: "Test OAuth App",
|
||||
CancelURI: "https://coder.com/cancel",
|
||||
RedirectURI: "https://coder.com/oauth2/authorize?client_id=test",
|
||||
CSRFToken: csrfFieldValue,
|
||||
Username: "test-user",
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Result().StatusCode)
|
||||
assert.Contains(t, rec.Body.String(), `name="csrf_token"`)
|
||||
assert.Contains(t, rec.Body.String(), `value="`+csrfFieldValue+`"`)
|
||||
}
|
||||
@@ -158,9 +158,7 @@ func TestOAuth2InvalidPKCE(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
// TestOAuth2WithoutPKCEIsRejected verifies that authorization requests without
|
||||
// a code_challenge are rejected now that PKCE is mandatory.
|
||||
func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
|
||||
func TestOAuth2WithoutPKCE(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
@@ -168,15 +166,15 @@ func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create OAuth2 app.
|
||||
app, _ := oauth2providertest.CreateTestOAuth2App(t, client)
|
||||
// Create OAuth2 app
|
||||
app, clientSecret := oauth2providertest.CreateTestOAuth2App(t, client)
|
||||
t.Cleanup(func() {
|
||||
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
|
||||
})
|
||||
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
|
||||
// Authorization without code_challenge should be rejected.
|
||||
// Perform authorization without PKCE
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
@@ -184,9 +182,21 @@ func TestOAuth2WithoutPKCEIsRejected(t *testing.T) {
|
||||
State: state,
|
||||
}
|
||||
|
||||
oauth2providertest.AuthorizeOAuth2AppExpectingError(
|
||||
t, client, client.URL.String(), authParams, http.StatusBadRequest,
|
||||
)
|
||||
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
|
||||
require.NotEmpty(t, code, "should receive authorization code")
|
||||
|
||||
// Exchange code for token without PKCE
|
||||
tokenParams := oauth2providertest.TokenExchangeParams{
|
||||
GrantType: "authorization_code",
|
||||
Code: code,
|
||||
ClientID: app.ID.String(),
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
}
|
||||
|
||||
token := oauth2providertest.ExchangeCodeForToken(t, client.URL.String(), tokenParams)
|
||||
require.NotEmpty(t, token.AccessToken, "should receive access token")
|
||||
require.NotEmpty(t, token.RefreshToken, "should receive refresh token")
|
||||
}
|
||||
|
||||
func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
|
||||
@@ -202,16 +212,13 @@ func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
|
||||
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
|
||||
})
|
||||
|
||||
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: "S256",
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
}
|
||||
|
||||
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
|
||||
@@ -222,7 +229,6 @@ func TestOAuth2TokenExchangeClientSecretBasic(t *testing.T) {
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", oauth2providertest.TestRedirectURI)
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", client.URL.String()+"/oauth2/tokens", strings.NewReader(data.Encode()))
|
||||
require.NoError(t, err, "failed to create token request")
|
||||
@@ -259,16 +265,13 @@ func TestOAuth2TokenExchangeClientSecretBasicInvalidSecret(t *testing.T) {
|
||||
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
|
||||
})
|
||||
|
||||
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: "S256",
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
}
|
||||
|
||||
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
|
||||
@@ -279,7 +282,6 @@ func TestOAuth2TokenExchangeClientSecretBasicInvalidSecret(t *testing.T) {
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", oauth2providertest.TestRedirectURI)
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
|
||||
wrongSecret := clientSecret + "x"
|
||||
|
||||
@@ -347,30 +349,26 @@ func TestOAuth2ResourceParameter(t *testing.T) {
|
||||
})
|
||||
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
|
||||
|
||||
// Perform authorization with resource parameter.
|
||||
// Perform authorization with resource parameter
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: "S256",
|
||||
Resource: oauth2providertest.TestResourceURI,
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
Resource: oauth2providertest.TestResourceURI,
|
||||
}
|
||||
|
||||
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
|
||||
require.NotEmpty(t, code, "should receive authorization code")
|
||||
|
||||
// Exchange code for token with resource parameter.
|
||||
// Exchange code for token with resource parameter
|
||||
tokenParams := oauth2providertest.TokenExchangeParams{
|
||||
GrantType: "authorization_code",
|
||||
Code: code,
|
||||
ClientID: app.ID.String(),
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
Resource: oauth2providertest.TestResourceURI,
|
||||
}
|
||||
|
||||
@@ -394,16 +392,13 @@ func TestOAuth2TokenRefresh(t *testing.T) {
|
||||
})
|
||||
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
codeVerifier, codeChallenge := oauth2providertest.GeneratePKCE(t)
|
||||
|
||||
// Get initial token.
|
||||
// Get initial token
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: "S256",
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: "code",
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
}
|
||||
|
||||
code := oauth2providertest.AuthorizeOAuth2App(t, client, client.URL.String(), authParams)
|
||||
@@ -414,7 +409,6 @@ func TestOAuth2TokenRefresh(t *testing.T) {
|
||||
ClientID: app.ID.String(),
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
CodeVerifier: codeVerifier,
|
||||
}
|
||||
|
||||
initialToken := oauth2providertest.ExchangeCodeForToken(t, client.URL.String(), tokenParams)
|
||||
|
||||
@@ -254,27 +254,14 @@ func authorizationCodeGrant(ctx context.Context, db database.Store, app database
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
}
|
||||
|
||||
// Verify redirect_uri matches the one used during authorization
|
||||
// (RFC 6749 §4.1.3).
|
||||
if dbCode.RedirectUri.Valid && dbCode.RedirectUri.String != "" {
|
||||
if req.RedirectURI != dbCode.RedirectUri.String {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
// Verify PKCE challenge if present
|
||||
if dbCode.CodeChallenge.Valid && dbCode.CodeChallenge.String != "" {
|
||||
if req.CodeVerifier == "" {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
if !VerifyPKCE(dbCode.CodeChallenge.String, req.CodeVerifier) {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
}
|
||||
|
||||
// PKCE is mandatory for all authorization code flows
|
||||
// (OAuth 2.1). Verify the code verifier against the stored
|
||||
// challenge.
|
||||
if req.CodeVerifier == "" {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
if !dbCode.CodeChallenge.Valid || dbCode.CodeChallenge.String == "" {
|
||||
// Code was issued without a challenge — should not happen
|
||||
// with authorize endpoint enforcement, but defend in depth.
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
if !VerifyPKCE(dbCode.CodeChallenge.String, req.CodeVerifier) {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
|
||||
// Verify resource parameter consistency (RFC 8707)
|
||||
|
||||
@@ -318,7 +318,6 @@ func TestExtractAuthorizeParams_Scopes(t *testing.T) {
|
||||
query.Set("response_type", "code")
|
||||
query.Set("client_id", "test-client")
|
||||
query.Set("redirect_uri", "http://localhost:3000/callback")
|
||||
query.Set("code_challenge", "test-challenge")
|
||||
if tc.scopeParam != "" {
|
||||
query.Set("scope", tc.scopeParam)
|
||||
}
|
||||
@@ -342,34 +341,6 @@ func TestExtractAuthorizeParams_Scopes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractAuthorizeParams_TokenResponseTypeDoesNotRequirePKCE ensures
|
||||
// response_type=token is parsed without requiring PKCE fields so callers can
|
||||
// return unsupported_response_type instead of invalid_request.
|
||||
func TestExtractAuthorizeParams_TokenResponseTypeDoesNotRequirePKCE(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
callbackURL, err := url.Parse("http://localhost:3000/callback")
|
||||
require.NoError(t, err)
|
||||
|
||||
query := url.Values{}
|
||||
query.Set("response_type", string(codersdk.OAuth2ProviderResponseTypeToken))
|
||||
query.Set("client_id", "test-client")
|
||||
query.Set("redirect_uri", "http://localhost:3000/callback")
|
||||
|
||||
reqURL, err := url.Parse("http://localhost:8080/oauth2/authorize?" + query.Encode())
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: reqURL,
|
||||
}
|
||||
|
||||
params, validationErrs, err := extractAuthorizeParams(req, callbackURL)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, validationErrs)
|
||||
require.Equal(t, codersdk.OAuth2ProviderResponseTypeToken, params.responseType)
|
||||
}
|
||||
|
||||
// TestRefreshTokenGrant_Scopes tests that scopes can be requested during refresh
|
||||
func TestRefreshTokenGrant_Scopes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -19,9 +19,9 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
templatesActiveUsersDesc = prometheus.NewDesc("coderd_insights_templates_active_users", "The number of active users of the template.", []string{"template_name", "organization_name"}, nil)
|
||||
applicationsUsageSecondsDesc = prometheus.NewDesc("coderd_insights_applications_usage_seconds", "The application usage per template.", []string{"template_name", "application_name", "slug", "organization_name"}, nil)
|
||||
parametersDesc = prometheus.NewDesc("coderd_insights_parameters", "The parameter usage per template.", []string{"template_name", "parameter_name", "parameter_type", "parameter_value", "organization_name"}, nil)
|
||||
templatesActiveUsersDesc = prometheus.NewDesc("coderd_insights_templates_active_users", "The number of active users of the template.", []string{"template_name"}, nil)
|
||||
applicationsUsageSecondsDesc = prometheus.NewDesc("coderd_insights_applications_usage_seconds", "The application usage per template.", []string{"template_name", "application_name", "slug"}, nil)
|
||||
parametersDesc = prometheus.NewDesc("coderd_insights_parameters", "The parameter usage per template.", []string{"template_name", "parameter_name", "parameter_type", "parameter_value"}, nil)
|
||||
)
|
||||
|
||||
type MetricsCollector struct {
|
||||
@@ -38,8 +38,7 @@ type insightsData struct {
|
||||
apps []database.GetTemplateAppInsightsByTemplateRow
|
||||
params []parameterRow
|
||||
|
||||
templateNames map[uuid.UUID]string
|
||||
organizationNames map[uuid.UUID]string // template ID → org name
|
||||
templateNames map[uuid.UUID]string
|
||||
}
|
||||
|
||||
type parameterRow struct {
|
||||
@@ -138,7 +137,6 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
|
||||
templateIDs := uniqueTemplateIDs(templateInsights, appInsights, paramInsights)
|
||||
|
||||
templateNames := make(map[uuid.UUID]string, len(templateIDs))
|
||||
organizationNames := make(map[uuid.UUID]string, len(templateIDs))
|
||||
if len(templateIDs) > 0 {
|
||||
templates, err := mc.database.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{
|
||||
IDs: templateIDs,
|
||||
@@ -148,31 +146,6 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
|
||||
return
|
||||
}
|
||||
templateNames = onlyTemplateNames(templates)
|
||||
|
||||
// Build org name lookup so that metrics can
|
||||
// distinguish templates with the same name across
|
||||
// different organizations.
|
||||
orgIDs := make([]uuid.UUID, 0, len(templates))
|
||||
for _, t := range templates {
|
||||
orgIDs = append(orgIDs, t.OrganizationID)
|
||||
}
|
||||
orgIDs = slice.Unique(orgIDs)
|
||||
|
||||
orgs, err := mc.database.GetOrganizations(ctx, database.GetOrganizationsParams{
|
||||
IDs: orgIDs,
|
||||
})
|
||||
if err != nil {
|
||||
mc.logger.Error(ctx, "unable to fetch organizations from database", slog.Error(err))
|
||||
return
|
||||
}
|
||||
orgNameByID := make(map[uuid.UUID]string, len(orgs))
|
||||
for _, o := range orgs {
|
||||
orgNameByID[o.ID] = o.Name
|
||||
}
|
||||
organizationNames = make(map[uuid.UUID]string, len(templates))
|
||||
for _, t := range templates {
|
||||
organizationNames[t.ID] = orgNameByID[t.OrganizationID]
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh the collector state
|
||||
@@ -181,8 +154,7 @@ func (mc *MetricsCollector) Run(ctx context.Context) (func(), error) {
|
||||
apps: appInsights,
|
||||
params: paramInsights,
|
||||
|
||||
templateNames: templateNames,
|
||||
organizationNames: organizationNames,
|
||||
templateNames: templateNames,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -222,46 +194,44 @@ func (mc *MetricsCollector) Collect(metricsCh chan<- prometheus.Metric) {
|
||||
// Custom apps
|
||||
for _, appRow := range data.apps {
|
||||
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue, float64(appRow.UsageSeconds), data.templateNames[appRow.TemplateID],
|
||||
appRow.DisplayName, appRow.SlugOrPort, data.organizationNames[appRow.TemplateID])
|
||||
appRow.DisplayName, appRow.SlugOrPort)
|
||||
}
|
||||
|
||||
// Built-in apps
|
||||
for _, templateRow := range data.templates {
|
||||
orgName := data.organizationNames[templateRow.TemplateID]
|
||||
|
||||
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
|
||||
float64(templateRow.UsageVscodeSeconds),
|
||||
data.templateNames[templateRow.TemplateID],
|
||||
codersdk.TemplateBuiltinAppDisplayNameVSCode,
|
||||
"", orgName)
|
||||
"")
|
||||
|
||||
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
|
||||
float64(templateRow.UsageJetbrainsSeconds),
|
||||
data.templateNames[templateRow.TemplateID],
|
||||
codersdk.TemplateBuiltinAppDisplayNameJetBrains,
|
||||
"", orgName)
|
||||
"")
|
||||
|
||||
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
|
||||
float64(templateRow.UsageReconnectingPtySeconds),
|
||||
data.templateNames[templateRow.TemplateID],
|
||||
codersdk.TemplateBuiltinAppDisplayNameWebTerminal,
|
||||
"", orgName)
|
||||
"")
|
||||
|
||||
metricsCh <- prometheus.MustNewConstMetric(applicationsUsageSecondsDesc, prometheus.GaugeValue,
|
||||
float64(templateRow.UsageSshSeconds),
|
||||
data.templateNames[templateRow.TemplateID],
|
||||
codersdk.TemplateBuiltinAppDisplayNameSSH,
|
||||
"", orgName)
|
||||
"")
|
||||
}
|
||||
|
||||
// Templates
|
||||
for _, templateRow := range data.templates {
|
||||
metricsCh <- prometheus.MustNewConstMetric(templatesActiveUsersDesc, prometheus.GaugeValue, float64(templateRow.ActiveUsers), data.templateNames[templateRow.TemplateID], data.organizationNames[templateRow.TemplateID])
|
||||
metricsCh <- prometheus.MustNewConstMetric(templatesActiveUsersDesc, prometheus.GaugeValue, float64(templateRow.ActiveUsers), data.templateNames[templateRow.TemplateID])
|
||||
}
|
||||
|
||||
// Parameters
|
||||
for _, parameterRow := range data.params {
|
||||
metricsCh <- prometheus.MustNewConstMetric(parametersDesc, prometheus.GaugeValue, float64(parameterRow.count), data.templateNames[parameterRow.templateID], parameterRow.name, parameterRow.aType, parameterRow.value, data.organizationNames[parameterRow.templateID])
|
||||
metricsCh <- prometheus.MustNewConstMetric(parametersDesc, prometheus.GaugeValue, float64(parameterRow.count), data.templateNames[parameterRow.templateID], parameterRow.name, parameterRow.aType, parameterRow.value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
{
|
||||
"coderd_insights_applications_usage_seconds[application_name=JetBrains,organization_name=coder,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Visual Studio Code,organization_name=coder,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Web Terminal,organization_name=coder,slug=,template_name=golden-template]": 0,
|
||||
"coderd_insights_applications_usage_seconds[application_name=SSH,organization_name=coder,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Golden Slug,organization_name=coder,slug=golden-slug,template_name=golden-template]": 180,
|
||||
"coderd_insights_parameters[organization_name=coder,parameter_name=first_parameter,parameter_type=string,parameter_value=Foobar,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[organization_name=coder,parameter_name=first_parameter,parameter_type=string,parameter_value=Baz,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[organization_name=coder,parameter_name=second_parameter,parameter_type=bool,parameter_value=true,template_name=golden-template]": 2,
|
||||
"coderd_insights_parameters[organization_name=coder,parameter_name=third_parameter,parameter_type=number,parameter_value=789,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[organization_name=coder,parameter_name=third_parameter,parameter_type=number,parameter_value=999,template_name=golden-template]": 1,
|
||||
"coderd_insights_templates_active_users[organization_name=coder,template_name=golden-template]": 1
|
||||
"coderd_insights_applications_usage_seconds[application_name=JetBrains,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Visual Studio Code,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Web Terminal,slug=,template_name=golden-template]": 0,
|
||||
"coderd_insights_applications_usage_seconds[application_name=SSH,slug=,template_name=golden-template]": 60,
|
||||
"coderd_insights_applications_usage_seconds[application_name=Golden Slug,slug=golden-slug,template_name=golden-template]": 180,
|
||||
"coderd_insights_parameters[parameter_name=first_parameter,parameter_type=string,parameter_value=Foobar,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[parameter_name=first_parameter,parameter_type=string,parameter_value=Baz,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[parameter_name=second_parameter,parameter_type=bool,parameter_value=true,template_name=golden-template]": 2,
|
||||
"coderd_insights_parameters[parameter_name=third_parameter,parameter_type=number,parameter_value=789,template_name=golden-template]": 1,
|
||||
"coderd_insights_parameters[parameter_name=third_parameter,parameter_type=number,parameter_value=999,template_name=golden-template]": 1,
|
||||
"coderd_insights_templates_active_users[template_name=golden-template]": 1
|
||||
}
|
||||
|
||||
@@ -725,16 +725,11 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
}
|
||||
}
|
||||
|
||||
provisionerStateRow, err := s.Database.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuild.ID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("get workspace build provisioner state: %s", err))
|
||||
}
|
||||
|
||||
protoJob.Type = &proto.AcquiredJob_WorkspaceBuild_{
|
||||
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
|
||||
WorkspaceBuildId: workspaceBuild.ID.String(),
|
||||
WorkspaceName: workspace.Name,
|
||||
State: provisionerStateRow.ProvisionerState,
|
||||
State: workspaceBuild.ProvisionerState,
|
||||
RichParameterValues: convertRichParameterValues(workspaceBuildParameters),
|
||||
PreviousParameterValues: convertRichParameterValues(lastWorkspaceBuildParameters),
|
||||
VariableValues: asVariableValues(templateVariables),
|
||||
@@ -845,11 +840,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
|
||||
// Record the time the job spent waiting in the queue.
|
||||
if s.metrics != nil && job.StartedAt.Valid && job.Provisioner.Valid() {
|
||||
// These timestamps lose their monotonic clock component after a Postgres
|
||||
// round-trip, so the subtraction is based purely on wall-clock time. Floor at
|
||||
// 1ms as a defensive measure against clock adjustments producing a negative
|
||||
// delta while acknowledging there's a non-zero queue time.
|
||||
queueWaitSeconds := max(job.StartedAt.Time.Sub(job.CreatedAt).Seconds(), 0.001)
|
||||
queueWaitSeconds := job.StartedAt.Time.Sub(job.CreatedAt).Seconds()
|
||||
s.metrics.ObserveJobQueueWait(string(job.Provisioner), string(job.Type), jobTransition, jobBuildReason, queueWaitSeconds)
|
||||
}
|
||||
|
||||
|
||||
@@ -1321,9 +1321,7 @@ func TestFailJob(t *testing.T) {
|
||||
<-publishedLogs
|
||||
build, err := db.GetWorkspaceBuildByID(ctx, buildID)
|
||||
require.NoError(t, err)
|
||||
provisionerStateRow, err := db.GetWorkspaceBuildProvisionerStateByID(ctx, build.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "some state", string(provisionerStateRow.ProvisionerState))
|
||||
require.Equal(t, "some state", string(build.ProvisionerState))
|
||||
require.Len(t, auditor.AuditLogs(), 1)
|
||||
|
||||
// Assert that the workspace_id field get populated
|
||||
|
||||
@@ -81,7 +81,6 @@ const (
|
||||
SubjectAibridged SubjectType = "aibridged"
|
||||
SubjectTypeDBPurge SubjectType = "dbpurge"
|
||||
SubjectTypeBoundaryUsageTracker SubjectType = "boundary_usage_tracker"
|
||||
SubjectTypeWorkspaceBuilder SubjectType = "workspace_builder"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -401,35 +401,6 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string,
|
||||
return filter, parser.Errors
|
||||
}
|
||||
|
||||
func AIBridgeModels(query string, page codersdk.Pagination) (database.ListAIBridgeModelsParams, []codersdk.ValidationError) {
|
||||
// nolint:exhaustruct // Empty values just means "don't filter by that field".
|
||||
filter := database.ListAIBridgeModelsParams{
|
||||
// #nosec G115 - Safe conversion for pagination offset which is expected to be within int32 range
|
||||
Offset: int32(page.Offset),
|
||||
// #nosec G115 - Safe conversion for pagination limit which is expected to be within int32 range
|
||||
Limit: int32(page.Limit),
|
||||
}
|
||||
|
||||
if query == "" {
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
values, errors := searchTerms(query, func(term string, values url.Values) error {
|
||||
// Defaults to the `model` if no `key:value` pair is provided.
|
||||
values.Add("model", term)
|
||||
return nil
|
||||
})
|
||||
if len(errors) > 0 {
|
||||
return filter, errors
|
||||
}
|
||||
|
||||
parser := httpapi.NewQueryParamParser()
|
||||
filter.Model = parser.String(values, "", "model")
|
||||
|
||||
parser.ErrorExcessParams(values)
|
||||
return filter, parser.Errors
|
||||
}
|
||||
|
||||
// Tasks parses a search query for tasks.
|
||||
//
|
||||
// Supported query parameters:
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultModel = anthropic.ModelClaudeHaiku4_5
|
||||
defaultModel = anthropic.ModelClaude3_5HaikuLatest
|
||||
systemPrompt = `Generate a short task display name and name from this AI task prompt.
|
||||
Identify the main task (the core action and subject) and base both names on it.
|
||||
The task display name and name should be as similar as possible so a human can easily associate them.
|
||||
|
||||
+25
-163
@@ -416,10 +416,9 @@ func checkIDPOrgSync(ctx context.Context, db database.Store, values *codersdk.De
|
||||
func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
|
||||
var (
|
||||
ctx = r.ctx
|
||||
now = r.options.Clock.Now()
|
||||
// For resources that grow in size very quickly (like workspace builds),
|
||||
// we only report events that occurred within the past hour.
|
||||
createdAfter = dbtime.Time(now.Add(-1 * time.Hour)).UTC()
|
||||
createdAfter = dbtime.Time(r.options.Clock.Now().Add(-1 * time.Hour)).UTC()
|
||||
eg errgroup.Group
|
||||
snapshot = &Snapshot{
|
||||
DeploymentID: r.options.DeploymentID,
|
||||
@@ -741,19 +740,17 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
tasks, err := CollectTasks(ctx, r.options.Database)
|
||||
dbTasks, err := r.options.Database.ListTasks(ctx, database.ListTasksParams{
|
||||
OwnerID: uuid.Nil,
|
||||
OrganizationID: uuid.Nil,
|
||||
Status: "",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("collect tasks telemetry: %w", err)
|
||||
return err
|
||||
}
|
||||
snapshot.Tasks = tasks
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
events, err := CollectTaskEvents(ctx, r.options.Database, createdAfter, now)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("collect task events telemetry: %w", err)
|
||||
for _, dbTask := range dbTasks {
|
||||
snapshot.Tasks = append(snapshot.Tasks, ConvertTask(dbTask))
|
||||
}
|
||||
snapshot.TaskEvents = events
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
@@ -905,129 +902,6 @@ func (r *remoteReporter) collectBoundaryUsageSummary(ctx context.Context) (*Boun
|
||||
}, nil
|
||||
}
|
||||
|
||||
func CollectTasks(ctx context.Context, db database.Store) ([]Task, error) {
|
||||
dbTasks, err := db.ListTasks(ctx, database.ListTasksParams{
|
||||
OwnerID: uuid.Nil,
|
||||
OrganizationID: uuid.Nil,
|
||||
Status: "",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("list tasks: %w", err)
|
||||
}
|
||||
if len(dbTasks) == 0 {
|
||||
return []Task{}, nil
|
||||
}
|
||||
|
||||
tasks := make([]Task, 0, len(dbTasks))
|
||||
for _, dbTask := range dbTasks {
|
||||
tasks = append(tasks, ConvertTask(dbTask))
|
||||
}
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
// buildTaskEvent constructs a TaskEvent from the combined query row.
|
||||
func buildTaskEvent(
|
||||
row database.GetTelemetryTaskEventsRow,
|
||||
createdAfter time.Time,
|
||||
now time.Time,
|
||||
) TaskEvent {
|
||||
event := TaskEvent{
|
||||
TaskID: row.TaskID.String(),
|
||||
}
|
||||
|
||||
var (
|
||||
hasStartBuild = row.StartBuildCreatedAt.Valid
|
||||
isResumed = hasStartBuild && row.StartBuildNumber.Valid && row.StartBuildNumber.Int32 > 1
|
||||
hasStopBuild = row.StopBuildCreatedAt.Valid
|
||||
startedAfterStop = hasStartBuild && hasStopBuild && row.StartBuildCreatedAt.Time.After(row.StopBuildCreatedAt.Time)
|
||||
currentlyPaused = hasStopBuild && !startedAfterStop
|
||||
)
|
||||
|
||||
// Pause-related fields (requires a stop build).
|
||||
if hasStopBuild {
|
||||
event.LastPausedAt = &row.StopBuildCreatedAt.Time
|
||||
switch {
|
||||
case row.StopBuildReason.Valid && row.StopBuildReason.BuildReason == database.BuildReasonTaskAutoPause:
|
||||
event.PauseReason = ptr.Ref("auto")
|
||||
case row.StopBuildReason.Valid && row.StopBuildReason.BuildReason == database.BuildReasonTaskManualPause:
|
||||
event.PauseReason = ptr.Ref("manual")
|
||||
default:
|
||||
event.PauseReason = ptr.Ref("other")
|
||||
}
|
||||
|
||||
// Idle duration: time between last working status and the pause.
|
||||
if row.LastWorkingStatusAt.Valid &&
|
||||
row.StopBuildCreatedAt.Time.After(row.LastWorkingStatusAt.Time) {
|
||||
idle := row.StopBuildCreatedAt.Time.Sub(row.LastWorkingStatusAt.Time)
|
||||
event.IdleDurationMS = ptr.Ref(idle.Milliseconds())
|
||||
}
|
||||
}
|
||||
|
||||
// Resume-related fields (requires task_resume start after stop).
|
||||
if startedAfterStop {
|
||||
// Paused duration: time between pause and resume.
|
||||
if row.StartBuildCreatedAt.Time.After(createdAfter) {
|
||||
paused := row.StartBuildCreatedAt.Time.Sub(row.StopBuildCreatedAt.Time)
|
||||
event.PausedDurationMS = ptr.Ref(paused.Milliseconds())
|
||||
}
|
||||
|
||||
// Below only relevant for "resumed" tasks, not when initially created.
|
||||
if isResumed {
|
||||
event.LastResumedAt = &row.StartBuildCreatedAt.Time
|
||||
switch {
|
||||
// TODO(Cian): will this exist? Future readers may know better than I.
|
||||
// case row.StartBuildReason == database.BuildReasonTaskAutoResume:
|
||||
// event.ResumeReason = ptr.Ref("auto")
|
||||
case row.StartBuildReason.BuildReason == database.BuildReasonTaskResume:
|
||||
event.ResumeReason = ptr.Ref("manual")
|
||||
default: // Task resumed by starting workspace?
|
||||
event.ResumeReason = ptr.Ref("other")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unresolved pause: report current paused duration.
|
||||
if currentlyPaused {
|
||||
paused := now.Sub(row.StopBuildCreatedAt.Time)
|
||||
event.PausedDurationMS = ptr.Ref(paused.Milliseconds())
|
||||
}
|
||||
|
||||
// Resume-to-status duration.
|
||||
if row.FirstStatusAfterResumeAt.Valid && isResumed {
|
||||
delta := row.FirstStatusAfterResumeAt.Time.Sub(row.StartBuildCreatedAt.Time)
|
||||
event.ResumeToStatusMS = ptr.Ref(delta.Milliseconds())
|
||||
}
|
||||
|
||||
// Active duration: from SQL calculation.
|
||||
if row.ActiveDurationMs > 0 {
|
||||
event.ActiveDurationMS = ptr.Ref(row.ActiveDurationMs)
|
||||
}
|
||||
|
||||
return event
|
||||
}
|
||||
|
||||
// CollectTaskEvents collects lifecycle events for tasks with recent activity.
|
||||
func CollectTaskEvents(ctx context.Context, db database.Store, createdAfter, now time.Time) ([]TaskEvent, error) {
|
||||
rows, err := db.GetTelemetryTaskEvents(ctx, database.GetTelemetryTaskEventsParams{
|
||||
CreatedAfter: createdAfter,
|
||||
Now: now,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get telemetry task events: %w", err)
|
||||
}
|
||||
events := make([]TaskEvent, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
events = append(events, buildTaskEvent(row, createdAfter, now))
|
||||
}
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// HashContent returns a SHA256 hash of the content as a hex string.
|
||||
// This is useful for hashing sensitive content like prompts for telemetry.
|
||||
func HashContent(content string) string {
|
||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(content)))
|
||||
}
|
||||
|
||||
// ConvertAPIKey anonymizes an API key.
|
||||
func ConvertAPIKey(apiKey database.APIKey) APIKey {
|
||||
a := APIKey{
|
||||
@@ -1496,7 +1370,6 @@ type Snapshot struct {
|
||||
NetworkEvents []NetworkEvent `json:"network_events"`
|
||||
Organizations []Organization `json:"organizations"`
|
||||
Tasks []Task `json:"tasks"`
|
||||
TaskEvents []TaskEvent `json:"task_events"`
|
||||
TelemetryItems []TelemetryItem `json:"telemetry_items"`
|
||||
UserTailnetConnections []UserTailnetConnection `json:"user_tailnet_connections"`
|
||||
PrebuiltWorkspaces []PrebuiltWorkspace `json:"prebuilt_workspaces"`
|
||||
@@ -2058,36 +1931,25 @@ type Task struct {
|
||||
WorkspaceAppID *string `json:"workspace_app_id"`
|
||||
TemplateVersionID string `json:"template_version_id"`
|
||||
PromptHash string `json:"prompt_hash"` // Prompt is hashed for privacy.
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// TaskEvent represents lifecycle events for a task (pause/resume
|
||||
// cycles). The createdAfter parameter gates PausedDurationMS so
|
||||
// that only recent pause/resume pairs are reported.
|
||||
type TaskEvent struct {
|
||||
TaskID string `json:"task_id"`
|
||||
LastPausedAt *time.Time `json:"last_paused_at"`
|
||||
LastResumedAt *time.Time `json:"last_resumed_at"`
|
||||
PauseReason *string `json:"pause_reason"`
|
||||
ResumeReason *string `json:"resume_reason"`
|
||||
IdleDurationMS *int64 `json:"idle_duration_ms"`
|
||||
PausedDurationMS *int64 `json:"paused_duration_ms"`
|
||||
ResumeToStatusMS *int64 `json:"resume_to_status_ms"`
|
||||
ActiveDurationMS *int64 `json:"active_duration_ms"`
|
||||
}
|
||||
|
||||
// ConvertTask converts a database Task to a telemetry Task.
|
||||
// ConvertTask anonymizes a Task.
|
||||
func ConvertTask(task database.Task) Task {
|
||||
t := Task{
|
||||
ID: task.ID.String(),
|
||||
OrganizationID: task.OrganizationID.String(),
|
||||
OwnerID: task.OwnerID.String(),
|
||||
Name: task.Name,
|
||||
TemplateVersionID: task.TemplateVersionID.String(),
|
||||
PromptHash: HashContent(task.Prompt),
|
||||
Status: string(task.Status),
|
||||
CreatedAt: task.CreatedAt,
|
||||
t := &Task{
|
||||
ID: task.ID.String(),
|
||||
OrganizationID: task.OrganizationID.String(),
|
||||
OwnerID: task.OwnerID.String(),
|
||||
Name: task.Name,
|
||||
WorkspaceID: nil,
|
||||
WorkspaceBuildNumber: nil,
|
||||
WorkspaceAgentID: nil,
|
||||
WorkspaceAppID: nil,
|
||||
TemplateVersionID: task.TemplateVersionID.String(),
|
||||
PromptHash: fmt.Sprintf("%x", sha256.Sum256([]byte(task.Prompt))),
|
||||
CreatedAt: task.CreatedAt,
|
||||
Status: string(task.Status),
|
||||
}
|
||||
if task.WorkspaceID.Valid {
|
||||
t.WorkspaceID = ptr.Ref(task.WorkspaceID.UUID.String())
|
||||
@@ -2101,7 +1963,7 @@ func ConvertTask(task database.Task) Task {
|
||||
if task.WorkspaceAppID.Valid {
|
||||
t.WorkspaceAppID = ptr.Ref(task.WorkspaceAppID.UUID.String())
|
||||
}
|
||||
return t
|
||||
return *t
|
||||
}
|
||||
|
||||
type telemetryItemKey string
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -14,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -23,14 +21,12 @@ import (
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/boundaryusage"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/idpsync"
|
||||
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
@@ -317,17 +313,6 @@ func TestTelemetry(t *testing.T) {
|
||||
require.Equal(t, string(database.WorkspaceAgentSubsystemEnvbox), wsa.Subsystems[0])
|
||||
require.Equal(t, string(database.WorkspaceAgentSubsystemExectrace), wsa.Subsystems[1])
|
||||
require.Len(t, snapshot.Tasks, 1)
|
||||
require.Len(t, snapshot.TaskEvents, 1)
|
||||
taskEvent := snapshot.TaskEvents[0]
|
||||
assert.Equal(t, task.ID.String(), taskEvent.TaskID)
|
||||
assert.Nil(t, taskEvent.LastResumedAt)
|
||||
assert.Nil(t, taskEvent.LastPausedAt)
|
||||
assert.Nil(t, taskEvent.PauseReason)
|
||||
assert.Nil(t, taskEvent.ResumeReason)
|
||||
assert.Nil(t, taskEvent.IdleDurationMS)
|
||||
assert.Nil(t, taskEvent.PausedDurationMS)
|
||||
assert.Nil(t, taskEvent.ResumeToStatusMS)
|
||||
assert.Nil(t, taskEvent.ActiveDurationMS)
|
||||
for _, snapTask := range snapshot.Tasks {
|
||||
assert.Equal(t, task.ID.String(), snapTask.ID)
|
||||
assert.Equal(t, task.OrganizationID.String(), snapTask.OrganizationID)
|
||||
@@ -341,7 +326,6 @@ func TestTelemetry(t *testing.T) {
|
||||
assert.Equal(t, taskWA.WorkspaceAppID.UUID.String(), *snapTask.WorkspaceAppID)
|
||||
assert.Equal(t, task.TemplateVersionID.String(), snapTask.TemplateVersionID)
|
||||
assert.Equal(t, "e196fe22e61cfa32d8c38749e0ce348108bb4cae29e2c36cdcce7e77faa9eb5f", snapTask.PromptHash)
|
||||
assert.Equal(t, string(task.Status), snapTask.Status)
|
||||
assert.Equal(t, task.CreatedAt.UTC(), snapTask.CreatedAt.UTC())
|
||||
}
|
||||
|
||||
@@ -691,573 +675,6 @@ func TestPrebuiltWorkspacesTelemetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// taskTelemetryHelper is a grab bag of stuff useful in task telemetry test cases
|
||||
type taskTelemetryHelper struct {
|
||||
t *testing.T
|
||||
ctx context.Context
|
||||
db database.Store
|
||||
org database.Organization
|
||||
user database.User
|
||||
}
|
||||
|
||||
// createBuild creates a workspace build with the given parameters,
|
||||
// handling provisioner job creation automatically.
|
||||
func (h *taskTelemetryHelper) createBuild(
|
||||
resp dbfake.WorkspaceResponse,
|
||||
buildNumber int32,
|
||||
createdAt time.Time,
|
||||
transition database.WorkspaceTransition,
|
||||
reason database.BuildReason,
|
||||
) (database.WorkspaceBuild, *database.WorkspaceApp) {
|
||||
job := dbgen.ProvisionerJob(h.t, h.db, nil, database.ProvisionerJob{
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
StorageMethod: database.ProvisionerStorageMethodFile,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
OrganizationID: h.org.ID,
|
||||
})
|
||||
bld := dbgen.WorkspaceBuild(h.t, h.db, database.WorkspaceBuild{
|
||||
WorkspaceID: resp.Workspace.ID,
|
||||
TemplateVersionID: resp.TemplateVersion.ID,
|
||||
JobID: job.ID,
|
||||
Transition: transition,
|
||||
Reason: reason,
|
||||
BuildNumber: buildNumber,
|
||||
CreatedAt: createdAt,
|
||||
HasAITask: sql.NullBool{
|
||||
Bool: true,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if transition == database.WorkspaceTransitionStart {
|
||||
require.NotEmpty(h.t, resp.Agents, "need at least one agent")
|
||||
agt := resp.Agents[0]
|
||||
// App IDs are regenerated by provisionerd each build.
|
||||
app := dbgen.WorkspaceApp(h.t, h.db, database.WorkspaceApp{
|
||||
AgentID: agt.ID,
|
||||
})
|
||||
_, err := h.db.UpsertTaskWorkspaceApp(h.ctx, database.UpsertTaskWorkspaceAppParams{
|
||||
TaskID: resp.Task.ID,
|
||||
WorkspaceBuildNumber: buildNumber,
|
||||
WorkspaceAgentID: uuid.NullUUID{UUID: agt.ID, Valid: true},
|
||||
WorkspaceAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
|
||||
})
|
||||
require.NoError(h.t, err, "failed to upsert task app")
|
||||
return bld, &app
|
||||
}
|
||||
return bld, nil
|
||||
}
|
||||
|
||||
// nolint: dupl // Test code is better WET than DRY.
|
||||
func TestTasksTelemetry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Define a fixed reference time for deterministic testing.
|
||||
now := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
createAppStatus := func(ctx context.Context, db database.Store, wsID uuid.UUID, agentID, appID uuid.UUID, state database.WorkspaceAppStatusState, message string, createdAt time.Time) {
|
||||
_, err := db.InsertWorkspaceAppStatus(ctx, database.InsertWorkspaceAppStatusParams{
|
||||
ID: uuid.New(),
|
||||
CreatedAt: createdAt,
|
||||
WorkspaceID: wsID,
|
||||
AgentID: agentID,
|
||||
AppID: appID,
|
||||
State: state,
|
||||
Message: message,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
getApp := func(ctx context.Context, db database.Store, agentID uuid.UUID) database.WorkspaceApp {
|
||||
apps, err := db.GetWorkspaceAppsByAgentID(ctx, agentID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, apps, "expected at least one app")
|
||||
return apps[0]
|
||||
}
|
||||
|
||||
type statusSpec struct {
|
||||
state database.WorkspaceAppStatusState
|
||||
message string
|
||||
offset time.Duration
|
||||
}
|
||||
|
||||
type buildSpec struct {
|
||||
buildNumber int32
|
||||
offset time.Duration
|
||||
transition database.WorkspaceTransition
|
||||
reason database.BuildReason
|
||||
statuses []statusSpec // created after this build, using this build's app
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
// Input: DB setup.
|
||||
skipWorkspace bool
|
||||
createdOffset time.Duration
|
||||
buildOffset *time.Duration
|
||||
extraBuilds []buildSpec
|
||||
appStatuses []statusSpec
|
||||
|
||||
// Expected output.
|
||||
expectEvent bool
|
||||
lastPausedOffset *time.Duration
|
||||
lastResumedOffset *time.Duration
|
||||
pauseReason *string
|
||||
resumeReason *string
|
||||
idleDurationMS *int64
|
||||
pausedDurationMS *int64
|
||||
resumeToStatusMS *int64
|
||||
activeDurationMS *int64
|
||||
}{
|
||||
{
|
||||
name: "no workspace - all lifecycle fields nil",
|
||||
skipWorkspace: true,
|
||||
createdOffset: -1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "running workspace - no pause/resume events",
|
||||
createdOffset: -45 * time.Minute,
|
||||
buildOffset: ptr.Ref(-30 * time.Minute),
|
||||
expectEvent: true,
|
||||
},
|
||||
{
|
||||
name: "with app status - no lifecycle events",
|
||||
createdOffset: -90 * time.Minute,
|
||||
buildOffset: ptr.Ref(-45 * time.Minute),
|
||||
appStatuses: []statusSpec{
|
||||
{database.WorkspaceAppStatusStateWorking, "Task started", -40 * time.Minute},
|
||||
},
|
||||
expectEvent: true,
|
||||
// ResumeToStatusMS is nil because initial start (BuildReasonInitiator)
|
||||
// doesn't count - only task_resume starts are considered.
|
||||
activeDurationMS: ptr.Ref(int64(40 * time.Minute / time.Millisecond)),
|
||||
},
|
||||
{
|
||||
name: "auto paused - LastPausedAt and PauseReason=auto",
|
||||
createdOffset: -3 * time.Hour,
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -20 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-20 * time.Minute),
|
||||
pauseReason: ptr.Ref("auto"),
|
||||
pausedDurationMS: ptr.Ref(20 * time.Minute.Milliseconds()), // Ongoing pause.
|
||||
},
|
||||
{
|
||||
name: "manual paused - LastPausedAt and PauseReason=manual",
|
||||
createdOffset: -4 * time.Hour,
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -15 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskManualPause, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-15 * time.Minute),
|
||||
pauseReason: ptr.Ref("manual"),
|
||||
pausedDurationMS: ptr.Ref(15 * time.Minute.Milliseconds()), // Ongoing pause.
|
||||
},
|
||||
{
|
||||
name: "paused with idle time - IdleDurationMS calculated",
|
||||
createdOffset: -5 * time.Hour,
|
||||
appStatuses: []statusSpec{
|
||||
{database.WorkspaceAppStatusStateWorking, "Working on something", -40 * time.Minute},
|
||||
{database.WorkspaceAppStatusStateIdle, "Idle now", -35 * time.Minute},
|
||||
},
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -25 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-25 * time.Minute),
|
||||
pauseReason: ptr.Ref("auto"),
|
||||
idleDurationMS: ptr.Ref(15 * time.Minute.Milliseconds()), // Last working (-40) to stop (-25).
|
||||
activeDurationMS: ptr.Ref(5 * time.Minute.Milliseconds()), // -40 min (working) to -35 min (idle).
|
||||
pausedDurationMS: ptr.Ref(25 * time.Minute.Milliseconds()), // Ongoing pause: now - (-25min).
|
||||
},
|
||||
{
|
||||
name: "paused with working status after pause - IdleDurationMS nil",
|
||||
createdOffset: -5 * time.Hour,
|
||||
appStatuses: []statusSpec{
|
||||
{database.WorkspaceAppStatusStateWorking, "Working after pause", -20 * time.Minute},
|
||||
},
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -25 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-25 * time.Minute),
|
||||
pauseReason: ptr.Ref("auto"),
|
||||
pausedDurationMS: ptr.Ref(25 * time.Minute.Milliseconds()), // Ongoing pause.
|
||||
// IdleDurationMS is nil because "last working" is after pause.
|
||||
// ActiveDurationMS is nil because working→stop interval is negative.
|
||||
},
|
||||
{
|
||||
name: "recently resumed - PausedDurationMS calculated",
|
||||
createdOffset: -6 * time.Hour,
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -50 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
{3, -10 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-50 * time.Minute),
|
||||
lastResumedOffset: ptr.Ref(-10 * time.Minute),
|
||||
pauseReason: ptr.Ref("auto"),
|
||||
resumeReason: ptr.Ref("manual"),
|
||||
pausedDurationMS: ptr.Ref(40 * time.Minute.Milliseconds()),
|
||||
},
|
||||
{
|
||||
// This test verifies that we do not double-report task events outside of the window.
|
||||
name: "resumed long ago - PausedDurationMS nil",
|
||||
createdOffset: -10 * time.Hour,
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -5 * time.Hour, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
{3, -2 * time.Hour, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, nil},
|
||||
},
|
||||
expectEvent: false,
|
||||
},
|
||||
{
|
||||
name: "multiple cycles - captures latest pause/resume",
|
||||
createdOffset: -8 * time.Hour,
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -3 * time.Hour, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
{3, -150 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, nil},
|
||||
{4, -30 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskManualPause, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-30 * time.Minute),
|
||||
pauseReason: ptr.Ref("manual"),
|
||||
pausedDurationMS: ptr.Ref(30 * time.Minute.Milliseconds()), // Ongoing pause: now - (-30min).
|
||||
},
|
||||
{
|
||||
name: "currently paused after recent resume - reports ongoing pause",
|
||||
createdOffset: -6 * time.Hour,
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -50 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
{3, -30 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, nil},
|
||||
{4, -10 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskManualPause, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-10 * time.Minute),
|
||||
pauseReason: ptr.Ref("manual"),
|
||||
pausedDurationMS: ptr.Ref(10 * time.Minute.Milliseconds()), // Ongoing pause: now - pause time.
|
||||
},
|
||||
{
|
||||
name: "multiple cycles with recent resume - pairs with preceding pause",
|
||||
createdOffset: -6 * time.Hour,
|
||||
appStatuses: []statusSpec{
|
||||
{database.WorkspaceAppStatusStateWorking, "started work", -6 * time.Hour},
|
||||
},
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -50 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
{3, -30 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, []statusSpec{
|
||||
{database.WorkspaceAppStatusStateWorking, "resumed work", -25 * time.Minute},
|
||||
}},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-50 * time.Minute),
|
||||
lastResumedOffset: ptr.Ref(-30 * time.Minute),
|
||||
pauseReason: ptr.Ref("auto"),
|
||||
resumeReason: ptr.Ref("manual"),
|
||||
pausedDurationMS: ptr.Ref(20 * time.Minute.Milliseconds()),
|
||||
resumeToStatusMS: ptr.Ref((5 * time.Minute).Milliseconds()),
|
||||
// Build 1 ("started work") -> Build 2 (stop) (5h10m) + Build 3 ("resumed work") -> now (25m)
|
||||
// TODO(cian): We define IdleDurationMS as "the time from the last working status to pause".
|
||||
// We know that the task has reported working since T-6h and got auto-paused at T-50m.
|
||||
// We can reasonably assume that it has been 'idle' from when it was stopped (T-30m) to
|
||||
// its next report at T-25m. This is covered by ResumeToStatusMS.
|
||||
// But do we consider the time since its last report (T-6h) to its being auto-paused
|
||||
// as truly "idle"?
|
||||
idleDurationMS: ptr.Ref(310 * time.Minute.Milliseconds()),
|
||||
activeDurationMS: ptr.Ref((5*time.Hour + 10*time.Minute + 25*time.Minute).Milliseconds()),
|
||||
},
|
||||
{
|
||||
name: "all fields populated - full lifecycle",
|
||||
createdOffset: -7 * time.Hour,
|
||||
appStatuses: []statusSpec{
|
||||
{database.WorkspaceAppStatusStateWorking, "Started working", -390 * time.Minute},
|
||||
{database.WorkspaceAppStatusStateWorking, "Still working", -45 * time.Minute},
|
||||
},
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -35 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
{3, -5 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonTaskResume, []statusSpec{
|
||||
{database.WorkspaceAppStatusStateWorking, "Resumed work", -3 * time.Minute},
|
||||
{database.WorkspaceAppStatusStateIdle, "Finished work", -2 * time.Minute},
|
||||
}},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-35 * time.Minute),
|
||||
lastResumedOffset: ptr.Ref(-5 * time.Minute),
|
||||
pauseReason: ptr.Ref("auto"),
|
||||
resumeReason: ptr.Ref("manual"),
|
||||
idleDurationMS: ptr.Ref(10 * time.Minute.Milliseconds()),
|
||||
pausedDurationMS: ptr.Ref(30 * time.Minute.Milliseconds()),
|
||||
resumeToStatusMS: ptr.Ref((2 * time.Minute).Milliseconds()),
|
||||
// Active duration: (-390 to -35) + (-3 to -2) = 355 + 1 = 356 min.
|
||||
activeDurationMS: ptr.Ref(356 * time.Minute.Milliseconds()),
|
||||
},
|
||||
{
|
||||
name: "non-task_resume builds are tracked as other",
|
||||
createdOffset: -4 * time.Hour,
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -60 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
{3, -30 * time.Minute, database.WorkspaceTransitionStart, database.BuildReasonInitiator, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-60 * time.Minute),
|
||||
pauseReason: ptr.Ref("auto"),
|
||||
resumeReason: ptr.Ref("other"),
|
||||
// LastResumedAt is set because isResumed is true (build_number > 1)
|
||||
// even though the start reason isn't task_resume.
|
||||
lastResumedOffset: ptr.Ref(-30 * time.Minute),
|
||||
// PausedDurationMS reports ongoing pause: now - (-60min) = 60min.
|
||||
pausedDurationMS: ptr.Ref(30 * time.Minute.Milliseconds()),
|
||||
},
|
||||
{
|
||||
name: "simple ongoing pause reports duration",
|
||||
createdOffset: -3 * time.Hour,
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -45 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-45 * time.Minute),
|
||||
pauseReason: ptr.Ref("auto"),
|
||||
// No resume, so ongoing pause: now - (-45min) = 45min.
|
||||
pausedDurationMS: ptr.Ref(45 * time.Minute.Milliseconds()),
|
||||
},
|
||||
{
|
||||
name: "active duration with paused task",
|
||||
createdOffset: -2 * time.Hour,
|
||||
buildOffset: ptr.Ref(-2 * time.Hour),
|
||||
appStatuses: []statusSpec{
|
||||
{database.WorkspaceAppStatusStateWorking, "Started", -90 * time.Minute},
|
||||
{database.WorkspaceAppStatusStateIdle, "Thinking", -60 * time.Minute}, // 30min working
|
||||
{database.WorkspaceAppStatusStateWorking, "Resumed", -45 * time.Minute},
|
||||
{database.WorkspaceAppStatusStateComplete, "Done", -30 * time.Minute}, // 15min working
|
||||
},
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -25 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-25 * time.Minute),
|
||||
pauseReason: ptr.Ref("auto"),
|
||||
idleDurationMS: ptr.Ref(20 * time.Minute.Milliseconds()), // Last working (-45) to stop (-25).
|
||||
activeDurationMS: ptr.Ref(45 * time.Minute.Milliseconds()), // 30 + 15 = 45min of "working".
|
||||
pausedDurationMS: ptr.Ref(25 * time.Minute.Milliseconds()), // Ongoing pause.
|
||||
},
|
||||
{
|
||||
// When a workspace_app_status and a workspace_build share
|
||||
// the exact same created_at timestamp, the ordering inside
|
||||
// task_status_timeline is ambiguous. The boundary row must
|
||||
// sort after real statuses so that LEAD() and the lws
|
||||
// lateral join produce deterministic results.
|
||||
name: "status and build at same timestamp - deterministic ordering",
|
||||
createdOffset: -3 * time.Hour,
|
||||
buildOffset: ptr.Ref(-2 * time.Hour),
|
||||
appStatuses: []statusSpec{
|
||||
{database.WorkspaceAppStatusStateWorking, "Started work", -90 * time.Minute},
|
||||
// This status has the exact same timestamp as the
|
||||
// stop build below, exercising the tiebreaker.
|
||||
{database.WorkspaceAppStatusStateWorking, "Last update before pause", -30 * time.Minute},
|
||||
},
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -30 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-30 * time.Minute),
|
||||
pauseReason: ptr.Ref("auto"),
|
||||
// IdleDurationMS is nil: the Go code requires
|
||||
// stop.After(lastWorking), which is false when equal.
|
||||
// Active: -90m (working) → -30m (boundary/stop) = 60 min.
|
||||
activeDurationMS: ptr.Ref(60 * time.Minute.Milliseconds()),
|
||||
pausedDurationMS: ptr.Ref(30 * time.Minute.Milliseconds()),
|
||||
},
|
||||
{
|
||||
// SQL filter: EXISTS (workspace_builds.created_at > createdAfter).
|
||||
// This task has only old builds (7 days ago), so it won't match
|
||||
// the 1-hour createdAfter filter and should not return an event.
|
||||
name: "old task with no recent builds - not returned",
|
||||
createdOffset: -7 * 24 * time.Hour,
|
||||
buildOffset: ptr.Ref(-7 * 24 * time.Hour),
|
||||
expectEvent: false,
|
||||
},
|
||||
{
|
||||
// SQL filter: EXISTS (workspace_builds.created_at > createdAfter).
|
||||
// This task was created 7 days ago, but has a recent stop build,
|
||||
// so it should match the filter and return an event.
|
||||
name: "old task with recent build - returned",
|
||||
createdOffset: -7 * 24 * time.Hour,
|
||||
buildOffset: ptr.Ref(-7 * 24 * time.Hour),
|
||||
extraBuilds: []buildSpec{
|
||||
{2, -30 * time.Minute, database.WorkspaceTransitionStop, database.BuildReasonTaskAutoPause, nil},
|
||||
},
|
||||
expectEvent: true,
|
||||
lastPausedOffset: ptr.Ref(-30 * time.Minute),
|
||||
pauseReason: ptr.Ref("auto"),
|
||||
pausedDurationMS: ptr.Ref(30 * time.Minute.Milliseconds()), // Ongoing pause.
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org, err := db.GetDefaultOrganization(ctx)
|
||||
require.NoError(t, err)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
h := &taskTelemetryHelper{
|
||||
t: t,
|
||||
ctx: ctx,
|
||||
db: db,
|
||||
org: org,
|
||||
user: user,
|
||||
}
|
||||
|
||||
// Create a deleted task. This is a test antagonist that should never show up in results.
|
||||
deletedTaskResp := dbfake.WorkspaceBuild(h.t, h.db, database.WorkspaceTable{
|
||||
OrganizationID: h.org.ID,
|
||||
OwnerID: h.user.ID,
|
||||
}).WithTask(database.TaskTable{
|
||||
Prompt: fmt.Sprintf("deleted-task-%s", t.Name()),
|
||||
CreatedAt: now.Add(-100 * time.Hour),
|
||||
}, nil).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
Reason: database.BuildReasonInitiator,
|
||||
BuildNumber: 1,
|
||||
CreatedAt: now.Add(-100 * time.Hour),
|
||||
}).Succeeded().Do()
|
||||
_, err = db.DeleteTask(h.ctx, database.DeleteTaskParams{
|
||||
DeletedAt: now.Add(-99 * time.Hour),
|
||||
ID: deletedTaskResp.Task.ID,
|
||||
})
|
||||
require.NoError(h.t, err, "creating deleted task antagonist")
|
||||
|
||||
var expectedTask telemetry.Task
|
||||
|
||||
if tt.skipWorkspace {
|
||||
tv := dbgen.TemplateVersion(t, h.db, database.TemplateVersion{
|
||||
OrganizationID: h.org.ID,
|
||||
CreatedBy: h.user.ID,
|
||||
HasAITask: sql.NullBool{Bool: true, Valid: true},
|
||||
})
|
||||
task := dbgen.Task(h.t, h.db, database.TaskTable{
|
||||
OwnerID: h.user.ID,
|
||||
OrganizationID: h.org.ID,
|
||||
WorkspaceID: uuid.NullUUID{},
|
||||
TemplateVersionID: tv.ID,
|
||||
Prompt: fmt.Sprintf("pending-task-%s", t.Name()),
|
||||
CreatedAt: now.Add(tt.createdOffset),
|
||||
})
|
||||
expectedTask = telemetry.Task{
|
||||
ID: task.ID.String(),
|
||||
OrganizationID: h.org.ID.String(),
|
||||
OwnerID: h.user.ID.String(),
|
||||
Name: task.Name,
|
||||
TemplateVersionID: tv.ID.String(),
|
||||
PromptHash: telemetry.HashContent(task.Prompt),
|
||||
Status: "pending",
|
||||
CreatedAt: task.CreatedAt,
|
||||
}
|
||||
} else {
|
||||
buildCreatedAt := now.Add(tt.createdOffset)
|
||||
if tt.buildOffset != nil {
|
||||
buildCreatedAt = now.Add(*tt.buildOffset)
|
||||
}
|
||||
|
||||
resp := dbfake.WorkspaceBuild(h.t, h.db, database.WorkspaceTable{
|
||||
OrganizationID: h.org.ID,
|
||||
OwnerID: h.user.ID,
|
||||
}).WithTask(database.TaskTable{
|
||||
Prompt: fmt.Sprintf("task-%s", t.Name()),
|
||||
CreatedAt: now.Add(tt.createdOffset),
|
||||
}, nil).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
Reason: database.BuildReasonInitiator,
|
||||
BuildNumber: 1,
|
||||
CreatedAt: buildCreatedAt,
|
||||
}).Succeeded().Do()
|
||||
|
||||
app := getApp(h.ctx, h.db, resp.Agents[0].ID)
|
||||
|
||||
for _, s := range tt.appStatuses {
|
||||
createAppStatus(h.ctx, h.db, resp.Workspace.ID, resp.Agents[0].ID, app.ID, s.state, s.message, now.Add(s.offset))
|
||||
}
|
||||
|
||||
for _, b := range tt.extraBuilds {
|
||||
bld, bldApp := h.createBuild(resp, b.buildNumber, now.Add(b.offset), b.transition, b.reason)
|
||||
_ = bld
|
||||
if bldApp != nil {
|
||||
for _, s := range b.statuses {
|
||||
createAppStatus(h.ctx, h.db, resp.Workspace.ID, resp.Agents[0].ID, bldApp.ID, s.state, s.message, now.Add(s.offset))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh the task
|
||||
updated, err := h.db.GetTaskByID(ctx, resp.Task.ID)
|
||||
require.NoError(t, err, "fetching updated task")
|
||||
expectedTask = telemetry.Task{
|
||||
ID: updated.ID.String(),
|
||||
OrganizationID: updated.OrganizationID.String(),
|
||||
OwnerID: updated.OwnerID.String(),
|
||||
Name: updated.Name,
|
||||
WorkspaceID: ptr.Ref(updated.WorkspaceID.UUID.String()),
|
||||
WorkspaceBuildNumber: ptr.Ref(int64(updated.WorkspaceBuildNumber.Int32)),
|
||||
WorkspaceAgentID: ptr.Ref(updated.WorkspaceAgentID.UUID.String()),
|
||||
WorkspaceAppID: ptr.Ref(updated.WorkspaceAppID.UUID.String()),
|
||||
TemplateVersionID: updated.TemplateVersionID.String(),
|
||||
PromptHash: telemetry.HashContent(updated.Prompt),
|
||||
Status: string(updated.Status),
|
||||
CreatedAt: updated.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
actualTasks, err := telemetry.CollectTasks(h.ctx, h.db)
|
||||
require.NoError(t, err, "unexpected error collecting tasks telemetry")
|
||||
// Invariant: deleted tasks should NEVER appear in results.
|
||||
require.Len(t, actualTasks, 1, "expected exactly one task")
|
||||
|
||||
if diff := cmp.Diff(expectedTask, actualTasks[0]); diff != "" {
|
||||
t.Fatalf("test case %q: task diff (-want +got):\n%s", tt.name, diff)
|
||||
}
|
||||
|
||||
actualEvents, err := telemetry.CollectTaskEvents(h.ctx, h.db, now.Add(-1*time.Hour), now)
|
||||
require.NoError(t, err)
|
||||
if !tt.expectEvent {
|
||||
require.Empty(t, actualEvents)
|
||||
} else {
|
||||
expectedEvent := telemetry.TaskEvent{
|
||||
TaskID: expectedTask.ID,
|
||||
}
|
||||
if tt.lastPausedOffset != nil {
|
||||
t := now.Add(*tt.lastPausedOffset)
|
||||
expectedEvent.LastPausedAt = &t
|
||||
}
|
||||
if tt.lastResumedOffset != nil {
|
||||
t := now.Add(*tt.lastResumedOffset)
|
||||
expectedEvent.LastResumedAt = &t
|
||||
}
|
||||
expectedEvent.PauseReason = tt.pauseReason
|
||||
expectedEvent.ResumeReason = tt.resumeReason
|
||||
expectedEvent.IdleDurationMS = tt.idleDurationMS
|
||||
expectedEvent.PausedDurationMS = tt.pausedDurationMS
|
||||
expectedEvent.ResumeToStatusMS = tt.resumeToStatusMS
|
||||
expectedEvent.ActiveDurationMS = tt.activeDurationMS
|
||||
|
||||
// Each test case creates exactly one workspace with lifecycle
|
||||
// activity, so we expect exactly one event.
|
||||
require.Len(t, actualEvents, 1)
|
||||
actual := actualEvents[0]
|
||||
|
||||
if diff := cmp.Diff(expectedEvent, actual); diff != "" {
|
||||
t.Fatalf("test case %q: event diff (-want +got):\n%s", tt.name, diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockDB struct {
|
||||
database.Store
|
||||
}
|
||||
@@ -1350,7 +767,7 @@ func TestRecordTelemetryStatus(t *testing.T) {
|
||||
require.Nil(t, snapshot1)
|
||||
}
|
||||
|
||||
for range 3 {
|
||||
for i := 0; i < 3; i++ {
|
||||
// Whatever happens, subsequent calls should not report if telemetryEnabled didn't change
|
||||
snapshot2, err := telemetry.RecordTelemetryStatus(ctx, logger, db, testCase.telemetryEnabled)
|
||||
require.NoError(t, err)
|
||||
|
||||
+205
-39
@@ -25,7 +25,6 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
@@ -36,11 +35,14 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/prebuilds"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
maputil "github.com/coder/coder/v2/coderd/util/maps"
|
||||
strutil "github.com/coder/coder/v2/coderd/util/strings"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
@@ -293,7 +295,6 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request)
|
||||
// @Param request body agentsdk.PatchAppStatus true "app status"
|
||||
// @Success 200 {object} codersdk.Response
|
||||
// @Router /workspaceagents/me/app-status [patch]
|
||||
// @Deprecated Use UpdateAppStatus on the Agent API instead.
|
||||
func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
@@ -303,6 +304,45 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req
|
||||
return
|
||||
}
|
||||
|
||||
app, err := api.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
|
||||
AgentID: workspaceAgent.ID,
|
||||
Slug: req.AppSlug,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to get workspace app.",
|
||||
Detail: fmt.Sprintf("No app found with slug %q", req.AppSlug),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Message) > 160 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Message is too long.",
|
||||
Detail: "Message must be less than 160 characters.",
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "message", Detail: "Message must be less than 160 characters."},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
switch req.State {
|
||||
case codersdk.WorkspaceAppStatusStateComplete,
|
||||
codersdk.WorkspaceAppStatusStateFailure,
|
||||
codersdk.WorkspaceAppStatusStateWorking,
|
||||
codersdk.WorkspaceAppStatusStateIdle: // valid states
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid state provided.",
|
||||
Detail: fmt.Sprintf("invalid state: %q", req.State),
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "state", Detail: "State must be one of: complete, failure, working."},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
@@ -312,50 +352,176 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req
|
||||
return
|
||||
}
|
||||
|
||||
// This functionality has been moved to the AppsAPI in the agentapi. We keep this HTTP handler around for back
|
||||
// compatibility with older agents. We'll translate the request into the protobuf so there is only one primary
|
||||
// implementation.
|
||||
appAPI := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return workspaceAgent, nil
|
||||
},
|
||||
Database: api.Database,
|
||||
Log: api.Logger,
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{
|
||||
Kind: kind,
|
||||
WorkspaceID: workspace.ID,
|
||||
AgentID: &agent.ID,
|
||||
})
|
||||
return nil
|
||||
},
|
||||
NotificationsEnqueuer: api.NotificationsEnqueuer,
|
||||
Clock: api.Clock,
|
||||
}
|
||||
protoReq, err := agentsdk.ProtoFromPatchAppStatus(req)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to parse request.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
_, err = appAPI.UpdateAppStatus(r.Context(), protoReq)
|
||||
if err != nil {
|
||||
sdkErr := new(codersdk.Error)
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
httpapi.Write(ctx, rw, sdkErr.StatusCode(), sdkErr.Response)
|
||||
return
|
||||
}
|
||||
// Treat the message as untrusted input.
|
||||
cleaned := strutil.UISanitize(req.Message)
|
||||
|
||||
// Get the latest status for the workspace app to detect no-op updates
|
||||
// nolint:gocritic // This is a system restricted operation.
|
||||
latestAppStatus, err := api.Database.GetLatestWorkspaceAppStatusByAppID(dbauthz.AsSystemRestricted(ctx), app.ID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to update app status.",
|
||||
Message: "Failed to get latest workspace app status.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// If no rows found, latestAppStatus will be a zero-value struct (ID == uuid.Nil)
|
||||
|
||||
// nolint:gocritic // This is a system restricted operation.
|
||||
_, err = api.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
|
||||
ID: uuid.New(),
|
||||
CreatedAt: dbtime.Now(),
|
||||
WorkspaceID: workspace.ID,
|
||||
AgentID: workspaceAgent.ID,
|
||||
AppID: app.ID,
|
||||
State: database.WorkspaceAppStatusState(req.State),
|
||||
Message: cleaned,
|
||||
Uri: sql.NullString{
|
||||
String: req.URI,
|
||||
Valid: req.URI != "",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to insert workspace app status.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{
|
||||
Kind: wspubsub.WorkspaceEventKindAgentAppStatusUpdate,
|
||||
WorkspaceID: workspace.ID,
|
||||
AgentID: &workspaceAgent.ID,
|
||||
})
|
||||
|
||||
// Notify on state change to Working/Idle for AI tasks
|
||||
api.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, req.State, workspace, workspaceAgent)
|
||||
|
||||
// Bump deadline when agent reports working or transitions away from working.
|
||||
// This prevents auto-pause during active work and gives users time to interact
|
||||
// after work completes.
|
||||
shouldBump := false
|
||||
newState := database.WorkspaceAppStatusState(req.State)
|
||||
|
||||
// Bump if reporting working state.
|
||||
if newState == database.WorkspaceAppStatusStateWorking {
|
||||
shouldBump = true
|
||||
}
|
||||
|
||||
// Bump if transitioning away from working state.
|
||||
if latestAppStatus.ID != uuid.Nil {
|
||||
prevState := latestAppStatus.State
|
||||
if prevState == database.WorkspaceAppStatusStateWorking {
|
||||
shouldBump = true
|
||||
}
|
||||
}
|
||||
if shouldBump {
|
||||
// We pass time.Time{} for nextAutostart since we don't have access to
|
||||
// TemplateScheduleStore here. The activity bump logic handles this by
|
||||
// defaulting to the template's activity_bump duration (typically 1 hour).
|
||||
workspacestats.ActivityBumpWorkspace(ctx, api.Logger, api.Database, workspace.ID, time.Time{})
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
// enqueueAITaskStateNotification enqueues a notification when an AI task's app
|
||||
// transitions to Working or Idle.
|
||||
// No-op if:
|
||||
// - the workspace agent app isn't configured as an AI task,
|
||||
// - the new state equals the latest persisted state,
|
||||
// - the workspace agent is not ready (still starting up).
|
||||
func (api *API) enqueueAITaskStateNotification(
|
||||
ctx context.Context,
|
||||
appID uuid.UUID,
|
||||
latestAppStatus database.WorkspaceAppStatus,
|
||||
newAppStatus codersdk.WorkspaceAppStatusState,
|
||||
workspace database.Workspace,
|
||||
agent database.WorkspaceAgent,
|
||||
) {
|
||||
// Select notification template based on the new state
|
||||
var notificationTemplate uuid.UUID
|
||||
switch newAppStatus {
|
||||
case codersdk.WorkspaceAppStatusStateWorking:
|
||||
notificationTemplate = notifications.TemplateTaskWorking
|
||||
case codersdk.WorkspaceAppStatusStateIdle:
|
||||
notificationTemplate = notifications.TemplateTaskIdle
|
||||
case codersdk.WorkspaceAppStatusStateComplete:
|
||||
notificationTemplate = notifications.TemplateTaskCompleted
|
||||
case codersdk.WorkspaceAppStatusStateFailure:
|
||||
notificationTemplate = notifications.TemplateTaskFailed
|
||||
default:
|
||||
// Not a notifiable state, do nothing
|
||||
return
|
||||
}
|
||||
|
||||
if !workspace.TaskID.Valid {
|
||||
// Workspace has no task ID, do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
// Only send notifications when the agent is ready. We want to skip
|
||||
// any state transitions that occur whilst the workspace is starting
|
||||
// up as it doesn't make sense to receive them.
|
||||
if agent.LifecycleState != database.WorkspaceAgentLifecycleStateReady {
|
||||
api.Logger.Debug(ctx, "skipping AI task notification because agent is not ready",
|
||||
slog.F("agent_id", agent.ID),
|
||||
slog.F("lifecycle_state", agent.LifecycleState),
|
||||
slog.F("new_app_status", newAppStatus),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
task, err := api.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to get task", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if !task.WorkspaceAppID.Valid || task.WorkspaceAppID.UUID != appID {
|
||||
// Non-task app, do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if the latest persisted state equals the new state (no new transition)
|
||||
// Note: uuid.Nil check is valid here. If no previous status exists,
|
||||
// GetLatestWorkspaceAppStatusByAppID returns sql.ErrNoRows and we get a zero-value struct.
|
||||
if latestAppStatus.ID != uuid.Nil && latestAppStatus.State == database.WorkspaceAppStatusState(newAppStatus) {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip the initial "Working" notification when task first starts.
|
||||
// This is obvious to the user since they just created the task.
|
||||
// We still notify on first "Idle" status and all subsequent transitions.
|
||||
if latestAppStatus.ID == uuid.Nil && newAppStatus == codersdk.WorkspaceAppStatusStateWorking {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := api.NotificationsEnqueuer.EnqueueWithData(
|
||||
// nolint:gocritic // Need notifier actor to enqueue notifications
|
||||
dbauthz.AsNotifier(ctx),
|
||||
workspace.OwnerID,
|
||||
notificationTemplate,
|
||||
map[string]string{
|
||||
"task": task.Name,
|
||||
"workspace": workspace.Name,
|
||||
},
|
||||
map[string]any{
|
||||
// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
|
||||
// allowing identical content to resend within the same day
|
||||
// (but not more than once every 10s).
|
||||
"dedupe_bypass_ts": api.Clock.Now().UTC().Truncate(time.Minute),
|
||||
},
|
||||
"api-workspace-agent-app-status",
|
||||
// Associate this notification with related entities
|
||||
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
|
||||
); err != nil {
|
||||
api.Logger.Warn(ctx, "failed to notify of task state", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// workspaceAgentLogs returns the logs associated with a workspace agent
|
||||
//
|
||||
// @Summary Get logs by workspace agent
|
||||
@@ -2045,7 +2211,7 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
|
||||
// No point in trying to validate the same token over and over again.
|
||||
if previousToken.OAuthAccessToken == externalAuthLink.OAuthAccessToken &&
|
||||
previousToken.OAuthRefreshToken == externalAuthLink.OAuthRefreshToken &&
|
||||
previousToken.OAuthExpiry.Equal(externalAuthLink.OAuthExpiry) {
|
||||
previousToken.OAuthExpiry == externalAuthLink.OAuthExpiry {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -2784,12 +2784,12 @@ func TestWorkspaceAgentExternalAuthListen(t *testing.T) {
|
||||
const providerID = "fake-idp"
|
||||
|
||||
// Count all the times we call validate
|
||||
var validateCalls atomic.Int32
|
||||
validateCalls := 0
|
||||
fake := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithMiddlewares(func(handler http.Handler) http.Handler {
|
||||
return http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Count all the validate calls
|
||||
if strings.Contains(r.URL.Path, "/external-auth-validate/") {
|
||||
validateCalls.Add(1)
|
||||
validateCalls++
|
||||
}
|
||||
handler.ServeHTTP(w, r)
|
||||
}))
|
||||
@@ -2852,7 +2852,7 @@ func TestWorkspaceAgentExternalAuthListen(t *testing.T) {
|
||||
// other should be skipped.
|
||||
// In a failed test, you will likely see 9, as the last one
|
||||
// gets canceled.
|
||||
require.EqualValues(t, 1, validateCalls.Load(), "validate calls duplicated on same token")
|
||||
require.Equal(t, 1, validateCalls, "validate calls duplicated on same token")
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -19,10 +19,7 @@ var (
|
||||
appURL = regexp.MustCompile(fmt.Sprintf(
|
||||
`^(?P<AppSlug>%[1]s)(?:--(?P<AgentName>%[1]s))?--(?P<WorkspaceName>%[1]s)--(?P<Username>%[1]s)$`,
|
||||
nameRegex))
|
||||
// PortRegex should not be able to be greater than 65535. In usage though, if a
|
||||
// user tries to use a greater port, the proxy will just block it and not cause
|
||||
// any issues. This is a good enough regex check.
|
||||
PortRegex = regexp.MustCompile(`^\d{4,5}s?$`)
|
||||
PortRegex = regexp.MustCompile(`^\d{4}s?$`)
|
||||
|
||||
validHostnameLabelRegex = regexp.MustCompile(`^[a-z0-9]([-a-z0-9]*[a-z0-9])?$`)
|
||||
)
|
||||
|
||||
@@ -193,16 +193,6 @@ func TestParseSubdomainAppURL(t *testing.T) {
|
||||
Username: "user",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Port(5)--Agent--Workspace--User",
|
||||
Subdomain: "12412--agent--workspace--user",
|
||||
Expected: appurl.ApplicationURL{
|
||||
AppSlugOrPort: "12412",
|
||||
AgentName: "agent",
|
||||
WorkspaceName: "workspace",
|
||||
Username: "user",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Port--Agent--Workspace--User",
|
||||
Subdomain: "8080s--agent--workspace--user",
|
||||
@@ -235,11 +225,11 @@ func TestParseSubdomainAppURL(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "5DigitPort--agent--Workspace--User",
|
||||
Subdomain: "30000--agent--workspace--user",
|
||||
Name: "5DigitAppSlug--Workspace--User",
|
||||
Subdomain: "30000--workspace--user",
|
||||
Expected: appurl.ApplicationURL{
|
||||
AppSlugOrPort: "30000",
|
||||
AgentName: "agent",
|
||||
AgentName: "",
|
||||
WorkspaceName: "workspace",
|
||||
Username: "user",
|
||||
},
|
||||
@@ -609,14 +599,6 @@ func TestURLGenerationVsParsing(t *testing.T) {
|
||||
Name: "5DigitAppSlug_AgentOmittedInParsing",
|
||||
AppSlugOrPort: "30000",
|
||||
AgentName: "agent",
|
||||
ExpectedParsed: "agent",
|
||||
},
|
||||
{
|
||||
// 6 digits is not a valid port, so it is treated as an app slug.
|
||||
// App slugs do not require the agent name, so it is dropped
|
||||
Name: "6DigitAppSlug_AgentOmittedInParsing",
|
||||
AppSlugOrPort: "300000",
|
||||
AgentName: "agent",
|
||||
ExpectedParsed: "",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -856,24 +856,32 @@ func (api *API) workspaceBuildLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *API) workspaceBuildState(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceBuild := httpmw.WorkspaceBuildParam(r)
|
||||
|
||||
// The dbauthz layer enforces policy.ActionUpdate on the template.
|
||||
row, err := api.Database.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuild.ID)
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching provisioner state.",
|
||||
Message: "No workspace exists for this job.",
|
||||
})
|
||||
return
|
||||
}
|
||||
template, err := api.Database.GetTemplateByID(ctx, workspace.TemplateID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get template",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// You must have update permissions on the template to get the state.
|
||||
// This matches a push!
|
||||
if !api.Authorize(r, policy.ActionUpdate, template.RBACObject()) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write(row.ProvisionerState)
|
||||
_, _ = rw.Write(workspaceBuild.ProvisionerState)
|
||||
}
|
||||
|
||||
// @Summary Update workspace build state
|
||||
|
||||
+17
-2
@@ -114,6 +114,7 @@ func (api *API) workspace(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w, err := convertWorkspace(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
apiKey.UserID,
|
||||
workspace,
|
||||
@@ -239,6 +240,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
wss, err := convertWorkspaces(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
apiKey.UserID,
|
||||
workspaces,
|
||||
@@ -334,6 +336,7 @@ func (api *API) workspaceByOwnerAndName(rw http.ResponseWriter, r *http.Request)
|
||||
|
||||
w, err := convertWorkspace(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
apiKey.UserID,
|
||||
workspace,
|
||||
@@ -865,6 +868,7 @@ func createWorkspace(
|
||||
|
||||
w, err := convertWorkspace(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
initiatorID,
|
||||
workspace,
|
||||
@@ -1510,6 +1514,7 @@ func (api *API) putWorkspaceDormant(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w, err := convertWorkspace(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
apiKey.UserID,
|
||||
workspace,
|
||||
@@ -2089,6 +2094,7 @@ func (api *API) watchWorkspace(
|
||||
}
|
||||
w, err := convertWorkspace(
|
||||
ctx,
|
||||
api.Experiments,
|
||||
api.Logger,
|
||||
apiKey.UserID,
|
||||
workspace,
|
||||
@@ -2230,7 +2236,8 @@ func (api *API) workspaceACL(rw http.ResponseWriter, r *http.Request) {
|
||||
// the case here. This data goes directly to an unauthorized user. We are
|
||||
// just straight up breaking security promises.
|
||||
//
|
||||
// TODO: This needs to be fixed before GA. Currently in beta.
|
||||
// Fine for now while behind the shared-workspaces experiment, but needs to
|
||||
// be fixed before GA.
|
||||
|
||||
// Fetch all of the users and their organization memberships
|
||||
userIDs := make([]uuid.UUID, 0, len(workspaceACL.Users))
|
||||
@@ -2588,6 +2595,7 @@ func (api *API) workspaceData(ctx context.Context, workspaces []database.Workspa
|
||||
|
||||
func convertWorkspaces(
|
||||
ctx context.Context,
|
||||
experiments codersdk.Experiments,
|
||||
logger slog.Logger,
|
||||
requesterID uuid.UUID,
|
||||
workspaces []database.Workspace,
|
||||
@@ -2625,6 +2633,7 @@ func convertWorkspaces(
|
||||
|
||||
w, err := convertWorkspace(
|
||||
ctx,
|
||||
experiments,
|
||||
logger,
|
||||
requesterID,
|
||||
workspace,
|
||||
@@ -2644,6 +2653,7 @@ func convertWorkspaces(
|
||||
|
||||
func convertWorkspace(
|
||||
ctx context.Context,
|
||||
experiments codersdk.Experiments,
|
||||
logger slog.Logger,
|
||||
requesterID uuid.UUID,
|
||||
workspace database.Workspace,
|
||||
@@ -2742,15 +2752,20 @@ func convertWorkspace(
|
||||
NextStartAt: nextStartAt,
|
||||
IsPrebuild: workspace.IsPrebuild(),
|
||||
TaskID: workspace.TaskID,
|
||||
SharedWith: sharedWorkspaceActors(ctx, logger, workspace),
|
||||
SharedWith: sharedWorkspaceActors(ctx, experiments, logger, workspace),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func sharedWorkspaceActors(
|
||||
ctx context.Context,
|
||||
experiments codersdk.Experiments,
|
||||
logger slog.Logger,
|
||||
workspace database.Workspace,
|
||||
) []codersdk.SharedWorkspaceActor {
|
||||
if !experiments.Enabled(codersdk.ExperimentWorkspaceSharing) {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]codersdk.SharedWorkspaceActor, 0, len(workspace.UserACL)+len(workspace.GroupACL))
|
||||
|
||||
// Users
|
||||
|
||||
@@ -1899,6 +1899,7 @@ func TestWorkspaceFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
@@ -1936,6 +1937,7 @@ func TestWorkspaceFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
@@ -1973,6 +1975,7 @@ func TestWorkspaceFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
@@ -2010,6 +2013,7 @@ func TestWorkspaceFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
@@ -5245,7 +5249,7 @@ func TestUpdateWorkspaceACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
adminClient := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
DeploymentValues: dv,
|
||||
@@ -5281,7 +5285,7 @@ func TestUpdateWorkspaceACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
adminClient := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
DeploymentValues: dv,
|
||||
@@ -5314,7 +5318,7 @@ func TestUpdateWorkspaceACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
adminClient := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
DeploymentValues: dv,
|
||||
@@ -5354,7 +5358,7 @@ func TestUpdateWorkspaceACL(t *testing.T) {
|
||||
t.Cleanup(func() { rbac.SetWorkspaceACLDisabled(prevWorkspaceACLDisabled) })
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
adminClient := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
DeploymentValues: dv,
|
||||
@@ -5422,7 +5426,11 @@ func TestDeleteWorkspaceACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
admin = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
_, toShareWithUser = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
@@ -5453,7 +5461,11 @@ func TestDeleteWorkspaceACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
admin = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
sharedUseClient, toShareWithUser = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
@@ -5492,7 +5504,11 @@ func TestWorkspaceReadCanListACL(t *testing.T) {
|
||||
t.Cleanup(func() { rbac.SetWorkspaceACLDisabled(prevWorkspaceACLDisabled) })
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, nil)
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
admin = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
sharedUserClientA, sharedUserA = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
@@ -5542,6 +5558,7 @@ func TestWorkspaceSharingDisabled(t *testing.T) {
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
// DisableWorkspaceSharing is false (default)
|
||||
}),
|
||||
})
|
||||
@@ -5575,6 +5592,7 @@ func TestWorkspaceSharingDisabled(t *testing.T) {
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
dv.DisableWorkspaceSharing = true
|
||||
}),
|
||||
})
|
||||
|
||||
@@ -452,7 +452,7 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object
|
||||
// to read all provisioner daemons. We need to retrieve the eligible
|
||||
// provisioner daemons for this job to show in the UI if there is no
|
||||
// matching provisioner daemon.
|
||||
provisionerDaemons, err := b.store.GetEligibleProvisionerDaemonsByProvisionerJobIDs(dbauthz.AsWorkspaceBuilder(b.ctx), []uuid.UUID{provisionerJob.ID})
|
||||
provisionerDaemons, err := b.store.GetEligibleProvisionerDaemonsByProvisionerJobIDs(dbauthz.AsSystemReadProvisionerDaemons(b.ctx), []uuid.UUID{provisionerJob.ID})
|
||||
if err != nil {
|
||||
// NOTE: we do **not** want to fail a workspace build if we fail to
|
||||
// retrieve provisioner daemons. This is just to show in the UI if there
|
||||
@@ -570,8 +570,8 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object
|
||||
}
|
||||
}
|
||||
if b.state.orphan && !hasActiveEligibleProvisioner {
|
||||
// nolint: gocritic // User won't necessarily have the permission to do this so we act as a system user.
|
||||
if err := store.UpdateProvisionerJobWithCompleteWithStartedAtByID(dbauthz.AsWorkspaceBuilder(b.ctx), database.UpdateProvisionerJobWithCompleteWithStartedAtByIDParams{
|
||||
// nolint: gocritic // At this moment, we are pretending to be provisionerd.
|
||||
if err := store.UpdateProvisionerJobWithCompleteWithStartedAtByID(dbauthz.AsProvisionerd(b.ctx), database.UpdateProvisionerJobWithCompleteWithStartedAtByIDParams{
|
||||
CompletedAt: sql.NullTime{Valid: true, Time: now},
|
||||
Error: sql.NullString{Valid: false},
|
||||
ErrorCode: sql.NullString{Valid: false},
|
||||
@@ -815,12 +815,7 @@ func (b *Builder) getState() ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get last build to get state: %w", err)
|
||||
}
|
||||
// nolint: gocritic // Workspace builder needs to read provisioner state for the new build.
|
||||
state, err := b.store.GetWorkspaceBuildProvisionerStateByID(dbauthz.AsWorkspaceBuilder(b.ctx), bld.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace build provisioner state: %w", err)
|
||||
}
|
||||
return state.ProvisionerState, nil
|
||||
return bld.ProvisionerState, nil
|
||||
}
|
||||
|
||||
func (b *Builder) getParameters() (names, values []string, err error) {
|
||||
|
||||
@@ -65,7 +65,6 @@ func TestBuilder_NoOptions(t *testing.T) {
|
||||
withTemplate,
|
||||
withInactiveVersion(nil),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(inactiveVersionID, nil),
|
||||
withRichParameters(nil),
|
||||
withParameterSchemas(inactiveJobID, nil),
|
||||
@@ -125,7 +124,6 @@ func TestBuilder_Initiator(t *testing.T) {
|
||||
withTemplate,
|
||||
withInactiveVersion(nil),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(inactiveVersionID, nil),
|
||||
withRichParameters(nil),
|
||||
withParameterSchemas(inactiveJobID, nil),
|
||||
@@ -176,7 +174,6 @@ func TestBuilder_Baggage(t *testing.T) {
|
||||
withTemplate,
|
||||
withInactiveVersion(nil),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(inactiveVersionID, nil),
|
||||
withRichParameters(nil),
|
||||
withParameterSchemas(inactiveJobID, nil),
|
||||
@@ -219,7 +216,6 @@ func TestBuilder_Reason(t *testing.T) {
|
||||
withTemplate,
|
||||
withInactiveVersion(nil),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(inactiveVersionID, nil),
|
||||
withRichParameters(nil),
|
||||
withParameterSchemas(inactiveJobID, nil),
|
||||
@@ -369,7 +365,6 @@ func TestWorkspaceBuildWithTags(t *testing.T) {
|
||||
withTemplate,
|
||||
withInactiveVersion(richParameters),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(inactiveVersionID, templateVersionVariables),
|
||||
withRichParameters(nil),
|
||||
withParameterSchemas(inactiveJobID, nil),
|
||||
@@ -469,7 +464,6 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
|
||||
withTemplate,
|
||||
withInactiveVersion(richParameters),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(inactiveVersionID, nil),
|
||||
withRichParameters(initialBuildParameters),
|
||||
withParameterSchemas(inactiveJobID, nil),
|
||||
@@ -521,7 +515,6 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
|
||||
withTemplate,
|
||||
withInactiveVersion(richParameters),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(inactiveVersionID, nil),
|
||||
withRichParameters(initialBuildParameters),
|
||||
withParameterSchemas(inactiveJobID, nil),
|
||||
@@ -668,7 +661,6 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
|
||||
withTemplate,
|
||||
withActiveVersion(version2params),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(activeVersionID, nil),
|
||||
withRichParameters(initialBuildParameters),
|
||||
withParameterSchemas(activeJobID, nil),
|
||||
@@ -735,7 +727,6 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
|
||||
withTemplate,
|
||||
withActiveVersion(version2params),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(activeVersionID, nil),
|
||||
withRichParameters(initialBuildParameters),
|
||||
withParameterSchemas(activeJobID, nil),
|
||||
@@ -800,7 +791,6 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
|
||||
withTemplate,
|
||||
withActiveVersion(version2params),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(activeVersionID, nil),
|
||||
withRichParameters(initialBuildParameters),
|
||||
withParameterSchemas(activeJobID, nil),
|
||||
@@ -1072,7 +1062,6 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) {
|
||||
withTemplate,
|
||||
withInactiveVersion(nil),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(inactiveVersionID, nil),
|
||||
withRichParameters(nil),
|
||||
withParameterSchemas(inactiveJobID, nil),
|
||||
@@ -1186,7 +1175,6 @@ func TestWorkspaceBuildWithTask(t *testing.T) {
|
||||
withTemplate,
|
||||
withInactiveVersion(nil),
|
||||
withLastBuildFound,
|
||||
withLastBuildState,
|
||||
withTemplateVersionVariables(inactiveVersionID, nil),
|
||||
withRichParameters(nil),
|
||||
withParameterSchemas(inactiveJobID, nil),
|
||||
@@ -1390,6 +1378,7 @@ func withLastBuildFound(mTx *dbmock.MockStore) {
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
InitiatorID: userID,
|
||||
JobID: lastBuildJobID,
|
||||
ProvisionerState: []byte("last build state"),
|
||||
Reason: database.BuildReasonInitiator,
|
||||
}, nil)
|
||||
|
||||
@@ -1409,14 +1398,6 @@ func withLastBuildFound(mTx *dbmock.MockStore) {
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func withLastBuildState(mTx *dbmock.MockStore) {
|
||||
mTx.EXPECT().GetWorkspaceBuildProvisionerStateByID(gomock.Any(), lastBuildID).
|
||||
Times(1).
|
||||
Return(database.GetWorkspaceBuildProvisionerStateByIDRow{
|
||||
ProvisionerState: []byte("last build state"),
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func withLastBuildNotFound(mTx *dbmock.MockStore) {
|
||||
mTx.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Times(1).
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -320,15 +321,21 @@ func (c *Client) connectRPCVersion(ctx context.Context, version *apiversion.APIV
|
||||
}
|
||||
rpcURL.RawQuery = q.Encode()
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create cookie jar: %w", err)
|
||||
}
|
||||
jar.SetCookies(rpcURL, []*http.Cookie{{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: c.SDK.SessionToken(),
|
||||
}})
|
||||
httpClient := &http.Client{
|
||||
Jar: jar,
|
||||
Transport: c.SDK.HTTPClient.Transport,
|
||||
}
|
||||
// nolint:bodyclose
|
||||
conn, res, err := websocket.Dial(ctx, rpcURL.String(), &websocket.DialOptions{
|
||||
HTTPClient: httpClient,
|
||||
HTTPHeader: http.Header{
|
||||
codersdk.SessionTokenHeader: []string{c.SDK.SessionToken()},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
if res == nil {
|
||||
@@ -702,7 +709,16 @@ func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, err
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create cookie jar: %w", err)
|
||||
}
|
||||
jar.SetCookies(rpcURL, []*http.Cookie{{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: c.SDK.SessionToken(),
|
||||
}})
|
||||
httpClient := &http.Client{
|
||||
Jar: jar,
|
||||
Transport: c.SDK.HTTPClient.Transport,
|
||||
}
|
||||
|
||||
@@ -710,7 +726,6 @@ func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, err
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("build request: %w", err)
|
||||
}
|
||||
req.Header[codersdk.SessionTokenHeader] = []string{c.SDK.SessionToken()}
|
||||
|
||||
res, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
|
||||
@@ -464,16 +464,3 @@ func ProtoFromDevcontainer(dc codersdk.WorkspaceAgentDevcontainer) *proto.Worksp
|
||||
SubagentId: subagentID,
|
||||
}
|
||||
}
|
||||
|
||||
func ProtoFromPatchAppStatus(pas PatchAppStatus) (*proto.UpdateAppStatusRequest, error) {
|
||||
state, ok := proto.UpdateAppStatusRequest_AppStatusState_value[strings.ToUpper(string(pas.State))]
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("Invalid state: %s", pas.State)
|
||||
}
|
||||
return &proto.UpdateAppStatusRequest{
|
||||
Slug: pas.AppSlug,
|
||||
State: proto.UpdateAppStatusRequest_AppStatusState(state),
|
||||
Message: pas.Message,
|
||||
Uri: pas.URI,
|
||||
}, nil
|
||||
}
|
||||
|
||||
+1
-3
@@ -94,8 +94,7 @@ func (c *Client) CreateAPIKey(ctx context.Context, user string) (GenerateAPIKeyR
|
||||
}
|
||||
|
||||
type TokensFilter struct {
|
||||
IncludeAll bool `json:"include_all"`
|
||||
IncludeExpired bool `json:"include_expired"`
|
||||
IncludeAll bool `json:"include_all"`
|
||||
}
|
||||
|
||||
type APIKeyWithOwner struct {
|
||||
@@ -113,7 +112,6 @@ func (f TokensFilter) asRequestOption() RequestOption {
|
||||
return func(r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
q.Set("include_all", fmt.Sprintf("%t", f.IncludeAll))
|
||||
q.Set("include_expired", fmt.Sprintf("%t", f.IncludeExpired))
|
||||
r.URL.RawQuery = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -535,14 +535,6 @@ func NewTestError(statusCode int, method string, u string) *Error {
|
||||
}
|
||||
}
|
||||
|
||||
// NewError creates a new Error with the response and status code.
|
||||
func NewError(statusCode int, response Response) *Error {
|
||||
return &Error{
|
||||
statusCode: statusCode,
|
||||
Response: response,
|
||||
}
|
||||
}
|
||||
|
||||
type closeFunc func() error
|
||||
|
||||
func (c closeFunc) Close() error {
|
||||
|
||||
@@ -3042,7 +3042,7 @@ func (c *DeploymentValues) Options() serpent.OptionSet {
|
||||
},
|
||||
{
|
||||
Name: "Disable Workspace Sharing",
|
||||
Description: `Disable workspace sharing. Workspace ACL checking is disabled and only owners can have ssh, apps and terminal access to workspaces. Access based on the 'owner' role is also allowed unless disabled via --disable-owner-workspace-access.`,
|
||||
Description: `Disable workspace sharing (requires the "workspace-sharing" experiment to be enabled). Workspace ACL checking is disabled and only owners can have ssh, apps and terminal access to workspaces. Access based on the 'owner' role is also allowed unless disabled via --disable-owner-workspace-access.`,
|
||||
Flag: "disable-workspace-sharing",
|
||||
Env: "CODER_DISABLE_WORKSPACE_SHARING",
|
||||
|
||||
@@ -4265,6 +4265,7 @@ const (
|
||||
ExperimentWebPush Experiment = "web-push" // Enables web push notifications through the browser.
|
||||
ExperimentOAuth2 Experiment = "oauth2" // Enables OAuth2 provider functionality.
|
||||
ExperimentMCPServerHTTP Experiment = "mcp-server-http" // Enables the MCP HTTP server functionality.
|
||||
ExperimentWorkspaceSharing Experiment = "workspace-sharing" // Enables updating workspace ACLs for sharing with users and groups.
|
||||
)
|
||||
|
||||
func (e Experiment) DisplayName() string {
|
||||
@@ -4283,6 +4284,8 @@ func (e Experiment) DisplayName() string {
|
||||
return "OAuth2 Provider Functionality"
|
||||
case ExperimentMCPServerHTTP:
|
||||
return "MCP HTTP Server Functionality"
|
||||
case ExperimentWorkspaceSharing:
|
||||
return "Workspace Sharing"
|
||||
default:
|
||||
// Split on hyphen and convert to title case
|
||||
// e.g. "web-push" -> "Web Push", "mcp-server-http" -> "Mcp Server Http"
|
||||
@@ -4300,6 +4303,7 @@ var ExperimentsKnown = Experiments{
|
||||
ExperimentWebPush,
|
||||
ExperimentOAuth2,
|
||||
ExperimentMCPServerHTTP,
|
||||
ExperimentWorkspaceSharing,
|
||||
}
|
||||
|
||||
// ExperimentsSafe should include all experiments that are safe for
|
||||
|
||||
+5
-1
@@ -217,7 +217,11 @@ const (
|
||||
)
|
||||
|
||||
func (e OAuth2ProviderResponseType) Valid() bool {
|
||||
return e == OAuth2ProviderResponseTypeCode || e == OAuth2ProviderResponseTypeToken
|
||||
switch e {
|
||||
case OAuth2ProviderResponseTypeCode, OAuth2ProviderResponseTypeToken:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuth2TokenEndpointAuthMethod string
|
||||
|
||||
@@ -265,7 +265,7 @@ Bad Tasks
|
||||
Use the "state" field to indicate your progress. Periodically report
|
||||
progress with state "working" to keep the user updated. It is not possible to send too many updates!
|
||||
|
||||
ONLY report a "complete", "idle", or "failure" state if you have FULLY completed the task.
|
||||
ONLY report an "idle" or "failure" state if you have FULLY completed the task.
|
||||
`,
|
||||
Schema: aisdk.Schema{
|
||||
Properties: map[string]any{
|
||||
@@ -279,10 +279,9 @@ ONLY report a "complete", "idle", or "failure" state if you have FULLY completed
|
||||
},
|
||||
"state": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The state of your task. This can be one of the following: working, complete, idle, or failure. Select the state that best represents your current progress.",
|
||||
"description": "The state of your task. This can be one of the following: working, idle, or failure. Select the state that best represents your current progress.",
|
||||
"enum": []string{
|
||||
string(codersdk.WorkspaceAppStatusStateWorking),
|
||||
string(codersdk.WorkspaceAppStatusStateComplete),
|
||||
string(codersdk.WorkspaceAppStatusStateIdle),
|
||||
string(codersdk.WorkspaceAppStatusStateFailure),
|
||||
},
|
||||
|
||||
@@ -128,16 +128,9 @@ If client authentication fails, the token endpoint returns **HTTP 401** with an
|
||||
"$CODER_URL/api/v2/users/me"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The PKCE flow below is the **required** integration path. The example
|
||||
> above is shown for reference but omits the mandatory `code_challenge`
|
||||
> parameter. See [PKCE Flow](#pkce-flow-required) for the complete flow.
|
||||
### PKCE Flow (Recommended)
|
||||
|
||||
### PKCE Flow (Required)
|
||||
|
||||
PKCE is **required** for all OAuth2 authorization code flows. Coder enforces
|
||||
PKCE in compliance with the OAuth 2.1 specification. Both public and
|
||||
confidential clients must include PKCE parameters:
|
||||
Use PKCE for enhanced security (recommended for both public and confidential clients):
|
||||
|
||||
1. Generate a code verifier and challenge:
|
||||
|
||||
@@ -258,8 +251,7 @@ Verify that the `code_verifier` used in the token request matches the one used t
|
||||
## Security Considerations
|
||||
|
||||
- **Use HTTPS**: Always use HTTPS in production to protect tokens in transit
|
||||
- **Implement PKCE**: PKCE is mandatory for all authorization code clients
|
||||
(public and confidential)
|
||||
- **Implement PKCE**: Use PKCE for all public clients (mobile apps, SPAs)
|
||||
- **Validate redirect URLs**: Only register trusted redirect URIs for your applications
|
||||
- **Rotate secrets**: Periodically rotate client secrets using the management API
|
||||
|
||||
@@ -269,20 +261,11 @@ As an experimental feature, the current implementation has limitations:
|
||||
|
||||
- No scope system - all tokens have full API access
|
||||
- No client credentials grant support
|
||||
- Implicit grant (`response_type=token`) is not supported; OAuth 2.1
|
||||
deprecated this flow due to token leakage risks, and requests return
|
||||
`unsupported_response_type`
|
||||
- Limited to opaque access tokens (no JWT support)
|
||||
|
||||
## Standards Compliance
|
||||
|
||||
This implementation follows established OAuth2 standards including
|
||||
[RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749) (OAuth2 core),
|
||||
[RFC 7636](https://datatracker.ietf.org/doc/html/rfc7636) (PKCE), and the
|
||||
[OAuth 2.1 draft](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12).
|
||||
Coder enforces OAuth 2.1 requirements including mandatory PKCE for all
|
||||
authorization code grants, exact redirect URI string matching, rejection
|
||||
of the implicit grant, and CSRF protections on consent pages.
|
||||
This implementation follows established OAuth2 standards including [RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749) (OAuth2 core), [RFC 7636](https://datatracker.ietf.org/doc/html/rfc7636) (PKCE), and related specifications for discovery and client registration.
|
||||
|
||||
## Next Steps
|
||||
|
||||
|
||||
@@ -175,9 +175,9 @@ deployment. They will always be available from the agent.
|
||||
| `coderd_dbpurge_iteration_duration_seconds` | histogram | Duration of each dbpurge iteration in seconds. | `success` |
|
||||
| `coderd_dbpurge_records_purged_total` | counter | Total number of records purged by type. | `record_type` |
|
||||
| `coderd_experiments` | gauge | Indicates whether each experiment is enabled (1) or not (0) | `experiment` |
|
||||
| `coderd_insights_applications_usage_seconds` | gauge | The application usage per template. | `application_name` `organization_name` `slug` `template_name` |
|
||||
| `coderd_insights_parameters` | gauge | The parameter usage per template. | `organization_name` `parameter_name` `parameter_type` `parameter_value` `template_name` |
|
||||
| `coderd_insights_templates_active_users` | gauge | The number of active users of the template. | `organization_name` `template_name` |
|
||||
| `coderd_insights_applications_usage_seconds` | gauge | The application usage per template. | `application_name` `slug` `template_name` |
|
||||
| `coderd_insights_parameters` | gauge | The parameter usage per template. | `parameter_name` `parameter_type` `parameter_value` `template_name` |
|
||||
| `coderd_insights_templates_active_users` | gauge | The number of active users of the template. | `template_name` |
|
||||
| `coderd_license_active_users` | gauge | The number of active users. | |
|
||||
| `coderd_license_errors` | gauge | The number of active license errors. | |
|
||||
| `coderd_license_limit_users` | gauge | The user seats limit based on the active Coder license. | |
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user