Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 19fb37a8f6 | |||
| 42f9fea0ae | |||
| 7a5bdac25b | |||
| fd3dcb8b97 | |||
| 59938256c3 | |||
| af6cbf89dc | |||
| cb270ca9f7 | |||
| 3c80616b1e | |||
| 4c95cee4cf | |||
| f6a69d5f38 | |||
| e660df8f6e | |||
| a2a0dd920b | |||
| d1c28c964c | |||
| 8f3bed6a99 | |||
| 2bd23fc233 | |||
| 88ed59eca8 | |||
| 8b46de43fe | |||
| cac89b714b | |||
| 0409f8d2a1 | |||
| 017a0caa4b | |||
| 87577b6a6c | |||
| 2dcd9a9388 | |||
| 811c189d1c |
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/resourcesmonitor"
|
||||
"github.com/coder/coder/v2/coderd/appearance"
|
||||
"github.com/coder/coder/v2/coderd/connectionlog"
|
||||
@@ -80,6 +81,7 @@ type Options struct {
|
||||
DerpMapFn func() *tailcfg.DERPMap
|
||||
TailnetCoordinator *atomic.Pointer[tailnet.Coordinator]
|
||||
StatsReporter *workspacestats.Reporter
|
||||
MetadataBatcher *metadatabatcher.Batcher
|
||||
AppearanceFetcher *atomic.Pointer[appearance.Fetcher]
|
||||
PublishWorkspaceUpdateFn func(ctx context.Context, userID uuid.UUID, event wspubsub.WorkspaceEvent)
|
||||
PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage)
|
||||
@@ -178,8 +180,8 @@ func New(opts Options, workspace database.Workspace) *API {
|
||||
AgentFn: api.agent,
|
||||
Workspace: api.cachedWorkspaceFields,
|
||||
Database: opts.Database,
|
||||
Pubsub: opts.Pubsub,
|
||||
Log: opts.Log,
|
||||
Batcher: opts.MetadataBatcher,
|
||||
}
|
||||
|
||||
api.LogsAPI = &LogsAPI{
|
||||
|
||||
@@ -2,27 +2,25 @@ package agentapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
type MetadataAPI struct {
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
Workspace *CachedWorkspaceFields
|
||||
Database database.Store
|
||||
Pubsub pubsub.Pubsub
|
||||
Log slog.Logger
|
||||
Batcher *metadatabatcher.Batcher
|
||||
|
||||
TimeNowFn func() time.Time // defaults to dbtime.Now()
|
||||
}
|
||||
@@ -122,21 +120,10 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
|
||||
)
|
||||
}
|
||||
|
||||
err = a.Database.UpdateWorkspaceAgentMetadata(rbacCtx, dbUpdate)
|
||||
// Use batcher to batch metadata updates.
|
||||
err = a.Batcher.Add(workspaceAgent.ID, dbUpdate.Key, dbUpdate.Value, dbUpdate.Error, dbUpdate.CollectedAt)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err)
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(WorkspaceAgentMetadataChannelPayload{
|
||||
CollectedAt: collectedAt,
|
||||
Keys: dbUpdate.Key,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal workspace agent metadata channel payload: %w", err)
|
||||
}
|
||||
err = a.Pubsub.Publish(WatchWorkspaceAgentMetadataChannel(workspaceAgent.ID), payload)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace agent metadata: %w", err)
|
||||
return nil, xerrors.Errorf("add metadata to batcher: %w", err)
|
||||
}
|
||||
|
||||
// If the metadata keys were too large, we return an error so the agent can
|
||||
@@ -154,12 +141,3 @@ func ellipse(v string, n int) string {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
type WorkspaceAgentMetadataChannelPayload struct {
|
||||
CollectedAt time.Time `json:"collected_at"`
|
||||
Keys []string `json:"keys"`
|
||||
}
|
||||
|
||||
func WatchWorkspaceAgentMetadataChannel(id uuid.UUID) string {
|
||||
return "workspace_agent_metadata:" + id.String()
|
||||
}
|
||||
|
||||
@@ -2,9 +2,6 @@ package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -16,30 +13,14 @@ import (
|
||||
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
type fakePublisher struct {
|
||||
// Nil pointer to pass interface check.
|
||||
pubsub.Pubsub
|
||||
publishes [][]byte
|
||||
}
|
||||
|
||||
var _ pubsub.Pubsub = &fakePublisher{}
|
||||
|
||||
func (f *fakePublisher) Publish(_ string, message []byte) error {
|
||||
f.publishes = append(f.publishes, message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBatchUpdateMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -50,8 +31,24 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
pub := &fakePublisher{}
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
ps := pubsub.NewInMemory()
|
||||
reg := prometheus.NewRegistry()
|
||||
|
||||
// Mock the database calls that batcher will make when it flushes.
|
||||
store.EXPECT().
|
||||
BatchUpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()).
|
||||
Return(nil).
|
||||
AnyTimes()
|
||||
|
||||
// Create a real batcher for the test
|
||||
batcher, err := metadatabatcher.NewBatcher(ctx, reg, store, ps,
|
||||
metadatabatcher.WithLogger(testutil.Logger(t)),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(batcher.Close)
|
||||
|
||||
now := dbtime.Now()
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
@@ -77,23 +74,13 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Key: []string{req.Metadata[0].Key, req.Metadata[1].Key},
|
||||
Value: []string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value},
|
||||
Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error},
|
||||
// The value from the agent is ignored.
|
||||
CollectedAt: []time.Time{now, now},
|
||||
}).Return(nil)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
Batcher: batcher,
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
@@ -102,21 +89,28 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
resp, err := api.BatchUpdateMetadata(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp)
|
||||
|
||||
require.Equal(t, 1, len(pub.publishes))
|
||||
var gotEvent agentapi.WorkspaceAgentMetadataChannelPayload
|
||||
require.NoError(t, json.Unmarshal(pub.publishes[0], &gotEvent))
|
||||
require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{
|
||||
CollectedAt: now,
|
||||
Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key},
|
||||
}, gotEvent)
|
||||
})
|
||||
|
||||
t.Run("ExceededLength", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
pub := pubsub.NewInMemory()
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
ps := pubsub.NewInMemory()
|
||||
reg := prometheus.NewRegistry()
|
||||
|
||||
// Mock the database calls that batcher will make when it flushes.
|
||||
store.EXPECT().
|
||||
BatchUpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()).
|
||||
Return(nil).
|
||||
AnyTimes()
|
||||
|
||||
batcher, err := metadatabatcher.NewBatcher(ctx, reg, store, ps,
|
||||
metadatabatcher.WithLogger(testutil.Logger(t)),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(batcher.Close)
|
||||
|
||||
almostLongValue := ""
|
||||
for i := 0; i < 2048; i++ {
|
||||
@@ -153,33 +147,13 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key, req.Metadata[3].Key},
|
||||
Value: []string{
|
||||
almostLongValue,
|
||||
almostLongValue, // truncated
|
||||
"",
|
||||
"",
|
||||
},
|
||||
Error: []string{
|
||||
"",
|
||||
"value of 2049 bytes exceeded 2048 bytes",
|
||||
almostLongValue,
|
||||
"error of 2049 bytes exceeded 2048 bytes", // replaced
|
||||
},
|
||||
// The value from the agent is ignored.
|
||||
CollectedAt: []time.Time{now, now, now, now},
|
||||
}).Return(nil)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
Batcher: batcher,
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
@@ -193,8 +167,23 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
t.Run("KeysTooLong", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
pub := pubsub.NewInMemory()
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
ps := pubsub.NewInMemory()
|
||||
reg := prometheus.NewRegistry()
|
||||
|
||||
// Mock the database calls that batcher will make when it flushes.
|
||||
store.EXPECT().
|
||||
BatchUpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()).
|
||||
Return(nil).
|
||||
AnyTimes()
|
||||
|
||||
batcher, err := metadatabatcher.NewBatcher(ctx, reg, store, ps,
|
||||
metadatabatcher.WithLogger(testutil.Logger(t)),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(batcher.Close)
|
||||
|
||||
now := dbtime.Now()
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
@@ -232,594 +221,21 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
// No key 4.
|
||||
Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key},
|
||||
Value: []string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value, req.Metadata[2].Result.Value},
|
||||
Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error, req.Metadata[2].Result.Error},
|
||||
// The value from the agent is ignored.
|
||||
CollectedAt: []time.Time{now, now, now},
|
||||
}).Return(nil)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
Batcher: batcher,
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
// Watch the pubsub for events.
|
||||
var (
|
||||
eventCount int64
|
||||
gotEvent agentapi.WorkspaceAgentMetadataChannelPayload
|
||||
)
|
||||
cancel, err := pub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(agent.ID), func(ctx context.Context, message []byte) {
|
||||
if atomic.AddInt64(&eventCount, 1) > 1 {
|
||||
return
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(message, &gotEvent))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
resp, err := api.BatchUpdateMetadata(context.Background(), req)
|
||||
// Should return error because keys are too long.
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "metadata keys of 6145 bytes exceeded 6144 bytes", err.Error())
|
||||
require.Nil(t, resp)
|
||||
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&eventCount))
|
||||
require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{
|
||||
CollectedAt: now,
|
||||
// No key 4.
|
||||
Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key},
|
||||
}, gotEvent)
|
||||
})
|
||||
|
||||
// Test RBAC fast path with valid RBAC object - should NOT call GetWorkspaceByAgentID
|
||||
// This test verifies that when a valid RBAC object is present in context, the dbauthz layer
|
||||
// uses the fast path and skips the GetWorkspaceByAgentID database call.
|
||||
t.Run("WorkspaceCached_SkipsDBCall", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
// Set up consistent IDs that represent a valid workspace->agent relationship
|
||||
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||
templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000")
|
||||
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||
agentID = uuid.MustParse("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
|
||||
)
|
||||
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
// In a real scenario, this agent would belong to a resource in the workspace above
|
||||
}
|
||||
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
Metadata: []*agentproto.Metadata{
|
||||
{
|
||||
Key: "test_key",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||
Age: 1,
|
||||
Value: "test_value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Expect UpdateWorkspaceAgentMetadata to be called
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Key: []string{"test_key"},
|
||||
Value: []string{"test_value"},
|
||||
Error: []string{""},
|
||||
CollectedAt: []time.Time{now},
|
||||
}).Return(nil)
|
||||
|
||||
// DO NOT expect GetWorkspaceByAgentID - the fast path should skip this call
|
||||
// If GetWorkspaceByAgentID is called, the test will fail with "unexpected call"
|
||||
|
||||
// dbauthz will call Wrappers() to check for wrapped databases
|
||||
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||
|
||||
// Set up dbauthz to test the actual authorization layer
|
||||
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||
accessControlStore.Store(&acs)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
api.Workspace.UpdateValues(database.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
})
|
||||
|
||||
// Create roles with workspace permissions
|
||||
userRoles := rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleMember(),
|
||||
User: []rbac.Permission{
|
||||
{
|
||||
Negate: false,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.WildcardSymbol,
|
||||
},
|
||||
},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{
|
||||
orgID.String(): {
|
||||
Member: []rbac.Permission{
|
||||
{
|
||||
Negate: false,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.WildcardSymbol,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
agentScope := rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{
|
||||
WorkspaceID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
TemplateID: templateID,
|
||||
VersionID: uuid.New(),
|
||||
})
|
||||
|
||||
ctx := dbauthz.As(context.Background(), rbac.Subject{
|
||||
Type: rbac.SubjectTypeUser,
|
||||
FriendlyName: "testuser",
|
||||
Email: "testuser@example.com",
|
||||
ID: ownerID.String(),
|
||||
Roles: userRoles,
|
||||
Groups: []string{orgID.String()},
|
||||
Scope: agentScope,
|
||||
}.WithCachedASTValue())
|
||||
|
||||
resp, err := api.BatchUpdateMetadata(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
})
|
||||
// Test RBAC slow path - invalid RBAC object should fall back to GetWorkspaceByAgentID
|
||||
// This test verifies that when the RBAC object has invalid IDs (nil UUIDs), the dbauthz layer
|
||||
// falls back to the slow path and calls GetWorkspaceByAgentID.
|
||||
t.Run("InvalidWorkspaceCached_RequiresDBCall", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||
templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000")
|
||||
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||
agentID = uuid.MustParse("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
|
||||
)
|
||||
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
}
|
||||
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
Metadata: []*agentproto.Metadata{
|
||||
{
|
||||
Key: "test_key",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||
Age: 1,
|
||||
Value: "test_value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// EXPECT GetWorkspaceByAgentID to be called because the RBAC fast path validation fails
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(database.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
}, nil)
|
||||
|
||||
// Expect UpdateWorkspaceAgentMetadata to be called after authorization
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Key: []string{"test_key"},
|
||||
Value: []string{"test_value"},
|
||||
Error: []string{""},
|
||||
CollectedAt: []time.Time{now},
|
||||
}).Return(nil)
|
||||
|
||||
// dbauthz will call Wrappers() to check for wrapped databases
|
||||
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||
|
||||
// Set up dbauthz to test the actual authorization layer
|
||||
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||
accessControlStore.Store(&acs)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
// Create an invalid RBAC object with nil UUIDs for owner/org
|
||||
// This will fail dbauthz fast path validation and trigger GetWorkspaceByAgentID
|
||||
api.Workspace.UpdateValues(database.Workspace{
|
||||
ID: uuid.MustParse("cccccccc-cccc-cccc-cccc-cccccccccccc"),
|
||||
OwnerID: uuid.Nil, // Invalid: fails dbauthz fast path validation
|
||||
OrganizationID: uuid.Nil, // Invalid: fails dbauthz fast path validation
|
||||
})
|
||||
|
||||
// Create roles with workspace permissions
|
||||
userRoles := rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleMember(),
|
||||
User: []rbac.Permission{
|
||||
{
|
||||
Negate: false,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.WildcardSymbol,
|
||||
},
|
||||
},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{
|
||||
orgID.String(): {
|
||||
Member: []rbac.Permission{
|
||||
{
|
||||
Negate: false,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.WildcardSymbol,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
agentScope := rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{
|
||||
WorkspaceID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
TemplateID: templateID,
|
||||
VersionID: uuid.New(),
|
||||
})
|
||||
|
||||
ctx := dbauthz.As(context.Background(), rbac.Subject{
|
||||
Type: rbac.SubjectTypeUser,
|
||||
FriendlyName: "testuser",
|
||||
Email: "testuser@example.com",
|
||||
ID: ownerID.String(),
|
||||
Roles: userRoles,
|
||||
Groups: []string{orgID.String()},
|
||||
Scope: agentScope,
|
||||
}.WithCachedASTValue())
|
||||
|
||||
resp, err := api.BatchUpdateMetadata(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
})
|
||||
|
||||
// Test RBAC slow path - no RBAC object in context
|
||||
// This test verifies that when no RBAC object is present in context, the dbauthz layer
|
||||
// falls back to the slow path and calls GetWorkspaceByAgentID.
|
||||
t.Run("WorkspaceNotCached_RequiresDBCall", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||
templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000")
|
||||
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||
agentID = uuid.MustParse("dddddddd-dddd-dddd-dddd-dddddddddddd")
|
||||
)
|
||||
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
}
|
||||
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
Metadata: []*agentproto.Metadata{
|
||||
{
|
||||
Key: "test_key",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||
Age: 1,
|
||||
Value: "test_value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// EXPECT GetWorkspaceByAgentID to be called because no RBAC object is in context
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(database.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
}, nil)
|
||||
|
||||
// Expect UpdateWorkspaceAgentMetadata to be called after authorization
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Key: []string{"test_key"},
|
||||
Value: []string{"test_value"},
|
||||
Error: []string{""},
|
||||
CollectedAt: []time.Time{now},
|
||||
}).Return(nil)
|
||||
|
||||
// dbauthz will call Wrappers() to check for wrapped databases
|
||||
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||
|
||||
// Set up dbauthz to test the actual authorization layer
|
||||
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||
accessControlStore.Store(&acs)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
// Create roles with workspace permissions
|
||||
userRoles := rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleMember(),
|
||||
User: []rbac.Permission{
|
||||
{
|
||||
Negate: false,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.WildcardSymbol,
|
||||
},
|
||||
},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{
|
||||
orgID.String(): {
|
||||
Member: []rbac.Permission{
|
||||
{
|
||||
Negate: false,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.WildcardSymbol,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
agentScope := rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{
|
||||
WorkspaceID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
TemplateID: templateID,
|
||||
VersionID: uuid.New(),
|
||||
})
|
||||
|
||||
ctx := dbauthz.As(context.Background(), rbac.Subject{
|
||||
Type: rbac.SubjectTypeUser,
|
||||
FriendlyName: "testuser",
|
||||
Email: "testuser@example.com",
|
||||
ID: ownerID.String(),
|
||||
Roles: userRoles,
|
||||
Groups: []string{orgID.String()},
|
||||
Scope: agentScope,
|
||||
}.WithCachedASTValue())
|
||||
|
||||
resp, err := api.BatchUpdateMetadata(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
})
|
||||
|
||||
// Test cache refresh - AutostartSchedule updated
|
||||
// This test verifies that the cache refresh mechanism actually calls GetWorkspaceByID
|
||||
// and updates the cached workspace fields when the workspace is modified (e.g., autostart schedule changes).
|
||||
t.Run("CacheRefreshed_AutostartScheduleUpdated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
mClock = quartz.NewMock(t)
|
||||
tickerTrap = mClock.Trap().TickerFunc("cache_refresh")
|
||||
|
||||
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||
templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000")
|
||||
agentID = uuid.MustParse("ffffffff-ffff-ffff-ffff-ffffffffffff")
|
||||
)
|
||||
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
}
|
||||
|
||||
// Initial workspace - has Monday-Friday 9am autostart
|
||||
initialWorkspace := database.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
TemplateID: templateID,
|
||||
Name: "my-workspace",
|
||||
OwnerUsername: "testuser",
|
||||
TemplateName: "test-template",
|
||||
AutostartSchedule: sql.NullString{Valid: true, String: "CRON_TZ=UTC 0 9 * * 1-5"},
|
||||
}
|
||||
|
||||
// Updated workspace - user changed autostart to 5pm and renamed workspace
|
||||
updatedWorkspace := database.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
TemplateID: templateID,
|
||||
Name: "my-workspace-renamed", // Changed!
|
||||
OwnerUsername: "testuser",
|
||||
TemplateName: "test-template",
|
||||
AutostartSchedule: sql.NullString{Valid: true, String: "CRON_TZ=UTC 0 17 * * 1-5"}, // Changed!
|
||||
DormantAt: sql.NullTime{},
|
||||
}
|
||||
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
Metadata: []*agentproto.Metadata{
|
||||
{
|
||||
Key: "test_key",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||
Age: 1,
|
||||
Value: "test_value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// EXPECT GetWorkspaceByID to be called during cache refresh
|
||||
// This is the key assertion - proves the refresh mechanism is working
|
||||
dbM.EXPECT().GetWorkspaceByID(gomock.Any(), workspaceID).Return(updatedWorkspace, nil)
|
||||
|
||||
// API needs to fetch the agent when calling metadata update
|
||||
dbM.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(agent, nil)
|
||||
|
||||
// After refresh, metadata update should work with updated cache
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx context.Context, params database.UpdateWorkspaceAgentMetadataParams) error {
|
||||
require.Equal(t, agent.ID, params.WorkspaceAgentID)
|
||||
require.Equal(t, []string{"test_key"}, params.Key)
|
||||
require.Equal(t, []string{"test_value"}, params.Value)
|
||||
require.Equal(t, []string{""}, params.Error)
|
||||
require.Len(t, params.CollectedAt, 1)
|
||||
return nil
|
||||
},
|
||||
).AnyTimes()
|
||||
|
||||
// May call GetWorkspaceByAgentID if slow path is used before refresh
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(updatedWorkspace, nil).AnyTimes()
|
||||
|
||||
// dbauthz will call Wrappers()
|
||||
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||
|
||||
// Set up dbauthz
|
||||
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||
accessControlStore.Store(&acs)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create roles with workspace permissions
|
||||
userRoles := rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleMember(),
|
||||
User: []rbac.Permission{
|
||||
{
|
||||
Negate: false,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.WildcardSymbol,
|
||||
},
|
||||
},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{
|
||||
orgID.String(): {
|
||||
Member: []rbac.Permission{
|
||||
{
|
||||
Negate: false,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.WildcardSymbol,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
agentScope := rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{
|
||||
WorkspaceID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
TemplateID: templateID,
|
||||
VersionID: uuid.New(),
|
||||
})
|
||||
|
||||
ctxWithActor := dbauthz.As(ctx, rbac.Subject{
|
||||
Type: rbac.SubjectTypeUser,
|
||||
FriendlyName: "testuser",
|
||||
Email: "testuser@example.com",
|
||||
ID: ownerID.String(),
|
||||
Roles: userRoles,
|
||||
Groups: []string{orgID.String()},
|
||||
Scope: agentScope,
|
||||
}.WithCachedASTValue())
|
||||
|
||||
// Create full API with cached workspace fields (initial state)
|
||||
api := agentapi.New(agentapi.Options{
|
||||
AuthenticatedCtx: ctxWithActor,
|
||||
AgentID: agentID,
|
||||
WorkspaceID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Log: testutil.Logger(t),
|
||||
Clock: mClock,
|
||||
Pubsub: pub,
|
||||
}, initialWorkspace) // Cache is initialized with 9am schedule and "my-workspace" name
|
||||
|
||||
// Wait for ticker to be set up and release it so it can fire
|
||||
tickerTrap.MustWait(ctx).MustRelease(ctx)
|
||||
tickerTrap.Close()
|
||||
|
||||
// Advance clock to trigger cache refresh and wait for it to complete
|
||||
_, aw := mClock.AdvanceNext()
|
||||
aw.MustWait(ctx)
|
||||
|
||||
// At this point, GetWorkspaceByID should have been called and cache updated
|
||||
// The cache now has the 5pm schedule and "my-workspace-renamed" name
|
||||
|
||||
// Now call metadata update to verify the refreshed cache works
|
||||
resp, err := api.MetadataAPI.BatchUpdateMetadata(ctxWithActor, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,429 @@
|
||||
package metadatabatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultMetadataBatchSize is the maximum number of metadata entries
|
||||
// (key-value pairs across all agents) to batch before forcing a flush.
|
||||
// With typical agents having 5-15 metadata keys, this accommodates
|
||||
// 30-100 agents per batch.
|
||||
defaultMetadataBatchSize = 500
|
||||
|
||||
// defaultChannelBufferMultiplier is the multiplier for the channel buffer size
|
||||
// relative to the batch size. A 5x multiplier provides significant headroom
|
||||
// for bursts while the batch is being flushed.
|
||||
defaultChannelBufferMultiplier = 5
|
||||
|
||||
// defaultMetadataFlushInterval is how frequently to flush batched metadata
|
||||
// updates to the database and pubsub. 5 seconds provides a good balance
|
||||
// between reducing database load and maintaining reasonable UI update
|
||||
// latency.
|
||||
defaultMetadataFlushInterval = 5 * time.Second
|
||||
|
||||
// maxPubsubPayloadSize is the maximum size of a single pubsub message.
|
||||
// PostgreSQL NOTIFY has an 8KB limit for the payload.
|
||||
maxPubsubPayloadSize = 8000 // Leave some headroom below 8192 bytes
|
||||
|
||||
// uuidBase64Size is the size of a base64-encoded UUID without padding.
|
||||
// A UUID is 16 bytes, which encodes to 22 base64 characters (16 * 4 / 3 rounded up).
|
||||
// We use RawStdEncoding (no padding) to maximize space efficiency.
|
||||
uuidBase64Size = 22
|
||||
|
||||
// maxAgentIDsPerChunk is the maximum number of agent IDs that can fit in a
|
||||
// single pubsub message. With 22 bytes per base64-encoded UUID and 8KB limit,
|
||||
// we can fit ~363 agent IDs per chunk (8000 / 22 = 363.6).
|
||||
maxAgentIDsPerChunk = maxPubsubPayloadSize / uuidBase64Size
|
||||
|
||||
// Timeout to use for the context created when flushing the final batch due to the top level context being 'Done'
|
||||
finalFlushTimeout = 15 * time.Second
|
||||
|
||||
// Channel to publish batch metadata updates to, each update contains a list of all Agent IDs that have an update in
|
||||
// the most recent batch
|
||||
MetadataBatchPubsubChannel = "workspace_agent_metadata_batch"
|
||||
|
||||
// flush reasons
|
||||
flushCapacity = "capacity"
|
||||
flushTicker = "scheduled"
|
||||
flushExit = "shutdown"
|
||||
)
|
||||
|
||||
// compositeKey uniquely identifies a metadata entry by agent ID and key name.
|
||||
type compositeKey struct {
|
||||
agentID uuid.UUID
|
||||
key string
|
||||
}
|
||||
|
||||
// value holds a single metadata key-value pair with its error state
|
||||
// and collection timestamp.
|
||||
type value struct {
|
||||
v string
|
||||
error string
|
||||
collectedAt time.Time
|
||||
}
|
||||
|
||||
// update represents a single metadata update to be batched.
|
||||
type update struct {
|
||||
compositeKey
|
||||
value
|
||||
}
|
||||
|
||||
// Batcher holds a buffer of agent metadata updates and periodically
|
||||
// flushes them to the database and pubsub. This reduces database write
|
||||
// frequency and pubsub publish rate.
|
||||
type Batcher struct {
|
||||
store database.Store
|
||||
ps pubsub.Pubsub
|
||||
log slog.Logger
|
||||
|
||||
// updateCh is the buffered channel that receives metadata updates from Add() calls.
|
||||
updateCh chan update
|
||||
|
||||
// batch holds the current batch being accumulated. Updates with the same
|
||||
// composite key are deduplicated, keeping only the most recent value.
|
||||
batch map[compositeKey]value
|
||||
batchSize int
|
||||
|
||||
// clock is used to create tickers and get the current time.
|
||||
clock quartz.Clock
|
||||
ticker *quartz.Ticker
|
||||
interval time.Duration
|
||||
|
||||
// ctx is the context for the batcher. Used to check if shutdown has begun.
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
|
||||
// metrics collects Prometheus metrics for the batcher.
|
||||
metrics *Metrics
|
||||
}
|
||||
|
||||
// Option is a functional option for configuring a Batcher.
|
||||
type Option func(b *Batcher)
|
||||
|
||||
func WithBatchSize(size int) Option {
|
||||
return func(b *Batcher) {
|
||||
b.batchSize = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithInterval(d time.Duration) Option {
|
||||
return func(b *Batcher) {
|
||||
b.interval = d
|
||||
}
|
||||
}
|
||||
|
||||
func WithLogger(log slog.Logger) Option {
|
||||
return func(b *Batcher) {
|
||||
b.log = log
|
||||
}
|
||||
}
|
||||
|
||||
func WithClock(clock quartz.Clock) Option {
|
||||
return func(b *Batcher) {
|
||||
b.clock = clock
|
||||
}
|
||||
}
|
||||
|
||||
// NewBatcher creates a new Batcher and starts it. Here ctx controls the lifetime of the batcher, canceling it will
|
||||
// result in the Batcher exiting it's processing routine (run).
|
||||
func NewBatcher(ctx context.Context, reg prometheus.Registerer, store database.Store, ps pubsub.Pubsub, opts ...Option) (*Batcher, error) {
|
||||
b := &Batcher{
|
||||
store: store,
|
||||
ps: ps,
|
||||
metrics: NewMetrics(),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
b.log = slog.Logger{}
|
||||
b.clock = quartz.NewReal()
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(b)
|
||||
}
|
||||
|
||||
b.metrics.register(reg)
|
||||
|
||||
if b.interval == 0 {
|
||||
b.interval = defaultMetadataFlushInterval
|
||||
}
|
||||
|
||||
if b.batchSize == 0 {
|
||||
b.batchSize = defaultMetadataBatchSize
|
||||
}
|
||||
|
||||
if b.ticker == nil {
|
||||
b.ticker = b.clock.NewTicker(b.interval)
|
||||
}
|
||||
|
||||
// Create buffered channel with 5x batch size capacity
|
||||
channelSize := b.batchSize * defaultChannelBufferMultiplier
|
||||
b.updateCh = make(chan update, channelSize)
|
||||
|
||||
// Initialize batch map
|
||||
b.batch = make(map[compositeKey]value)
|
||||
|
||||
b.ctx, b.cancel = context.WithCancel(ctx)
|
||||
go func() {
|
||||
b.run(b.ctx)
|
||||
close(b.done)
|
||||
}()
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *Batcher) Close() {
|
||||
b.cancel()
|
||||
if b.ticker != nil {
|
||||
b.ticker.Stop()
|
||||
}
|
||||
// Wait for the run function to end, it may be sending one last batch.
|
||||
<-b.done
|
||||
}
|
||||
|
||||
// Add adds metadata updates for an agent to the batcher by writing to a
|
||||
// buffered channel. If the channel is full, updates are dropped. Updates
|
||||
// to the same metadata key for the same agent are deduplicated in the batch,
|
||||
// keeping only the value with the most recent collectedAt timestamp.
|
||||
func (b *Batcher) Add(agentID uuid.UUID, keys []string, values []string, errors []string, collectedAt []time.Time) error {
|
||||
if !(len(keys) == len(values) && len(values) == len(errors) && len(errors) == len(collectedAt)) {
|
||||
return xerrors.Errorf("invalid Add call, all inputs must have the same number of items; keys: %d, values: %d, errors: %d, collectedAt: %d", len(keys), len(values), len(errors), len(collectedAt))
|
||||
}
|
||||
|
||||
// Write each update to the channel. If the channel is full, drop the update.
|
||||
var u update
|
||||
droppedCount := 0
|
||||
for i := range keys {
|
||||
u.agentID = agentID
|
||||
u.key = keys[i]
|
||||
u.v = values[i]
|
||||
u.error = errors[i]
|
||||
u.collectedAt = collectedAt[i]
|
||||
|
||||
select {
|
||||
case b.updateCh <- u:
|
||||
// Successfully queued
|
||||
default:
|
||||
// Channel is full, drop this update
|
||||
droppedCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Log dropped keys if any were dropped.
|
||||
if droppedCount > 0 {
|
||||
b.log.Warn(context.Background(), "metadata channel at capacity, dropped updates",
|
||||
slog.F("agent_id", agentID),
|
||||
slog.F("channel_size", cap(b.updateCh)),
|
||||
slog.F("dropped_count", droppedCount),
|
||||
)
|
||||
if b.metrics != nil {
|
||||
b.metrics.droppedKeysTotal.Add(float64(droppedCount))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processUpdate adds a metadata update to the batch with deduplication based on timestamp.
|
||||
func (b *Batcher) processUpdate(update update) {
|
||||
ck := compositeKey{
|
||||
agentID: update.agentID,
|
||||
key: update.key,
|
||||
}
|
||||
|
||||
// Check if key already exists and only update if new value is newer
|
||||
if existing, exists := b.batch[ck]; exists {
|
||||
if update.collectedAt.After(existing.collectedAt) {
|
||||
b.batch[ck] = value{
|
||||
v: update.v,
|
||||
error: update.error,
|
||||
collectedAt: update.collectedAt,
|
||||
}
|
||||
}
|
||||
// Else: existing value is newer or same, ignore this update
|
||||
return
|
||||
}
|
||||
|
||||
// New key, add to batch
|
||||
b.batch[ck] = value{
|
||||
v: update.v,
|
||||
error: update.error,
|
||||
collectedAt: update.collectedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// run runs the batcher loop, reading from the update channel and flushing
|
||||
// periodically or when the batch reaches capacity.
|
||||
func (b *Batcher) run(ctx context.Context) {
|
||||
flush := func(ctx context.Context, reason string) {
|
||||
if err := b.flush(ctx, reason); err != nil {
|
||||
// Don't error level log here, database errors here are inconvenient but very much possible.
|
||||
//nolint:gocritic
|
||||
b.log.Warn(context.Background(), "metadata flush failed",
|
||||
slog.F("err_msg", err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// nolint:gocritic // This is only ever used for one thing - updating agent metadata.
|
||||
authCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
for {
|
||||
select {
|
||||
case update := <-b.updateCh:
|
||||
b.processUpdate(update)
|
||||
|
||||
// Check if batch has reached capacity
|
||||
if len(b.batch) >= b.batchSize {
|
||||
flush(authCtx, flushCapacity)
|
||||
}
|
||||
|
||||
case <-b.ticker.C:
|
||||
flush(authCtx, flushTicker)
|
||||
|
||||
case <-ctx.Done():
|
||||
b.log.Debug(ctx, "context done, flushing before exit")
|
||||
|
||||
// We must create a new context here as the parent context is done.
|
||||
ctxTimeout, cancel := context.WithTimeout(context.Background(), finalFlushTimeout)
|
||||
defer cancel() //nolint:revive // We're returning, defer is fine.
|
||||
|
||||
// nolint:gocritic // This is only ever used for one thing - updating agent metadata.
|
||||
flush(dbauthz.AsSystemRestricted(ctxTimeout), flushExit)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flush flushes the current batch to the database and pubsub.
|
||||
func (b *Batcher) flush(ctx context.Context, reason string) error {
|
||||
count := len(b.batch)
|
||||
|
||||
if count == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
b.log.Debug(ctx, "flushing metadata batch",
|
||||
slog.F("reason", reason),
|
||||
slog.F("count", count),
|
||||
)
|
||||
|
||||
// Convert batch map to parallel arrays for the batch query.
|
||||
// Also build map of agent IDs for per-agent metrics and pubsub.
|
||||
var (
|
||||
agentIDs = make([]uuid.UUID, 0, count)
|
||||
keys = make([]string, 0, count)
|
||||
values = make([]string, 0, count)
|
||||
errors = make([]string, 0, count)
|
||||
collectedAt = make([]time.Time, 0, count)
|
||||
agentKeys = make(map[uuid.UUID]int) // Track keys per agent for metrics
|
||||
)
|
||||
|
||||
for ck, mv := range b.batch {
|
||||
agentIDs = append(agentIDs, ck.agentID)
|
||||
keys = append(keys, ck.key)
|
||||
values = append(values, mv.v)
|
||||
errors = append(errors, mv.error)
|
||||
collectedAt = append(collectedAt, mv.collectedAt)
|
||||
agentKeys[ck.agentID]++
|
||||
}
|
||||
|
||||
// Batch has been processed into slices for our DB request, so we can clear it.
|
||||
b.batch = make(map[compositeKey]value)
|
||||
|
||||
// Record per-agent utilization metrics.
|
||||
if b.metrics != nil {
|
||||
for _, keyCount := range agentKeys {
|
||||
b.metrics.batchUtilization.Observe(float64(keyCount))
|
||||
}
|
||||
}
|
||||
|
||||
// Update the database with all metadata updates in a single query.
|
||||
err := b.store.BatchUpdateWorkspaceAgentMetadata(ctx, database.BatchUpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agentIDs,
|
||||
Key: keys,
|
||||
Value: values,
|
||||
Error: errors,
|
||||
CollectedAt: collectedAt,
|
||||
})
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
if database.IsQueryCanceledError(err) {
|
||||
b.log.Debug(ctx, "query canceled, skipping update of workspace agent metadata", slog.F("elapsed", elapsed))
|
||||
return err
|
||||
}
|
||||
b.log.Error(ctx, "error updating workspace agent metadata", slog.Error(err), slog.F("elapsed", elapsed))
|
||||
return err
|
||||
}
|
||||
|
||||
// Build list of unique agent IDs for pubsub notification.
|
||||
uniqueAgentIDs := make([]uuid.UUID, 0, len(agentKeys))
|
||||
for agentID := range agentKeys {
|
||||
uniqueAgentIDs = append(uniqueAgentIDs, agentID)
|
||||
}
|
||||
|
||||
// Publish agent IDs in chunks that fit within the pubsub size limit.
|
||||
b.publishAgentIDsInChunks(ctx, uniqueAgentIDs)
|
||||
|
||||
// Record successful batch size and flush duration after successful send/publish.
|
||||
if b.metrics != nil {
|
||||
b.metrics.batchSize.Observe(float64(count))
|
||||
b.metrics.metadataTotal.Add(float64(count))
|
||||
b.metrics.batchesTotal.WithLabelValues(reason).Inc()
|
||||
b.metrics.flushDuration.WithLabelValues(reason).Observe(time.Since(start).Seconds())
|
||||
}
|
||||
|
||||
elapsed = time.Since(start)
|
||||
b.log.Debug(ctx, "flush complete",
|
||||
slog.F("count", count),
|
||||
slog.F("elapsed", elapsed),
|
||||
slog.F("reason", reason),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishAgentIDsInChunks publishes agent IDs in chunks that fit within the
|
||||
// PostgreSQL NOTIFY 8KB payload size limit. Each UUID is base64-encoded
|
||||
// (without padding) and concatenated into a single string.
|
||||
func (b *Batcher) publishAgentIDsInChunks(ctx context.Context, agentIDs []uuid.UUID) {
|
||||
for i := 0; i < len(agentIDs); i += maxAgentIDsPerChunk {
|
||||
end := i + maxAgentIDsPerChunk
|
||||
if end > len(agentIDs) {
|
||||
end = len(agentIDs)
|
||||
}
|
||||
|
||||
chunk := agentIDs[i:end]
|
||||
|
||||
// Build payload by base64-encoding each UUID (without padding) and
|
||||
// concatenating them. This is UTF-8 safe for PostgreSQL NOTIFY.
|
||||
payload := make([]byte, 0, len(chunk)*uuidBase64Size)
|
||||
for _, agentID := range chunk {
|
||||
// Encode UUID bytes to base64 without padding (RawStdEncoding).
|
||||
// This produces exactly 22 characters per UUID.
|
||||
encoded := base64.RawStdEncoding.AppendEncode(payload, agentID[:])
|
||||
payload = encoded
|
||||
}
|
||||
|
||||
err := b.ps.Publish(MetadataBatchPubsubChannel, payload)
|
||||
if err != nil {
|
||||
b.log.Error(ctx, "failed to publish workspace agent metadata batch",
|
||||
slog.Error(err),
|
||||
slog.F("chunk_size", len(chunk)),
|
||||
slog.F("payload_size", len(payload)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,94 @@
|
||||
package metadatabatcher
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
batchUtilization prometheus.Histogram
|
||||
droppedKeysTotal prometheus.Counter
|
||||
metadataTotal prometheus.Counter
|
||||
batchesTotal *prometheus.CounterVec
|
||||
batchSize prometheus.Histogram
|
||||
flushDuration *prometheus.HistogramVec
|
||||
}
|
||||
|
||||
func NewMetrics() *Metrics {
|
||||
// Native histogram configuration (matching provisionerdserver pattern).
|
||||
nativeHistogramOpts := func(opts prometheus.HistogramOpts) prometheus.HistogramOpts {
|
||||
opts.NativeHistogramBucketFactor = 1.1
|
||||
opts.NativeHistogramMaxBucketNumber = 100
|
||||
opts.NativeHistogramMinResetDuration = time.Hour
|
||||
opts.NativeHistogramZeroThreshold = 0
|
||||
opts.NativeHistogramMaxZeroThreshold = 0
|
||||
return opts
|
||||
}
|
||||
|
||||
return &Metrics{
|
||||
batchUtilization: prometheus.NewHistogram(nativeHistogramOpts(prometheus.HistogramOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "agentapi",
|
||||
Name: "metadata_batch_utilization",
|
||||
Help: "Number of metadata keys per agent in each flushed batch.",
|
||||
})),
|
||||
|
||||
droppedKeysTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "agentapi",
|
||||
Name: "metadata_dropped_keys_total",
|
||||
Help: "Total number of metadata keys dropped due to capacity limits.",
|
||||
}),
|
||||
|
||||
batchesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "agentapi",
|
||||
Name: "metadata_batches_total",
|
||||
Help: "Total number of metadata batches flushed.",
|
||||
}, []string{"reason"}),
|
||||
|
||||
metadataTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "agentapi",
|
||||
Name: "metadata_flushed_total",
|
||||
Help: "Total number of unique metadatas flushed.",
|
||||
}),
|
||||
|
||||
batchSize: prometheus.NewHistogram(nativeHistogramOpts(prometheus.HistogramOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "agentapi",
|
||||
Name: "metadata_batch_size",
|
||||
Help: "Total number of metadata entries in each flushed batch.",
|
||||
})),
|
||||
|
||||
flushDuration: prometheus.NewHistogramVec(nativeHistogramOpts(prometheus.HistogramOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "agentapi",
|
||||
Name: "metadata_flush_duration_seconds",
|
||||
Help: "Time taken to flush metadata batch to database and pubsub.",
|
||||
}), []string{"reason"}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Metrics) Collectors() []prometheus.Collector {
|
||||
return []prometheus.Collector{
|
||||
m.batchUtilization,
|
||||
m.droppedKeysTotal,
|
||||
m.batchesTotal,
|
||||
m.metadataTotal,
|
||||
m.batchSize,
|
||||
m.flushDuration,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Metrics) register(reg prometheus.Registerer) {
|
||||
if reg != nil {
|
||||
reg.MustRegister(m.batchUtilization)
|
||||
reg.MustRegister(m.droppedKeysTotal)
|
||||
reg.MustRegister(m.batchesTotal)
|
||||
reg.MustRegister(m.metadataTotal)
|
||||
reg.MustRegister(m.batchSize)
|
||||
reg.MustRegister(m.flushDuration)
|
||||
}
|
||||
}
|
||||
+19
-1
@@ -43,6 +43,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||||
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
|
||||
"github.com/coder/coder/v2/coderd/appearance"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
@@ -773,6 +774,19 @@ func New(options *Options) *API {
|
||||
AppStatBatchSize: workspaceapps.DefaultStatsDBReporterBatchSize,
|
||||
DisableDatabaseInserts: !options.DeploymentValues.StatsCollection.UsageStats.Enable.Value(),
|
||||
})
|
||||
|
||||
// Initialize the metadata batcher for batching agent metadata updates.
|
||||
api.metadataBatcher, err = metadatabatcher.NewBatcher(
|
||||
api.ctx,
|
||||
options.PrometheusRegistry,
|
||||
options.Database,
|
||||
options.Pubsub,
|
||||
metadatabatcher.WithLogger(options.Logger.Named("metadata_batcher")),
|
||||
)
|
||||
if err != nil {
|
||||
api.Logger.Fatal(context.Background(), "failed to initialize metadata batcher", slog.Error(err))
|
||||
}
|
||||
|
||||
workspaceAppsLogger := options.Logger.Named("workspaceapps")
|
||||
if options.WorkspaceAppsStatsCollectorOptions.Logger == nil {
|
||||
named := workspaceAppsLogger.Named("stats_collector")
|
||||
@@ -1852,7 +1866,8 @@ type API struct {
|
||||
healthCheckGroup *singleflight.Group[string, *healthsdk.HealthcheckReport]
|
||||
healthCheckCache atomic.Pointer[healthsdk.HealthcheckReport]
|
||||
|
||||
statsReporter *workspacestats.Reporter
|
||||
statsReporter *workspacestats.Reporter
|
||||
metadataBatcher *metadatabatcher.Batcher
|
||||
|
||||
Acquirer *provisionerdserver.Acquirer
|
||||
// dbRolluper rolls up template usage stats from raw agent and app
|
||||
@@ -1904,6 +1919,9 @@ func (api *API) Close() error {
|
||||
_ = (*coordinator).Close()
|
||||
}
|
||||
_ = api.statsReporter.Close()
|
||||
if api.metadataBatcher != nil {
|
||||
api.metadataBatcher.Close()
|
||||
}
|
||||
_ = api.NetworkTelemetryBatcher.Close()
|
||||
_ = api.OIDCConvertKeyCache.Close()
|
||||
_ = api.AppSigningKeyCache.Close()
|
||||
|
||||
@@ -1437,6 +1437,15 @@ func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg databas
|
||||
return q.db.ArchiveUnusedTemplateVersions(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
|
||||
// Could be any workspace agent and checking auth to each workspace agent is overkill for
|
||||
// the purpose of this function.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceWorkspace.All()); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error {
|
||||
// Could be any workspace and checking auth to each workspace is overkill for
|
||||
// the purpose of this function.
|
||||
|
||||
@@ -133,6 +133,13 @@ func (m queryMetricsStore) ArchiveUnusedTemplateVersions(ctx context.Context, ar
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("BatchUpdateWorkspaceAgentMetadata").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.BatchUpdateWorkspaceLastUsedAt(ctx, arg)
|
||||
|
||||
@@ -132,6 +132,20 @@ func (mr *MockStoreMockRecorder) ArchiveUnusedTemplateVersions(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ArchiveUnusedTemplateVersions", reflect.TypeOf((*MockStore)(nil).ArchiveUnusedTemplateVersions), ctx, arg)
|
||||
}
|
||||
|
||||
// BatchUpdateWorkspaceAgentMetadata mocks base method.
|
||||
func (m *MockStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BatchUpdateWorkspaceAgentMetadata", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BatchUpdateWorkspaceAgentMetadata indicates an expected call of BatchUpdateWorkspaceAgentMetadata.
|
||||
func (mr *MockStoreMockRecorder) BatchUpdateWorkspaceAgentMetadata(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateWorkspaceAgentMetadata", reflect.TypeOf((*MockStore)(nil).BatchUpdateWorkspaceAgentMetadata), ctx, arg)
|
||||
}
|
||||
|
||||
// BatchUpdateWorkspaceLastUsedAt mocks base method.
|
||||
func (m *MockStore) BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -56,6 +56,7 @@ type sqlcQuerier interface {
|
||||
// Only unused template versions will be archived, which are any versions not
|
||||
// referenced by the latest build of a workspace.
|
||||
ArchiveUnusedTemplateVersions(ctx context.Context, arg ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error)
|
||||
BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error
|
||||
BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error
|
||||
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
|
||||
BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error)
|
||||
|
||||
@@ -17878,6 +17878,47 @@ func (q *sqlQuerier) UpdateVolumeResourceMonitor(ctx context.Context, arg Update
|
||||
return err
|
||||
}
|
||||
|
||||
const batchUpdateWorkspaceAgentMetadata = `-- name: BatchUpdateWorkspaceAgentMetadata :exec
|
||||
WITH metadata AS (
|
||||
SELECT
|
||||
unnest($1::uuid[]) AS workspace_agent_id,
|
||||
unnest($2::text[]) AS key,
|
||||
unnest($3::text[]) AS value,
|
||||
unnest($4::text[]) AS error,
|
||||
unnest($5::timestamptz[]) AS collected_at
|
||||
)
|
||||
UPDATE
|
||||
workspace_agent_metadata wam
|
||||
SET
|
||||
value = m.value,
|
||||
error = m.error,
|
||||
collected_at = m.collected_at
|
||||
FROM
|
||||
metadata m
|
||||
WHERE
|
||||
wam.workspace_agent_id = m.workspace_agent_id
|
||||
AND wam.key = m.key
|
||||
`
|
||||
|
||||
type BatchUpdateWorkspaceAgentMetadataParams struct {
|
||||
WorkspaceAgentID []uuid.UUID `db:"workspace_agent_id" json:"workspace_agent_id"`
|
||||
Key []string `db:"key" json:"key"`
|
||||
Value []string `db:"value" json:"value"`
|
||||
Error []string `db:"error" json:"error"`
|
||||
CollectedAt []time.Time `db:"collected_at" json:"collected_at"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error {
|
||||
_, err := q.db.ExecContext(ctx, batchUpdateWorkspaceAgentMetadata,
|
||||
pq.Array(arg.WorkspaceAgentID),
|
||||
pq.Array(arg.Key),
|
||||
pq.Array(arg.Value),
|
||||
pq.Array(arg.Error),
|
||||
pq.Array(arg.CollectedAt),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteOldWorkspaceAgentLogs = `-- name: DeleteOldWorkspaceAgentLogs :execrows
|
||||
WITH
|
||||
latest_builds AS (
|
||||
|
||||
@@ -142,6 +142,27 @@ WHERE
|
||||
wam.workspace_agent_id = $1
|
||||
AND wam.key = m.key;
|
||||
|
||||
-- name: BatchUpdateWorkspaceAgentMetadata :exec
|
||||
WITH metadata AS (
|
||||
SELECT
|
||||
unnest(sqlc.arg('workspace_agent_id')::uuid[]) AS workspace_agent_id,
|
||||
unnest(sqlc.arg('key')::text[]) AS key,
|
||||
unnest(sqlc.arg('value')::text[]) AS value,
|
||||
unnest(sqlc.arg('error')::text[]) AS error,
|
||||
unnest(sqlc.arg('collected_at')::timestamptz[]) AS collected_at
|
||||
)
|
||||
UPDATE
|
||||
workspace_agent_metadata wam
|
||||
SET
|
||||
value = m.value,
|
||||
error = m.error,
|
||||
collected_at = m.collected_at
|
||||
FROM
|
||||
metadata m
|
||||
WHERE
|
||||
wam.workspace_agent_id = m.workspace_agent_id
|
||||
AND wam.key = m.key;
|
||||
|
||||
-- name: GetWorkspaceAgentMetadata :many
|
||||
SELECT
|
||||
*
|
||||
|
||||
+55
-35
@@ -3,6 +3,7 @@ package coderd
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -24,7 +25,7 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
@@ -1726,34 +1727,69 @@ func (api *API) watchWorkspaceAgentMetadata(
|
||||
|
||||
// Send metadata on updates, we must ensure subscription before sending
|
||||
// initial metadata to guarantee that events in-between are not missed.
|
||||
update := make(chan agentapi.WorkspaceAgentMetadataChannelPayload, 1)
|
||||
cancelSub, err := api.Pubsub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(workspaceAgent.ID), func(_ context.Context, byt []byte) {
|
||||
// The channel carries no data - it's just a signal to fetch all metadata.
|
||||
update := make(chan struct{}, 1)
|
||||
|
||||
// Subscribe to the global batched metadata channel.
|
||||
// The batcher publishes only to this channel to achieve O(1) NOTIFY scaling.
|
||||
cancelBatchSub, err := api.Pubsub.Subscribe(metadatabatcher.MetadataBatchPubsubChannel, func(_ context.Context, byt []byte) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var payload agentapi.WorkspaceAgentMetadataChannelPayload
|
||||
err := json.Unmarshal(byt, &payload)
|
||||
if err != nil {
|
||||
log.Error(ctx, "failed to unmarshal pubsub message", slog.Error(err))
|
||||
// Parse the payload as concatenated base64-encoded UUIDs (22 chars each, no padding).
|
||||
// Each UUID is base64-encoded without padding (RawStdEncoding) to 22 characters.
|
||||
if len(byt)%22 != 0 {
|
||||
log.Error(ctx, "invalid batched pubsub message size", slog.F("size", len(byt)))
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug(ctx, "received metadata update", slog.F("payload", payload))
|
||||
// Check if this agent is in the batch update.
|
||||
for i := 0; i < len(byt); i += 22 {
|
||||
// Decode the 22-character base64 string back to 16 bytes.
|
||||
var agentIDBytes [16]byte
|
||||
n, err := base64.RawStdEncoding.Decode(agentIDBytes[:], byt[i:i+22])
|
||||
if err != nil || n != 16 {
|
||||
log.Error(ctx, "failed to decode agent ID from batch message",
|
||||
slog.Error(err),
|
||||
slog.F("offset", i),
|
||||
slog.F("decoded_bytes", n),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case prev := <-update:
|
||||
payload.Keys = appendUnique(prev.Keys, payload.Keys)
|
||||
default:
|
||||
agentID, err := uuid.FromBytes(agentIDBytes[:])
|
||||
if err != nil {
|
||||
log.Error(ctx, "invalid agent ID bytes", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if agentID != workspaceAgent.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug(ctx, "received metadata update from batch channel", slog.F("agent_id", agentID), slog.F("batch_size", len(byt)/22))
|
||||
|
||||
// Signal to re-fetch all metadata for this agent.
|
||||
// Batch notifications don't include which keys changed, so we
|
||||
// always fetch all keys for this agent.
|
||||
|
||||
// Clear any pending signals - batch always means "fetch all".
|
||||
select {
|
||||
case <-update:
|
||||
default:
|
||||
}
|
||||
// This can never block since we drained beforehand.
|
||||
// Send empty struct as signal to fetch all metadata.
|
||||
update <- struct{}{}
|
||||
break
|
||||
}
|
||||
// This can never block since we pop and merge beforehand.
|
||||
update <- payload
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
defer cancelSub()
|
||||
defer cancelBatchSub()
|
||||
|
||||
// We always use the original Request context because it contains
|
||||
// the RBAC actor.
|
||||
@@ -1837,10 +1873,11 @@ func (api *API) watchWorkspaceAgentMetadata(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case payload := <-update:
|
||||
case <-update:
|
||||
// Batch notification received - fetch all metadata for this agent.
|
||||
md, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: workspaceAgent.ID,
|
||||
Keys: payload.Keys,
|
||||
Keys: nil, // nil means fetch all keys
|
||||
})
|
||||
if err != nil {
|
||||
if !database.IsQueryCanceledError(err) {
|
||||
@@ -1861,9 +1898,7 @@ func (api *API) watchWorkspaceAgentMetadata(
|
||||
// We want to block here to avoid constantly pinging the
|
||||
// database when the metadata isn't being processed.
|
||||
case fetchedMetadata <- md:
|
||||
log.Debug(ctx, "fetched metadata update for keys",
|
||||
slog.F("keys", payload.Keys),
|
||||
slog.F("num", len(md)))
|
||||
log.Debug(ctx, "fetched all metadata after batch update", slog.F("num", len(md)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1899,21 +1934,6 @@ func (api *API) watchWorkspaceAgentMetadata(
|
||||
}
|
||||
}
|
||||
|
||||
// appendUnique is like append and adds elements from src to dst,
|
||||
// skipping any elements that already exist in dst.
|
||||
func appendUnique[T comparable](dst, src []T) []T {
|
||||
exists := make(map[T]struct{}, len(dst))
|
||||
for _, key := range dst {
|
||||
exists[key] = struct{}{}
|
||||
}
|
||||
for _, key := range src {
|
||||
if _, ok := exists[key]; !ok {
|
||||
dst = append(dst, key)
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func convertWorkspaceAgentMetadata(db []database.WorkspaceAgentMetadatum) []codersdk.WorkspaceAgentMetadata {
|
||||
// Sort the input database slice by DisplayOrder and then by Key before processing
|
||||
sort.Slice(db, func(i, j int) bool {
|
||||
|
||||
@@ -143,6 +143,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
|
||||
TailnetCoordinator: &api.TailnetCoordinator,
|
||||
AppearanceFetcher: &api.AppearanceFetcher,
|
||||
StatsReporter: api.statsReporter,
|
||||
MetadataBatcher: api.metadataBatcher,
|
||||
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
|
||||
PublishWorkspaceAgentLogsUpdateFn: api.publishWorkspaceAgentLogsUpdate,
|
||||
NetworkTelemetryHandler: api.NetworkTelemetryBatcher.Handler,
|
||||
|
||||
Reference in New Issue
Block a user