perf(coderd): reduce duplicated reads in push and webpush paths (#23115)
## Background

A 5000-chat scaletest (~50k turns, ~2m45s wall time) completed successfully, but the main bottleneck was **DB pool starvation from repeated reads**, not individually expensive SQL.

The push/webpush path showed a few especially noisy reads:

- `GetLastChatMessageByRole` for push body generation
- `GetEnabledChatProviders` + `GetChatModelConfigByID` for push summary model resolution
- `GetWebpushSubscriptionsByUserID` for every webpush dispatch

This PR keeps the optimizations that remove those duplicate reads while leaving stream behavior unchanged.

## What changes in this PR

### 1. Reuse resolved chat state for push notifications

`maybeSendPushNotification` used to re-read the last assistant message and re-resolve the chat model/provider after `runChat` had already done that work. Now `runChat` returns the final assistant text plus the already-resolved model and provider keys, and the push goroutine uses that state directly.

That removes the extra push-path reads for:

- `GetLastChatMessageByRole`
- the second `resolveChatModel` path
- the provider/model lookups that came with that second resolution

### 2. Cache webpush subscriptions during dispatch

`Dispatch()` previously hit `GetWebpushSubscriptionsByUserID` on every push. A small per-user in-memory cache now avoids those repeated reads.

The follow-up fix keeps that optimization correct: `InvalidateUser()` bumps a per-user generation so an older in-flight fetch cannot repopulate the cache with pre-mutation data after subscribe/unsubscribe. That preserves the cache win without letting local subscription changes be silently overwritten by stale fetch results.

## Why this is safe

- The push change only reuses data already produced during the same chat run. It does not change notification semantics; if there is no assistant text to summarize, the existing fallback body still applies.
- The webpush change keeps the existing TTL and `410 Gone` cleanup behavior. The generation guard only prevents stale in-flight fetches from poisoning the shared cache after invalidation.
- The final PR does **not** change stream setup, pubsub/relay behavior, or chat status snapshot timing.

## Deliberately not included

- No stream-path optimization in `Subscribe`.
- No inline pubsub message payloads.
- No distributed cross-replica webpush cache invalidation.
This commit is contained in:
+41
-28
@@ -2038,6 +2038,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
status := database.ChatStatusWaiting
|
||||
wasInterrupted := false
|
||||
lastError := ""
|
||||
runResult := runChatResult{}
|
||||
remainingQueuedMessages := []database.ChatQueuedMessage{}
|
||||
shouldPublishQueueUpdate := false
|
||||
var promotedMessage *database.ChatMessage
|
||||
@@ -2144,11 +2145,12 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
|
||||
if !wasInterrupted {
|
||||
p.maybeSendPushNotification(cleanupCtx, chat, status, lastError, logger)
|
||||
p.maybeSendPushNotification(cleanupCtx, chat, status, lastError, runResult, logger)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := p.runChat(chatCtx, chat, logger); err != nil {
|
||||
runResult, err := p.runChat(chatCtx, chat, logger)
|
||||
if err != nil {
|
||||
if errors.Is(err, chatloop.ErrInterrupted) || errors.Is(context.Cause(chatCtx), chatloop.ErrInterrupted) {
|
||||
logger.Info(ctx, "chat interrupted")
|
||||
status = database.ChatStatusWaiting
|
||||
@@ -2205,11 +2207,18 @@ func isShutdownCancellation(
|
||||
return errors.Is(context.Cause(chatCtx), context.Canceled)
|
||||
}
|
||||
|
||||
type runChatResult struct {
|
||||
FinalAssistantText string
|
||||
PushSummaryModel fantasy.LanguageModel
|
||||
ProviderKeys chatprovider.ProviderAPIKeys
|
||||
}
|
||||
|
||||
func (p *Server) runChat(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
logger slog.Logger,
|
||||
) error {
|
||||
) (runChatResult, error) {
|
||||
result := runChatResult{}
|
||||
var (
|
||||
model fantasy.LanguageModel
|
||||
modelConfig database.ChatModelConfig
|
||||
@@ -2241,14 +2250,16 @@ func (p *Server) runChat(
|
||||
return nil
|
||||
})
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
return result, err
|
||||
}
|
||||
result.PushSummaryModel = model
|
||||
result.ProviderKeys = providerKeys
|
||||
// Fire title generation asynchronously so it doesn't block the
|
||||
// chat response. It uses a detached context so it can finish
|
||||
// even after the chat processing context is canceled.
|
||||
// Snapshot model so the goroutine doesn't race with the
|
||||
// model = cuModel reassignment below.
|
||||
titleModel := model
|
||||
// Snapshot the original chat model so the goroutine doesn't
|
||||
// race with the model = cuModel reassignment below.
|
||||
titleModel := result.PushSummaryModel
|
||||
p.inflight.Add(1)
|
||||
go func() {
|
||||
defer p.inflight.Done()
|
||||
@@ -2257,7 +2268,7 @@ func (p *Server) runChat(
|
||||
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(ctx, messages, p.chatFileResolver(), logger)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("build chat prompt: %w", err)
|
||||
return result, xerrors.Errorf("build chat prompt: %w", err)
|
||||
}
|
||||
if chat.ParentChatID.Valid {
|
||||
prompt = chatprompt.InsertSystem(prompt, defaultSubagentInstruction)
|
||||
@@ -2389,9 +2400,11 @@ func (p *Server) runChat(
|
||||
prompt = chatprompt.InsertSystem(prompt, resolvedUserPrompt)
|
||||
}
|
||||
|
||||
// Use the model config's context_limit as a fallback when the LLM // provider doesn't include context_limit in its response metadata
|
||||
// Use the model config's context_limit as a fallback when the LLM
|
||||
// provider doesn't include context_limit in its response metadata
|
||||
// (which is the common case).
|
||||
modelConfigContextLimit := modelConfig.ContextLimit
|
||||
var finalAssistantText string
|
||||
|
||||
persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error {
|
||||
// If the chat context has been canceled, bail out before
|
||||
@@ -2455,6 +2468,7 @@ func (p *Server) runChat(
|
||||
for _, block := range assistantBlocks {
|
||||
sdkParts = append(sdkParts, chatprompt.PartFromContent(block))
|
||||
}
|
||||
finalAssistantText = strings.TrimSpace(contentBlocksToText(sdkParts))
|
||||
assistantContent, marshalErr := chatprompt.MarshalParts(sdkParts)
|
||||
if marshalErr != nil {
|
||||
return marshalErr
|
||||
@@ -2630,7 +2644,7 @@ func (p *Server) runChat(
|
||||
chatprovider.UserAgent(),
|
||||
)
|
||||
if cuErr != nil {
|
||||
return xerrors.Errorf("resolve computer use model: %w", cuErr)
|
||||
return result, xerrors.Errorf("resolve computer use model: %w", cuErr)
|
||||
}
|
||||
model = cuModel
|
||||
}
|
||||
@@ -2796,7 +2810,11 @@ func (p *Server) runChat(
|
||||
p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err))
|
||||
},
|
||||
})
|
||||
return err
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
result.FinalAssistantText = finalAssistantText
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// buildProviderTools creates provider-native tool definitions
|
||||
@@ -3301,6 +3319,7 @@ func (p *Server) maybeSendPushNotification(
|
||||
chat database.Chat,
|
||||
status database.ChatStatus,
|
||||
lastError string,
|
||||
runResult runChatResult,
|
||||
logger slog.Logger,
|
||||
) {
|
||||
if p.webpushDispatcher == nil || p.webpushDispatcher.PublicKey() == "" {
|
||||
@@ -3328,23 +3347,17 @@ func (p *Server) maybeSendPushNotification(
|
||||
defer p.inflight.Done()
|
||||
pushCtx := context.WithoutCancel(ctx)
|
||||
pushBody := "Agent has finished running."
|
||||
|
||||
msg, err := p.db.GetLastChatMessageByRole(pushCtx, database.GetLastChatMessageByRoleParams{
|
||||
ChatID: chat.ID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
})
|
||||
if err == nil {
|
||||
content, parseErr := chatprompt.ParseContent(msg)
|
||||
if parseErr == nil {
|
||||
assistantText := strings.TrimSpace(contentBlocksToText(content))
|
||||
if assistantText != "" {
|
||||
model, _, keys, resolveErr := p.resolveChatModel(pushCtx, chat)
|
||||
if resolveErr == nil {
|
||||
if summary := generatePushSummary(pushCtx, chat.Title, assistantText, model, keys, logger); summary != "" {
|
||||
pushBody = summary
|
||||
}
|
||||
}
|
||||
}
|
||||
assistantText := strings.TrimSpace(runResult.FinalAssistantText)
|
||||
if assistantText != "" && runResult.PushSummaryModel != nil {
|
||||
if summary := generatePushSummary(
|
||||
pushCtx,
|
||||
chat.Title,
|
||||
assistantText,
|
||||
runResult.PushSummaryModel,
|
||||
runResult.ProviderKeys,
|
||||
logger,
|
||||
); summary != "" {
|
||||
pushBody = summary
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2234,12 +2234,10 @@ func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
|
||||
const assistantText = "I have completed the task successfully and all tests are passing now."
|
||||
const summaryText = "Completed task and verified all tests pass."
|
||||
|
||||
var nonStreamingRequests atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
// Non-streaming calls are used for title
|
||||
// generation and push summary generation.
|
||||
// Return the summary text for both — the title
|
||||
// result is irrelevant to this test.
|
||||
nonStreamingRequests.Add(1)
|
||||
return chattest.OpenAINonStreamingResponse(summaryText)
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
@@ -2286,6 +2284,63 @@ func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
|
||||
"push body should be the LLM-generated summary")
|
||||
require.NotEqual(t, "Agent has finished running.", msg.Body,
|
||||
"push body should not use the default fallback text")
|
||||
require.Equal(t, int32(1), nonStreamingRequests.Load(),
|
||||
"expected exactly one non-streaming request for push summary generation")
|
||||
}
|
||||
|
||||
func TestSuccessfulChatSendsWebPushFallbackWithoutSummaryForEmptyAssistantText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
var nonStreamingRequests atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
nonStreamingRequests.Add(1)
|
||||
return chattest.OpenAINonStreamingResponse("unexpected summary request")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks(" ")...,
|
||||
)
|
||||
})
|
||||
|
||||
mockPush := &mockWebpushDispatcher{}
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
WebpushDispatcher: mockPush,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
_, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "empty-summary-push-test",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
return mockPush.dispatchCount.Load() >= 1
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
msg := mockPush.getLastMessage()
|
||||
require.Equal(t, "Agent has finished running.", msg.Body,
|
||||
"push body should fall back when the final assistant text is empty")
|
||||
require.Equal(t, int32(0), nonStreamingRequests.Load(),
|
||||
"push summary should not be requested when final assistant text has no usable text")
|
||||
}
|
||||
|
||||
func TestComputerUseSubagentToolsAndModel(t *testing.T) {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -54,6 +55,9 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
|
||||
})
|
||||
return
|
||||
}
|
||||
if invalidator, ok := api.WebpushDispatcher.(webpush.SubscriptionCacheInvalidator); ok {
|
||||
invalidator.InvalidateUser(user.ID)
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@@ -111,6 +115,9 @@ func (api *API) deleteUserWebpushSubscription(rw http.ResponseWriter, r *http.Re
|
||||
})
|
||||
return
|
||||
}
|
||||
if invalidator, ok := api.WebpushDispatcher.(webpush.SubscriptionCacheInvalidator); ok {
|
||||
invalidator.InvalidateUser(user.ID)
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
+190
-7
@@ -9,18 +9,23 @@ import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/util/singleflight"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const defaultSubscriptionCacheTTL = 3 * time.Minute
|
||||
|
||||
// Dispatcher is an interface that can be used to dispatch
|
||||
// web push notifications to clients such as browsers.
|
||||
type Dispatcher interface {
|
||||
@@ -33,6 +38,36 @@ type Dispatcher interface {
|
||||
PublicKey() string
|
||||
}
|
||||
|
||||
// SubscriptionCacheInvalidator is an optional interface that lets local
|
||||
// subscription mutation handlers invalidate cached subscriptions.
|
||||
type SubscriptionCacheInvalidator interface {
|
||||
InvalidateUser(userID uuid.UUID)
|
||||
}
|
||||
|
||||
type options struct {
|
||||
clock quartz.Clock
|
||||
subscriptionCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// Option configures optional behavior for a Webpusher.
|
||||
type Option func(*options)
|
||||
|
||||
// WithClock sets the clock used by the subscription cache. Defaults to a real
|
||||
// clock when not provided.
|
||||
func WithClock(clock quartz.Clock) Option {
|
||||
return func(o *options) {
|
||||
o.clock = clock
|
||||
}
|
||||
}
|
||||
|
||||
// WithSubscriptionCacheTTL sets the in-memory subscription cache TTL. Defaults
|
||||
// to three minutes when not provided or when given a non-positive duration.
|
||||
func WithSubscriptionCacheTTL(ttl time.Duration) Option {
|
||||
return func(o *options) {
|
||||
o.subscriptionCacheTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new Dispatcher to dispatch web push notifications.
|
||||
//
|
||||
// This is *not* integrated into the enqueue system unfortunately.
|
||||
@@ -41,7 +76,21 @@ type Dispatcher interface {
|
||||
// for updates inside of a workspace, which we want to be immediate.
|
||||
//
|
||||
// See: https://github.com/coder/internal/issues/528
|
||||
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string) (Dispatcher, error) {
|
||||
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string, opts ...Option) (Dispatcher, error) {
|
||||
cfg := options{
|
||||
clock: quartz.NewReal(),
|
||||
subscriptionCacheTTL: defaultSubscriptionCacheTTL,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&cfg)
|
||||
}
|
||||
if cfg.clock == nil {
|
||||
cfg.clock = quartz.NewReal()
|
||||
}
|
||||
if cfg.subscriptionCacheTTL <= 0 {
|
||||
cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
|
||||
}
|
||||
|
||||
keys, err := db.GetWebpushVAPIDKeys(ctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -63,14 +112,23 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
|
||||
}
|
||||
|
||||
return &Webpusher{
|
||||
vapidSub: vapidSub,
|
||||
store: db,
|
||||
log: log,
|
||||
VAPIDPublicKey: keys.VapidPublicKey,
|
||||
VAPIDPrivateKey: keys.VapidPrivateKey,
|
||||
vapidSub: vapidSub,
|
||||
store: db,
|
||||
log: log,
|
||||
VAPIDPublicKey: keys.VapidPublicKey,
|
||||
VAPIDPrivateKey: keys.VapidPrivateKey,
|
||||
clock: cfg.clock,
|
||||
subscriptionCacheTTL: cfg.subscriptionCacheTTL,
|
||||
subscriptionCache: make(map[uuid.UUID]cachedSubscriptions),
|
||||
subscriptionGenerations: make(map[uuid.UUID]uint64),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type cachedSubscriptions struct {
|
||||
subscriptions []database.WebpushSubscription
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type Webpusher struct {
|
||||
store database.Store
|
||||
log *slog.Logger
|
||||
@@ -83,10 +141,18 @@ type Webpusher struct {
|
||||
// the message payload.
|
||||
VAPIDPublicKey string
|
||||
VAPIDPrivateKey string
|
||||
|
||||
clock quartz.Clock
|
||||
|
||||
cacheMu sync.RWMutex
|
||||
subscriptionCache map[uuid.UUID]cachedSubscriptions
|
||||
subscriptionGenerations map[uuid.UUID]uint64
|
||||
subscriptionCacheTTL time.Duration
|
||||
subscriptionFetches singleflight.Group[string, []database.WebpushSubscription]
|
||||
}
|
||||
|
||||
func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
|
||||
subscriptions, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
|
||||
subscriptions, err := n.subscriptionsForUser(ctx, userID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get web push subscriptions by user ID: %w", err)
|
||||
}
|
||||
@@ -142,12 +208,129 @@ func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk
|
||||
err = n.store.DeleteWebpushSubscriptions(dbauthz.AsNotifier(ctx), cleanupSubscriptions)
|
||||
if err != nil {
|
||||
n.log.Error(ctx, "failed to delete stale push subscriptions", slog.Error(err))
|
||||
} else {
|
||||
n.pruneSubscriptions(userID, cleanupSubscriptions)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Webpusher) subscriptionsForUser(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
|
||||
if subscriptions, ok := n.cachedSubscriptions(userID); ok {
|
||||
return subscriptions, nil
|
||||
}
|
||||
|
||||
subscriptions, err, _ := n.subscriptionFetches.Do(userID.String(), func() ([]database.WebpushSubscription, error) {
|
||||
if cached, ok := n.cachedSubscriptions(userID); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
generation := n.subscriptionGeneration(userID)
|
||||
fetched, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n.storeSubscriptions(userID, generation, fetched)
|
||||
return slices.Clone(fetched), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return slices.Clone(subscriptions), nil
|
||||
}
|
||||
|
||||
func (n *Webpusher) cachedSubscriptions(userID uuid.UUID) ([]database.WebpushSubscription, bool) {
|
||||
n.cacheMu.RLock()
|
||||
entry, ok := n.subscriptionCache[userID]
|
||||
n.cacheMu.RUnlock()
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if n.clock.Now().Before(entry.expiresAt) {
|
||||
return slices.Clone(entry.subscriptions), true
|
||||
}
|
||||
|
||||
n.cacheMu.Lock()
|
||||
if current, ok := n.subscriptionCache[userID]; ok && !n.clock.Now().Before(current.expiresAt) {
|
||||
delete(n.subscriptionCache, userID)
|
||||
}
|
||||
n.cacheMu.Unlock()
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (n *Webpusher) subscriptionGeneration(userID uuid.UUID) uint64 {
|
||||
n.cacheMu.RLock()
|
||||
generation := n.subscriptionGenerations[userID]
|
||||
n.cacheMu.RUnlock()
|
||||
return generation
|
||||
}
|
||||
|
||||
func (n *Webpusher) storeSubscriptions(userID uuid.UUID, generation uint64, subscriptions []database.WebpushSubscription) {
|
||||
n.cacheMu.Lock()
|
||||
defer n.cacheMu.Unlock()
|
||||
|
||||
if n.subscriptionGenerations[userID] != generation {
|
||||
return
|
||||
}
|
||||
|
||||
n.subscriptionCache[userID] = cachedSubscriptions{
|
||||
subscriptions: slices.Clone(subscriptions),
|
||||
expiresAt: n.clock.Now().Add(n.subscriptionCacheTTL),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Webpusher) pruneSubscriptions(userID uuid.UUID, staleIDs []uuid.UUID) {
|
||||
if len(staleIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
stale := make(map[uuid.UUID]struct{}, len(staleIDs))
|
||||
for _, id := range staleIDs {
|
||||
stale[id] = struct{}{}
|
||||
}
|
||||
|
||||
n.cacheMu.Lock()
|
||||
defer n.cacheMu.Unlock()
|
||||
|
||||
entry, ok := n.subscriptionCache[userID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if !n.clock.Now().Before(entry.expiresAt) {
|
||||
delete(n.subscriptionCache, userID)
|
||||
return
|
||||
}
|
||||
|
||||
filtered := make([]database.WebpushSubscription, 0, len(entry.subscriptions))
|
||||
for _, subscription := range entry.subscriptions {
|
||||
if _, shouldDelete := stale[subscription.ID]; shouldDelete {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, subscription)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(n.subscriptionCache, userID)
|
||||
return
|
||||
}
|
||||
|
||||
entry.subscriptions = filtered
|
||||
n.subscriptionCache[userID] = entry
|
||||
}
|
||||
|
||||
// InvalidateUser clears the cached subscriptions for a user and advances
|
||||
// its invalidation generation. Local subscribe and unsubscribe handlers call
|
||||
// this after mutating subscriptions in the same process.
|
||||
func (n *Webpusher) InvalidateUser(userID uuid.UUID) {
|
||||
n.cacheMu.Lock()
|
||||
delete(n.subscriptionCache, userID)
|
||||
n.subscriptionGenerations[userID]++
|
||||
n.cacheMu.Unlock()
|
||||
n.subscriptionFetches.Forget(userID.String())
|
||||
}
|
||||
|
||||
func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string, keys webpush.Keys) (int, []byte, error) {
|
||||
// Copy the message to avoid modifying the original.
|
||||
cpy := slices.Clone(msg)
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -21,6 +23,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,6 +31,20 @@ const (
|
||||
validEndpointP256dhKey = "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk="
|
||||
)
|
||||
|
||||
type countingWebpushStore struct {
|
||||
database.Store
|
||||
getSubscriptionsCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (s *countingWebpushStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
|
||||
s.getSubscriptionsCalls.Add(1)
|
||||
return s.Store.GetWebpushSubscriptionsByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *countingWebpushStore) getCallCount() int32 {
|
||||
return s.getSubscriptionsCalls.Load()
|
||||
}
|
||||
|
||||
func TestPush(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -216,6 +233,131 @@ func TestPush(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, subscriptions, "No subscriptions should be returned")
|
||||
})
|
||||
|
||||
t.Run("CachesSubscriptionsWithinTTL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
rawStore, _ := dbtestutil.NewDB(t)
|
||||
store := &countingWebpushStore{Store: rawStore}
|
||||
var delivered atomic.Int32
|
||||
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
|
||||
delivered.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
|
||||
|
||||
user := dbgen.User(t, rawStore, database.User{})
|
||||
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: serverURL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := randomWebpushMessage(t)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(1), store.getCallCount(), "subscriptions should be read once within the TTL")
|
||||
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
|
||||
})
|
||||
|
||||
t.Run("RefreshesSubscriptionsAfterTTLExpires", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
rawStore, _ := dbtestutil.NewDB(t)
|
||||
store := &countingWebpushStore{Store: rawStore}
|
||||
var delivered atomic.Int32
|
||||
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
|
||||
delivered.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
|
||||
|
||||
user := dbgen.User(t, rawStore, database.User{})
|
||||
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: serverURL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := randomWebpushMessage(t)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
clock.Advance(time.Minute)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(2), store.getCallCount(), "dispatch should refresh subscriptions after the TTL expires")
|
||||
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
|
||||
})
|
||||
|
||||
t.Run("PrunesStaleSubscriptionsFromCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
rawStore, _ := dbtestutil.NewDB(t)
|
||||
store := &countingWebpushStore{Store: rawStore}
|
||||
var okCalls atomic.Int32
|
||||
var goneCalls atomic.Int32
|
||||
manager, _, okServerURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
|
||||
okCalls.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
|
||||
|
||||
goneServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
goneCalls.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusGone)
|
||||
}))
|
||||
defer goneServer.Close()
|
||||
|
||||
user := dbgen.User(t, rawStore, database.User{})
|
||||
okSubscription, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: okServerURL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: goneServer.URL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := randomWebpushMessage(t)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(1), store.getCallCount(), "stale subscription cleanup should not force a second DB read within the TTL")
|
||||
require.Equal(t, int32(2), okCalls.Load(), "the healthy endpoint should receive both dispatches")
|
||||
require.Equal(t, int32(1), goneCalls.Load(), "the stale endpoint should be pruned from the cache after the first dispatch")
|
||||
|
||||
subscriptions, err := rawStore.GetWebpushSubscriptionsByUserID(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, subscriptions, 1, "only the healthy subscription should remain")
|
||||
require.Equal(t, okSubscription.ID, subscriptions[0].ID)
|
||||
})
|
||||
}
|
||||
|
||||
func randomWebpushMessage(t testing.TB) codersdk.WebpushMessage {
|
||||
@@ -244,16 +386,21 @@ func assertWebpushPayload(t testing.TB, r *http.Request) {
|
||||
assert.Error(t, json.NewDecoder(r.Body).Decode(io.Discard))
|
||||
}
|
||||
|
||||
// setupPushTest creates a common test setup for webpush notification tests
|
||||
// setupPushTest creates a common test setup for webpush notification tests.
|
||||
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
return setupPushTestWithOptions(ctx, t, db, handlerFunc)
|
||||
}
|
||||
|
||||
func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Store, handlerFunc func(w http.ResponseWriter, r *http.Request), opts ...webpush.Option) (webpush.Dispatcher, database.Store, string) {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(handlerFunc))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
manager, err := webpush.New(ctx, &logger, db, "http://example.com")
|
||||
manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...)
|
||||
require.NoError(t, err, "Failed to create webpush manager")
|
||||
|
||||
return manager, db, server.URL
|
||||
|
||||
+18
-7
@@ -35,31 +35,42 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
_, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
handlerCalled := make(chan bool, 1)
|
||||
var handlerCalls atomic.Int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
handlerCalled <- true
|
||||
handlerCalls.Add(1)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
err := memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
|
||||
// Seed the dispatcher cache with an empty subscription set. Creating the
|
||||
// subscription should invalidate that entry so the next dispatch sees the new
|
||||
// subscription immediately.
|
||||
err := memberClient.PostTestWebpushMessage(ctx)
|
||||
require.NoError(t, err, "test webpush message without a subscription")
|
||||
require.Zero(t, handlerCalls.Load(), "a user without subscriptions should not receive a push")
|
||||
|
||||
err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
AuthKey: validEndpointAuthKey,
|
||||
P256DHKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err, "create webpush subscription")
|
||||
require.True(t, <-handlerCalled, "handler should have been called")
|
||||
require.Equal(t, int32(1), handlerCalls.Load(), "subscription validation should hit the endpoint once")
|
||||
|
||||
err = memberClient.PostTestWebpushMessage(ctx)
|
||||
require.NoError(t, err, "test webpush message")
|
||||
require.True(t, <-handlerCalled, "handler should have been called again")
|
||||
require.NoError(t, err, "test webpush message after subscribing")
|
||||
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate empty cache entries after subscribing")
|
||||
|
||||
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
})
|
||||
require.NoError(t, err, "delete webpush subscription")
|
||||
|
||||
// Deleting the subscription for a non-existent endpoint should return a 404
|
||||
err = memberClient.PostTestWebpushMessage(ctx)
|
||||
require.NoError(t, err, "test webpush message after unsubscribing")
|
||||
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate cached subscriptions after unsubscribing")
|
||||
|
||||
// Deleting the subscription for a non-existent endpoint should return a 404.
|
||||
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user