feat: add resume support to coordinator connections (#14234)
This commit is contained in:
+13
-4
@@ -56,6 +56,7 @@ import (
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/retry"
|
||||
"github.com/coder/serpent"
|
||||
"github.com/coder/wgtunnel/tunnelsdk"
|
||||
@@ -791,18 +792,26 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
}
|
||||
}
|
||||
|
||||
keyBytes, err := hex.DecodeString(oauthSigningKeyStr)
|
||||
oauthKeyBytes, err := hex.DecodeString(oauthSigningKeyStr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("decode oauth signing key from database: %w", err)
|
||||
}
|
||||
if len(keyBytes) != len(options.OAuthSigningKey) {
|
||||
return xerrors.Errorf("oauth signing key in database is not the correct length, expect %d got %d", len(options.OAuthSigningKey), len(keyBytes))
|
||||
if len(oauthKeyBytes) != len(options.OAuthSigningKey) {
|
||||
return xerrors.Errorf("oauth signing key in database is not the correct length, expect %d got %d", len(options.OAuthSigningKey), len(oauthKeyBytes))
|
||||
}
|
||||
copy(options.OAuthSigningKey[:], keyBytes)
|
||||
copy(options.OAuthSigningKey[:], oauthKeyBytes)
|
||||
if options.OAuthSigningKey == [32]byte{} {
|
||||
return xerrors.Errorf("oauth signing key in database is empty")
|
||||
}
|
||||
|
||||
// Read the coordinator resume token signing key from the
|
||||
// database.
|
||||
resumeTokenKey, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, tx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get coordinator resume token key from database: %w", err)
|
||||
}
|
||||
options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider(resumeTokenKey, quartz.NewReal(), tailnet.DefaultResumeTokenExpiry)
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -182,6 +182,9 @@ type Options struct {
|
||||
// AppSecurityKey is the crypto key used to sign and encrypt tokens related to
|
||||
// workspace applications. It consists of both a signing and encryption key.
|
||||
AppSecurityKey workspaceapps.SecurityKey
|
||||
// CoordinatorResumeTokenProvider is used to provide and validate resume
|
||||
// tokens issued by and passed to the coordinator DRPC API.
|
||||
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
|
||||
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport
|
||||
HealthcheckTimeout time.Duration
|
||||
@@ -584,12 +587,16 @@ func New(options *Options) *API {
|
||||
api.Options.NetworkTelemetryBatchMaxSize,
|
||||
api.handleNetworkTelemetry,
|
||||
)
|
||||
if options.CoordinatorResumeTokenProvider == nil {
|
||||
panic("CoordinatorResumeTokenProvider is nil")
|
||||
}
|
||||
api.TailnetClientService, err = tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: api.Logger.Named("tailnetclient"),
|
||||
CoordPtr: &api.TailnetCoordinator,
|
||||
DERPMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
|
||||
DERPMapFn: api.DERPMap,
|
||||
NetworkTelemetryHandler: api.NetworkTelemetryBatcher.Handler,
|
||||
ResumeTokenProvider: api.Options.CoordinatorResumeTokenProvider,
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err))
|
||||
@@ -614,6 +621,9 @@ func New(options *Options) *API {
|
||||
options.WorkspaceAppsStatsCollectorOptions.Reporter = api.statsReporter
|
||||
}
|
||||
|
||||
if options.AppSecurityKey.IsZero() {
|
||||
api.Logger.Fatal(api.ctx, "app security key cannot be zero")
|
||||
}
|
||||
api.workspaceAppServer = &workspaceapps.Server{
|
||||
Logger: workspaceAppsLogger,
|
||||
|
||||
|
||||
@@ -96,25 +96,26 @@ type Options struct {
|
||||
// AccessURL denotes a custom access URL. By default we use the httptest
|
||||
// server's URL. Setting this may result in unexpected behavior (especially
|
||||
// with running agents).
|
||||
AccessURL *url.URL
|
||||
AppHostname string
|
||||
AWSCertificates awsidentity.Certificates
|
||||
Authorizer rbac.Authorizer
|
||||
AzureCertificates x509.VerifyOptions
|
||||
GithubOAuth2Config *coderd.GithubOAuth2Config
|
||||
RealIPConfig *httpmw.RealIPConfig
|
||||
OIDCConfig *coderd.OIDCConfig
|
||||
GoogleTokenValidator *idtoken.Validator
|
||||
SSHKeygenAlgorithm gitsshkey.Algorithm
|
||||
AutobuildTicker <-chan time.Time
|
||||
AutobuildStats chan<- autobuild.Stats
|
||||
Auditor audit.Auditor
|
||||
TLSCertificates []tls.Certificate
|
||||
ExternalAuthConfigs []*externalauth.Config
|
||||
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
|
||||
RefreshEntitlements func(ctx context.Context) error
|
||||
TemplateScheduleStore schedule.TemplateScheduleStore
|
||||
Coordinator tailnet.Coordinator
|
||||
AccessURL *url.URL
|
||||
AppHostname string
|
||||
AWSCertificates awsidentity.Certificates
|
||||
Authorizer rbac.Authorizer
|
||||
AzureCertificates x509.VerifyOptions
|
||||
GithubOAuth2Config *coderd.GithubOAuth2Config
|
||||
RealIPConfig *httpmw.RealIPConfig
|
||||
OIDCConfig *coderd.OIDCConfig
|
||||
GoogleTokenValidator *idtoken.Validator
|
||||
SSHKeygenAlgorithm gitsshkey.Algorithm
|
||||
AutobuildTicker <-chan time.Time
|
||||
AutobuildStats chan<- autobuild.Stats
|
||||
Auditor audit.Auditor
|
||||
TLSCertificates []tls.Certificate
|
||||
ExternalAuthConfigs []*externalauth.Config
|
||||
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
|
||||
RefreshEntitlements func(ctx context.Context) error
|
||||
TemplateScheduleStore schedule.TemplateScheduleStore
|
||||
Coordinator tailnet.Coordinator
|
||||
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
|
||||
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport
|
||||
HealthcheckTimeout time.Duration
|
||||
@@ -240,6 +241,9 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
|
||||
if options.Database == nil {
|
||||
options.Database, options.Pubsub = dbtestutil.NewDB(t)
|
||||
}
|
||||
if options.CoordinatorResumeTokenProvider == nil {
|
||||
options.CoordinatorResumeTokenProvider = tailnet.NewInsecureTestResumeTokenProvider()
|
||||
}
|
||||
|
||||
if options.NotificationsEnqueuer == nil {
|
||||
options.NotificationsEnqueuer = new(testutil.FakeNotificationsEnqueuer)
|
||||
@@ -492,6 +496,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
|
||||
TailnetCoordinator: options.Coordinator,
|
||||
BaseDERPMap: derpMap,
|
||||
DERPMapUpdateFrequency: 150 * time.Millisecond,
|
||||
CoordinatorResumeTokenProvider: options.CoordinatorResumeTokenProvider,
|
||||
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
|
||||
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
|
||||
DeploymentValues: options.DeploymentValues,
|
||||
|
||||
@@ -1332,7 +1332,9 @@ func (q *querier) GetAnnouncementBanners(ctx context.Context) (string, error) {
|
||||
}
|
||||
|
||||
func (q *querier) GetAppSecurityKey(ctx context.Context) (string, error) {
|
||||
// No authz checks
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return q.db.GetAppSecurityKey(ctx)
|
||||
}
|
||||
|
||||
@@ -1364,6 +1366,13 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
|
||||
return q.db.GetAuthorizationUserRoles(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return q.db.GetCoordinatorResumeTokenSigningKey(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
@@ -3792,7 +3801,9 @@ func (q *querier) UpsertAnnouncementBanners(ctx context.Context, value string) e
|
||||
}
|
||||
|
||||
func (q *querier) UpsertAppSecurityKey(ctx context.Context, data string) error {
|
||||
// No authz checks as this is done during startup
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertAppSecurityKey(ctx, data)
|
||||
}
|
||||
|
||||
@@ -3803,6 +3814,13 @@ func (q *querier) UpsertApplicationName(ctx context.Context, value string) error
|
||||
return q.db.UpsertApplicationName(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertCoordinatorResumeTokenSigningKey(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
|
||||
@@ -2531,10 +2531,10 @@ func (s *MethodTestSuite) TestSystemFunctions() {
|
||||
check.Args(int32(0)).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetAppSecurityKey", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args().Asserts()
|
||||
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("UpsertAppSecurityKey", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args("").Asserts()
|
||||
check.Args("foo").Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetApplicationName", s.Subtest(func(db database.Store, check *expects) {
|
||||
db.UpsertApplicationName(context.Background(), "foo")
|
||||
@@ -2574,6 +2574,13 @@ func (s *MethodTestSuite) TestSystemFunctions() {
|
||||
db.UpsertOAuthSigningKey(context.Background(), "foo")
|
||||
check.Args().Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertCoordinatorResumeTokenSigningKey", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args("foo").Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetCoordinatorResumeTokenSigningKey", s.Subtest(func(db database.Store, check *expects) {
|
||||
db.UpsertCoordinatorResumeTokenSigningKey(context.Background(), "foo")
|
||||
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("InsertMissingGroups", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args(database.InsertMissingGroupsParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate).Errors(errMatchAny)
|
||||
}))
|
||||
|
||||
@@ -196,20 +196,21 @@ type data struct {
|
||||
customRoles []database.CustomRole
|
||||
// Locks is a map of lock names. Any keys within the map are currently
|
||||
// locked.
|
||||
locks map[int64]struct{}
|
||||
deploymentID string
|
||||
derpMeshKey string
|
||||
lastUpdateCheck []byte
|
||||
announcementBanners []byte
|
||||
healthSettings []byte
|
||||
notificationsSettings []byte
|
||||
applicationName string
|
||||
logoURL string
|
||||
appSecurityKey string
|
||||
oauthSigningKey string
|
||||
lastLicenseID int32
|
||||
defaultProxyDisplayName string
|
||||
defaultProxyIconURL string
|
||||
locks map[int64]struct{}
|
||||
deploymentID string
|
||||
derpMeshKey string
|
||||
lastUpdateCheck []byte
|
||||
announcementBanners []byte
|
||||
healthSettings []byte
|
||||
notificationsSettings []byte
|
||||
applicationName string
|
||||
logoURL string
|
||||
appSecurityKey string
|
||||
oauthSigningKey string
|
||||
coordinatorResumeTokenSigningKey string
|
||||
lastLicenseID int32
|
||||
defaultProxyDisplayName string
|
||||
defaultProxyIconURL string
|
||||
}
|
||||
|
||||
func validateDatabaseTypeWithValid(v reflect.Value) (handled bool, err error) {
|
||||
@@ -2222,6 +2223,15 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetCoordinatorResumeTokenSigningKey(_ context.Context) (string, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
if q.coordinatorResumeTokenSigningKey == "" {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
return q.coordinatorResumeTokenSigningKey, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
@@ -8942,6 +8952,14 @@ func (q *FakeQuerier) UpsertApplicationName(_ context.Context, data string) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) UpsertCoordinatorResumeTokenSigningKey(_ context.Context, value string) error {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
q.coordinatorResumeTokenSigningKey = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) UpsertDefaultProxy(_ context.Context, arg database.UpsertDefaultProxyParams) error {
|
||||
q.defaultProxyDisplayName = arg.DisplayName
|
||||
q.defaultProxyIconURL = arg.IconUrl
|
||||
|
||||
@@ -529,6 +529,13 @@ func (m metricsStore) GetAuthorizationUserRoles(ctx context.Context, userID uuid
|
||||
return row, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetCoordinatorResumeTokenSigningKey(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetCoordinatorResumeTokenSigningKey").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m metricsStore) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetDBCryptKeys(ctx)
|
||||
@@ -2377,6 +2384,13 @@ func (m metricsStore) UpsertApplicationName(ctx context.Context, value string) e
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m metricsStore) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertCoordinatorResumeTokenSigningKey(ctx, value)
|
||||
m.queryLatencies.WithLabelValues("UpsertCoordinatorResumeTokenSigningKey").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m metricsStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertDefaultProxy(ctx, arg)
|
||||
|
||||
@@ -1029,6 +1029,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedWorkspaces(arg0, arg1, arg2 any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspaces", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspaces), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// GetCoordinatorResumeTokenSigningKey mocks base method.
|
||||
func (m *MockStore) GetCoordinatorResumeTokenSigningKey(arg0 context.Context) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetCoordinatorResumeTokenSigningKey", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetCoordinatorResumeTokenSigningKey indicates an expected call of GetCoordinatorResumeTokenSigningKey.
|
||||
func (mr *MockStoreMockRecorder) GetCoordinatorResumeTokenSigningKey(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCoordinatorResumeTokenSigningKey", reflect.TypeOf((*MockStore)(nil).GetCoordinatorResumeTokenSigningKey), arg0)
|
||||
}
|
||||
|
||||
// GetDBCryptKeys mocks base method.
|
||||
func (m *MockStore) GetDBCryptKeys(arg0 context.Context) ([]database.DBCryptKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4994,6 +5009,20 @@ func (mr *MockStoreMockRecorder) UpsertApplicationName(arg0, arg1 any) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertApplicationName", reflect.TypeOf((*MockStore)(nil).UpsertApplicationName), arg0, arg1)
|
||||
}
|
||||
|
||||
// UpsertCoordinatorResumeTokenSigningKey mocks base method.
|
||||
func (m *MockStore) UpsertCoordinatorResumeTokenSigningKey(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertCoordinatorResumeTokenSigningKey", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertCoordinatorResumeTokenSigningKey indicates an expected call of UpsertCoordinatorResumeTokenSigningKey.
|
||||
func (mr *MockStoreMockRecorder) UpsertCoordinatorResumeTokenSigningKey(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertCoordinatorResumeTokenSigningKey", reflect.TypeOf((*MockStore)(nil).UpsertCoordinatorResumeTokenSigningKey), arg0, arg1)
|
||||
}
|
||||
|
||||
// UpsertDefaultProxy mocks base method.
|
||||
func (m *MockStore) UpsertDefaultProxy(arg0 context.Context, arg1 database.UpsertDefaultProxyParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -128,6 +128,7 @@ type sqlcQuerier interface {
|
||||
// This function returns roles for authorization purposes. Implied member roles
|
||||
// are included.
|
||||
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
|
||||
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
|
||||
GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error)
|
||||
GetDERPMeshKey(ctx context.Context) (string, error)
|
||||
GetDefaultOrganization(ctx context.Context) (Organization, error)
|
||||
@@ -463,6 +464,7 @@ type sqlcQuerier interface {
|
||||
UpsertAnnouncementBanners(ctx context.Context, value string) error
|
||||
UpsertAppSecurityKey(ctx context.Context, value string) error
|
||||
UpsertApplicationName(ctx context.Context, value string) error
|
||||
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error
|
||||
// The default proxy is implied and not actually stored in the database.
|
||||
// So we need to store its configuration here for display purposes.
|
||||
// The functional values are immutable and controlled implicitly.
|
||||
|
||||
@@ -6624,6 +6624,17 @@ func (q *sqlQuerier) GetApplicationName(ctx context.Context) (string, error) {
|
||||
return value, err
|
||||
}
|
||||
|
||||
const getCoordinatorResumeTokenSigningKey = `-- name: GetCoordinatorResumeTokenSigningKey :one
|
||||
SELECT value FROM site_configs WHERE key = 'coordinator_resume_token_signing_key'
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
|
||||
row := q.db.QueryRowContext(ctx, getCoordinatorResumeTokenSigningKey)
|
||||
var value string
|
||||
err := row.Scan(&value)
|
||||
return value, err
|
||||
}
|
||||
|
||||
const getDERPMeshKey = `-- name: GetDERPMeshKey :one
|
||||
SELECT value FROM site_configs WHERE key = 'derp_mesh_key'
|
||||
`
|
||||
@@ -6769,6 +6780,16 @@ func (q *sqlQuerier) UpsertApplicationName(ctx context.Context, value string) er
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertCoordinatorResumeTokenSigningKey = `-- name: UpsertCoordinatorResumeTokenSigningKey :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('coordinator_resume_token_signing_key', $1)
|
||||
ON CONFLICT (key) DO UPDATE set value = $1 WHERE site_configs.key = 'coordinator_resume_token_signing_key'
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
|
||||
_, err := q.db.ExecContext(ctx, upsertCoordinatorResumeTokenSigningKey, value)
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertDefaultProxy = `-- name: UpsertDefaultProxy :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES
|
||||
|
||||
@@ -71,6 +71,13 @@ SELECT value FROM site_configs WHERE key = 'oauth_signing_key';
|
||||
INSERT INTO site_configs (key, value) VALUES ('oauth_signing_key', $1)
|
||||
ON CONFLICT (key) DO UPDATE set value = $1 WHERE site_configs.key = 'oauth_signing_key';
|
||||
|
||||
-- name: GetCoordinatorResumeTokenSigningKey :one
|
||||
SELECT value FROM site_configs WHERE key = 'coordinator_resume_token_signing_key';
|
||||
|
||||
-- name: UpsertCoordinatorResumeTokenSigningKey :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('coordinator_resume_token_signing_key', $1)
|
||||
ON CONFLICT (key) DO UPDATE set value = $1 WHERE site_configs.key = 'coordinator_resume_token_signing_key';
|
||||
|
||||
-- name: GetHealthSettings :one
|
||||
SELECT
|
||||
COALESCE((SELECT value FROM site_configs WHERE key = 'health_settings'), '{}') :: text AS health_settings
|
||||
|
||||
@@ -846,6 +846,26 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
|
||||
return
|
||||
}
|
||||
|
||||
// Accept a resume_token query parameter to use the same peer ID.
|
||||
var (
|
||||
peerID = uuid.New()
|
||||
resumeToken = r.URL.Query().Get("resume_token")
|
||||
)
|
||||
if resumeToken != "" {
|
||||
var err error
|
||||
peerID, err = api.Options.CoordinatorResumeTokenProvider.VerifyResumeToken(resumeToken)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: workspacesdk.CoordinateAPIInvalidResumeToken,
|
||||
Detail: err.Error(),
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "resume_token", Detail: workspacesdk.CoordinateAPIInvalidResumeToken},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
api.WebsocketWaitMutex.Lock()
|
||||
api.WebsocketWaitGroup.Add(1)
|
||||
api.WebsocketWaitMutex.Unlock()
|
||||
@@ -866,7 +886,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
|
||||
go httpapi.Heartbeat(ctx, conn)
|
||||
|
||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||
err = api.TailnetClientService.ServeClient(ctx, version, wsNetConn, uuid.New(), workspaceAgent.ID)
|
||||
err = api.TailnetClientService.ServeClient(ctx, version, wsNetConn, peerID, workspaceAgent.ID)
|
||||
if err != nil && !xerrors.Is(err, io.EOF) && !xerrors.Is(err, context.Canceled) {
|
||||
_ = conn.Close(websocket.StatusInternalError, err.Error())
|
||||
return
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"nhooyr.io/websocket"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
@@ -40,8 +41,11 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestWorkspaceAgent(t *testing.T) {
|
||||
@@ -509,6 +513,147 @@ func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) {
|
||||
require.Equal(t, "version", sdkErr.Validations[0].Field)
|
||||
}
|
||||
|
||||
type resumeTokenTestFakeCoordinator struct {
|
||||
tailnet.Coordinator
|
||||
lastPeerID uuid.UUID
|
||||
}
|
||||
|
||||
var _ tailnet.Coordinator = &resumeTokenTestFakeCoordinator{}
|
||||
|
||||
func (c *resumeTokenTestFakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agentID uuid.UUID) error {
|
||||
c.lastPeerID = id
|
||||
return c.Coordinator.ServeClient(conn, id, agentID)
|
||||
}
|
||||
|
||||
func (c *resumeTokenTestFakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *tailnetproto.CoordinateRequest, <-chan *tailnetproto.CoordinateResponse) {
|
||||
c.lastPeerID = id
|
||||
return c.Coordinator.Coordinate(ctx, id, name, a)
|
||||
}
|
||||
|
||||
func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
clock := quartz.NewMock(t)
|
||||
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
require.NoError(t, err)
|
||||
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
|
||||
coordinator := &resumeTokenTestFakeCoordinator{
|
||||
Coordinator: tailnet.NewCoordinator(logger),
|
||||
}
|
||||
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
Coordinator: coordinator,
|
||||
CoordinatorResumeTokenProvider: resumeTokenProvider,
|
||||
})
|
||||
defer closer.Close()
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create a workspace with an agent. No need to connect it since clients can
|
||||
// still connect to the coordinator while the agent isn't connected.
|
||||
r := dbfake.WorkspaceBuild(t, api.Database, database.Workspace{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
agentTokenUUID, err := uuid.Parse(r.AgentToken)
|
||||
require.NoError(t, err)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect with no resume token, and ensure that the peer ID is set to a
|
||||
// random value.
|
||||
coordinator.lastPeerID = uuid.Nil
|
||||
originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
|
||||
require.NoError(t, err)
|
||||
originalPeerID := coordinator.lastPeerID
|
||||
require.NotEqual(t, originalPeerID, uuid.Nil)
|
||||
|
||||
// Connect with a valid resume token, and ensure that the peer ID is set to
|
||||
// the stored value.
|
||||
clock.Advance(time.Second)
|
||||
coordinator.lastPeerID = uuid.Nil
|
||||
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalPeerID, coordinator.lastPeerID)
|
||||
require.NotEqual(t, originalResumeToken, newResumeToken)
|
||||
|
||||
// Connect with an invalid resume token, and ensure that the request is
|
||||
// rejected.
|
||||
clock.Advance(time.Second)
|
||||
coordinator.lastPeerID = uuid.Nil
|
||||
_, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
|
||||
require.Len(t, sdkErr.Validations, 1)
|
||||
require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
|
||||
require.Equal(t, uuid.Nil, coordinator.lastPeerID)
|
||||
}
|
||||
|
||||
// connectToCoordinatorAndFetchResumeToken connects to the tailnet coordinator
|
||||
// with a given resume token. It returns an error if the connection is rejected.
|
||||
// If the connection is accepted, it is immediately closed and no error is
|
||||
// returned.
|
||||
func connectToCoordinatorAndFetchResumeToken(ctx context.Context, logger slog.Logger, sdkClient *codersdk.Client, agentID uuid.UUID, resumeToken string) (string, error) {
|
||||
u, err := sdkClient.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", agentID))
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("parse URL: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("version", "2.0")
|
||||
if resumeToken != "" {
|
||||
q.Set("resume_token", resumeToken)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
//nolint:bodyclose
|
||||
wsConn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
|
||||
HTTPHeader: http.Header{
|
||||
"Coder-Session-Token": []string{sdkClient.SessionToken()},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
err = codersdk.ReadBodyAsError(resp)
|
||||
}
|
||||
return "", xerrors.Errorf("websocket dial: %w", err)
|
||||
}
|
||||
defer wsConn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
// Send a request to the server to ensure that we're plumbed all the way
|
||||
// through.
|
||||
rpcClient, err := tailnet.NewDRPCClient(
|
||||
websocket.NetConn(ctx, wsConn, websocket.MessageBinary),
|
||||
logger,
|
||||
)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("new dRPC client: %w", err)
|
||||
}
|
||||
|
||||
// Send an empty coordination request. This will do nothing on the server,
|
||||
// but ensures our wrapped coordinator can record the peer ID.
|
||||
coordinateClient, err := rpcClient.Coordinate(ctx)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("coordinate: %w", err)
|
||||
}
|
||||
err = coordinateClient.Send(&tailnetproto.CoordinateRequest{})
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("send empty coordination request: %w", err)
|
||||
}
|
||||
err = coordinateClient.Close()
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("close coordination request: %w", err)
|
||||
}
|
||||
|
||||
// Fetch a resume token.
|
||||
newResumeToken, err := rpcClient.RefreshResumeToken(ctx, &tailnetproto.RefreshResumeTokenRequest{})
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("fetch resume token: %w", err)
|
||||
}
|
||||
return newResumeToken.Token, nil
|
||||
}
|
||||
|
||||
func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -65,6 +65,10 @@ func (t SignedToken) MatchesRequest(req Request) bool {
|
||||
// two keys.
|
||||
type SecurityKey [96]byte

// IsZero reports whether every byte of the key is zero, i.e. the key equals
// the zero value of SecurityKey.
func (k SecurityKey) IsZero() bool {
	var zero SecurityKey
	return k == zero
}

// String returns the key as a lowercase hexadecimal string (two characters
// per byte).
func (k SecurityKey) String() string {
	raw := k[:]
	return hex.EncodeToString(raw)
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/retry"
|
||||
)
|
||||
|
||||
@@ -61,6 +63,7 @@ type tailnetAPIConnector struct {
|
||||
|
||||
agentID uuid.UUID
|
||||
coordinateURL string
|
||||
clock quartz.Clock
|
||||
dialOptions *websocket.DialOptions
|
||||
conn tailnetConn
|
||||
customDialFn func() (proto.DRPCTailnetClient, error)
|
||||
@@ -68,9 +71,10 @@ type tailnetAPIConnector struct {
|
||||
clientMu sync.RWMutex
|
||||
client proto.DRPCTailnetClient
|
||||
|
||||
connected chan error
|
||||
isFirst bool
|
||||
closed chan struct{}
|
||||
connected chan error
|
||||
resumeToken *proto.RefreshResumeTokenResponse
|
||||
isFirst bool
|
||||
closed chan struct{}
|
||||
|
||||
// Only set to true if we get a response from the server that it doesn't support
|
||||
// network telemetry.
|
||||
@@ -78,12 +82,13 @@ type tailnetAPIConnector struct {
|
||||
}
|
||||
|
||||
// Create a new tailnetAPIConnector without running it
|
||||
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions) *tailnetAPIConnector {
|
||||
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, clock quartz.Clock, dialOptions *websocket.DialOptions) *tailnetAPIConnector {
|
||||
return &tailnetAPIConnector{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
agentID: agentID,
|
||||
coordinateURL: coordinateURL,
|
||||
clock: clock,
|
||||
dialOptions: dialOptions,
|
||||
conn: nil,
|
||||
connected: make(chan error, 1),
|
||||
@@ -96,7 +101,7 @@ func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uui
|
||||
func (tac *tailnetAPIConnector) manageGracefulTimeout() {
|
||||
defer tac.cancelGracefulCtx()
|
||||
<-tac.ctx.Done()
|
||||
timer := time.NewTimer(tailnetConnectorGracefulTimeout)
|
||||
timer := tac.clock.NewTimer(tailnetConnectorGracefulTimeout, "tailnetAPIClient", "gracefulTimeout")
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-tac.closed:
|
||||
@@ -112,6 +117,8 @@ func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
|
||||
go func() {
|
||||
tac.isFirst = true
|
||||
defer close(tac.closed)
|
||||
// Sadly retry doesn't support quartz.Clock yet so this is not
|
||||
// influenced by the configured clock.
|
||||
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
|
||||
tailnetClient, err := tac.dial()
|
||||
if err != nil {
|
||||
@@ -121,7 +128,7 @@ func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
|
||||
tac.client = tailnetClient
|
||||
tac.clientMu.Unlock()
|
||||
tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client")
|
||||
tac.coordinateAndDERPMap(tailnetClient)
|
||||
tac.runConnectorOnce(tailnetClient)
|
||||
tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost")
|
||||
}
|
||||
}()
|
||||
@@ -138,8 +145,23 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
|
||||
return tac.customDialFn()
|
||||
}
|
||||
tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API")
|
||||
|
||||
u, err := url.Parse(tac.coordinateURL)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse URL %q: %w", tac.coordinateURL, err)
|
||||
}
|
||||
if tac.resumeToken != nil {
|
||||
q := u.Query()
|
||||
q.Set("resume_token", tac.resumeToken.Token)
|
||||
u.RawQuery = q.Encode()
|
||||
tac.logger.Debug(tac.ctx, "using resume token", slog.F("resume_token", tac.resumeToken))
|
||||
}
|
||||
|
||||
coordinateURL := u.String()
|
||||
tac.logger.Debug(tac.ctx, "using coordinate URL", slog.F("url", coordinateURL))
|
||||
|
||||
// nolint:bodyclose
|
||||
ws, res, err := websocket.Dial(tac.ctx, tac.coordinateURL, tac.dialOptions)
|
||||
ws, res, err := websocket.Dial(tac.ctx, coordinateURL, tac.dialOptions)
|
||||
if tac.isFirst {
|
||||
if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) {
|
||||
err = codersdk.ReadBodyAsError(res)
|
||||
@@ -160,8 +182,20 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
|
||||
close(tac.connected)
|
||||
}
|
||||
if err != nil {
|
||||
bodyErr := codersdk.ReadBodyAsError(res)
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(bodyErr, &sdkErr) {
|
||||
for _, v := range sdkErr.Validations {
|
||||
if v.Field == "resume_token" {
|
||||
// Unset the resume token for the next attempt
|
||||
tac.logger.Warn(tac.ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
|
||||
tac.resumeToken = nil
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err))
|
||||
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -177,11 +211,11 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
|
||||
return client, err
|
||||
}
|
||||
|
||||
// coordinateAndDERPMap uses the provided client to coordinate and stream DERP Maps. It is combined
|
||||
// runConnectorOnce uses the provided client to coordinate and stream DERP Maps. It is combined
|
||||
// into one function so that a problem with one tears down the other and triggers a retry (if
|
||||
// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same
|
||||
// fate.
|
||||
func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetClient) {
|
||||
func (tac *tailnetAPIConnector) runConnectorOnce(client proto.DRPCTailnetClient) {
|
||||
defer func() {
|
||||
conn := client.DRPCConn()
|
||||
closeErr := conn.Close()
|
||||
@@ -193,14 +227,17 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
|
||||
<-conn.Closed()
|
||||
}
|
||||
}()
|
||||
|
||||
refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tac.coordinate(client)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer refreshTokenCancel()
|
||||
dErr := tac.derpMap(client)
|
||||
if dErr != nil && tac.ctx.Err() == nil {
|
||||
// The main context is still active, meaning that we want the tailnet data plane to stay
|
||||
@@ -215,6 +252,10 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
|
||||
// Note that derpMap() logs its own errors, we don't bother here.
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tac.refreshToken(refreshTokenCtx, client)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -278,6 +319,41 @@ func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.DRPCTailnetClient) {
|
||||
ticker := tac.clock.NewTicker(15*time.Second, "tailnetAPIConnector", "refreshToken")
|
||||
defer ticker.Stop()
|
||||
|
||||
initialCh := make(chan struct{}, 1)
|
||||
initialCh <- struct{}{}
|
||||
defer close(initialCh)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
case <-initialCh:
|
||||
}
|
||||
|
||||
attemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
res, err := client.RefreshResumeToken(attemptCtx, &proto.RefreshResumeTokenRequest{})
|
||||
cancel()
|
||||
if err != nil {
|
||||
if ctx.Err() == nil {
|
||||
tac.logger.Error(tac.ctx, "error refreshing coordinator resume token", slog.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
tac.logger.Debug(tac.ctx, "refreshed coordinator resume token", slog.F("resume_token", res))
|
||||
tac.resumeToken = res
|
||||
dur := res.RefreshIn.AsDuration()
|
||||
if dur <= 0 {
|
||||
// A sensible delay to refresh again.
|
||||
dur = 30 * time.Minute
|
||||
}
|
||||
ticker.Reset(dur, "tailnetAPIConnector", "refreshToken", "reset")
|
||||
}
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) {
|
||||
tac.clientMu.RLock()
|
||||
// We hold the lock for the entire telemetry request, but this would only block
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"nhooyr.io/websocket"
|
||||
"storj.io/drpc"
|
||||
"storj.io/drpc/drpcerr"
|
||||
@@ -28,6 +30,7 @@ import (
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -59,6 +62,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
||||
DERPMapUpdateFrequency: time.Millisecond,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
|
||||
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -78,7 +82,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{})
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
|
||||
uut.runConnector(fConn)
|
||||
|
||||
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
||||
@@ -131,7 +135,7 @@ func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{})
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
|
||||
uut.runConnector(fConn)
|
||||
|
||||
err := testutil.RequireRecvCtx(ctx, t, uut.connected)
|
||||
@@ -142,6 +146,215 @@ func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
|
||||
require.NotEmpty(t, sdkErr.Helper)
|
||||
}
|
||||
|
||||
func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
agentID := uuid.UUID{0x55}
|
||||
fCoord := tailnettest.NewFakeCoordinator()
|
||||
var coord tailnet.Coordinator = fCoord
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
derpMapCh := make(chan *tailcfg.DERPMap)
|
||||
defer close(derpMapCh)
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
require.NoError(t, err)
|
||||
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
|
||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: logger,
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Millisecond,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
|
||||
ResumeTokenProvider: resumeTokenProvider,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var (
|
||||
websocketConnCh = make(chan *websocket.Conn, 64)
|
||||
expectResumeToken = ""
|
||||
)
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Accept a resume_token query parameter to use the same peer ID. This
|
||||
// behavior matches the actual client coordinate route.
|
||||
var (
|
||||
peerID = uuid.New()
|
||||
resumeToken = r.URL.Query().Get("resume_token")
|
||||
)
|
||||
t.Logf("received resume token: %s", resumeToken)
|
||||
assert.Equal(t, expectResumeToken, resumeToken)
|
||||
if resumeToken != "" {
|
||||
peerID, err = resumeTokenProvider.VerifyResumeToken(resumeToken)
|
||||
assert.NoError(t, err, "failed to parse resume token")
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: CoordinateAPIInvalidResumeToken,
|
||||
Detail: err.Error(),
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sws, err := websocket.Accept(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.RequireSendCtx(ctx, t, websocketConnCh, sws)
|
||||
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
||||
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
||||
Name: "client",
|
||||
ID: peerID,
|
||||
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}))
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
|
||||
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
|
||||
defer newTickerTrap.Close()
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{})
|
||||
uut.runConnector(fConn)
|
||||
|
||||
// Fetch first token. We don't need to advance the clock since we use a
|
||||
// channel with a single item to immediately fetch.
|
||||
newTickerTrap.MustWait(ctx).Release()
|
||||
// We call ticker.Reset after each token fetch to apply the refresh duration
|
||||
// requested by the server.
|
||||
trappedReset := tickerResetTrap.MustWait(ctx)
|
||||
trappedReset.Release()
|
||||
require.NotNil(t, uut.resumeToken)
|
||||
originalResumeToken := uut.resumeToken.Token
|
||||
|
||||
// Fetch second token.
|
||||
waiter := clock.Advance(trappedReset.Duration)
|
||||
waiter.MustWait(ctx)
|
||||
trappedReset = tickerResetTrap.MustWait(ctx)
|
||||
trappedReset.Release()
|
||||
require.NotNil(t, uut.resumeToken)
|
||||
require.NotEqual(t, originalResumeToken, uut.resumeToken.Token)
|
||||
expectResumeToken = uut.resumeToken.Token
|
||||
t.Logf("expecting resume token: %s", expectResumeToken)
|
||||
|
||||
// Sever the connection and expect it to reconnect with the resume token.
|
||||
wsConn := testutil.RequireRecvCtx(ctx, t, websocketConnCh)
|
||||
_ = wsConn.Close(websocket.StatusGoingAway, "test")
|
||||
|
||||
// Wait for the resume token to be refreshed.
|
||||
trappedTicker := newTickerTrap.MustWait(ctx)
|
||||
// Advance the clock slightly to ensure the new JWT is different.
|
||||
clock.Advance(time.Second).MustWait(ctx)
|
||||
trappedTicker.Release()
|
||||
trappedReset = tickerResetTrap.MustWait(ctx)
|
||||
trappedReset.Release()
|
||||
|
||||
// The resume token should have changed again.
|
||||
require.NotNil(t, uut.resumeToken)
|
||||
require.NotEqual(t, expectResumeToken, uut.resumeToken.Token)
|
||||
}
|
||||
|
||||
func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
agentID := uuid.UUID{0x55}
|
||||
fCoord := tailnettest.NewFakeCoordinator()
|
||||
var coord tailnet.Coordinator = fCoord
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
derpMapCh := make(chan *tailcfg.DERPMap)
|
||||
defer close(derpMapCh)
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
require.NoError(t, err)
|
||||
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
|
||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: logger,
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Millisecond,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
|
||||
ResumeTokenProvider: resumeTokenProvider,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var (
|
||||
websocketConnCh = make(chan *websocket.Conn, 64)
|
||||
didFail int64
|
||||
)
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("resume_token") != "" {
|
||||
atomic.AddInt64(&didFail, 1)
|
||||
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: CoordinateAPIInvalidResumeToken,
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sws, err := websocket.Accept(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.RequireSendCtx(ctx, t, websocketConnCh, sws)
|
||||
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
||||
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
||||
Name: "client",
|
||||
ID: uuid.New(),
|
||||
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}))
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
|
||||
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
|
||||
defer newTickerTrap.Close()
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{})
|
||||
uut.runConnector(fConn)
|
||||
|
||||
// Wait for the resume token to be fetched for the first time.
|
||||
newTickerTrap.MustWait(ctx).Release()
|
||||
trappedReset := tickerResetTrap.MustWait(ctx)
|
||||
trappedReset.Release()
|
||||
originalResumeToken := uut.resumeToken.Token
|
||||
|
||||
// Sever the connection and expect it to reconnect with the resume token,
|
||||
// which should fail and cause the client to be disconnected. The client
|
||||
// should then reconnect with no resume token.
|
||||
wsConn := testutil.RequireRecvCtx(ctx, t, websocketConnCh)
|
||||
_ = wsConn.Close(websocket.StatusGoingAway, "test")
|
||||
|
||||
// Wait for the resume token to be refreshed, which indicates a successful
|
||||
// reconnect.
|
||||
trappedTicker := newTickerTrap.MustWait(ctx)
|
||||
// Since we failed the initial reconnect and we're definitely reconnected
|
||||
// now, the stored resume token should now be nil.
|
||||
require.Nil(t, uut.resumeToken)
|
||||
trappedTicker.Release()
|
||||
trappedReset = tickerResetTrap.MustWait(ctx)
|
||||
trappedReset.Release()
|
||||
require.NotNil(t, uut.resumeToken)
|
||||
require.NotEqual(t, originalResumeToken, uut.resumeToken.Token)
|
||||
|
||||
// The resume token should have been rejected by the server.
|
||||
require.EqualValues(t, 1, atomic.LoadInt64(&didFail))
|
||||
}
|
||||
|
||||
func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
@@ -161,8 +374,9 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
|
||||
DERPMapUpdateFrequency: time.Millisecond,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
|
||||
eventCh <- batch
|
||||
testutil.RequireSendCtx(ctx, t, eventCh, batch)
|
||||
},
|
||||
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -182,7 +396,7 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{})
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
|
||||
uut.runConnector(fConn)
|
||||
require.Eventually(t, func() bool {
|
||||
uut.clientMu.Lock()
|
||||
@@ -213,6 +427,7 @@ func TestTailnetAPIConnector_TelemetryUnimplemented(t *testing.T) {
|
||||
logger: logger,
|
||||
agentID: agentID,
|
||||
coordinateURL: "",
|
||||
clock: quartz.NewReal(),
|
||||
dialOptions: &websocket.DialOptions{},
|
||||
conn: nil,
|
||||
connected: make(chan error, 1),
|
||||
@@ -253,6 +468,7 @@ func TestTailnetAPIConnector_TelemetryNotRecognised(t *testing.T) {
|
||||
logger: logger,
|
||||
agentID: agentID,
|
||||
coordinateURL: "",
|
||||
clock: quartz.NewReal(),
|
||||
dialOptions: &websocket.DialOptions{},
|
||||
conn: nil,
|
||||
connected: make(chan error, 1),
|
||||
@@ -301,6 +517,7 @@ func newFakeTailnetConn() *fakeTailnetConn {
|
||||
|
||||
// fakeDRPCClient is a test double for the tailnet DRPC client.
type fakeDRPCClient struct {
|
||||
// NOTE(review): presumably counts PostTelemetry invocations — confirm
// against the PostTelemetry method.
postTelemetryCalls int64
|
||||
// refreshTokenFn, when non-nil, is called by RefreshResumeToken in place of
// returning the canned response.
refreshTokenFn func(context.Context, *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error)
|
||||
// NOTE(review): presumably the error returned from PostTelemetry — confirm
// against the PostTelemetry method.
telemetryError error
|
||||
fakeDRPPCMapStream
|
||||
}
|
||||
@@ -339,6 +556,19 @@ func (f *fakeDRPCClient) StreamDERPMaps(_ context.Context, _ *proto.StreamDERPMa
|
||||
return &f.fakeDRPPCMapStream, nil
|
||||
}
|
||||
|
||||
// RefreshResumeToken implements proto.DRPCTailnetClient.
|
||||
func (f *fakeDRPCClient) RefreshResumeToken(_ context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
|
||||
if f.refreshTokenFn != nil {
|
||||
return f.refreshTokenFn(context.Background(), nil)
|
||||
}
|
||||
|
||||
return &proto.RefreshResumeTokenResponse{
|
||||
Token: "test",
|
||||
RefreshIn: durationpb.New(30 * time.Minute),
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type fakeDRPCConn struct{}
|
||||
|
||||
var _ drpc.Conn = &fakeDRPCConn{}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// AgentIP is a static IPv6 address with the Tailscale prefix that is used to route
|
||||
@@ -55,7 +56,11 @@ const (
|
||||
AgentMinimumListeningPort = 9
|
||||
)
|
||||
|
||||
const AgentAPIMismatchMessage = "Unknown or unsupported API version"
|
||||
const (
|
||||
// AgentAPIMismatchMessage is the message text for an unknown or unsupported
// API version.
AgentAPIMismatchMessage = "Unknown or unsupported API version"
|
||||
|
||||
// CoordinateAPIInvalidResumeToken is the message text for an invalid
// coordinator resume token.
CoordinateAPIInvalidResumeToken = "Invalid resume token"
|
||||
)
|
||||
|
||||
// AgentIgnoredListeningPorts contains a list of ports to ignore when looking for
|
||||
// running applications inside a workspace. We want to ignore non-HTTP servers,
|
||||
@@ -232,7 +237,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
||||
q.Add("version", "2.0")
|
||||
coordinateURL.RawQuery = q.Encode()
|
||||
|
||||
connector := newTailnetAPIConnector(ctx, options.Logger, agentID, coordinateURL.String(),
|
||||
connector := newTailnetAPIConnector(ctx, options.Logger, agentID, coordinateURL.String(), quartz.NewReal(),
|
||||
&websocket.DialOptions{
|
||||
HTTPClient: c.client.HTTPClient,
|
||||
HTTPHeader: headers,
|
||||
|
||||
@@ -147,6 +147,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
DERPMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
|
||||
DERPMapFn: api.AGPL.DERPMap,
|
||||
NetworkTelemetryHandler: api.AGPL.NetworkTelemetryBatcher.Handler,
|
||||
ResumeTokenProvider: api.AGPL.CoordinatorResumeTokenProvider,
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err))
|
||||
|
||||
@@ -24,13 +24,7 @@ type ClientService struct {
|
||||
// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
|
||||
// loaded on each processed connection.
|
||||
func NewClientService(options agpl.ClientServiceOptions) (*ClientService, error) {
|
||||
s, err := agpl.NewClientService(agpl.ClientServiceOptions{
|
||||
Logger: options.Logger,
|
||||
CoordPtr: options.CoordPtr,
|
||||
DERPMapUpdateFrequency: options.DERPMapUpdateFrequency,
|
||||
DERPMapFn: options.DERPMapFn,
|
||||
NetworkTelemetryHandler: options.NetworkTelemetryHandler,
|
||||
})
|
||||
s, err := agpl.NewClientService(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -177,6 +177,7 @@ func TestDialCoordinator(t *testing.T) {
|
||||
DERPMapUpdateFrequency: time.Hour,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
|
||||
ResumeTokenProvider: agpl.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -608,6 +608,16 @@ func (c *configMaps) fillPeerDiagnostics(d *PeerDiagnostics, peerID uuid.UUID) {
|
||||
d.LastWireguardHandshake = ps.LastHandshake
|
||||
}
|
||||
|
||||
func (c *configMaps) knownPeerIDs() []uuid.UUID {
|
||||
c.L.Lock()
|
||||
defer c.L.Unlock()
|
||||
out := make([]uuid.UUID, 0, len(c.peers))
|
||||
for id := range c.peers {
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *configMaps) peerReadyForHandshakeTimeout(peerID uuid.UUID) {
|
||||
logger := c.logger.With(slog.F("peer_id", peerID))
|
||||
logger.Debug(context.Background(), "peer ready for handshake timeout")
|
||||
|
||||
@@ -847,6 +847,10 @@ func (c *Conn) GetPeerDiagnostics(peerID uuid.UUID) PeerDiagnostics {
|
||||
return d
|
||||
}
|
||||
|
||||
func (c *Conn) GetKnownPeerIDs() []uuid.UUID {
|
||||
return c.configMaps.knownPeerIDs()
|
||||
}
|
||||
|
||||
type listenKey struct {
|
||||
network string
|
||||
host string
|
||||
|
||||
@@ -630,6 +630,7 @@ func TestRemoteCoordination(t *testing.T) {
|
||||
DERPMapUpdateFrequency: time.Hour,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
|
||||
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
sC, cC := net.Pipe()
|
||||
@@ -681,6 +682,7 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
|
||||
DERPMapUpdateFrequency: time.Hour,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
|
||||
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
sC, cC := net.Pipe()
|
||||
|
||||
+549
-398
File diff suppressed because it is too large
Load Diff
@@ -57,6 +57,14 @@ message Node {
|
||||
repeated string endpoints = 10;
|
||||
}
|
||||
|
||||
// RefreshResumeTokenRequest is the (empty) request for the RefreshResumeToken
// RPC.
message RefreshResumeTokenRequest {}
|
||||
|
||||
// RefreshResumeTokenResponse carries a coordinator resume token together with
// its timing metadata.
message RefreshResumeTokenResponse {
|
||||
// The resume token string.
string token = 1;
|
||||
// How long the client should wait before requesting a fresh token.
google.protobuf.Duration refresh_in = 2;
|
||||
// When the token expires.
google.protobuf.Timestamp expires_at = 3;
|
||||
}
|
||||
|
||||
message CoordinateRequest {
|
||||
message UpdateSelf {
|
||||
Node node = 1;
|
||||
@@ -191,5 +199,6 @@ message TelemetryResponse {}
|
||||
// Tailnet is the tailnet API service.
service Tailnet {
|
||||
rpc PostTelemetry(TelemetryRequest) returns (TelemetryResponse);
|
||||
rpc StreamDERPMaps(StreamDERPMapsRequest) returns (stream DERPMap);
|
||||
// RefreshResumeToken returns a resume token along with its refresh and
// expiry timing.
rpc RefreshResumeToken(RefreshResumeTokenRequest) returns (RefreshResumeTokenResponse);
|
||||
rpc Coordinate(stream CoordinateRequest) returns (stream CoordinateResponse);
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ type DRPCTailnetClient interface {
|
||||
|
||||
PostTelemetry(ctx context.Context, in *TelemetryRequest) (*TelemetryResponse, error)
|
||||
StreamDERPMaps(ctx context.Context, in *StreamDERPMapsRequest) (DRPCTailnet_StreamDERPMapsClient, error)
|
||||
RefreshResumeToken(ctx context.Context, in *RefreshResumeTokenRequest) (*RefreshResumeTokenResponse, error)
|
||||
Coordinate(ctx context.Context) (DRPCTailnet_CoordinateClient, error)
|
||||
}
|
||||
|
||||
@@ -102,6 +103,15 @@ func (x *drpcTailnet_StreamDERPMapsClient) RecvMsg(m *DERPMap) error {
|
||||
return x.MsgRecv(m, drpcEncoding_File_tailnet_proto_tailnet_proto{})
|
||||
}
|
||||
|
||||
// RefreshResumeToken invokes the /coder.tailnet.v2.Tailnet/RefreshResumeToken
// RPC on the client connection and returns the decoded response, or the error
// reported by the invocation.
func (c *drpcTailnetClient) RefreshResumeToken(ctx context.Context, in *RefreshResumeTokenRequest) (*RefreshResumeTokenResponse, error) {
|
||||
out := new(RefreshResumeTokenResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.tailnet.v2.Tailnet/RefreshResumeToken", drpcEncoding_File_tailnet_proto_tailnet_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcTailnetClient) Coordinate(ctx context.Context) (DRPCTailnet_CoordinateClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, "/coder.tailnet.v2.Tailnet/Coordinate", drpcEncoding_File_tailnet_proto_tailnet_proto{})
|
||||
if err != nil {
|
||||
@@ -144,6 +154,7 @@ func (x *drpcTailnet_CoordinateClient) RecvMsg(m *CoordinateResponse) error {
|
||||
// DRPCTailnetServer is the server-side interface of the Tailnet DRPC service;
// it has one method per RPC.
type DRPCTailnetServer interface {
|
||||
PostTelemetry(context.Context, *TelemetryRequest) (*TelemetryResponse, error)
|
||||
StreamDERPMaps(*StreamDERPMapsRequest, DRPCTailnet_StreamDERPMapsStream) error
|
||||
RefreshResumeToken(context.Context, *RefreshResumeTokenRequest) (*RefreshResumeTokenResponse, error)
|
||||
Coordinate(DRPCTailnet_CoordinateStream) error
|
||||
}
|
||||
|
||||
@@ -157,13 +168,17 @@ func (s *DRPCTailnetUnimplementedServer) StreamDERPMaps(*StreamDERPMapsRequest,
|
||||
return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
// RefreshResumeToken always returns a nil response and an "Unimplemented"
// error tagged with the drpcerr.Unimplemented code.
func (s *DRPCTailnetUnimplementedServer) RefreshResumeToken(context.Context, *RefreshResumeTokenRequest) (*RefreshResumeTokenResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
// Coordinate always returns an "Unimplemented" error tagged with the
// drpcerr.Unimplemented code.
func (s *DRPCTailnetUnimplementedServer) Coordinate(DRPCTailnet_CoordinateStream) error {
|
||||
return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
type DRPCTailnetDescription struct{}
|
||||
|
||||
func (DRPCTailnetDescription) NumMethods() int { return 3 }
|
||||
func (DRPCTailnetDescription) NumMethods() int { return 4 }
|
||||
|
||||
func (DRPCTailnetDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
@@ -186,6 +201,15 @@ func (DRPCTailnetDescription) Method(n int) (string, drpc.Encoding, drpc.Receive
|
||||
)
|
||||
}, DRPCTailnetServer.StreamDERPMaps, true
|
||||
case 2:
|
||||
return "/coder.tailnet.v2.Tailnet/RefreshResumeToken", drpcEncoding_File_tailnet_proto_tailnet_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCTailnetServer).
|
||||
RefreshResumeToken(
|
||||
ctx,
|
||||
in1.(*RefreshResumeTokenRequest),
|
||||
)
|
||||
}, DRPCTailnetServer.RefreshResumeToken, true
|
||||
case 3:
|
||||
return "/coder.tailnet.v2.Tailnet/Coordinate", drpcEncoding_File_tailnet_proto_tailnet_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return nil, srv.(DRPCTailnetServer).
|
||||
@@ -231,6 +255,22 @@ func (x *drpcTailnet_StreamDERPMapsStream) Send(m *DERPMap) error {
|
||||
return x.MsgSend(m, drpcEncoding_File_tailnet_proto_tailnet_proto{})
|
||||
}
|
||||
|
||||
// DRPCTailnet_RefreshResumeTokenStream is a drpc.Stream that can additionally
// send a single RefreshResumeTokenResponse and close.
type DRPCTailnet_RefreshResumeTokenStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*RefreshResumeTokenResponse) error
|
||||
}
|
||||
|
||||
// drpcTailnet_RefreshResumeTokenStream wraps a drpc.Stream; its SendAndClose
// method provides the DRPCTailnet_RefreshResumeTokenStream behavior.
type drpcTailnet_RefreshResumeTokenStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
// SendAndClose sends m on the stream and then closes the send side. If the
// send fails, its error is returned and CloseSend is not called.
func (x *drpcTailnet_RefreshResumeTokenStream) SendAndClose(m *RefreshResumeTokenResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_tailnet_proto_tailnet_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCTailnet_CoordinateStream interface {
|
||||
drpc.Stream
|
||||
Send(*CoordinateResponse) error
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
package tailnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultResumeTokenExpiry is the default resume token lifetime (24 hours).
DefaultResumeTokenExpiry = 24 * time.Hour
|
||||
|
||||
// resumeTokenSigningAlgorithm is the JOSE algorithm identifier (HS512)
// associated with resume tokens.
resumeTokenSigningAlgorithm = jose.HS512
|
||||
)
|
||||
|
||||
// resumeTokenSigningKeyID is a fixed key ID for the resume token signing key.
|
||||
// If/when we add support for multiple keys (e.g. key rotation), this will move
|
||||
// to the database instead.
|
||||
var resumeTokenSigningKeyID = uuid.MustParse("97166747-9309-4d7f-9071-a230e257c2a4")
|
||||
|
||||
// NewInsecureTestResumeTokenProvider returns a ResumeTokenProvider that uses a
|
||||
// random key with short expiry for testing purposes. If any errors occur while
|
||||
// generating the key, the function panics.
|
||||
func NewInsecureTestResumeTokenProvider() ResumeTokenProvider {
|
||||
key, err := GenerateResumeTokenSigningKey()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return NewResumeTokenKeyProvider(key, quartz.NewReal(), time.Hour)
|
||||
}
|
||||
|
||||
type ResumeTokenProvider interface {
|
||||
GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error)
|
||||
VerifyResumeToken(token string) (uuid.UUID, error)
|
||||
}
|
||||
|
||||
// ResumeTokenSigningKey is the 64-byte secret used by ResumeTokenKeyProvider
// to sign and verify resume tokens.
type ResumeTokenSigningKey [64]byte
|
||||
|
||||
func GenerateResumeTokenSigningKey() (ResumeTokenSigningKey, error) {
|
||||
var key ResumeTokenSigningKey
|
||||
_, err := rand.Read(key[:])
|
||||
if err != nil {
|
||||
return key, xerrors.Errorf("generate random key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// ResumeTokenSigningKeyDatabaseStore is the subset of database operations
// that ResumeTokenSigningKeyFromDatabase needs to load and persist the
// signing key.
type ResumeTokenSigningKeyDatabaseStore interface {
	// GetCoordinatorResumeTokenSigningKey returns the stored key string.
	// ResumeTokenSigningKeyFromDatabase expects it to be hex-encoded and
	// treats sql.ErrNoRows as "no key stored yet".
	GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
	// UpsertCoordinatorResumeTokenSigningKey stores key, which
	// ResumeTokenSigningKeyFromDatabase supplies hex-encoded.
	UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, key string) error
}
|
||||
|
||||
// ResumeTokenSigningKeyFromDatabase retrieves the coordinator resume token
|
||||
// signing key from the database. If the key is not found, a new key is
|
||||
// generated and inserted into the database.
|
||||
func ResumeTokenSigningKeyFromDatabase(ctx context.Context, db ResumeTokenSigningKeyDatabaseStore) (ResumeTokenSigningKey, error) {
|
||||
var resumeTokenKey ResumeTokenSigningKey
|
||||
resumeTokenKeyStr, err := db.GetCoordinatorResumeTokenSigningKey(ctx)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return resumeTokenKey, xerrors.Errorf("get coordinator resume token key: %w", err)
|
||||
}
|
||||
if decoded, err := hex.DecodeString(resumeTokenKeyStr); err != nil || len(decoded) != len(resumeTokenKey) {
|
||||
newKey, err := GenerateResumeTokenSigningKey()
|
||||
if err != nil {
|
||||
return resumeTokenKey, xerrors.Errorf("generate fresh coordinator resume token key: %w", err)
|
||||
}
|
||||
|
||||
resumeTokenKeyStr = hex.EncodeToString(newKey[:])
|
||||
err = db.UpsertCoordinatorResumeTokenSigningKey(ctx, resumeTokenKeyStr)
|
||||
if err != nil {
|
||||
return resumeTokenKey, xerrors.Errorf("insert freshly generated coordinator resume token key to database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
resumeTokenKeyBytes, err := hex.DecodeString(resumeTokenKeyStr)
|
||||
if err != nil {
|
||||
return resumeTokenKey, xerrors.Errorf("decode coordinator resume token key from database: %w", err)
|
||||
}
|
||||
if len(resumeTokenKeyBytes) != len(resumeTokenKey) {
|
||||
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is not the correct length, expect %d got %d", len(resumeTokenKey), len(resumeTokenKeyBytes))
|
||||
}
|
||||
copy(resumeTokenKey[:], resumeTokenKeyBytes)
|
||||
if resumeTokenKey == [64]byte{} {
|
||||
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is empty")
|
||||
}
|
||||
return resumeTokenKey, nil
|
||||
}
|
||||
|
||||
type ResumeTokenKeyProvider struct {
|
||||
key ResumeTokenSigningKey
|
||||
clock quartz.Clock
|
||||
expiry time.Duration
|
||||
}
|
||||
|
||||
func NewResumeTokenKeyProvider(key ResumeTokenSigningKey, clock quartz.Clock, expiry time.Duration) ResumeTokenProvider {
|
||||
if expiry <= 0 {
|
||||
expiry = DefaultResumeTokenExpiry
|
||||
}
|
||||
return ResumeTokenKeyProvider{
|
||||
key: key,
|
||||
clock: clock,
|
||||
expiry: DefaultResumeTokenExpiry,
|
||||
}
|
||||
}
|
||||
|
||||
type resumeTokenPayload struct {
|
||||
PeerID uuid.UUID `json:"sub"`
|
||||
Expiry int64 `json:"exp"`
|
||||
}
|
||||
|
||||
func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
|
||||
exp := p.clock.Now().Add(p.expiry)
|
||||
payload := resumeTokenPayload{
|
||||
PeerID: peerID,
|
||||
Expiry: exp.Unix(),
|
||||
}
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal payload to JSON: %w", err)
|
||||
}
|
||||
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: resumeTokenSigningAlgorithm,
|
||||
Key: p.key[:],
|
||||
}, &jose.SignerOptions{
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
"kid": resumeTokenSigningKeyID.String(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create signer: %w", err)
|
||||
}
|
||||
|
||||
signedObject, err := signer.Sign(payloadBytes)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("sign payload: %w", err)
|
||||
}
|
||||
|
||||
serialized, err := signedObject.CompactSerialize()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("serialize JWS: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RefreshResumeTokenResponse{
|
||||
Token: serialized,
|
||||
RefreshIn: durationpb.New(p.expiry / 2),
|
||||
ExpiresAt: timestamppb.New(exp),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyResumeToken parses a signed tailnet resume token with the given key and
|
||||
// returns the payload. If the token is invalid or expired, an error is
|
||||
// returned.
|
||||
func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error) {
|
||||
object, err := jose.ParseSigned(str)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("parse JWS: %w", err)
|
||||
}
|
||||
if len(object.Signatures) != 1 {
|
||||
return uuid.Nil, xerrors.New("expected 1 signature")
|
||||
}
|
||||
if object.Signatures[0].Header.Algorithm != string(resumeTokenSigningAlgorithm) {
|
||||
return uuid.Nil, xerrors.Errorf("expected token signing algorithm to be %q, got %q", resumeTokenSigningAlgorithm, object.Signatures[0].Header.Algorithm)
|
||||
}
|
||||
if object.Signatures[0].Header.KeyID != resumeTokenSigningKeyID.String() {
|
||||
return uuid.Nil, xerrors.Errorf("expected token key ID to be %q, got %q", resumeTokenSigningKeyID, object.Signatures[0].Header.KeyID)
|
||||
}
|
||||
|
||||
output, err := object.Verify(p.key[:])
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("verify JWS: %w", err)
|
||||
}
|
||||
|
||||
var tok resumeTokenPayload
|
||||
err = json.Unmarshal(output, &tok)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("unmarshal payload: %w", err)
|
||||
}
|
||||
exp := time.Unix(tok.Expiry, 0)
|
||||
if exp.Before(p.clock.Now()) {
|
||||
return uuid.Nil, xerrors.New("signed resume token expired")
|
||||
}
|
||||
|
||||
return tok.PeerID, nil
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package tailnet_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assertRandomKey := func(t *testing.T, key tailnet.ResumeTokenSigningKey) {
|
||||
t.Helper()
|
||||
assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key should not be empty")
|
||||
assert.NotEqualValues(t, [64]byte{1}, key, "key should not be all 1s")
|
||||
}
|
||||
|
||||
t.Run("GenerateRetrieve", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
key1, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.NoError(t, err)
|
||||
assertRandomKey(t, key1)
|
||||
|
||||
key2, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, key1, key2, "keys should not be different")
|
||||
})
|
||||
|
||||
t.Run("GetError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbmock.NewMockStore(gomock.NewController(t))
|
||||
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", assert.AnError)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
})
|
||||
|
||||
t.Run("UpsertError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbmock.NewMockStore(gomock.NewController(t))
|
||||
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", nil)
|
||||
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(assert.AnError)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
})
|
||||
|
||||
t.Run("DecodeErrorShouldRegenerate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbmock.NewMockStore(gomock.NewController(t))
|
||||
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("invalid", nil)
|
||||
|
||||
var storedKey tailnet.ResumeTokenSigningKey
|
||||
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Do(func(_ context.Context, value string) error {
|
||||
keyBytes, err := hex.DecodeString(value)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keyBytes, len(storedKey))
|
||||
copy(storedKey[:], keyBytes)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.NoError(t, err)
|
||||
assertRandomKey(t, key)
|
||||
require.Equal(t, storedKey, key, "key should match stored value")
|
||||
})
|
||||
|
||||
t.Run("LengthErrorShouldRegenerate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbmock.NewMockStore(gomock.NewController(t))
|
||||
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("deadbeef", nil)
|
||||
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(nil)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.NoError(t, err)
|
||||
assertRandomKey(t, key)
|
||||
})
|
||||
|
||||
t.Run("EmptyError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbmock.NewMockStore(gomock.NewController(t))
|
||||
emptyKey := hex.EncodeToString(make([]byte, 64))
|
||||
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return(emptyKey, nil)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.ErrorContains(t, err, "is empty")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResumeTokenKeyProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
key, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
id := uuid.New()
|
||||
clock := quartz.NewMock(t)
|
||||
provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry)
|
||||
token, err := provider.GenerateResumeToken(id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, token)
|
||||
require.NotEmpty(t, token.Token)
|
||||
require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration())
|
||||
require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)
|
||||
|
||||
gotID, err := provider.VerifyResumeToken(token.Token)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, gotID)
|
||||
})
|
||||
|
||||
t.Run("Expired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
id := uuid.New()
|
||||
clock := quartz.NewMock(t)
|
||||
provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry)
|
||||
token, err := provider.GenerateResumeToken(id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, token)
|
||||
require.NotEmpty(t, token.Token)
|
||||
require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration())
|
||||
require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)
|
||||
|
||||
// Advance time past expiry
|
||||
_ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second)
|
||||
|
||||
_, err = provider.VerifyResumeToken(token.Token)
|
||||
require.ErrorContains(t, err, "expired")
|
||||
})
|
||||
|
||||
t.Run("InvalidToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
|
||||
_, err := provider.VerifyResumeToken("invalid")
|
||||
require.ErrorContains(t, err, "parse JWS")
|
||||
})
|
||||
|
||||
t.Run("VerifyError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Generate a resume token with a different key
|
||||
otherKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
require.NoError(t, err)
|
||||
otherProvider := tailnet.NewResumeTokenKeyProvider(otherKey, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
|
||||
token, err := otherProvider.GenerateResumeToken(uuid.New())
|
||||
require.NoError(t, err)
|
||||
|
||||
provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
|
||||
_, err = provider.VerifyResumeToken(token.Token)
|
||||
require.ErrorContains(t, err, "verify JWS")
|
||||
})
|
||||
}
|
||||
@@ -43,6 +43,7 @@ type ClientServiceOptions struct {
|
||||
DERPMapUpdateFrequency time.Duration
|
||||
DERPMapFn func() *tailcfg.DERPMap
|
||||
NetworkTelemetryHandler func(batch []*proto.TelemetryEvent)
|
||||
ResumeTokenProvider ResumeTokenProvider
|
||||
}
|
||||
|
||||
// ClientService is a tailnet coordination service that accepts a connection and version from a
|
||||
@@ -66,6 +67,7 @@ func NewClientService(options ClientServiceOptions) (
|
||||
DerpMapUpdateFrequency: options.DERPMapUpdateFrequency,
|
||||
DerpMapFn: options.DERPMapFn,
|
||||
NetworkTelemetryHandler: options.NetworkTelemetryHandler,
|
||||
ResumeTokenProvider: options.ResumeTokenProvider,
|
||||
}
|
||||
err := proto.DRPCRegisterTailnet(mux, drpcService)
|
||||
if err != nil {
|
||||
@@ -127,6 +129,7 @@ type DRPCService struct {
|
||||
DerpMapUpdateFrequency time.Duration
|
||||
DerpMapFn func() *tailcfg.DERPMap
|
||||
NetworkTelemetryHandler func(batch []*proto.TelemetryEvent)
|
||||
ResumeTokenProvider ResumeTokenProvider
|
||||
}
|
||||
|
||||
func (s *DRPCService) PostTelemetry(_ context.Context, req *proto.TelemetryRequest) (*proto.TelemetryResponse, error) {
|
||||
@@ -167,6 +170,19 @@ func (s *DRPCService) StreamDERPMaps(_ *proto.StreamDERPMapsRequest, stream prot
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DRPCService) RefreshResumeToken(ctx context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
|
||||
streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
|
||||
if !ok {
|
||||
return nil, xerrors.New("no Stream ID")
|
||||
}
|
||||
|
||||
res, err := s.ResumeTokenProvider.GenerateResumeToken(streamID.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("generate resume token: %w", err)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *DRPCService) Coordinate(stream proto.DRPCTailnet_CoordinateStream) error {
|
||||
ctx := stream.Context()
|
||||
streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
|
||||
|
||||
@@ -40,6 +40,7 @@ func TestClientService_ServeClient_V2(t *testing.T) {
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
|
||||
telemetryEvents <- batch
|
||||
},
|
||||
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -144,6 +145,7 @@ func TestClientService_ServeClient_V1(t *testing.T) {
|
||||
DERPMapUpdateFrequency: 0,
|
||||
DERPMapFn: nil,
|
||||
NetworkTelemetryHandler: nil,
|
||||
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -181,6 +181,7 @@ func (o SimpleServerOptions) Router(t *testing.T, logger slog.Logger) *chi.Mux {
|
||||
}
|
||||
},
|
||||
NetworkTelemetryHandler: func(batch []*tailnetproto.TelemetryEvent) {},
|
||||
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -23,6 +23,9 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A)
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: there is no AssertRecvCtx because it would have to return a default
// value in the cases where it times out, which would be bad.
|
||||
|
||||
func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
|
||||
t.Helper()
|
||||
select {
|
||||
@@ -32,3 +35,13 @@ func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
|
||||
// OK!
|
||||
}
|
||||
}
|
||||
|
||||
// AssertSendCtx sends a on c, giving up once ctx is done. If ctx finishes
// before the send goes through, the failure is reported with t.Error, so the
// calling test is marked failed but keeps running.
func AssertSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
	t.Helper()
	select {
	case c <- a:
		// Sent successfully; nothing to report.
		return
	case <-ctx.Done():
	}
	t.Error("timeout")
}
|
||||
|
||||
Reference in New Issue
Block a user