feat: add resume support to coordinator connections (#14234)

This commit is contained in:
Dean Sheather
2024-08-20 17:16:49 +10:00
committed by GitHub
parent 0b2ba96065
commit cf8be4eac5
32 changed files with 1706 additions and 465 deletions
+13 -4
View File
@@ -56,6 +56,7 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/pretty"
"github.com/coder/quartz"
"github.com/coder/retry"
"github.com/coder/serpent"
"github.com/coder/wgtunnel/tunnelsdk"
@@ -791,18 +792,26 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
}
}
keyBytes, err := hex.DecodeString(oauthSigningKeyStr)
oauthKeyBytes, err := hex.DecodeString(oauthSigningKeyStr)
if err != nil {
return xerrors.Errorf("decode oauth signing key from database: %w", err)
}
if len(keyBytes) != len(options.OAuthSigningKey) {
return xerrors.Errorf("oauth signing key in database is not the correct length, expect %d got %d", len(options.OAuthSigningKey), len(keyBytes))
if len(oauthKeyBytes) != len(options.OAuthSigningKey) {
return xerrors.Errorf("oauth signing key in database is not the correct length, expect %d got %d", len(options.OAuthSigningKey), len(oauthKeyBytes))
}
copy(options.OAuthSigningKey[:], keyBytes)
copy(options.OAuthSigningKey[:], oauthKeyBytes)
if options.OAuthSigningKey == [32]byte{} {
return xerrors.Errorf("oauth signing key in database is empty")
}
// Read the coordinator resume token signing key from the
// database.
resumeTokenKey, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, tx)
if err != nil {
return xerrors.Errorf("get coordinator resume token key from database: %w", err)
}
options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider(resumeTokenKey, quartz.NewReal(), tailnet.DefaultResumeTokenExpiry)
return nil
}, nil)
if err != nil {
+10
View File
@@ -182,6 +182,9 @@ type Options struct {
// AppSecurityKey is the crypto key used to sign and encrypt tokens related to
// workspace applications. It consists of both a signing and encryption key.
AppSecurityKey workspaceapps.SecurityKey
// CoordinatorResumeTokenProvider is used to provide and validate resume
// tokens issued by and passed to the coordinator DRPC API.
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport
HealthcheckTimeout time.Duration
@@ -584,12 +587,16 @@ func New(options *Options) *API {
api.Options.NetworkTelemetryBatchMaxSize,
api.handleNetworkTelemetry,
)
if options.CoordinatorResumeTokenProvider == nil {
panic("CoordinatorResumeTokenProvider is nil")
}
api.TailnetClientService, err = tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: api.Logger.Named("tailnetclient"),
CoordPtr: &api.TailnetCoordinator,
DERPMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
DERPMapFn: api.DERPMap,
NetworkTelemetryHandler: api.NetworkTelemetryBatcher.Handler,
ResumeTokenProvider: api.Options.CoordinatorResumeTokenProvider,
})
if err != nil {
api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err))
@@ -614,6 +621,9 @@ func New(options *Options) *API {
options.WorkspaceAppsStatsCollectorOptions.Reporter = api.statsReporter
}
if options.AppSecurityKey.IsZero() {
api.Logger.Fatal(api.ctx, "app security key cannot be zero")
}
api.workspaceAppServer = &workspaceapps.Server{
Logger: workspaceAppsLogger,
+24 -19
View File
@@ -96,25 +96,26 @@ type Options struct {
// AccessURL denotes a custom access URL. By default we use the httptest
// server's URL. Setting this may result in unexpected behavior (especially
// with running agents).
AccessURL *url.URL
AppHostname string
AWSCertificates awsidentity.Certificates
Authorizer rbac.Authorizer
AzureCertificates x509.VerifyOptions
GithubOAuth2Config *coderd.GithubOAuth2Config
RealIPConfig *httpmw.RealIPConfig
OIDCConfig *coderd.OIDCConfig
GoogleTokenValidator *idtoken.Validator
SSHKeygenAlgorithm gitsshkey.Algorithm
AutobuildTicker <-chan time.Time
AutobuildStats chan<- autobuild.Stats
Auditor audit.Auditor
TLSCertificates []tls.Certificate
ExternalAuthConfigs []*externalauth.Config
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
RefreshEntitlements func(ctx context.Context) error
TemplateScheduleStore schedule.TemplateScheduleStore
Coordinator tailnet.Coordinator
AccessURL *url.URL
AppHostname string
AWSCertificates awsidentity.Certificates
Authorizer rbac.Authorizer
AzureCertificates x509.VerifyOptions
GithubOAuth2Config *coderd.GithubOAuth2Config
RealIPConfig *httpmw.RealIPConfig
OIDCConfig *coderd.OIDCConfig
GoogleTokenValidator *idtoken.Validator
SSHKeygenAlgorithm gitsshkey.Algorithm
AutobuildTicker <-chan time.Time
AutobuildStats chan<- autobuild.Stats
Auditor audit.Auditor
TLSCertificates []tls.Certificate
ExternalAuthConfigs []*externalauth.Config
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
RefreshEntitlements func(ctx context.Context) error
TemplateScheduleStore schedule.TemplateScheduleStore
Coordinator tailnet.Coordinator
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport
HealthcheckTimeout time.Duration
@@ -240,6 +241,9 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
if options.Database == nil {
options.Database, options.Pubsub = dbtestutil.NewDB(t)
}
if options.CoordinatorResumeTokenProvider == nil {
options.CoordinatorResumeTokenProvider = tailnet.NewInsecureTestResumeTokenProvider()
}
if options.NotificationsEnqueuer == nil {
options.NotificationsEnqueuer = new(testutil.FakeNotificationsEnqueuer)
@@ -492,6 +496,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
TailnetCoordinator: options.Coordinator,
BaseDERPMap: derpMap,
DERPMapUpdateFrequency: 150 * time.Millisecond,
CoordinatorResumeTokenProvider: options.CoordinatorResumeTokenProvider,
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
DeploymentValues: options.DeploymentValues,
+20 -2
View File
@@ -1332,7 +1332,9 @@ func (q *querier) GetAnnouncementBanners(ctx context.Context) (string, error) {
}
// GetAppSecurityKey authorizes the caller in ctx for a read of
// rbac.ResourceSystem and, if permitted, delegates to the wrapped store.
func (q *querier) GetAppSecurityKey(ctx context.Context) (string, error) {
// Reading the key requires policy.ActionRead on the system resource; an
// authorization failure is returned as-is with an empty key.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return "", err
}
return q.db.GetAppSecurityKey(ctx)
}
@@ -1364,6 +1366,13 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
return q.db.GetAuthorizationUserRoles(ctx, userID)
}
// GetCoordinatorResumeTokenSigningKey checks that the caller in ctx may read
// rbac.ResourceSystem and then fetches the key from the wrapped store. An
// authorization failure is returned unchanged together with an empty key.
func (q *querier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
	authErr := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem)
	if authErr != nil {
		return "", authErr
	}
	return q.db.GetCoordinatorResumeTokenSigningKey(ctx)
}
func (q *querier) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
@@ -3792,7 +3801,9 @@ func (q *querier) UpsertAnnouncementBanners(ctx context.Context, value string) e
}
// UpsertAppSecurityKey authorizes the caller in ctx for an update of
// rbac.ResourceSystem and, if permitted, delegates to the wrapped store.
func (q *querier) UpsertAppSecurityKey(ctx context.Context, data string) error {
// Writing the key requires policy.ActionUpdate on the system resource; an
// authorization failure is returned as-is.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.UpsertAppSecurityKey(ctx, data)
}
@@ -3803,6 +3814,13 @@ func (q *querier) UpsertApplicationName(ctx context.Context, value string) error
return q.db.UpsertApplicationName(ctx, value)
}
// UpsertCoordinatorResumeTokenSigningKey checks that the caller in ctx may
// update rbac.ResourceSystem and then writes value through the wrapped store.
// An authorization failure is returned unchanged.
func (q *querier) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
	authErr := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem)
	if authErr != nil {
		return authErr
	}
	return q.db.UpsertCoordinatorResumeTokenSigningKey(ctx, value)
}
func (q *querier) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
+9 -2
View File
@@ -2531,10 +2531,10 @@ func (s *MethodTestSuite) TestSystemFunctions() {
check.Args(int32(0)).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetAppSecurityKey", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts()
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("UpsertAppSecurityKey", s.Subtest(func(db database.Store, check *expects) {
check.Args("").Asserts()
check.Args("foo").Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("GetApplicationName", s.Subtest(func(db database.Store, check *expects) {
db.UpsertApplicationName(context.Background(), "foo")
@@ -2574,6 +2574,13 @@ func (s *MethodTestSuite) TestSystemFunctions() {
db.UpsertOAuthSigningKey(context.Background(), "foo")
check.Args().Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("UpsertCoordinatorResumeTokenSigningKey", s.Subtest(func(db database.Store, check *expects) {
check.Args("foo").Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("GetCoordinatorResumeTokenSigningKey", s.Subtest(func(db database.Store, check *expects) {
db.UpsertCoordinatorResumeTokenSigningKey(context.Background(), "foo")
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("InsertMissingGroups", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertMissingGroupsParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate).Errors(errMatchAny)
}))
+32 -14
View File
@@ -196,20 +196,21 @@ type data struct {
customRoles []database.CustomRole
// Locks is a map of lock names. Any keys within the map are currently
// locked.
locks map[int64]struct{}
deploymentID string
derpMeshKey string
lastUpdateCheck []byte
announcementBanners []byte
healthSettings []byte
notificationsSettings []byte
applicationName string
logoURL string
appSecurityKey string
oauthSigningKey string
lastLicenseID int32
defaultProxyDisplayName string
defaultProxyIconURL string
locks map[int64]struct{}
deploymentID string
derpMeshKey string
lastUpdateCheck []byte
announcementBanners []byte
healthSettings []byte
notificationsSettings []byte
applicationName string
logoURL string
appSecurityKey string
oauthSigningKey string
coordinatorResumeTokenSigningKey string
lastLicenseID int32
defaultProxyDisplayName string
defaultProxyIconURL string
}
func validateDatabaseTypeWithValid(v reflect.Value) (handled bool, err error) {
@@ -2222,6 +2223,15 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
}, nil
}
// GetCoordinatorResumeTokenSigningKey returns the signing key held in memory
// under the read lock. An empty key is reported as sql.ErrNoRows.
func (q *FakeQuerier) GetCoordinatorResumeTokenSigningKey(_ context.Context) (string, error) {
	q.mutex.RLock()
	defer q.mutex.RUnlock()

	key := q.coordinatorResumeTokenSigningKey
	if key != "" {
		return key, nil
	}
	return "", sql.ErrNoRows
}
func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@@ -8942,6 +8952,14 @@ func (q *FakeQuerier) UpsertApplicationName(_ context.Context, data string) erro
return nil
}
// UpsertCoordinatorResumeTokenSigningKey overwrites the in-memory signing key
// with value while holding the write lock. It always returns nil.
func (q *FakeQuerier) UpsertCoordinatorResumeTokenSigningKey(_ context.Context, value string) error {
	q.mutex.Lock()
	q.coordinatorResumeTokenSigningKey = value
	q.mutex.Unlock()
	return nil
}
func (q *FakeQuerier) UpsertDefaultProxy(_ context.Context, arg database.UpsertDefaultProxyParams) error {
q.defaultProxyDisplayName = arg.DisplayName
q.defaultProxyIconURL = arg.IconUrl
+14
View File
@@ -529,6 +529,13 @@ func (m metricsStore) GetAuthorizationUserRoles(ctx context.Context, userID uuid
return row, err
}
// GetCoordinatorResumeTokenSigningKey forwards the call to the wrapped store
// and records how long it took under the
// "GetCoordinatorResumeTokenSigningKey" label of the query latency metric.
func (m metricsStore) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
	began := time.Now()
	key, err := m.s.GetCoordinatorResumeTokenSigningKey(ctx)
	observer := m.queryLatencies.WithLabelValues("GetCoordinatorResumeTokenSigningKey")
	observer.Observe(time.Since(began).Seconds())
	return key, err
}
func (m metricsStore) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
start := time.Now()
r0, r1 := m.s.GetDBCryptKeys(ctx)
@@ -2377,6 +2384,13 @@ func (m metricsStore) UpsertApplicationName(ctx context.Context, value string) e
return r0
}
// UpsertCoordinatorResumeTokenSigningKey forwards the call to the wrapped
// store and records how long it took under the
// "UpsertCoordinatorResumeTokenSigningKey" label of the query latency metric.
func (m metricsStore) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
	began := time.Now()
	err := m.s.UpsertCoordinatorResumeTokenSigningKey(ctx, value)
	observer := m.queryLatencies.WithLabelValues("UpsertCoordinatorResumeTokenSigningKey")
	observer.Observe(time.Since(began).Seconds())
	return err
}
func (m metricsStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
start := time.Now()
r0 := m.s.UpsertDefaultProxy(ctx, arg)
+29
View File
@@ -1029,6 +1029,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedWorkspaces(arg0, arg1, arg2 any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspaces", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspaces), arg0, arg1, arg2)
}
// GetCoordinatorResumeTokenSigningKey mocks base method.
// It routes the call through the mock controller and returns the first two
// values of the controller's result as (string, error).
func (m *MockStore) GetCoordinatorResumeTokenSigningKey(arg0 context.Context) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCoordinatorResumeTokenSigningKey", arg0)
// The type-assertion results are deliberately discarded: a value of the
// wrong type leaves the corresponding return at its zero value.
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetCoordinatorResumeTokenSigningKey indicates an expected call of GetCoordinatorResumeTokenSigningKey.
// arg0 is the value or matcher for the context argument; the returned
// *gomock.Call is the recorded expectation.
func (mr *MockStoreMockRecorder) GetCoordinatorResumeTokenSigningKey(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCoordinatorResumeTokenSigningKey", reflect.TypeOf((*MockStore)(nil).GetCoordinatorResumeTokenSigningKey), arg0)
}
// GetDBCryptKeys mocks base method.
func (m *MockStore) GetDBCryptKeys(arg0 context.Context) ([]database.DBCryptKey, error) {
m.ctrl.T.Helper()
@@ -4994,6 +5009,20 @@ func (mr *MockStoreMockRecorder) UpsertApplicationName(arg0, arg1 any) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertApplicationName", reflect.TypeOf((*MockStore)(nil).UpsertApplicationName), arg0, arg1)
}
// UpsertCoordinatorResumeTokenSigningKey mocks base method.
// It routes the call through the mock controller and returns the first value
// of the controller's result as an error.
func (m *MockStore) UpsertCoordinatorResumeTokenSigningKey(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertCoordinatorResumeTokenSigningKey", arg0, arg1)
// The type-assertion result is deliberately discarded: a value that is not
// an error leaves ret0 nil.
ret0, _ := ret[0].(error)
return ret0
}
// UpsertCoordinatorResumeTokenSigningKey indicates an expected call of UpsertCoordinatorResumeTokenSigningKey.
// arg0 and arg1 are the values or matchers for the context and key
// arguments; the returned *gomock.Call is the recorded expectation.
func (mr *MockStoreMockRecorder) UpsertCoordinatorResumeTokenSigningKey(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertCoordinatorResumeTokenSigningKey", reflect.TypeOf((*MockStore)(nil).UpsertCoordinatorResumeTokenSigningKey), arg0, arg1)
}
// UpsertDefaultProxy mocks base method.
func (m *MockStore) UpsertDefaultProxy(arg0 context.Context, arg1 database.UpsertDefaultProxyParams) error {
m.ctrl.T.Helper()
+2
View File
@@ -128,6 +128,7 @@ type sqlcQuerier interface {
// This function returns roles for authorization purposes. Implied member roles
// are included.
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error)
GetDERPMeshKey(ctx context.Context) (string, error)
GetDefaultOrganization(ctx context.Context) (Organization, error)
@@ -463,6 +464,7 @@ type sqlcQuerier interface {
UpsertAnnouncementBanners(ctx context.Context, value string) error
UpsertAppSecurityKey(ctx context.Context, value string) error
UpsertApplicationName(ctx context.Context, value string) error
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error
// The default proxy is implied and not actually stored in the database.
// So we need to store it's configuration here for display purposes.
// The functional values are immutable and controlled implicitly.
+21
View File
@@ -6624,6 +6624,17 @@ func (q *sqlQuerier) GetApplicationName(ctx context.Context) (string, error) {
return value, err
}
const getCoordinatorResumeTokenSigningKey = `-- name: GetCoordinatorResumeTokenSigningKey :one
SELECT value FROM site_configs WHERE key = 'coordinator_resume_token_signing_key'
`

// GetCoordinatorResumeTokenSigningKey runs the query above and scans the
// single returned column into a string. Any error from the scan is returned
// alongside whatever value was scanned.
func (q *sqlQuerier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
	var signingKey string
	scanErr := q.db.QueryRowContext(ctx, getCoordinatorResumeTokenSigningKey).Scan(&signingKey)
	return signingKey, scanErr
}
const getDERPMeshKey = `-- name: GetDERPMeshKey :one
SELECT value FROM site_configs WHERE key = 'derp_mesh_key'
`
@@ -6769,6 +6780,16 @@ func (q *sqlQuerier) UpsertApplicationName(ctx context.Context, value string) er
return err
}
const upsertCoordinatorResumeTokenSigningKey = `-- name: UpsertCoordinatorResumeTokenSigningKey :exec
INSERT INTO site_configs (key, value) VALUES ('coordinator_resume_token_signing_key', $1)
ON CONFLICT (key) DO UPDATE set value = $1 WHERE site_configs.key = 'coordinator_resume_token_signing_key'
`

// UpsertCoordinatorResumeTokenSigningKey executes the statement above with
// value as its only parameter and returns the execution error, if any.
func (q *sqlQuerier) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
	if _, execErr := q.db.ExecContext(ctx, upsertCoordinatorResumeTokenSigningKey, value); execErr != nil {
		return execErr
	}
	return nil
}
const upsertDefaultProxy = `-- name: UpsertDefaultProxy :exec
INSERT INTO site_configs (key, value)
VALUES
+7
View File
@@ -71,6 +71,13 @@ SELECT value FROM site_configs WHERE key = 'oauth_signing_key';
INSERT INTO site_configs (key, value) VALUES ('oauth_signing_key', $1)
ON CONFLICT (key) DO UPDATE set value = $1 WHERE site_configs.key = 'oauth_signing_key';
-- name: GetCoordinatorResumeTokenSigningKey :one
SELECT value FROM site_configs WHERE key = 'coordinator_resume_token_signing_key';
-- name: UpsertCoordinatorResumeTokenSigningKey :exec
INSERT INTO site_configs (key, value) VALUES ('coordinator_resume_token_signing_key', $1)
ON CONFLICT (key) DO UPDATE set value = $1 WHERE site_configs.key = 'coordinator_resume_token_signing_key';
-- name: GetHealthSettings :one
SELECT
COALESCE((SELECT value FROM site_configs WHERE key = 'health_settings'), '{}') :: text AS health_settings
+21 -1
View File
@@ -846,6 +846,26 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
return
}
// Accept a resume_token query parameter to use the same peer ID.
var (
peerID = uuid.New()
resumeToken = r.URL.Query().Get("resume_token")
)
if resumeToken != "" {
var err error
peerID, err = api.Options.CoordinatorResumeTokenProvider.VerifyResumeToken(resumeToken)
if err != nil {
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
Message: workspacesdk.CoordinateAPIInvalidResumeToken,
Detail: err.Error(),
Validations: []codersdk.ValidationError{
{Field: "resume_token", Detail: workspacesdk.CoordinateAPIInvalidResumeToken},
},
})
return
}
}
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
@@ -866,7 +886,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
go httpapi.Heartbeat(ctx, conn)
defer conn.Close(websocket.StatusNormalClosure, "")
err = api.TailnetClientService.ServeClient(ctx, version, wsNetConn, uuid.New(), workspaceAgent.ID)
err = api.TailnetClientService.ServeClient(ctx, version, wsNetConn, peerID, workspaceAgent.ID)
if err != nil && !xerrors.Is(err, io.EOF) && !xerrors.Is(err, context.Canceled) {
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
+145
View File
@@ -18,6 +18,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"
"cdr.dev/slog"
@@ -40,8 +41,11 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestWorkspaceAgent(t *testing.T) {
@@ -509,6 +513,147 @@ func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) {
require.Equal(t, "version", sdkErr.Validations[0].Field)
}
// resumeTokenTestFakeCoordinator wraps a real tailnet.Coordinator and
// remembers the peer ID passed to the most recent ServeClient or Coordinate
// call so tests can inspect it.
//
// NOTE(review): lastPeerID is written by whichever goroutine calls these
// methods and has no synchronization here — confirm readers are race-free.
type resumeTokenTestFakeCoordinator struct {
	tailnet.Coordinator
	lastPeerID uuid.UUID
}

// Compile-time check that the wrapper still satisfies tailnet.Coordinator.
var _ tailnet.Coordinator = (*resumeTokenTestFakeCoordinator)(nil)

// ServeClient records id and then delegates to the embedded coordinator.
func (f *resumeTokenTestFakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agentID uuid.UUID) error {
	f.lastPeerID = id
	inner := f.Coordinator
	return inner.ServeClient(conn, id, agentID)
}

// Coordinate records id and then delegates to the embedded coordinator.
func (f *resumeTokenTestFakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *tailnetproto.CoordinateRequest, <-chan *tailnetproto.CoordinateResponse) {
	f.lastPeerID = id
	inner := f.Coordinator
	return inner.Coordinate(ctx, id, name, a)
}
// TestWorkspaceAgentClientCoordinate_ResumeToken exercises the resume_token
// query parameter of the client coordinate endpoint in three steps: a
// connection without a token, a connection with the token obtained from the
// first step, and a connection with an invalid token.
func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
// A mock clock is injected into the token provider so the test can move
// time forward between connections with clock.Advance.
clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
// The wrapper records the peer ID the server hands to the coordinator.
coordinator := &resumeTokenTestFakeCoordinator{
Coordinator: tailnet.NewCoordinator(logger),
}
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Coordinator: coordinator,
CoordinatorResumeTokenProvider: resumeTokenProvider,
})
defer closer.Close()
user := coderdtest.CreateFirstUser(t, client)
// Create a workspace with an agent. No need to connect it since clients can
// still connect to the coordinator while the agent isn't connected.
r := dbfake.WorkspaceBuild(t, api.Database, database.Workspace{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent().Do()
agentTokenUUID, err := uuid.Parse(r.AgentToken)
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitLong)
agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint
require.NoError(t, err)
// Connect with no resume token, and ensure that the peer ID is set to a
// random value.
coordinator.lastPeerID = uuid.Nil
originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
require.NoError(t, err)
originalPeerID := coordinator.lastPeerID
require.NotEqual(t, originalPeerID, uuid.Nil)
// Connect with a valid resume token, and ensure that the peer ID is set to
// the stored value.
clock.Advance(time.Second)
coordinator.lastPeerID = uuid.Nil
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
require.NoError(t, err)
require.Equal(t, originalPeerID, coordinator.lastPeerID)
// The refreshed token is expected to differ from the one presented.
require.NotEqual(t, originalResumeToken, newResumeToken)
// Connect with an invalid resume token, and ensure that the request is
// rejected.
clock.Advance(time.Second)
coordinator.lastPeerID = uuid.Nil
_, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid")
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
require.Len(t, sdkErr.Validations, 1)
require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
// The rejected request must never have reached the coordinator.
require.Equal(t, uuid.Nil, coordinator.lastPeerID)
}
// connectToCoordinatorAndFetchResumeToken connects to the tailnet coordinator
// for agentID, presenting resumeToken as the resume_token query parameter when
// it is non-empty. If the connection is rejected, the error is returned (as
// the decoded response body when an HTTP response was received). If the
// connection is accepted, it sends one empty coordination request, requests a
// fresh resume token over dRPC, closes the websocket, and returns that token.
func connectToCoordinatorAndFetchResumeToken(ctx context.Context, logger slog.Logger, sdkClient *codersdk.Client, agentID uuid.UUID, resumeToken string) (string, error) {
	u, err := sdkClient.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", agentID))
	if err != nil {
		return "", xerrors.Errorf("parse URL: %w", err)
	}
	q := u.Query()
	q.Set("version", "2.0")
	if resumeToken != "" {
		q.Set("resume_token", resumeToken)
	}
	u.RawQuery = q.Encode()

	//nolint:bodyclose
	wsConn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
		HTTPHeader: http.Header{
			"Coder-Session-Token": []string{sdkClient.SessionToken()},
		},
	})
	if err != nil {
		// resp is nil when the dial failed before any HTTP response was
		// received (e.g. connection refused or context cancellation), so it
		// must be checked before reading the status code.
		if resp != nil && resp.StatusCode != http.StatusSwitchingProtocols {
			err = codersdk.ReadBodyAsError(resp)
		}
		return "", xerrors.Errorf("websocket dial: %w", err)
	}
	defer wsConn.Close(websocket.StatusNormalClosure, "done")

	// Send a request to the server to ensure that we're plumbed all the way
	// through.
	rpcClient, err := tailnet.NewDRPCClient(
		websocket.NetConn(ctx, wsConn, websocket.MessageBinary),
		logger,
	)
	if err != nil {
		return "", xerrors.Errorf("new dRPC client: %w", err)
	}

	// Send an empty coordination request. This will do nothing on the server,
	// but ensures our wrapped coordinator can record the peer ID.
	coordinateClient, err := rpcClient.Coordinate(ctx)
	if err != nil {
		return "", xerrors.Errorf("coordinate: %w", err)
	}
	err = coordinateClient.Send(&tailnetproto.CoordinateRequest{})
	if err != nil {
		return "", xerrors.Errorf("send empty coordination request: %w", err)
	}
	err = coordinateClient.Close()
	if err != nil {
		return "", xerrors.Errorf("close coordination request: %w", err)
	}

	// Fetch a resume token.
	newResumeToken, err := rpcClient.RefreshResumeToken(ctx, &tailnetproto.RefreshResumeTokenRequest{})
	if err != nil {
		return "", xerrors.Errorf("fetch resume token: %w", err)
	}
	return newResumeToken.Token, nil
}
func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) {
t.Parallel()
+4
View File
@@ -65,6 +65,10 @@ func (t SignedToken) MatchesRequest(req Request) bool {
// two keys.
type SecurityKey [96]byte

// IsZero reports whether every byte of the key is zero.
func (k SecurityKey) IsZero() bool {
	var zero SecurityKey
	return k == zero
}

// String returns the key encoded as a hexadecimal string.
func (k SecurityKey) String() string {
	raw := k[:]
	return hex.EncodeToString(raw)
}
+87 -11
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"slices"
"strings"
"sync"
@@ -24,6 +25,7 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
"github.com/coder/retry"
)
@@ -61,6 +63,7 @@ type tailnetAPIConnector struct {
agentID uuid.UUID
coordinateURL string
clock quartz.Clock
dialOptions *websocket.DialOptions
conn tailnetConn
customDialFn func() (proto.DRPCTailnetClient, error)
@@ -68,9 +71,10 @@ type tailnetAPIConnector struct {
clientMu sync.RWMutex
client proto.DRPCTailnetClient
connected chan error
isFirst bool
closed chan struct{}
connected chan error
resumeToken *proto.RefreshResumeTokenResponse
isFirst bool
closed chan struct{}
// Only set to true if we get a response from the server that it doesn't support
// network telemetry.
@@ -78,12 +82,13 @@ type tailnetAPIConnector struct {
}
// Create a new tailnetAPIConnector without running it
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions) *tailnetAPIConnector {
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, clock quartz.Clock, dialOptions *websocket.DialOptions) *tailnetAPIConnector {
return &tailnetAPIConnector{
ctx: ctx,
logger: logger,
agentID: agentID,
coordinateURL: coordinateURL,
clock: clock,
dialOptions: dialOptions,
conn: nil,
connected: make(chan error, 1),
@@ -96,7 +101,7 @@ func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uui
func (tac *tailnetAPIConnector) manageGracefulTimeout() {
defer tac.cancelGracefulCtx()
<-tac.ctx.Done()
timer := time.NewTimer(tailnetConnectorGracefulTimeout)
timer := tac.clock.NewTimer(tailnetConnectorGracefulTimeout, "tailnetAPIClient", "gracefulTimeout")
defer timer.Stop()
select {
case <-tac.closed:
@@ -112,6 +117,8 @@ func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
go func() {
tac.isFirst = true
defer close(tac.closed)
// Sadly retry doesn't support quartz.Clock yet so this is not
// influenced by the configured clock.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
tailnetClient, err := tac.dial()
if err != nil {
@@ -121,7 +128,7 @@ func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
tac.client = tailnetClient
tac.clientMu.Unlock()
tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client")
tac.coordinateAndDERPMap(tailnetClient)
tac.runConnectorOnce(tailnetClient)
tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost")
}
}()
@@ -138,8 +145,23 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
return tac.customDialFn()
}
tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API")
u, err := url.Parse(tac.coordinateURL)
if err != nil {
return nil, xerrors.Errorf("parse URL %q: %w", tac.coordinateURL, err)
}
if tac.resumeToken != nil {
q := u.Query()
q.Set("resume_token", tac.resumeToken.Token)
u.RawQuery = q.Encode()
tac.logger.Debug(tac.ctx, "using resume token", slog.F("resume_token", tac.resumeToken))
}
coordinateURL := u.String()
tac.logger.Debug(tac.ctx, "using coordinate URL", slog.F("url", coordinateURL))
// nolint:bodyclose
ws, res, err := websocket.Dial(tac.ctx, tac.coordinateURL, tac.dialOptions)
ws, res, err := websocket.Dial(tac.ctx, coordinateURL, tac.dialOptions)
if tac.isFirst {
if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) {
err = codersdk.ReadBodyAsError(res)
@@ -160,8 +182,20 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
close(tac.connected)
}
if err != nil {
bodyErr := codersdk.ReadBodyAsError(res)
var sdkErr *codersdk.Error
if xerrors.As(bodyErr, &sdkErr) {
for _, v := range sdkErr.Validations {
if v.Field == "resume_token" {
// Unset the resume token for the next attempt
tac.logger.Warn(tac.ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
tac.resumeToken = nil
return nil, err
}
}
}
if !errors.Is(err, context.Canceled) {
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err))
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
}
return nil, err
}
@@ -177,11 +211,11 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
return client, err
}
// coordinateAndDERPMap uses the provided client to coordinate and stream DERP Maps. It is combined
// runConnectorOnce uses the provided client to coordinate and stream DERP Maps. It is combined
// into one function so that a problem with one tears down the other and triggers a retry (if
// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same
// fate.
func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetClient) {
func (tac *tailnetAPIConnector) runConnectorOnce(client proto.DRPCTailnetClient) {
defer func() {
conn := client.DRPCConn()
closeErr := conn.Close()
@@ -193,14 +227,17 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
<-conn.Closed()
}
}()
refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx)
wg := sync.WaitGroup{}
wg.Add(2)
wg.Add(3)
go func() {
defer wg.Done()
tac.coordinate(client)
}()
go func() {
defer wg.Done()
defer refreshTokenCancel()
dErr := tac.derpMap(client)
if dErr != nil && tac.ctx.Err() == nil {
// The main context is still active, meaning that we want the tailnet data plane to stay
@@ -215,6 +252,10 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
// Note that derpMap() logs it own errors, we don't bother here.
}
}()
go func() {
defer wg.Done()
tac.refreshToken(refreshTokenCtx, client)
}()
wg.Wait()
}
@@ -278,6 +319,41 @@ func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
}
}
// refreshToken keeps tac.resumeToken up to date for the lifetime of ctx. It
// requests a token from the coordinator immediately, stores the response on the
// connector, and then re-arms its ticker with the refresh interval the server
// returned (or 30 minutes when the server gives a non-positive interval).
//
// The function returns when ctx is canceled or when a refresh attempt fails;
// a failure is only logged if ctx itself is still alive.
func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.DRPCTailnetClient) {
	refreshTicker := tac.clock.NewTicker(15*time.Second, "tailnetAPIConnector", "refreshToken")
	defer refreshTicker.Stop()

	// A pre-filled one-slot channel lets the first loop iteration proceed
	// without waiting for the ticker.
	immediate := make(chan struct{}, 1)
	immediate <- struct{}{}
	defer close(immediate)

	for {
		select {
		case <-ctx.Done():
			return
		case <-refreshTicker.C:
		case <-immediate:
		}

		// Each attempt gets its own 5 second deadline derived from ctx.
		token, err := func() (*proto.RefreshResumeTokenResponse, error) {
			attemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
			defer cancel()
			return client.RefreshResumeToken(attemptCtx, &proto.RefreshResumeTokenRequest{})
		}()
		switch {
		case err == nil:
			// Fall through to store the token below.
		case ctx.Err() != nil:
			// We are shutting down; the error is expected and not worth logging.
			return
		default:
			tac.logger.Error(tac.ctx, "error refreshing coordinator resume token", slog.Error(err))
			return
		}

		tac.logger.Debug(tac.ctx, "refreshed coordinator resume token", slog.F("resume_token", token))
		tac.resumeToken = token

		next := token.RefreshIn.AsDuration()
		if next <= 0 {
			// A sensible delay to refresh again.
			next = 30 * time.Minute
		}
		refreshTicker.Reset(next, "tailnetAPIConnector", "refreshToken", "reset")
	}
}
func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) {
tac.clientMu.RLock()
// We hold the lock for the entire telemetry request, but this would only block
@@ -14,6 +14,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"nhooyr.io/websocket"
"storj.io/drpc"
"storj.io/drpc/drpcerr"
@@ -28,6 +30,7 @@ import (
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func init() {
@@ -59,6 +62,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
@@ -78,7 +82,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
fConn := newFakeTailnetConn()
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{})
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
uut.runConnector(fConn)
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
@@ -131,7 +135,7 @@ func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
fConn := newFakeTailnetConn()
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{})
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
uut.runConnector(fConn)
err := testutil.RequireRecvCtx(ctx, t, uut.connected)
@@ -142,6 +146,215 @@ func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
require.NotEmpty(t, sdkErr.Helper)
}
// TestTailnetAPIConnector_ResumeToken checks the happy path of resume tokens:
// the connector fetches a token right after connecting, fetches a different
// one once the mock clock is advanced by the interval passed to ticker.Reset,
// and sends the latest token as the resume_token query parameter when it
// reconnects after the websocket is severed.
func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
agentID := uuid.UUID{0x55}
fCoord := tailnettest.NewFakeCoordinator()
var coord tailnet.Coordinator = fCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
derpMapCh := make(chan *tailcfg.DERPMap)
defer close(derpMapCh)
// The same mock clock drives both the token provider and the connector.
clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
ResumeTokenProvider: resumeTokenProvider,
})
require.NoError(t, err)
// NOTE(review): expectResumeToken is written by the test goroutine further
// down and read inside the HTTP handler without any synchronization —
// confirm the ordering makes this safe under -race.
var (
websocketConnCh = make(chan *websocket.Conn, 64)
expectResumeToken = ""
)
// NOTE(review): no svr.Close() is visible in this function — confirm the test
// server is cleaned up.
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Accept a resume_token query parameter to use the same peer ID. This
// behavior matches the actual client coordinate route.
var (
peerID = uuid.New()
resumeToken = r.URL.Query().Get("resume_token")
)
t.Logf("received resume token: %s", resumeToken)
assert.Equal(t, expectResumeToken, resumeToken)
if resumeToken != "" {
// This assigns the err variable of the enclosing test function.
peerID, err = resumeTokenProvider.VerifyResumeToken(resumeToken)
assert.NoError(t, err, "failed to parse resume token")
if err != nil {
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
Message: CoordinateAPIInvalidResumeToken,
Detail: err.Error(),
Validations: []codersdk.ValidationError{
{Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken},
},
})
return
}
}
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
// Hand the server side of the websocket to the test so it can sever it.
testutil.RequireSendCtx(ctx, t, websocketConnCh, sws)
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
Name: "client",
ID: peerID,
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
})
assert.NoError(t, err)
}))
fConn := newFakeTailnetConn()
// Trap the clock calls made by refreshToken so the test can step through
// each token fetch deterministically.
// NOTE(review): only newTickerTrap is closed; tickerResetTrap has no matching
// Close — confirm that is intentional.
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
defer newTickerTrap.Close()
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{})
uut.runConnector(fConn)
// Fetch first token. We don't need to advance the clock since we use a
// channel with a single item to immediately fetch.
newTickerTrap.MustWait(ctx).Release()
// We call ticker.Reset after each token fetch to apply the refresh duration
// requested by the server.
trappedReset := tickerResetTrap.MustWait(ctx)
trappedReset.Release()
require.NotNil(t, uut.resumeToken)
originalResumeToken := uut.resumeToken.Token
// Fetch second token.
waiter := clock.Advance(trappedReset.Duration)
waiter.MustWait(ctx)
trappedReset = tickerResetTrap.MustWait(ctx)
trappedReset.Release()
require.NotNil(t, uut.resumeToken)
require.NotEqual(t, originalResumeToken, uut.resumeToken.Token)
expectResumeToken = uut.resumeToken.Token
t.Logf("expecting resume token: %s", expectResumeToken)
// Sever the connection and expect it to reconnect with the resume token.
wsConn := testutil.RequireRecvCtx(ctx, t, websocketConnCh)
_ = wsConn.Close(websocket.StatusGoingAway, "test")
// Wait for the resume token to be refreshed.
trappedTicker := newTickerTrap.MustWait(ctx)
// Advance the clock slightly to ensure the new JWT is different.
clock.Advance(time.Second).MustWait(ctx)
trappedTicker.Release()
trappedReset = tickerResetTrap.MustWait(ctx)
trappedReset.Release()
// The resume token should have changed again.
require.NotNil(t, uut.resumeToken)
require.NotEqual(t, expectResumeToken, uut.resumeToken.Token)
}
// TestTailnetAPIConnector_ResumeTokenFailure checks recovery from a rejected
// resume token: the test server answers any request carrying a resume_token
// with 401 and CoordinateAPIInvalidResumeToken, and the connector is expected
// to drop its stored token, reconnect without one, and then fetch a new token.
func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
agentID := uuid.UUID{0x55}
fCoord := tailnettest.NewFakeCoordinator()
var coord tailnet.Coordinator = fCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
derpMapCh := make(chan *tailcfg.DERPMap)
defer close(derpMapCh)
clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
ResumeTokenProvider: resumeTokenProvider,
})
require.NoError(t, err)
var (
websocketConnCh = make(chan *websocket.Conn, 64)
// didFail counts requests the handler rejected; accessed atomically.
didFail int64
)
// NOTE(review): no svr.Close() is visible in this function — confirm the test
// server is cleaned up.
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Reject every request that presents a resume token, regardless of validity.
if r.URL.Query().Get("resume_token") != "" {
atomic.AddInt64(&didFail, 1)
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
Message: CoordinateAPIInvalidResumeToken,
Validations: []codersdk.ValidationError{
{Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken},
},
})
return
}
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
testutil.RequireSendCtx(ctx, t, websocketConnCh, sws)
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
Name: "client",
ID: uuid.New(),
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
})
assert.NoError(t, err)
}))
fConn := newFakeTailnetConn()
// NOTE(review): only newTickerTrap is closed; tickerResetTrap has no matching
// Close — confirm that is intentional.
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
defer newTickerTrap.Close()
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{})
uut.runConnector(fConn)
// Wait for the resume token to be fetched for the first time.
newTickerTrap.MustWait(ctx).Release()
trappedReset := tickerResetTrap.MustWait(ctx)
trappedReset.Release()
originalResumeToken := uut.resumeToken.Token
// Sever the connection and expect it to reconnect with the resume token,
// which should fail and cause the client to be disconnected. The client
// should then reconnect with no resume token.
wsConn := testutil.RequireRecvCtx(ctx, t, websocketConnCh)
_ = wsConn.Close(websocket.StatusGoingAway, "test")
// Wait for the resume token to be refreshed, which indicates a successful
// reconnect.
trappedTicker := newTickerTrap.MustWait(ctx)
// Since we failed the initial reconnect and we're definitely reconnected
// now, the stored resume token should now be nil.
require.Nil(t, uut.resumeToken)
trappedTicker.Release()
trappedReset = tickerResetTrap.MustWait(ctx)
trappedReset.Release()
require.NotNil(t, uut.resumeToken)
require.NotEqual(t, originalResumeToken, uut.resumeToken.Token)
// The resume token should have been rejected by the server.
require.EqualValues(t, 1, atomic.LoadInt64(&didFail))
}
func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
@@ -161,8 +374,9 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
eventCh <- batch
testutil.RequireSendCtx(ctx, t, eventCh, batch)
},
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
@@ -182,7 +396,7 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
fConn := newFakeTailnetConn()
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{})
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
uut.runConnector(fConn)
require.Eventually(t, func() bool {
uut.clientMu.Lock()
@@ -213,6 +427,7 @@ func TestTailnetAPIConnector_TelemetryUnimplemented(t *testing.T) {
logger: logger,
agentID: agentID,
coordinateURL: "",
clock: quartz.NewReal(),
dialOptions: &websocket.DialOptions{},
conn: nil,
connected: make(chan error, 1),
@@ -253,6 +468,7 @@ func TestTailnetAPIConnector_TelemetryNotRecognised(t *testing.T) {
logger: logger,
agentID: agentID,
coordinateURL: "",
clock: quartz.NewReal(),
dialOptions: &websocket.DialOptions{},
conn: nil,
connected: make(chan error, 1),
@@ -301,6 +517,7 @@ func newFakeTailnetConn() *fakeTailnetConn {
// fakeDRPCClient is a test double for the tailnet DRPC client.
type fakeDRPCClient struct {
postTelemetryCalls int64
// refreshTokenFn, when non-nil, is called by RefreshResumeToken in place of
// the canned response.
refreshTokenFn func(context.Context, *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error)
telemetryError error
fakeDRPPCMapStream
}
@@ -339,6 +556,19 @@ func (f *fakeDRPCClient) StreamDERPMaps(_ context.Context, _ *proto.StreamDERPMa
return &f.fakeDRPPCMapStream, nil
}
// RefreshResumeToken implements proto.DRPCTailnetClient.
//
// When refreshTokenFn is set, the call is delegated to it with the caller's
// context and request so the hook can observe cancellation and the request
// it was given. Otherwise a canned response is returned: token "test", a
// 30 minute refresh interval, and an expiry one hour from now.
func (f *fakeDRPCClient) RefreshResumeToken(ctx context.Context, req *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
	if f.refreshTokenFn != nil {
		return f.refreshTokenFn(ctx, req)
	}

	return &proto.RefreshResumeTokenResponse{
		Token:     "test",
		RefreshIn: durationpb.New(30 * time.Minute),
		ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
	}, nil
}
type fakeDRPCConn struct{}
var _ drpc.Conn = &fakeDRPCConn{}
+7 -2
View File
@@ -22,6 +22,7 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
)
// AgentIP is a static IPv6 address with the Tailscale prefix that is used to route
@@ -55,7 +56,11 @@ const (
AgentMinimumListeningPort = 9
)
const AgentAPIMismatchMessage = "Unknown or unsupported API version"
const (
// AgentAPIMismatchMessage is the message text for an unknown or unsupported
// API version.
AgentAPIMismatchMessage = "Unknown or unsupported API version"
// CoordinateAPIInvalidResumeToken is the message text for a rejected resume
// token. The test servers in this change send it as the response message and
// as the "resume_token" validation detail of a 401 response.
CoordinateAPIInvalidResumeToken = "Invalid resume token"
)
// AgentIgnoredListeningPorts contains a list of ports to ignore when looking for
// running applications inside a workspace. We want to ignore non-HTTP servers,
@@ -232,7 +237,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
q.Add("version", "2.0")
coordinateURL.RawQuery = q.Encode()
connector := newTailnetAPIConnector(ctx, options.Logger, agentID, coordinateURL.String(),
connector := newTailnetAPIConnector(ctx, options.Logger, agentID, coordinateURL.String(), quartz.NewReal(),
&websocket.DialOptions{
HTTPClient: c.client.HTTPClient,
HTTPHeader: headers,
+1
View File
@@ -147,6 +147,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
DERPMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
DERPMapFn: api.AGPL.DERPMap,
NetworkTelemetryHandler: api.AGPL.NetworkTelemetryBatcher.Handler,
ResumeTokenProvider: api.AGPL.CoordinatorResumeTokenProvider,
})
if err != nil {
api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err))
+1 -7
View File
@@ -24,13 +24,7 @@ type ClientService struct {
// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
// loaded on each processed connection.
func NewClientService(options agpl.ClientServiceOptions) (*ClientService, error) {
s, err := agpl.NewClientService(agpl.ClientServiceOptions{
Logger: options.Logger,
CoordPtr: options.CoordPtr,
DERPMapUpdateFrequency: options.DERPMapUpdateFrequency,
DERPMapFn: options.DERPMapFn,
NetworkTelemetryHandler: options.NetworkTelemetryHandler,
})
s, err := agpl.NewClientService(options)
if err != nil {
return nil, err
}
@@ -177,6 +177,7 @@ func TestDialCoordinator(t *testing.T) {
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
ResumeTokenProvider: agpl.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
+10
View File
@@ -608,6 +608,16 @@ func (c *configMaps) fillPeerDiagnostics(d *PeerDiagnostics, peerID uuid.UUID) {
d.LastWireguardHandshake = ps.LastHandshake
}
// knownPeerIDs returns a snapshot of the IDs of every peer currently present
// in c.peers. The lock is held while the map is read. The result is always a
// non-nil slice, and its order follows map iteration and is therefore
// unspecified.
func (c *configMaps) knownPeerIDs() []uuid.UUID {
	c.L.Lock()
	defer c.L.Unlock()

	ids := make([]uuid.UUID, len(c.peers))
	next := 0
	for peerID := range c.peers {
		ids[next] = peerID
		next++
	}
	return ids
}
func (c *configMaps) peerReadyForHandshakeTimeout(peerID uuid.UUID) {
logger := c.logger.With(slog.F("peer_id", peerID))
logger.Debug(context.Background(), "peer ready for handshake timeout")
+4
View File
@@ -847,6 +847,10 @@ func (c *Conn) GetPeerDiagnostics(peerID uuid.UUID) PeerDiagnostics {
return d
}
// GetKnownPeerIDs returns the peer IDs reported by the connection's configMaps
// (see configMaps.knownPeerIDs).
func (c *Conn) GetKnownPeerIDs() []uuid.UUID {
return c.configMaps.knownPeerIDs()
}
type listenKey struct {
network string
host string
+2
View File
@@ -630,6 +630,7 @@ func TestRemoteCoordination(t *testing.T) {
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
sC, cC := net.Pipe()
@@ -681,6 +682,7 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
sC, cC := net.Pipe()
File diff suppressed because it is too large Load Diff
+9
View File
@@ -57,6 +57,14 @@ message Node {
repeated string endpoints = 10;
}
// RefreshResumeTokenRequest is the (empty) request for the RefreshResumeToken RPC.
message RefreshResumeTokenRequest {}
// RefreshResumeTokenResponse carries a coordinator resume token and its timing.
message RefreshResumeTokenResponse {
// token is the resume token string the client presents when reconnecting.
string token = 1;
// refresh_in is how long the client should wait before requesting a new token.
google.protobuf.Duration refresh_in = 2;
// expires_at is the time at which the token stops being valid.
google.protobuf.Timestamp expires_at = 3;
}
message CoordinateRequest {
message UpdateSelf {
Node node = 1;
@@ -191,5 +199,6 @@ message TelemetryResponse {}
// Tailnet is the RPC service spoken between tailnet clients and the coordinator.
service Tailnet {
rpc PostTelemetry(TelemetryRequest) returns (TelemetryResponse);
rpc StreamDERPMaps(StreamDERPMapsRequest) returns (stream DERPMap);
// RefreshResumeToken returns a resume token for the calling peer.
rpc RefreshResumeToken(RefreshResumeTokenRequest) returns (RefreshResumeTokenResponse);
rpc Coordinate(stream CoordinateRequest) returns (stream CoordinateResponse);
}
+41 -1
View File
@@ -40,6 +40,7 @@ type DRPCTailnetClient interface {
PostTelemetry(ctx context.Context, in *TelemetryRequest) (*TelemetryResponse, error)
StreamDERPMaps(ctx context.Context, in *StreamDERPMapsRequest) (DRPCTailnet_StreamDERPMapsClient, error)
RefreshResumeToken(ctx context.Context, in *RefreshResumeTokenRequest) (*RefreshResumeTokenResponse, error)
Coordinate(ctx context.Context) (DRPCTailnet_CoordinateClient, error)
}
@@ -102,6 +103,15 @@ func (x *drpcTailnet_StreamDERPMapsClient) RecvMsg(m *DERPMap) error {
return x.MsgRecv(m, drpcEncoding_File_tailnet_proto_tailnet_proto{})
}
// RefreshResumeToken invokes the /coder.tailnet.v2.Tailnet/RefreshResumeToken
// RPC with in and returns the decoded response, or the error from Invoke.
func (c *drpcTailnetClient) RefreshResumeToken(ctx context.Context, in *RefreshResumeTokenRequest) (*RefreshResumeTokenResponse, error) {
out := new(RefreshResumeTokenResponse)
err := c.cc.Invoke(ctx, "/coder.tailnet.v2.Tailnet/RefreshResumeToken", drpcEncoding_File_tailnet_proto_tailnet_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcTailnetClient) Coordinate(ctx context.Context) (DRPCTailnet_CoordinateClient, error) {
stream, err := c.cc.NewStream(ctx, "/coder.tailnet.v2.Tailnet/Coordinate", drpcEncoding_File_tailnet_proto_tailnet_proto{})
if err != nil {
@@ -144,6 +154,7 @@ func (x *drpcTailnet_CoordinateClient) RecvMsg(m *CoordinateResponse) error {
// DRPCTailnetServer is the server-side interface of the Tailnet DRPC service.
type DRPCTailnetServer interface {
PostTelemetry(context.Context, *TelemetryRequest) (*TelemetryResponse, error)
StreamDERPMaps(*StreamDERPMapsRequest, DRPCTailnet_StreamDERPMapsStream) error
// RefreshResumeToken handles the unary RefreshResumeToken RPC.
RefreshResumeToken(context.Context, *RefreshResumeTokenRequest) (*RefreshResumeTokenResponse, error)
Coordinate(DRPCTailnet_CoordinateStream) error
}
@@ -157,13 +168,17 @@ func (s *DRPCTailnetUnimplementedServer) StreamDERPMaps(*StreamDERPMapsRequest,
return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
// RefreshResumeToken always fails with an error carrying the
// drpcerr.Unimplemented code.
func (s *DRPCTailnetUnimplementedServer) RefreshResumeToken(context.Context, *RefreshResumeTokenRequest) (*RefreshResumeTokenResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCTailnetUnimplementedServer) Coordinate(DRPCTailnet_CoordinateStream) error {
return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCTailnetDescription struct{}
func (DRPCTailnetDescription) NumMethods() int { return 3 }
func (DRPCTailnetDescription) NumMethods() int { return 4 }
func (DRPCTailnetDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
@@ -186,6 +201,15 @@ func (DRPCTailnetDescription) Method(n int) (string, drpc.Encoding, drpc.Receive
)
}, DRPCTailnetServer.StreamDERPMaps, true
case 2:
return "/coder.tailnet.v2.Tailnet/RefreshResumeToken", drpcEncoding_File_tailnet_proto_tailnet_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCTailnetServer).
RefreshResumeToken(
ctx,
in1.(*RefreshResumeTokenRequest),
)
}, DRPCTailnetServer.RefreshResumeToken, true
case 3:
return "/coder.tailnet.v2.Tailnet/Coordinate", drpcEncoding_File_tailnet_proto_tailnet_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return nil, srv.(DRPCTailnetServer).
@@ -231,6 +255,22 @@ func (x *drpcTailnet_StreamDERPMapsStream) Send(m *DERPMap) error {
return x.MsgSend(m, drpcEncoding_File_tailnet_proto_tailnet_proto{})
}
// DRPCTailnet_RefreshResumeTokenStream is a drpc.Stream that can send a single
// RefreshResumeTokenResponse and then close its send side.
type DRPCTailnet_RefreshResumeTokenStream interface {
drpc.Stream
SendAndClose(*RefreshResumeTokenResponse) error
}
// drpcTailnet_RefreshResumeTokenStream wraps a drpc.Stream and adds the
// SendAndClose method.
type drpcTailnet_RefreshResumeTokenStream struct {
drpc.Stream
}
// SendAndClose sends m on the stream and, if that succeeds, closes the send
// side. A send error is returned without attempting the close.
func (x *drpcTailnet_RefreshResumeTokenStream) SendAndClose(m *RefreshResumeTokenResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_tailnet_proto_tailnet_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCTailnet_CoordinateStream interface {
drpc.Stream
Send(*CoordinateResponse) error
+196
View File
@@ -0,0 +1,196 @@
package tailnet
import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"encoding/json"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/google/uuid"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
)
const (
// DefaultResumeTokenExpiry is the resume token lifetime used when a provider
// is constructed with a non-positive expiry.
DefaultResumeTokenExpiry = 24 * time.Hour
// resumeTokenSigningAlgorithm is the JWS algorithm used to sign resume tokens
// and the only algorithm accepted when verifying them.
resumeTokenSigningAlgorithm = jose.HS512
)
// resumeTokenSigningKeyID is a fixed key ID for the resume token signing key.
// If/when we add support for multiple keys (e.g. key rotation), this will move
// to the database instead.
var resumeTokenSigningKeyID = uuid.MustParse("97166747-9309-4d7f-9071-a230e257c2a4")
// NewInsecureTestResumeTokenProvider builds a ResumeTokenProvider for tests. It
// generates a fresh random signing key, uses the real clock, and requests a
// one-hour token expiry. It panics if the key cannot be generated.
func NewInsecureTestResumeTokenProvider() ResumeTokenProvider {
	signingKey, keyErr := GenerateResumeTokenSigningKey()
	if keyErr == nil {
		return NewResumeTokenKeyProvider(signingKey, quartz.NewReal(), time.Hour)
	}
	panic(keyErr)
}
// ResumeTokenProvider issues and validates coordinator resume tokens.
type ResumeTokenProvider interface {
// GenerateResumeToken returns a resume token response issued for peerID.
GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error)
// VerifyResumeToken validates token and returns the peer ID it was issued
// for, or an error if the token is not acceptable.
VerifyResumeToken(token string) (uuid.UUID, error)
}
// ResumeTokenSigningKey is the 64-byte secret used to sign and verify resume
// tokens.
type ResumeTokenSigningKey [64]byte
// GenerateResumeTokenSigningKey fills a new ResumeTokenSigningKey with bytes
// from crypto/rand. If reading random bytes fails, the wrapped error is
// returned alongside the key as it stood at that point.
func GenerateResumeTokenSigningKey() (ResumeTokenSigningKey, error) {
	var signingKey ResumeTokenSigningKey
	if _, err := rand.Read(signingKey[:]); err != nil {
		return signingKey, xerrors.Errorf("generate random key: %w", err)
	}
	return signingKey, nil
}
// ResumeTokenSigningKeyDatabaseStore is the subset of database operations that
// ResumeTokenSigningKeyFromDatabase needs. Keys are exchanged as strings;
// ResumeTokenSigningKeyFromDatabase reads and writes them hex-encoded.
type ResumeTokenSigningKeyDatabaseStore interface {
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, key string) error
}
// ResumeTokenSigningKeyFromDatabase retrieves the coordinator resume token
// signing key from the database.
//
// The stored value is expected to be the hex encoding of a full-length key. If
// no row exists, or the stored value does not decode to exactly
// len(ResumeTokenSigningKey) bytes, a new random key is generated and upserted
// in its place. Any other lookup error, a generation error, or an upsert error
// is returned wrapped together with a zero key. A stored key that decodes to
// all zero bytes is rejected with an error rather than replaced.
func ResumeTokenSigningKeyFromDatabase(ctx context.Context, db ResumeTokenSigningKeyDatabaseStore) (ResumeTokenSigningKey, error) {
	stored, err := db.GetCoordinatorResumeTokenSigningKey(ctx)
	if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
		return ResumeTokenSigningKey{}, xerrors.Errorf("get coordinator resume token key: %w", err)
	}

	var key ResumeTokenSigningKey
	decoded, decodeErr := hex.DecodeString(stored)
	if decodeErr == nil && len(decoded) == len(key) {
		copy(key[:], decoded)
	} else {
		// Missing, malformed, or wrong-length value: replace it with a fresh key
		// and use that key directly instead of decoding it back from hex.
		key, err = GenerateResumeTokenSigningKey()
		if err != nil {
			return ResumeTokenSigningKey{}, xerrors.Errorf("generate fresh coordinator resume token key: %w", err)
		}
		err = db.UpsertCoordinatorResumeTokenSigningKey(ctx, hex.EncodeToString(key[:]))
		if err != nil {
			return ResumeTokenSigningKey{}, xerrors.Errorf("insert freshly generated coordinator resume token key to database: %w", err)
		}
	}

	if key == (ResumeTokenSigningKey{}) {
		return ResumeTokenSigningKey{}, xerrors.Errorf("coordinator resume token key in database is empty")
	}
	return key, nil
}
// ResumeTokenKeyProvider is a ResumeTokenProvider that signs and verifies
// tokens with a single signing key.
type ResumeTokenKeyProvider struct {
// key is the secret used for both signing and verification.
key ResumeTokenSigningKey
// clock supplies the current time for issuing tokens and checking expiry.
clock quartz.Clock
// expiry is the lifetime given to each generated token.
expiry time.Duration
}
// NewResumeTokenKeyProvider returns a ResumeTokenProvider that signs and
// verifies resume tokens with key, reads the current time from clock, and
// issues tokens that are valid for expiry. A non-positive expiry falls back to
// DefaultResumeTokenExpiry.
func NewResumeTokenKeyProvider(key ResumeTokenSigningKey, clock quartz.Clock, expiry time.Duration) ResumeTokenProvider {
	if expiry <= 0 {
		expiry = DefaultResumeTokenExpiry
	}
	return ResumeTokenKeyProvider{
		key:   key,
		clock: clock,
		// Honor the caller-supplied (or defaulted) expiry; previously the
		// argument was ignored and the default was always used.
		expiry: expiry,
	}
}
// resumeTokenPayload is the JSON body that is signed into a resume token.
type resumeTokenPayload struct {
// PeerID is the peer the token was issued for (JSON key "sub").
PeerID uuid.UUID `json:"sub"`
// Expiry is the expiry time in Unix seconds (JSON key "exp").
Expiry int64 `json:"exp"`
}
// GenerateResumeToken issues a signed resume token for peerID. The token is a
// compact-serialized JWS whose payload holds the peer ID and an expiry of
// "now + p.expiry" according to the provider's clock. The response tells the
// caller to refresh after half of the expiry period and reports the exact
// expiry time.
func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
	expiresAt := p.clock.Now().Add(p.expiry)

	claims, err := json.Marshal(resumeTokenPayload{
		PeerID: peerID,
		Expiry: expiresAt.Unix(),
	})
	if err != nil {
		return nil, xerrors.Errorf("marshal payload to JSON: %w", err)
	}

	// The key ID is embedded in the header so that verification can check it.
	signingKey := jose.SigningKey{
		Algorithm: resumeTokenSigningAlgorithm,
		Key:       p.key[:],
	}
	signerOpts := &jose.SignerOptions{
		ExtraHeaders: map[jose.HeaderKey]interface{}{
			"kid": resumeTokenSigningKeyID.String(),
		},
	}
	signer, err := jose.NewSigner(signingKey, signerOpts)
	if err != nil {
		return nil, xerrors.Errorf("create signer: %w", err)
	}

	jws, err := signer.Sign(claims)
	if err != nil {
		return nil, xerrors.Errorf("sign payload: %w", err)
	}

	compact, err := jws.CompactSerialize()
	if err != nil {
		return nil, xerrors.Errorf("serialize JWS: %w", err)
	}

	resp := &proto.RefreshResumeTokenResponse{
		Token:     compact,
		RefreshIn: durationpb.New(p.expiry / 2),
		ExpiresAt: timestamppb.New(expiresAt),
	}
	return resp, nil
}
// VerifyResumeToken validates a resume token produced by GenerateResumeToken
// and returns the peer ID stored in it. The token must parse as a JWS with
// exactly one signature, use the expected signing algorithm and key ID, verify
// against the provider's key, and not be expired according to the provider's
// clock. On any failure uuid.Nil and an error are returned.
func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error) {
	jws, err := jose.ParseSigned(str)
	if err != nil {
		return uuid.Nil, xerrors.Errorf("parse JWS: %w", err)
	}
	if n := len(jws.Signatures); n != 1 {
		return uuid.Nil, xerrors.New("expected 1 signature")
	}

	// Check the header before verifying the signature itself.
	header := jws.Signatures[0].Header
	if header.Algorithm != string(resumeTokenSigningAlgorithm) {
		return uuid.Nil, xerrors.Errorf("expected token signing algorithm to be %q, got %q", resumeTokenSigningAlgorithm, header.Algorithm)
	}
	if header.KeyID != resumeTokenSigningKeyID.String() {
		return uuid.Nil, xerrors.Errorf("expected token key ID to be %q, got %q", resumeTokenSigningKeyID, header.KeyID)
	}

	rawPayload, err := jws.Verify(p.key[:])
	if err != nil {
		return uuid.Nil, xerrors.Errorf("verify JWS: %w", err)
	}

	var claims resumeTokenPayload
	if err := json.Unmarshal(rawPayload, &claims); err != nil {
		return uuid.Nil, xerrors.Errorf("unmarshal payload: %w", err)
	}

	// The expiry is stored in whole Unix seconds.
	if time.Unix(claims.Expiry, 0).Before(p.clock.Now()) {
		return uuid.Nil, xerrors.New("signed resume token expired")
	}
	return claims.PeerID, nil
}
+181
View File
@@ -0,0 +1,181 @@
package tailnet_test
import (
"context"
"encoding/hex"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// TestResumeTokenSigningKeyFromDatabase covers loading the resume token signing
// key from the store: generating and re-reading a key, propagating get and
// upsert errors, regenerating when the stored value is undecodable or the
// wrong length, and rejecting an all-zero stored key.
func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
t.Parallel()
// assertRandomKey is a sanity check that key is not an obviously degenerate
// value.
assertRandomKey := func(t *testing.T, key tailnet.ResumeTokenSigningKey) {
t.Helper()
assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key should not be empty")
// NOTE(review): [64]byte{1} sets only the first byte to 1 and leaves the rest
// zero, so it is not an "all 1s" key — confirm the intended comparison value.
assert.NotEqualValues(t, [64]byte{1}, key, "key should not be all 1s")
}
// First call generates and stores a key; the second call must read it back.
t.Run("GenerateRetrieve", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
key1, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.NoError(t, err)
assertRandomKey(t, key1)
key2, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.NoError(t, err)
require.Equal(t, key1, key2, "keys should not be different")
})
// A lookup error is returned to the caller.
t.Run("GetError", func(t *testing.T) {
t.Parallel()
db := dbmock.NewMockStore(gomock.NewController(t))
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", assert.AnError)
ctx := testutil.Context(t, testutil.WaitShort)
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.ErrorIs(t, err, assert.AnError)
})
// An empty stored value triggers regeneration; the upsert failure is returned.
t.Run("UpsertError", func(t *testing.T) {
t.Parallel()
db := dbmock.NewMockStore(gomock.NewController(t))
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", nil)
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(assert.AnError)
ctx := testutil.Context(t, testutil.WaitShort)
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.ErrorIs(t, err, assert.AnError)
})
// A value that is not valid hex is replaced, and the returned key matches
// what was written to the store.
t.Run("DecodeErrorShouldRegenerate", func(t *testing.T) {
t.Parallel()
db := dbmock.NewMockStore(gomock.NewController(t))
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("invalid", nil)
var storedKey tailnet.ResumeTokenSigningKey
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Do(func(_ context.Context, value string) error {
keyBytes, err := hex.DecodeString(value)
require.NoError(t, err)
require.Len(t, keyBytes, len(storedKey))
copy(storedKey[:], keyBytes)
return nil
})
ctx := testutil.Context(t, testutil.WaitShort)
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.NoError(t, err)
assertRandomKey(t, key)
require.Equal(t, storedKey, key, "key should match stored value")
})
// Valid hex of the wrong length is also replaced.
t.Run("LengthErrorShouldRegenerate", func(t *testing.T) {
t.Parallel()
db := dbmock.NewMockStore(gomock.NewController(t))
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("deadbeef", nil)
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(nil)
ctx := testutil.Context(t, testutil.WaitShort)
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.NoError(t, err)
assertRandomKey(t, key)
})
// A correctly sized but all-zero stored key is an error, not regenerated.
t.Run("EmptyError", func(t *testing.T) {
t.Parallel()
db := dbmock.NewMockStore(gomock.NewController(t))
emptyKey := hex.EncodeToString(make([]byte, 64))
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return(emptyKey, nil)
ctx := testutil.Context(t, testutil.WaitShort)
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.ErrorContains(t, err, "is empty")
})
}
// TestResumeTokenKeyProvider covers generating and verifying resume tokens:
// a round trip, an expired token, an unparseable token, and a token signed
// with a different key.
// NOTE(review): every subtest constructs the provider with
// tailnet.DefaultResumeTokenExpiry, so a non-default expiry is not exercised
// here — consider adding a case that passes a different duration.
func TestResumeTokenKeyProvider(t *testing.T) {
t.Parallel()
key, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
// A freshly generated token verifies and yields the original peer ID.
t.Run("OK", func(t *testing.T) {
t.Parallel()
id := uuid.New()
clock := quartz.NewMock(t)
provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry)
token, err := provider.GenerateResumeToken(id)
require.NoError(t, err)
require.NotNil(t, token)
require.NotEmpty(t, token.Token)
require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration())
require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)
gotID, err := provider.VerifyResumeToken(token.Token)
require.NoError(t, err)
require.Equal(t, id, gotID)
})
// Once the mock clock passes the expiry, verification fails.
t.Run("Expired", func(t *testing.T) {
t.Parallel()
id := uuid.New()
clock := quartz.NewMock(t)
provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry)
token, err := provider.GenerateResumeToken(id)
require.NoError(t, err)
require.NotNil(t, token)
require.NotEmpty(t, token.Token)
require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration())
require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)
// Advance time past expiry
_ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second)
_, err = provider.VerifyResumeToken(token.Token)
require.ErrorContains(t, err, "expired")
})
// A string that is not a JWS fails at the parse step.
t.Run("InvalidToken", func(t *testing.T) {
t.Parallel()
provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
_, err := provider.VerifyResumeToken("invalid")
require.ErrorContains(t, err, "parse JWS")
})
// A well-formed token signed with another key fails signature verification.
t.Run("VerifyError", func(t *testing.T) {
t.Parallel()
// Generate a resume token with a different key
otherKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
otherProvider := tailnet.NewResumeTokenKeyProvider(otherKey, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
token, err := otherProvider.GenerateResumeToken(uuid.New())
require.NoError(t, err)
provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
_, err = provider.VerifyResumeToken(token.Token)
require.ErrorContains(t, err, "verify JWS")
})
}
+16
View File
@@ -43,6 +43,7 @@ type ClientServiceOptions struct {
DERPMapUpdateFrequency time.Duration
DERPMapFn func() *tailcfg.DERPMap
NetworkTelemetryHandler func(batch []*proto.TelemetryEvent)
ResumeTokenProvider ResumeTokenProvider
}
// ClientService is a tailnet coordination service that accepts a connection and version from a
@@ -66,6 +67,7 @@ func NewClientService(options ClientServiceOptions) (
DerpMapUpdateFrequency: options.DERPMapUpdateFrequency,
DerpMapFn: options.DERPMapFn,
NetworkTelemetryHandler: options.NetworkTelemetryHandler,
ResumeTokenProvider: options.ResumeTokenProvider,
}
err := proto.DRPCRegisterTailnet(mux, drpcService)
if err != nil {
@@ -127,6 +129,7 @@ type DRPCService struct {
DerpMapUpdateFrequency time.Duration
DerpMapFn func() *tailcfg.DERPMap
NetworkTelemetryHandler func(batch []*proto.TelemetryEvent)
ResumeTokenProvider ResumeTokenProvider
}
func (s *DRPCService) PostTelemetry(_ context.Context, req *proto.TelemetryRequest) (*proto.TelemetryResponse, error) {
@@ -167,6 +170,19 @@ func (s *DRPCService) StreamDERPMaps(_ *proto.StreamDERPMapsRequest, stream prot
}
}
// RefreshResumeToken issues a new resume token for the stream that made the
// request. The stream is identified by the StreamID stored in ctx under
// streamIDContextKey; the request message itself is not read.
//
// It returns an error when ctx carries no StreamID, when the service has no
// ResumeTokenProvider, or when the provider fails to generate a token.
func (s *DRPCService) RefreshResumeToken(ctx context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
	streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
	if !ok {
		return nil, xerrors.New("no Stream ID")
	}

	// The provider is an optional field that callers may leave unset. Calling
	// a method on a nil provider would panic inside the RPC handler, so report
	// the misconfiguration as an ordinary error instead.
	if s.ResumeTokenProvider == nil {
		return nil, xerrors.New("no resume token provider configured")
	}

	res, err := s.ResumeTokenProvider.GenerateResumeToken(streamID.ID)
	if err != nil {
		return nil, xerrors.Errorf("generate resume token: %w", err)
	}
	return res, nil
}
func (s *DRPCService) Coordinate(stream proto.DRPCTailnet_CoordinateStream) error {
ctx := stream.Context()
streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
+2
View File
@@ -40,6 +40,7 @@ func TestClientService_ServeClient_V2(t *testing.T) {
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
telemetryEvents <- batch
},
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
@@ -144,6 +145,7 @@ func TestClientService_ServeClient_V1(t *testing.T) {
DERPMapUpdateFrequency: 0,
DERPMapFn: nil,
NetworkTelemetryHandler: nil,
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
+1
View File
@@ -181,6 +181,7 @@ func (o SimpleServerOptions) Router(t *testing.T, logger slog.Logger) *chi.Mux {
}
},
NetworkTelemetryHandler: func(batch []*tailnetproto.TelemetryEvent) {},
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
+13
View File
@@ -23,6 +23,9 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A)
}
}
// NOTE: there is no AssertRecvCtx because it would be bad to return a default
// value in the cases where it times out.
func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
t.Helper()
select {
@@ -32,3 +35,13 @@ func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
// OK!
}
}
// AssertSendCtx sends a on c, giving up once ctx is done. If ctx finishes
// before the send can proceed, the failure is reported through t.Error with
// the message "timeout" and the function returns normally.
func AssertSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
	t.Helper()
	done := ctx.Done()
	select {
	case c <- a:
		// Sent successfully; nothing to report.
		return
	case <-done:
	}
	t.Error("timeout")
}