perf(coderd): reduce duplicated reads in push and webpush paths (#23115)

## Background

A 5000-chat scaletest (~50k turns, ~2m45s wall time) completed
successfully,
but the main bottleneck was **DB pool starvation from repeated reads**,
not
individually expensive SQL. The push/webpush path showed a few
especially noisy
reads:

- `GetLastChatMessageByRole` for push body generation
- `GetEnabledChatProviders` + `GetChatModelConfigByID` for push summary
model
  resolution
- `GetWebpushSubscriptionsByUserID` for every webpush dispatch

This PR keeps the optimizations that remove those duplicate reads while
leaving
stream behavior unchanged.

## What changes in this PR

### 1. Reuse resolved chat state for push notifications

`maybeSendPushNotification` used to re-read the last assistant message
and
re-resolve the chat model/provider after `runChat` had already done that
work.

Now `runChat` returns the final assistant text plus the already-resolved
model
and provider keys, and the push goroutine uses that state directly.

That removes the extra push-path reads for:

- `GetLastChatMessageByRole`
- the second `resolveChatModel` path
- the provider/model lookups that came with that second resolution

### 2. Cache webpush subscriptions during dispatch

`Dispatch()` previously hit `GetWebpushSubscriptionsByUserID` on every
push. A
small per-user in-memory cache now avoids those repeated reads.

The follow-up fix keeps that optimization correct: `InvalidateUser()`
bumps a
per-user generation so an older in-flight fetch cannot repopulate the
cache with
pre-mutation data after subscribe/unsubscribe.

That preserves the cache win without letting local subscription changes
be
silently overwritten by stale fetch results.

## Why this is safe

- The push change only reuses data already produced during the same chat
run. It
does not change notification semantics; if there is no assistant text to
  summarize, the existing fallback body still applies.
- The webpush change keeps the existing TTL and `410 Gone` cleanup
behavior. The
generation guard only prevents stale in-flight fetches from poisoning
the
  shared cache after invalidation.
- The final PR does **not** change stream setup, pubsub/relay behavior,
or chat
  status snapshot timing.

## Deliberately not included

- No stream-path optimization in `Subscribe`.
- No inline pubsub message payloads.
- No distributed cross-replica webpush cache invalidation.
This commit is contained in:
Ethan
2026-03-17 13:50:47 +11:00
committed by GitHub
parent 7cca2b6176
commit 04fca84872
6 changed files with 465 additions and 49 deletions
+41 -28
View File
@@ -2038,6 +2038,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
status := database.ChatStatusWaiting
wasInterrupted := false
lastError := ""
runResult := runChatResult{}
remainingQueuedMessages := []database.ChatQueuedMessage{}
shouldPublishQueueUpdate := false
var promotedMessage *database.ChatMessage
@@ -2144,11 +2145,12 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindStatusChange, nil)
if !wasInterrupted {
p.maybeSendPushNotification(cleanupCtx, chat, status, lastError, logger)
p.maybeSendPushNotification(cleanupCtx, chat, status, lastError, runResult, logger)
}
}()
if err := p.runChat(chatCtx, chat, logger); err != nil {
runResult, err := p.runChat(chatCtx, chat, logger)
if err != nil {
if errors.Is(err, chatloop.ErrInterrupted) || errors.Is(context.Cause(chatCtx), chatloop.ErrInterrupted) {
logger.Info(ctx, "chat interrupted")
status = database.ChatStatusWaiting
@@ -2205,11 +2207,18 @@ func isShutdownCancellation(
return errors.Is(context.Cause(chatCtx), context.Canceled)
}
type runChatResult struct {
FinalAssistantText string
PushSummaryModel fantasy.LanguageModel
ProviderKeys chatprovider.ProviderAPIKeys
}
func (p *Server) runChat(
ctx context.Context,
chat database.Chat,
logger slog.Logger,
) error {
) (runChatResult, error) {
result := runChatResult{}
var (
model fantasy.LanguageModel
modelConfig database.ChatModelConfig
@@ -2241,14 +2250,16 @@ func (p *Server) runChat(
return nil
})
if err := g.Wait(); err != nil {
return err
return result, err
}
result.PushSummaryModel = model
result.ProviderKeys = providerKeys
// Fire title generation asynchronously so it doesn't block the
// chat response. It uses a detached context so it can finish
// even after the chat processing context is canceled.
// Snapshot model so the goroutine doesn't race with the
// model = cuModel reassignment below.
titleModel := model
// Snapshot the original chat model so the goroutine doesn't
// race with the model = cuModel reassignment below.
titleModel := result.PushSummaryModel
p.inflight.Add(1)
go func() {
defer p.inflight.Done()
@@ -2257,7 +2268,7 @@ func (p *Server) runChat(
prompt, err := chatprompt.ConvertMessagesWithFiles(ctx, messages, p.chatFileResolver(), logger)
if err != nil {
return xerrors.Errorf("build chat prompt: %w", err)
return result, xerrors.Errorf("build chat prompt: %w", err)
}
if chat.ParentChatID.Valid {
prompt = chatprompt.InsertSystem(prompt, defaultSubagentInstruction)
@@ -2389,9 +2400,11 @@ func (p *Server) runChat(
prompt = chatprompt.InsertSystem(prompt, resolvedUserPrompt)
}
// Use the model config's context_limit as a fallback when the LLM // provider doesn't include context_limit in its response metadata
// Use the model config's context_limit as a fallback when the LLM
// provider doesn't include context_limit in its response metadata
// (which is the common case).
modelConfigContextLimit := modelConfig.ContextLimit
var finalAssistantText string
persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error {
// If the chat context has been canceled, bail out before
@@ -2455,6 +2468,7 @@ func (p *Server) runChat(
for _, block := range assistantBlocks {
sdkParts = append(sdkParts, chatprompt.PartFromContent(block))
}
finalAssistantText = strings.TrimSpace(contentBlocksToText(sdkParts))
assistantContent, marshalErr := chatprompt.MarshalParts(sdkParts)
if marshalErr != nil {
return marshalErr
@@ -2630,7 +2644,7 @@ func (p *Server) runChat(
chatprovider.UserAgent(),
)
if cuErr != nil {
return xerrors.Errorf("resolve computer use model: %w", cuErr)
return result, xerrors.Errorf("resolve computer use model: %w", cuErr)
}
model = cuModel
}
@@ -2796,7 +2810,11 @@ func (p *Server) runChat(
p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err))
},
})
return err
if err != nil {
return result, err
}
result.FinalAssistantText = finalAssistantText
return result, nil
}
// buildProviderTools creates provider-native tool definitions
@@ -3301,6 +3319,7 @@ func (p *Server) maybeSendPushNotification(
chat database.Chat,
status database.ChatStatus,
lastError string,
runResult runChatResult,
logger slog.Logger,
) {
if p.webpushDispatcher == nil || p.webpushDispatcher.PublicKey() == "" {
@@ -3328,23 +3347,17 @@ func (p *Server) maybeSendPushNotification(
defer p.inflight.Done()
pushCtx := context.WithoutCancel(ctx)
pushBody := "Agent has finished running."
msg, err := p.db.GetLastChatMessageByRole(pushCtx, database.GetLastChatMessageByRoleParams{
ChatID: chat.ID,
Role: database.ChatMessageRoleAssistant,
})
if err == nil {
content, parseErr := chatprompt.ParseContent(msg)
if parseErr == nil {
assistantText := strings.TrimSpace(contentBlocksToText(content))
if assistantText != "" {
model, _, keys, resolveErr := p.resolveChatModel(pushCtx, chat)
if resolveErr == nil {
if summary := generatePushSummary(pushCtx, chat.Title, assistantText, model, keys, logger); summary != "" {
pushBody = summary
}
}
}
assistantText := strings.TrimSpace(runResult.FinalAssistantText)
if assistantText != "" && runResult.PushSummaryModel != nil {
if summary := generatePushSummary(
pushCtx,
chat.Title,
assistantText,
runResult.PushSummaryModel,
runResult.ProviderKeys,
logger,
); summary != "" {
pushBody = summary
}
}
+59 -4
View File
@@ -2234,12 +2234,10 @@ func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
const assistantText = "I have completed the task successfully and all tests are passing now."
const summaryText = "Completed task and verified all tests pass."
var nonStreamingRequests atomic.Int32
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
// Non-streaming calls are used for title
// generation and push summary generation.
// Return the summary text for both — the title
// result is irrelevant to this test.
nonStreamingRequests.Add(1)
return chattest.OpenAINonStreamingResponse(summaryText)
}
return chattest.OpenAIStreamingResponse(
@@ -2286,6 +2284,63 @@ func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
"push body should be the LLM-generated summary")
require.NotEqual(t, "Agent has finished running.", msg.Body,
"push body should not use the default fallback text")
require.Equal(t, int32(1), nonStreamingRequests.Load(),
"expected exactly one non-streaming request for push summary generation")
}
func TestSuccessfulChatSendsWebPushFallbackWithoutSummaryForEmptyAssistantText(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
var nonStreamingRequests atomic.Int32
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
nonStreamingRequests.Add(1)
return chattest.OpenAINonStreamingResponse("unexpected summary request")
}
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks(" ")...,
)
})
mockPush := &mockWebpushDispatcher{}
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := chatd.New(chatd.Config{
Logger: logger,
Database: db,
ReplicaID: uuid.New(),
Pubsub: ps,
PendingChatAcquireInterval: 10 * time.Millisecond,
InFlightChatStaleAfter: testutil.WaitSuperLong,
WebpushDispatcher: mockPush,
})
t.Cleanup(func() {
require.NoError(t, server.Close())
})
user, model := seedChatDependencies(ctx, t, db)
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
_, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "empty-summary-push-test",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")},
})
require.NoError(t, err)
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
return mockPush.dispatchCount.Load() >= 1
}, testutil.IntervalFast)
msg := mockPush.getLastMessage()
require.Equal(t, "Agent has finished running.", msg.Body,
"push body should fall back when the final assistant text is empty")
require.Equal(t, int32(0), nonStreamingRequests.Load(),
"push summary should not be requested when final assistant text has no usable text")
}
func TestComputerUseSubagentToolsAndModel(t *testing.T) {
+7
View File
@@ -12,6 +12,7 @@ import (
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/codersdk"
)
@@ -54,6 +55,9 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
})
return
}
if invalidator, ok := api.WebpushDispatcher.(webpush.SubscriptionCacheInvalidator); ok {
invalidator.InvalidateUser(user.ID)
}
rw.WriteHeader(http.StatusNoContent)
}
@@ -111,6 +115,9 @@ func (api *API) deleteUserWebpushSubscription(rw http.ResponseWriter, r *http.Re
})
return
}
if invalidator, ok := api.WebpushDispatcher.(webpush.SubscriptionCacheInvalidator); ok {
invalidator.InvalidateUser(user.ID)
}
rw.WriteHeader(http.StatusNoContent)
}
+190 -7
View File
@@ -9,18 +9,23 @@ import (
"net/http"
"slices"
"sync"
"time"
"github.com/SherClockHolmes/webpush-go"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"tailscale.com/util/singleflight"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
const defaultSubscriptionCacheTTL = 3 * time.Minute
// Dispatcher is an interface that can be used to dispatch
// web push notifications to clients such as browsers.
type Dispatcher interface {
@@ -33,6 +38,36 @@ type Dispatcher interface {
PublicKey() string
}
// SubscriptionCacheInvalidator is an optional interface that lets local
// subscription mutation handlers invalidate cached subscriptions.
type SubscriptionCacheInvalidator interface {
InvalidateUser(userID uuid.UUID)
}
type options struct {
clock quartz.Clock
subscriptionCacheTTL time.Duration
}
// Option configures optional behavior for a Webpusher.
type Option func(*options)
// WithClock sets the clock used by the subscription cache. Defaults to a real
// clock when not provided.
func WithClock(clock quartz.Clock) Option {
return func(o *options) {
o.clock = clock
}
}
// WithSubscriptionCacheTTL sets the in-memory subscription cache TTL. Defaults
// to three minutes when not provided or when given a non-positive duration.
func WithSubscriptionCacheTTL(ttl time.Duration) Option {
return func(o *options) {
o.subscriptionCacheTTL = ttl
}
}
// New creates a new Dispatcher to dispatch web push notifications.
//
// This is *not* integrated into the enqueue system unfortunately.
@@ -41,7 +76,21 @@ type Dispatcher interface {
// for updates inside of a workspace, which we want to be immediate.
//
// See: https://github.com/coder/internal/issues/528
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string) (Dispatcher, error) {
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string, opts ...Option) (Dispatcher, error) {
cfg := options{
clock: quartz.NewReal(),
subscriptionCacheTTL: defaultSubscriptionCacheTTL,
}
for _, opt := range opts {
opt(&cfg)
}
if cfg.clock == nil {
cfg.clock = quartz.NewReal()
}
if cfg.subscriptionCacheTTL <= 0 {
cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
}
keys, err := db.GetWebpushVAPIDKeys(ctx)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
@@ -63,14 +112,23 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
}
return &Webpusher{
vapidSub: vapidSub,
store: db,
log: log,
VAPIDPublicKey: keys.VapidPublicKey,
VAPIDPrivateKey: keys.VapidPrivateKey,
vapidSub: vapidSub,
store: db,
log: log,
VAPIDPublicKey: keys.VapidPublicKey,
VAPIDPrivateKey: keys.VapidPrivateKey,
clock: cfg.clock,
subscriptionCacheTTL: cfg.subscriptionCacheTTL,
subscriptionCache: make(map[uuid.UUID]cachedSubscriptions),
subscriptionGenerations: make(map[uuid.UUID]uint64),
}, nil
}
type cachedSubscriptions struct {
subscriptions []database.WebpushSubscription
expiresAt time.Time
}
type Webpusher struct {
store database.Store
log *slog.Logger
@@ -83,10 +141,18 @@ type Webpusher struct {
// the message payload.
VAPIDPublicKey string
VAPIDPrivateKey string
clock quartz.Clock
cacheMu sync.RWMutex
subscriptionCache map[uuid.UUID]cachedSubscriptions
subscriptionGenerations map[uuid.UUID]uint64
subscriptionCacheTTL time.Duration
subscriptionFetches singleflight.Group[string, []database.WebpushSubscription]
}
func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
subscriptions, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
subscriptions, err := n.subscriptionsForUser(ctx, userID)
if err != nil {
return xerrors.Errorf("get web push subscriptions by user ID: %w", err)
}
@@ -142,12 +208,129 @@ func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk
err = n.store.DeleteWebpushSubscriptions(dbauthz.AsNotifier(ctx), cleanupSubscriptions)
if err != nil {
n.log.Error(ctx, "failed to delete stale push subscriptions", slog.Error(err))
} else {
n.pruneSubscriptions(userID, cleanupSubscriptions)
}
}
return nil
}
func (n *Webpusher) subscriptionsForUser(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
if subscriptions, ok := n.cachedSubscriptions(userID); ok {
return subscriptions, nil
}
subscriptions, err, _ := n.subscriptionFetches.Do(userID.String(), func() ([]database.WebpushSubscription, error) {
if cached, ok := n.cachedSubscriptions(userID); ok {
return cached, nil
}
generation := n.subscriptionGeneration(userID)
fetched, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
if err != nil {
return nil, err
}
n.storeSubscriptions(userID, generation, fetched)
return slices.Clone(fetched), nil
})
if err != nil {
return nil, err
}
return slices.Clone(subscriptions), nil
}
func (n *Webpusher) cachedSubscriptions(userID uuid.UUID) ([]database.WebpushSubscription, bool) {
n.cacheMu.RLock()
entry, ok := n.subscriptionCache[userID]
n.cacheMu.RUnlock()
if !ok {
return nil, false
}
if n.clock.Now().Before(entry.expiresAt) {
return slices.Clone(entry.subscriptions), true
}
n.cacheMu.Lock()
if current, ok := n.subscriptionCache[userID]; ok && !n.clock.Now().Before(current.expiresAt) {
delete(n.subscriptionCache, userID)
}
n.cacheMu.Unlock()
return nil, false
}
func (n *Webpusher) subscriptionGeneration(userID uuid.UUID) uint64 {
n.cacheMu.RLock()
generation := n.subscriptionGenerations[userID]
n.cacheMu.RUnlock()
return generation
}
func (n *Webpusher) storeSubscriptions(userID uuid.UUID, generation uint64, subscriptions []database.WebpushSubscription) {
n.cacheMu.Lock()
defer n.cacheMu.Unlock()
if n.subscriptionGenerations[userID] != generation {
return
}
n.subscriptionCache[userID] = cachedSubscriptions{
subscriptions: slices.Clone(subscriptions),
expiresAt: n.clock.Now().Add(n.subscriptionCacheTTL),
}
}
func (n *Webpusher) pruneSubscriptions(userID uuid.UUID, staleIDs []uuid.UUID) {
if len(staleIDs) == 0 {
return
}
stale := make(map[uuid.UUID]struct{}, len(staleIDs))
for _, id := range staleIDs {
stale[id] = struct{}{}
}
n.cacheMu.Lock()
defer n.cacheMu.Unlock()
entry, ok := n.subscriptionCache[userID]
if !ok {
return
}
if !n.clock.Now().Before(entry.expiresAt) {
delete(n.subscriptionCache, userID)
return
}
filtered := make([]database.WebpushSubscription, 0, len(entry.subscriptions))
for _, subscription := range entry.subscriptions {
if _, shouldDelete := stale[subscription.ID]; shouldDelete {
continue
}
filtered = append(filtered, subscription)
}
if len(filtered) == 0 {
delete(n.subscriptionCache, userID)
return
}
entry.subscriptions = filtered
n.subscriptionCache[userID] = entry
}
// InvalidateUser clears the cached subscriptions for a user and advances
// its invalidation generation. Local subscribe and unsubscribe handlers call
// this after mutating subscriptions in the same process.
func (n *Webpusher) InvalidateUser(userID uuid.UUID) {
n.cacheMu.Lock()
delete(n.subscriptionCache, userID)
n.subscriptionGenerations[userID]++
n.cacheMu.Unlock()
n.subscriptionFetches.Forget(userID.String())
}
func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string, keys webpush.Keys) (int, []byte, error) {
// Copy the message to avoid modifying the original.
cpy := slices.Clone(msg)
+150 -3
View File
@@ -6,7 +6,9 @@ import (
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -21,6 +23,7 @@ import (
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
const (
@@ -28,6 +31,20 @@ const (
validEndpointP256dhKey = "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk="
)
type countingWebpushStore struct {
database.Store
getSubscriptionsCalls atomic.Int32
}
func (s *countingWebpushStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
s.getSubscriptionsCalls.Add(1)
return s.Store.GetWebpushSubscriptionsByUserID(ctx, userID)
}
func (s *countingWebpushStore) getCallCount() int32 {
return s.getSubscriptionsCalls.Load()
}
func TestPush(t *testing.T) {
t.Parallel()
@@ -216,6 +233,131 @@ func TestPush(t *testing.T) {
require.NoError(t, err)
assert.Empty(t, subscriptions, "No subscriptions should be returned")
})
t.Run("CachesSubscriptionsWithinTTL", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
rawStore, _ := dbtestutil.NewDB(t)
store := &countingWebpushStore{Store: rawStore}
var delivered atomic.Int32
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
delivered.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusOK)
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
user := dbgen.User(t, rawStore, database.User{})
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: serverURL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
msg := randomWebpushMessage(t)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
require.Equal(t, int32(1), store.getCallCount(), "subscriptions should be read once within the TTL")
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
})
t.Run("RefreshesSubscriptionsAfterTTLExpires", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
rawStore, _ := dbtestutil.NewDB(t)
store := &countingWebpushStore{Store: rawStore}
var delivered atomic.Int32
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
delivered.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusOK)
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
user := dbgen.User(t, rawStore, database.User{})
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: serverURL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
msg := randomWebpushMessage(t)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
clock.Advance(time.Minute)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
require.Equal(t, int32(2), store.getCallCount(), "dispatch should refresh subscriptions after the TTL expires")
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
})
t.Run("PrunesStaleSubscriptionsFromCache", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
rawStore, _ := dbtestutil.NewDB(t)
store := &countingWebpushStore{Store: rawStore}
var okCalls atomic.Int32
var goneCalls atomic.Int32
manager, _, okServerURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
okCalls.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusOK)
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
goneServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
goneCalls.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusGone)
}))
defer goneServer.Close()
user := dbgen.User(t, rawStore, database.User{})
okSubscription, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: okServerURL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
_, err = rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: goneServer.URL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
msg := randomWebpushMessage(t)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
require.Equal(t, int32(1), store.getCallCount(), "stale subscription cleanup should not force a second DB read within the TTL")
require.Equal(t, int32(2), okCalls.Load(), "the healthy endpoint should receive both dispatches")
require.Equal(t, int32(1), goneCalls.Load(), "the stale endpoint should be pruned from the cache after the first dispatch")
subscriptions, err := rawStore.GetWebpushSubscriptionsByUserID(ctx, user.ID)
require.NoError(t, err)
require.Len(t, subscriptions, 1, "only the healthy subscription should remain")
require.Equal(t, okSubscription.ID, subscriptions[0].ID)
})
}
func randomWebpushMessage(t testing.TB) codersdk.WebpushMessage {
@@ -244,16 +386,21 @@ func assertWebpushPayload(t testing.TB, r *http.Request) {
assert.Error(t, json.NewDecoder(r.Body).Decode(io.Discard))
}
// setupPushTest creates a common test setup for webpush notification tests
// setupPushTest creates a common test setup for webpush notification tests.
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
db, _ := dbtestutil.NewDB(t)
return setupPushTestWithOptions(ctx, t, db, handlerFunc)
}
func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Store, handlerFunc func(w http.ResponseWriter, r *http.Request), opts ...webpush.Option) (webpush.Dispatcher, database.Store, string) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
server := httptest.NewServer(http.HandlerFunc(handlerFunc))
t.Cleanup(server.Close)
manager, err := webpush.New(ctx, &logger, db, "http://example.com")
manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...)
require.NoError(t, err, "Failed to create webpush manager")
return manager, db, server.URL
+18 -7
View File
@@ -35,31 +35,42 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
_, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
handlerCalled := make(chan bool, 1)
var handlerCalls atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusCreated)
handlerCalled <- true
handlerCalls.Add(1)
}))
defer server.Close()
err := memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
// Seed the dispatcher cache with an empty subscription set. Creating the
// subscription should invalidate that entry so the next dispatch sees the new
// subscription immediately.
err := memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message without a subscription")
require.Zero(t, handlerCalls.Load(), "a user without subscriptions should not receive a push")
err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
Endpoint: server.URL,
AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey,
})
require.NoError(t, err, "create webpush subscription")
require.True(t, <-handlerCalled, "handler should have been called")
require.Equal(t, int32(1), handlerCalls.Load(), "subscription validation should hit the endpoint once")
err = memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message")
require.True(t, <-handlerCalled, "handler should have been called again")
require.NoError(t, err, "test webpush message after subscribing")
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate empty cache entries after subscribing")
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
Endpoint: server.URL,
})
require.NoError(t, err, "delete webpush subscription")
// Deleting the subscription for a non-existent endpoint should return a 404
err = memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message after unsubscribing")
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate cached subscriptions after unsubscribing")
// Deleting the subscription for a non-existent endpoint should return a 404.
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
Endpoint: server.URL,
})