Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 27a1608782 | |||
| e89f8e1134 | |||
| 5b13118a18 |
@@ -177,7 +177,6 @@ func New(opts Options, workspace database.Workspace) *API {
|
||||
AgentFn: api.agent,
|
||||
Workspace: api.cachedWorkspaceFields,
|
||||
Database: opts.Database,
|
||||
Pubsub: opts.Pubsub,
|
||||
Log: opts.Log,
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package agentapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -14,14 +13,12 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
type MetadataAPI struct {
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
Workspace *CachedWorkspaceFields
|
||||
Database database.Store
|
||||
Pubsub pubsub.Pubsub
|
||||
Log slog.Logger
|
||||
|
||||
TimeNowFn func() time.Time // defaults to dbtime.Now()
|
||||
@@ -127,18 +124,6 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
|
||||
return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err)
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(WorkspaceAgentMetadataChannelPayload{
|
||||
CollectedAt: collectedAt,
|
||||
Keys: dbUpdate.Key,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal workspace agent metadata channel payload: %w", err)
|
||||
}
|
||||
err = a.Pubsub.Publish(WatchWorkspaceAgentMetadataChannel(workspaceAgent.ID), payload)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace agent metadata: %w", err)
|
||||
}
|
||||
|
||||
// If the metadata keys were too large, we return an error so the agent can
|
||||
// log it.
|
||||
if allKeysLen > maxAllKeysLen {
|
||||
|
||||
@@ -3,7 +3,6 @@ package agentapi_test
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -20,26 +19,12 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
type fakePublisher struct {
|
||||
// Nil pointer to pass interface check.
|
||||
pubsub.Pubsub
|
||||
publishes [][]byte
|
||||
}
|
||||
|
||||
var _ pubsub.Pubsub = &fakePublisher{}
|
||||
|
||||
func (f *fakePublisher) Publish(_ string, message []byte) error {
|
||||
f.publishes = append(f.publishes, message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBatchUpdateMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -51,7 +36,6 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
pub := &fakePublisher{}
|
||||
|
||||
now := dbtime.Now()
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
@@ -92,7 +76,6 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
@@ -102,21 +85,12 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
resp, err := api.BatchUpdateMetadata(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp)
|
||||
|
||||
require.Equal(t, 1, len(pub.publishes))
|
||||
var gotEvent agentapi.WorkspaceAgentMetadataChannelPayload
|
||||
require.NoError(t, json.Unmarshal(pub.publishes[0], &gotEvent))
|
||||
require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{
|
||||
CollectedAt: now,
|
||||
Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key},
|
||||
}, gotEvent)
|
||||
})
|
||||
|
||||
t.Run("ExceededLength", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
pub := pubsub.NewInMemory()
|
||||
|
||||
almostLongValue := ""
|
||||
for i := 0; i < 2048; i++ {
|
||||
@@ -178,7 +152,6 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
@@ -194,7 +167,6 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
pub := pubsub.NewInMemory()
|
||||
|
||||
now := dbtime.Now()
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
@@ -248,38 +220,16 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
// Watch the pubsub for events.
|
||||
var (
|
||||
eventCount int64
|
||||
gotEvent agentapi.WorkspaceAgentMetadataChannelPayload
|
||||
)
|
||||
cancel, err := pub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(agent.ID), func(ctx context.Context, message []byte) {
|
||||
if atomic.AddInt64(&eventCount, 1) > 1 {
|
||||
return
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(message, &gotEvent))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
resp, err := api.BatchUpdateMetadata(context.Background(), req)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "metadata keys of 6145 bytes exceeded 6144 bytes", err.Error())
|
||||
require.Nil(t, resp)
|
||||
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&eventCount))
|
||||
require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{
|
||||
CollectedAt: now,
|
||||
// No key 4.
|
||||
Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key},
|
||||
}, gotEvent)
|
||||
})
|
||||
|
||||
// Test RBAC fast path with valid RBAC object - should NOT call GetWorkspaceByAgentID
|
||||
@@ -291,7 +241,6 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
// Set up consistent IDs that represent a valid workspace->agent relationship
|
||||
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||
@@ -346,7 +295,6 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
@@ -414,7 +362,6 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||
templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000")
|
||||
@@ -469,10 +416,8 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
@@ -543,7 +488,6 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||
templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000")
|
||||
@@ -600,7 +544,6 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
@@ -663,7 +606,6 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
mClock = quartz.NewMock(t)
|
||||
tickerTrap = mClock.Trap().TickerFunc("cache_refresh")
|
||||
@@ -803,7 +745,6 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Log: testutil.Logger(t),
|
||||
Clock: mClock,
|
||||
Pubsub: pub,
|
||||
}, initialWorkspace) // Cache is initialized with 9am schedule and "my-workspace" name
|
||||
|
||||
// Wait for ticker to be set up and release it so it can fire
|
||||
|
||||
+48
-131
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -26,7 +25,6 @@ import (
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
@@ -1635,60 +1633,18 @@ func (api *API) watchWorkspaceAgentMetadata(
|
||||
slog.F("workspace_agent_id", workspaceAgent.ID),
|
||||
)
|
||||
|
||||
// Send metadata on updates, we must ensure subscription before sending
|
||||
// initial metadata to guarantee that events in-between are not missed.
|
||||
update := make(chan agentapi.WorkspaceAgentMetadataChannelPayload, 1)
|
||||
cancelSub, err := api.Pubsub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(workspaceAgent.ID), func(_ context.Context, byt []byte) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var payload agentapi.WorkspaceAgentMetadataChannelPayload
|
||||
err := json.Unmarshal(byt, &payload)
|
||||
if err != nil {
|
||||
log.Error(ctx, "failed to unmarshal pubsub message", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug(ctx, "received metadata update", "payload", payload)
|
||||
|
||||
select {
|
||||
case prev := <-update:
|
||||
payload.Keys = appendUnique(prev.Keys, payload.Keys)
|
||||
default:
|
||||
}
|
||||
// This can never block since we pop and merge beforehand.
|
||||
update <- payload
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
defer cancelSub()
|
||||
|
||||
// We always use the original Request context because it contains
|
||||
// the RBAC actor.
|
||||
// Fetch initial metadata.
|
||||
initialMD, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: workspaceAgent.ID,
|
||||
Keys: nil,
|
||||
})
|
||||
if err != nil {
|
||||
// If we can't successfully pull the initial metadata, pubsub
|
||||
// updates will be no-op so we may as well terminate the
|
||||
// connection early.
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug(ctx, "got initial metadata", "num", len(initialMD))
|
||||
|
||||
metadataMap := make(map[string]database.WorkspaceAgentMetadatum, len(initialMD))
|
||||
for _, datum := range initialMD {
|
||||
metadataMap[datum.Key] = datum
|
||||
}
|
||||
//nolint:ineffassign // Release memory.
|
||||
initialMD = nil
|
||||
|
||||
sendEvent, senderClosed, err := connect(rw, r)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -1712,115 +1668,76 @@ func (api *API) watchWorkspaceAgentMetadata(
|
||||
}
|
||||
}()
|
||||
|
||||
var lastSend time.Time
|
||||
sendMetadata := func() {
|
||||
lastSend = time.Now()
|
||||
values := maps.Values(metadataMap)
|
||||
|
||||
log.Debug(ctx, "sending metadata", "num", len(values))
|
||||
|
||||
sendMetadata := func(md []database.WorkspaceAgentMetadatum) {
|
||||
log.Debug(ctx, "sending metadata", "num", len(md))
|
||||
_ = sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: convertWorkspaceAgentMetadata(values),
|
||||
Data: convertWorkspaceAgentMetadata(md),
|
||||
})
|
||||
}
|
||||
|
||||
// We send updates exactly every second.
|
||||
const sendInterval = time.Second * 1
|
||||
sendTicker := time.NewTicker(sendInterval)
|
||||
defer sendTicker.Stop()
|
||||
|
||||
// Log the request immediately instead of after it completes.
|
||||
if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil {
|
||||
rl.WriteLog(ctx, http.StatusAccepted)
|
||||
}
|
||||
|
||||
// Send initial metadata.
|
||||
sendMetadata()
|
||||
sendMetadata(initialMD)
|
||||
|
||||
// Fetch updated metadata keys as they come in.
|
||||
fetchedMetadata := make(chan []database.WorkspaceAgentMetadatum)
|
||||
go func() {
|
||||
defer close(fetchedMetadata)
|
||||
defer cancel()
|
||||
// If no metadata exists, don't start the poll loop.
|
||||
if len(initialMD) == 0 {
|
||||
log.Debug(ctx, "no metadata to poll, skipping poll loop")
|
||||
<-ctx.Done()
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case payload := <-update:
|
||||
md, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: workspaceAgent.ID,
|
||||
Keys: payload.Keys,
|
||||
})
|
||||
if err != nil {
|
||||
if !database.IsQueryCanceledError(err) {
|
||||
log.Error(ctx, "failed to get metadata", slog.Error(err))
|
||||
_ = sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Failed to get metadata.",
|
||||
Detail: err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
// We want to block here to avoid constantly pinging the
|
||||
// database when the metadata isn't being processed.
|
||||
case fetchedMetadata <- md:
|
||||
log.Debug(ctx, "fetched metadata update for keys", "keys", payload.Keys, "num", len(md))
|
||||
}
|
||||
}
|
||||
// Calculate poll interval as the minimum interval from all metadata items.
|
||||
var pollInterval time.Duration
|
||||
for _, md := range initialMD {
|
||||
interval := time.Duration(md.Interval)
|
||||
if interval > 0 && (pollInterval == 0 || interval < pollInterval) {
|
||||
pollInterval = interval
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
<-fetchedMetadata
|
||||
}()
|
||||
}
|
||||
|
||||
// If all metadata items have zero intervals, log an error and don't start the loop.
|
||||
if pollInterval == 0 {
|
||||
log.Error(ctx, "all metadata items have zero intervals, skipping poll loop")
|
||||
<-ctx.Done()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug(ctx, "starting metadata poll loop", slog.F("poll_interval", pollInterval))
|
||||
|
||||
pollTicker := time.NewTicker(pollInterval)
|
||||
defer pollTicker.Stop()
|
||||
|
||||
pendingChanges := true
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case md, ok := <-fetchedMetadata:
|
||||
if !ok {
|
||||
case <-pollTicker.C:
|
||||
md, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: workspaceAgent.ID,
|
||||
Keys: nil,
|
||||
})
|
||||
if err != nil {
|
||||
if database.IsQueryCanceledError(err) {
|
||||
return
|
||||
}
|
||||
log.Error(ctx, "failed to get metadata", slog.Error(err))
|
||||
_ = sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Failed to get metadata.",
|
||||
Detail: err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
for _, datum := range md {
|
||||
metadataMap[datum.Key] = datum
|
||||
}
|
||||
pendingChanges = true
|
||||
continue
|
||||
case <-sendTicker.C:
|
||||
// We send an update even if there's no change every 5 seconds
|
||||
// to ensure that the frontend always has an accurate "Result.Age".
|
||||
if !pendingChanges && time.Since(lastSend) < 5*time.Second {
|
||||
continue
|
||||
}
|
||||
pendingChanges = false
|
||||
}
|
||||
|
||||
sendMetadata()
|
||||
}
|
||||
}
|
||||
|
||||
// appendUnique is like append and adds elements from src to dst,
|
||||
// skipping any elements that already exist in dst.
|
||||
func appendUnique[T comparable](dst, src []T) []T {
|
||||
exists := make(map[T]struct{}, len(dst))
|
||||
for _, key := range dst {
|
||||
exists[key] = struct{}{}
|
||||
}
|
||||
for _, key := range src {
|
||||
if _, ok := exists[key]; !ok {
|
||||
dst = append(dst, key)
|
||||
sendMetadata(md)
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func convertWorkspaceAgentMetadata(db []database.WorkspaceAgentMetadatum) []codersdk.WorkspaceAgentMetadata {
|
||||
|
||||
Reference in New Issue
Block a user