diff --git a/coderd/inboxnotifications.go b/coderd/inboxnotifications.go index 6c35e33dc7..454aefee79 100644 --- a/coderd/inboxnotifications.go +++ b/coderd/inboxnotifications.go @@ -20,7 +20,6 @@ import ( "github.com/coder/coder/v2/coderd/pubsub" markdown "github.com/coder/coder/v2/coderd/render" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/websocket" ) @@ -126,6 +125,7 @@ func (api *API) watchInboxNotifications(rw http.ResponseWriter, r *http.Request) templates = p.UUIDs(vals, []uuid.UUID{}, "templates") readStatus = p.String(vals, "all", "read_status") format = p.String(vals, notificationFormatMarkdown, "format") + logger = api.Logger.Named("inbox_notifications_watcher") ) p.ErrorExcessParams(vals) if len(p.Errors) > 0 { @@ -213,11 +213,17 @@ func (api *API) watchInboxNotifications(rw http.ResponseWriter, r *http.Request) return } - go httpapi.Heartbeat(ctx, conn) - defer conn.Close(websocket.StatusNormalClosure, "connection closed") + ctx, cancel := context.WithCancel(ctx) + defer cancel() - encoder := wsjson.NewEncoder[codersdk.GetInboxNotificationResponse](conn, websocket.MessageText) - defer encoder.Close(websocket.StatusNormalClosure) + _ = conn.CloseRead(context.Background()) + + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText) + defer wsNetConn.Close() + + go httpapi.HeartbeatClose(ctx, logger, cancel, conn) + + encoder := json.NewEncoder(wsNetConn) // Log the request immediately instead of after it completes. if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil { @@ -226,8 +232,12 @@ func (api *API) watchInboxNotifications(rw http.ResponseWriter, r *http.Request) for { select { + case <-api.ctx.Done(): + return + case <-ctx.Done(): return + case notif := <-notificationCh: unreadCount, err := api.Database.CountUnreadInboxNotificationsByUserID(ctx, apikey.UserID) if err != nil {