Compare commits

...

31 Commits

Author SHA1 Message Date
gcp-cherry-pick-bot[bot] f895f94702 fix: stop extending API key access if OIDC refresh is available (cherry-pick #17878) (#17962)
Cherry-picked fix: stop extending API key access if OIDC refresh is
available (#17878)

fixes #17070

Cleans up our handling of APIKey expiration and OIDC to keep them
separate concepts. For an OIDC-login APIKey, both the APIKey and the OIDC
link must be valid to log in. If the OIDC link is expired and we have a
refresh token, we will attempt to refresh.

OIDC refreshes do not have any effect on APIKey expiry.

https://github.com/coder/coder/issues/17070#issuecomment-2886183613
explains why this is the correct behavior.

Co-authored-by: Spike Curtis <spike@coder.com>
2025-05-21 11:21:48 +04:00
gcp-cherry-pick-bot[bot] 186d9b09fb chore: update alpine 3.21.2 => 3.21.3 (cherry-pick #17773) (#17801)
Co-authored-by: Charlie Voiselle <464492+angrycub@users.noreply.github.com>
2025-05-14 13:01:55 +05:00
Dean Sheather 8329bea8b7 chore: apply Dockerfile architecture fix (#17602) 2025-04-29 09:38:13 -05:00
gcp-cherry-pick-bot[bot] 8fe94e8c47 fix(scripts/release): handle cherry-pick bot titles in check commit metadata (cherry-pick #17535) (#17538)
Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
2025-04-23 17:30:32 +05:00
gcp-cherry-pick-bot[bot] 7670820b42 test: add unit test to exercise bug when IdP sync hits deleted orgs (cherry-pick #17405) (#17423)
Co-authored-by: Steven Masley <Emyrk@users.noreply.github.com>
2025-04-22 19:28:51 +05:00
gcp-cherry-pick-bot[bot] b56ffe01b9 chore: prevent null loading sync settings (cherry-pick #17430) (#17487)
Co-authored-by: Steven Masley <Emyrk@users.noreply.github.com>
2025-04-22 18:40:40 +05:00
Michael Suchacz fc8454b475 feat: extend request logs with auth & DB info (#17498) 2025-04-22 09:56:03 +02:00
gcp-cherry-pick-bot[bot] 550b81f1ea feat: add path & method labels to prometheus metrics for current requests (cherry-pick #17362) (#17493)
Cherry-picked feat: add path & method labels to prometheus metrics for
current requests (#17362)

Closes: #17212
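The real change labels a Prometheus gauge with `path` and `method`; purely as an illustration of the idea, a stdlib sketch of per-(method, path) in-flight request counting could look like this (all names hypothetical):

```go
package main

import (
	"fmt"
	"sync"
)

// concurrentRequests mimics a labeled Prometheus gauge: one counter per
// (method, path) pair rather than a single global number.
type concurrentRequests struct {
	mu     sync.Mutex
	counts map[[2]string]int
}

func newConcurrentRequests() *concurrentRequests {
	return &concurrentRequests{counts: map[[2]string]int{}}
}

// inc is called when a request starts; the returned func is deferred to
// decrement the same labeled counter when the request finishes.
func (c *concurrentRequests) inc(method, path string) func() {
	key := [2]string{method, path}
	c.mu.Lock()
	c.counts[key]++
	c.mu.Unlock()
	return func() {
		c.mu.Lock()
		c.counts[key]--
		c.mu.Unlock()
	}
}

func (c *concurrentRequests) get(method, path string) int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.counts[[2]string{method, path}]
}

func main() {
	cr := newConcurrentRequests()
	done := cr.inc("GET", "/api/v2/workspaces")
	fmt.Println(cr.get("GET", "/api/v2/workspaces")) // 1 while in flight
	done()
	fmt.Println(cr.get("GET", "/api/v2/workspaces")) // 0 after completion
}
```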

Co-authored-by: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com>
2025-04-22 09:17:54 +02:00
gcp-cherry-pick-bot[bot] 7d0e909ca2 fix: reduce excessive logging when database is unreachable (cherry-pick #17363) (#17420)
Cherry-picked fix: reduce excessive logging when database is unreachable
(#17363)

Fixes #17045

---------

Signed-off-by: Danny Kopping <dannykopping@gmail.com>

Co-authored-by: Danny Kopping <danny@coder.com>
2025-04-16 15:54:18 +02:00
gcp-cherry-pick-bot[bot] 828f149b88 chore: fix gpg forwarding test (cherry-pick #17355) (#17415)
Cherry-picked chore: fix gpg forwarding test (#17355)

Co-authored-by: Dean Sheather <dean@deansheather.com>
2025-04-16 15:03:25 +02:00
gcp-cherry-pick-bot[bot] 8a3ebaca01 fix: fix deployment settings navigation issues (cherry-pick #16780) (#16790)
Cherry-picked fix: fix deployment settings navigation issues (#16780)

Co-authored-by: ケイラ <mckayla@hey.com>
2025-03-17 16:00:39 -05:00
gcp-cherry-pick-bot[bot] 3569e56cf9 fix: hide deleted users from org members query (cherry-pick #16830) (#16832)
Cherry-picked fix: hide deleted users from org members query (#16830)

Co-authored-by: ケイラ <mckayla@hey.com>
2025-03-17 11:13:40 -05:00
gcp-cherry-pick-bot[bot] dba881fc7d fix: fix audit log search (cherry-pick #16944) (#16945)
Cherry-picked fix: fix audit log search (#16944)

Co-authored-by: ケイラ <mckayla@hey.com>
2025-03-14 20:37:38 -05:00
gcp-cherry-pick-bot[bot] b615a35d43 chore: use org-scoped roles for organization user e2e tests (cherry-pick #16691) (#16793)
Cherry-picked chore: use org-scoped roles for organization groups and
members e2e tests (#16691)

Co-authored-by: ケイラ <mckayla@hey.com>
2025-03-14 20:18:13 -05:00
gcp-cherry-pick-bot[bot] bdd7794e85 fix: remove provisioners from deployment sidebar (cherry-pick #16717) (#16927)
Cherry-picked fix: remove provisioners from deployment sidebar (#16717)

Provisioners should be only under orgs. This is a leftover from a
previous provisioner refactoring.

closes #16921

Co-authored-by: Bruno Quaresma <bruno@coder.com>
2025-03-14 21:50:16 +05:00
gcp-cherry-pick-bot[bot] 03b5012846 feat: update default audit log avatar (cherry-pick #16774) (#16805)
Cherry-picked feat: update default audit log avatar (#16774)

After update:


![image](https://github.com/user-attachments/assets/2ac6707f-2a56-45ec-a88f-651826776744)

Co-authored-by: Bruno Quaresma <bruno@coder.com>
2025-03-05 00:21:19 +05:00
gcp-cherry-pick-bot[bot] a5eb06e3f4 fix: add org role read perm to site template admins and auditors (cherry-pick #16733) (#16787)
Cherry-picked fix: add org role read permissions to site wide template
admins and auditors (#16733)

resolves coder/internal#388

Since site-wide admins and auditors are able to access the members page
of any org, they should have read access to org roles.

Co-authored-by: Jaayden Halko <jaayden.halko@gmail.com>
2025-03-03 18:43:01 -06:00
gcp-cherry-pick-bot[bot] 8aec4f2c21 chore: create collapsible summary component (cherry-pick #16705) (#16794)
Cherry-picked chore: create collapsible summary component (#16705)

This is based on the Figma designs here:

https://www.figma.com/design/WfqIgsTFXN2BscBSSyXWF8/Coder-kit?node-id=507-1525&m=dev

---------

Co-authored-by: Steven Masley <stevenmasley@gmail.com>

Co-authored-by: Jaayden Halko <jaayden.halko@gmail.com>
2025-03-03 18:41:30 -06:00
gcp-cherry-pick-bot[bot] e54e31e9f4 chore: add an unassign action for roles (cherry-pick #16728) (#16791)
Cherry-picked chore: add an unassign action for roles (#16728)

Co-authored-by: ケイラ <mckayla@hey.com>
2025-03-03 18:38:54 -06:00
gcp-cherry-pick-bot[bot] 32dc903d77 fix: allow viewOrgRoles for custom roles page (cherry-pick #16722) (#16789)
Cherry-picked fix: allow viewOrgRoles for custom roles page (#16722)

Users with viewOrgRoles should be able to see the custom roles page, as
this matches the left sidebar permissions.

Co-authored-by: Jaayden Halko <jaayden.halko@gmail.com>
2025-03-03 18:31:13 -06:00
gcp-cherry-pick-bot[bot] 7381f9a6c4 chore: warn user without permissions to view org members (cherry-pick #16721) (#16788)
Cherry-picked chore: warn user without permissions to view org members
(#16721)

resolves coder/internal#392

In situations where a user accesses the org members page without any
permissions beyond those of a normal member, they will only be able to
see themselves in the list of members.

This PR shows a warning to users who arrive at the members page in this
situation.

<img width="1145" alt="Screenshot 2025-02-26 at 18 36 59" src="https://github.com/user-attachments/assets/16ad6ce1-2aa9-4719-bdae-914aff0fcd52" />

Co-authored-by: Jaayden Halko <jaayden.halko@gmail.com>
2025-03-03 18:20:22 -06:00
gcp-cherry-pick-bot[bot] 4633658d59 feat: implement WorkspaceCreationBan org role (cherry-pick #16686) (#16786)
Cherry-picked feat: implement WorkspaceCreationBan org role (#16686)

Using negative permissions, this role prevents a user's ability to
create & delete a workspace within a given organization.

Workspaces are uniquely owned by an org and a user, so the org has to
supersede the user permission with a negative permission.

# Use case

Organizations must be able to restrict a member's ability to create a
workspace. This permission is implicitly granted (see
https://github.com/coder/coder/issues/16546#issuecomment-2655437860).

To revoke this permission, the solution chosen was to use negative
permissions in a built in role called `WorkspaceCreationBan`.

# Rationale

Using negative permissions is new territory, and not ideal. However,
workspaces are in a unique position.

Workspaces have two owners: the organization and the user. To prevent
users from creating a workspace in another organization, an [implied
negative permission](https://github.com/coder/coder/blob/36d9f5ddb3d98029fee07d004709e1e51022e979/coderd/rbac/policy.rego#L172-L192)
is used. So the truth table looks like this (_how to read the table is explained
[here](https://github.com/coder/coder/blob/36d9f5ddb3d98029fee07d004709e1e51022e979/coderd/rbac/README.md#roles)_):

| Role (example)  | Site | Org  | User | Result |
|-----------------|------|------|------|--------|
| non-org-member  | \_   | N    | YN\_ | N      |
| user            | \_   | \_   | Y    | Y      |
| WorkspaceBan    | \_   | N    | Y    | N      |
| unauthenticated | \_   | \_   | \_   | N      |


This new role, `WorkspaceCreationBan` is the same truth table condition
as if the user was not a member of the organization (when doing a
workspace create/delete). So this behavior **is not entirely new**.
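Under the table's semantics, an explicit N at any level denies, an explicit Y (with no N) allows, and all-abstain denies. A minimal evaluator sketch with illustrative names, simplifying the actual policy.rego logic:

```go
package main

import "fmt"

// vote is one level's opinion: "Y" (allow), "N" (deny), or "_" (abstain).
type vote string

// evaluate applies the truth table's rule: any explicit deny wins, then
// any explicit allow, and with only abstentions the default is deny.
func evaluate(site, org, user vote) bool {
	for _, v := range []vote{site, org, user} {
		if v == "N" {
			return false
		}
	}
	for _, v := range []vote{site, org, user} {
		if v == "Y" {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(evaluate("_", "N", "Y")) // ban role: org-level N wins
	fmt.Println(evaluate("_", "_", "Y")) // plain member: user-level Y allows
	fmt.Println(evaluate("_", "_", "_")) // unauthenticated: default deny
}
```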

<details>

<summary>How to do it without a negative permission</summary>

The alternative approach would be to remove the implied permission and
grant it via an organization role. However, this would add new behavior:
an organization role would be able to grant a user permissions on their
own resources.

It does not make sense for an org role to prevent a user from changing
their profile information, for example. So the only option is to create a
new truth table column for resources that are owned by both an
organization and a user.

| Role (example)  | Site | Org  |User+Org| User | Result |
|-----------------|------|------|--------|------|--------|
| non-org-member  | \_   | N    |  \_    | \_   | N      |
| user            | \_   | \_   |  \_    | \_   | N      |
| WorkspaceAllow  | \_   | \_   |   Y    | \_   | Y      |
| unauthenticated | \_   | \_   |  \_    | \_   | N      |

Now a user has no opinion on whether they can create a workspace, which
feels a little wrong. A user should have authority over what is theirs.

There is a fundamental _philosophical_ question of "Who does a workspace
belong to?". The user has some autonomy, yet it is the organization that
controls its existence. A head scratcher 🤔

</details>

## Will we need more negative built in roles?

Few resources have shared ownership: only `ResourceOrganizationMember`
and `ResourceGroupMember`. Since negative permissions are intended to
revoke access to a shared resource, then **no.** **This is the only one
we need.**

Classic resources like `ResourceTemplate` are entirely controlled by
organization permissions, and resources entirely under the user's control
(like the user profile) are controlled only by `User` permissions.



---------

Co-authored-by: Jaayden Halko <jaayden.halko@gmail.com>
Co-authored-by: ケイラ <mckayla@hey.com>

Co-authored-by: Steven Masley <Emyrk@users.noreply.github.com>
2025-03-03 18:17:34 -06:00
gcp-cherry-pick-bot[bot] 6da3c9d48c fix: allow orgs with default github provider (cherry-pick #16755) (#16784)
Cherry-picked fix: allow orgs with default github provider (#16755)

This PR fixes 2 bugs:

## Problem 1

The server would fail to start when the default github provider was
configured and the flag `--oauth2-github-allowed-orgs` was set. The
error was

```
error: configure github oauth2: allow everyone and allowed orgs cannot be used together
```

This PR fixes it by enabling "allow everyone" with the default provider
only if "allowed orgs" isn't set.

## Problem 2

The default github provider uses the device flow to authorize users, and
that's handled differently by our web UI than the standard oauth flow.
In particular, the web UI only handles JSON responses rather than HTTP
redirects. There were 2 code paths that returned redirects, and the PR
changes them to return JSON messages instead if the device flow is
configured.
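A sketch of the second fix: the same failure is reported as a JSON body when the device flow is configured and as a redirect otherwise. The handler and helper names below are invented for illustration, not Coder's actual code:

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// loginFailure responds according to how the client authenticated: the
// standard OAuth flow can follow an HTTP redirect, but the device-flow
// web UI only understands JSON bodies.
func loginFailure(deviceFlow bool, msg string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if deviceFlow {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusUnauthorized)
			_ = json.NewEncoder(w).Encode(map[string]string{"message": msg})
			return
		}
		http.Redirect(w, r, "/login?error="+msg, http.StatusTemporaryRedirect)
	}
}

// statusFor exercises the handler with a recorded request and returns
// the status code it produced.
func statusFor(deviceFlow bool) int {
	req := httptest.NewRequest("GET", "/oauth2/callback", nil)
	rec := httptest.NewRecorder()
	loginFailure(deviceFlow, "signups-disabled")(rec, req)
	return rec.Code
}

func main() {
	fmt.Println(statusFor(true))  // JSON error for the device flow
	fmt.Println(statusFor(false)) // redirect for the standard flow
}
```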

Co-authored-by: Hugo Dutka <hugo@coder.com>
2025-03-03 17:49:35 -06:00
gcp-cherry-pick-bot[bot] 99a5d72a8d docs: suggest disabling the default GitHub OAuth2 provider on k8s (cherry-pick #16758) (#16783)
Cherry-picked docs: suggest disabling the default GitHub OAuth2 provider
on k8s (#16758)

For production deployments we recommend disabling the default GitHub
OAuth2 app managed by Coder. This PR mentions it in k8s installation
docs and the helm README so users can stumble upon it more easily.

Co-authored-by: Hugo Dutka <hugo@coder.com>
2025-03-03 17:48:55 -06:00
gcp-cherry-pick-bot[bot] fc0db40791 docs: document default GitHub OAuth2 configuration and device flow (2.20) (#16782)
Cherry-picked docs: document default GitHub OAuth2 configuration and
device flow (#16663)

Document the changes made in https://github.com/coder/coder/pull/16629
and https://github.com/coder/coder/pull/16585.

Co-authored-by: Hugo Dutka <hugo@coder.com>
2025-03-03 17:48:24 -06:00
gcp-cherry-pick-bot[bot] b7ea479de3 chore: track workspace resource monitors in telemetry (cherry-pick #16776) (#16779)
Cherry-picked chore: track workspace resource monitors in telemetry
(#16776)

Addresses https://github.com/coder/nexus/issues/195. Specifically, just
the "tracking templates" requirement:

> ## Tracking in templates
> To enable resource alerts, a user must add the resource_monitoring
block to a template's coder_agent resource. We'd like to track if
customers have any resource monitoring enabled on a per-deployment
basis. Even better, we could identify which templates are using resource
monitoring.
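The `resource_monitoring` requirement quoted above refers to template HCL along these lines; the block and attribute names below are a guess at the shape for illustration, not a verified provider schema:

```hcl
resource "coder_agent" "main" {
  os   = "linux"
  arch = "amd64"

  # Hypothetical shape of the resource monitoring configuration that the
  # telemetry tracks; consult the provider docs for the real schema.
  resources_monitoring {
    memory {
      enabled   = true
      threshold = 90 # alert when memory usage exceeds 90%
    }
  }
}
```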

Co-authored-by: Hugo Dutka <hugo@coder.com>
2025-03-03 14:44:13 -06:00
gcp-cherry-pick-bot[bot] 735dc5d794 feat(agent): add second SSH listener on port 22 (cherry-pick #16627) (#16763)
Cherry-picked feat(agent): add second SSH listener on port 22 (#16627)

Fixes: https://github.com/coder/internal/issues/377

Added an additional SSH listener on port 22, so the agent now listens on both port 1 and port 22.

---
Change-Id: Ifd986b260f8ac317e37d65111cd4e0bd1dc38af8
Signed-off-by: Thomas Kosiewski <tk@coder.com>
2025-03-03 08:57:47 +01:00
gcp-cherry-pick-bot[bot] 114cf57580 fix: handle undefined job while updating build progress (cherry-pick #16732) (#16740)
Cherry-picked fix: handle undefined job while updating build progress
(#16732)

Fixes: https://github.com/coder/coder/issues/15444

Co-authored-by: Marcin Tojek <mtojek@users.noreply.github.com>
2025-02-28 15:08:19 +05:00
M Atif Ali 36186bbb78 feat: include winres metadata in Windows binaries (cherry-pick #16706) (#16742)
cherry picks #16706 to `release/2.20`

---------

Co-authored-by: Dean Sheather <dean@deansheather.com>
2025-02-28 13:47:04 +05:00
gcp-cherry-pick-bot[bot] 780b2714ff fix(vpn): fail early if wintun.dll is not present (cherry-pick #16707) (#16738)
Cherry-picked fix(vpn): fail early if wintun.dll is not present (#16707)

Prevents the VPN startup from hanging for 5 minutes due to a startup
backoff if `wintun.dll` cannot be loaded.

Because the `wintun` package doesn't expose an easy `Load() error`
method for us, the only way for us to force it to load (without unwanted
side effects) is through `wintun.Version()` which doesn't return an
error message.

So, we call that function so the `wintun` package loads the DLL and
configures the logging properly, then we try to load the DLL ourselves.
`LoadLibraryEx` will not load the library multiple times and returns a
reference to the existing library.

Closes https://github.com/coder/coder-desktop-windows/issues/24

Co-authored-by: Dean Sheather <dean@deansheather.com>
2025-02-28 12:47:59 +05:00
gcp-cherry-pick-bot[bot] 34740bc242 chore: update tailscale (cherry-pick #16737) (#16739)
Cherry-picked chore: update tailscale (#16737)

Co-authored-by: Dean Sheather <dean@deansheather.com>
2025-02-28 12:47:44 +05:00
129 changed files with 3426 additions and 982 deletions
+50 -3
@@ -1021,7 +1021,10 @@ jobs:
if: github.ref == 'refs/heads/main' && needs.changes.outputs.docs-only == 'false' && !github.event.pull_request.head.repo.fork
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-22.04' }}
permissions:
packages: write # Needed to push images to ghcr.io
# Necessary to push docker images to ghcr.io.
packages: write
# Necessary for GCP authentication (https://github.com/google-github-actions/setup-gcloud#usage)
id-token: write
env:
DOCKER_CLI_EXPERIMENTAL: "enabled"
outputs:
@@ -1050,12 +1053,44 @@ jobs:
- name: Setup Go
uses: ./.github/actions/setup-go
# Necessary for signing Windows binaries.
- name: Setup Java
uses: actions/setup-java@3a4f6e1af504cf6a31855fa899c6aa5355ba6c12 # v4.7.0
with:
distribution: "zulu"
java-version: "11.0"
- name: Install go-winres
run: go install github.com/tc-hib/go-winres@d743268d7ea168077ddd443c4240562d4f5e8c3e # v0.3.3
- name: Install nfpm
run: go install github.com/goreleaser/nfpm/v2/cmd/nfpm@v2.35.1
- name: Install zstd
run: sudo apt-get install -y zstd
- name: Setup Windows EV Signing Certificate
run: |
set -euo pipefail
touch /tmp/ev_cert.pem
chmod 600 /tmp/ev_cert.pem
echo "$EV_SIGNING_CERT" > /tmp/ev_cert.pem
wget https://github.com/ebourg/jsign/releases/download/6.0/jsign-6.0.jar -O /tmp/jsign-6.0.jar
env:
EV_SIGNING_CERT: ${{ secrets.EV_SIGNING_CERT }}
# Setup GCloud for signing Windows binaries.
- name: Authenticate to Google Cloud
id: gcloud_auth
uses: google-github-actions/auth@71f986410dfbc7added4569d411d040a91dc6935 # v2.1.8
with:
workload_identity_provider: ${{ secrets.GCP_CODE_SIGNING_WORKLOAD_ID_PROVIDER }}
service_account: ${{ secrets.GCP_CODE_SIGNING_SERVICE_ACCOUNT }}
token_format: "access_token"
- name: Setup GCloud SDK
uses: google-github-actions/setup-gcloud@77e7a554d41e2ee56fc945c52dfd3f33d12def9a # v2.1.4
- name: Download dylibs
uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
with:
@@ -1082,6 +1117,18 @@ jobs:
build/coder_linux_{amd64,arm64,armv7} \
build/coder_"$version"_windows_amd64.zip \
build/coder_"$version"_linux_amd64.{tar.gz,deb}
env:
# The Windows slim binary must be signed for Coder Desktop to accept
# it. The darwin executables don't need to be signed, but the dylibs
# do (see above).
CODER_SIGN_WINDOWS: "1"
CODER_WINDOWS_RESOURCES: "1"
EV_KEY: ${{ secrets.EV_KEY }}
EV_KEYSTORE: ${{ secrets.EV_KEYSTORE }}
EV_TSA_URL: ${{ secrets.EV_TSA_URL }}
EV_CERTIFICATE_PATH: /tmp/ev_cert.pem
GCLOUD_ACCESS_TOKEN: ${{ steps.gcloud_auth.outputs.access_token }}
JSIGN_PATH: /tmp/jsign-6.0.jar
- name: Build Linux Docker images
id: build-docker
@@ -1183,10 +1230,10 @@ jobs:
uses: google-github-actions/setup-gcloud@77e7a554d41e2ee56fc945c52dfd3f33d12def9a # v2.1.4
      - name: Set up Flux CLI
-       uses: fluxcd/flux2/action@af67405ee43a6cd66e0b73f4b3802e8583f9d961 # v2.5.0
+       uses: fluxcd/flux2/action@8d5f40dca5aa5d3c0fc3414457dda15a0ac92fa4 # v2.5.1
        with:
          # Keep this and the github action up to date with the version of flux installed in dogfood cluster
-         version: "2.2.1"
+         version: "2.5.1"
- name: Get Cluster Credentials
uses: google-github-actions/get-gke-credentials@7a108e64ed8546fe38316b4086e91da13f4785e1 # v2.3.1
+16 -12
@@ -223,21 +223,12 @@ jobs:
distribution: "zulu"
java-version: "11.0"
- name: Install go-winres
run: go install github.com/tc-hib/go-winres@d743268d7ea168077ddd443c4240562d4f5e8c3e # v0.3.3
- name: Install nsis and zstd
run: sudo apt-get install -y nsis zstd
- name: Download dylibs
uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
with:
name: dylibs
path: ./build
- name: Insert dylibs
run: |
mv ./build/*amd64.dylib ./site/out/bin/coder-vpn-darwin-amd64.dylib
mv ./build/*arm64.dylib ./site/out/bin/coder-vpn-darwin-arm64.dylib
mv ./build/*arm64.h ./site/out/bin/coder-vpn-darwin-dylib.h
- name: Install nfpm
run: |
set -euo pipefail
@@ -294,6 +285,18 @@ jobs:
- name: Setup GCloud SDK
uses: google-github-actions/setup-gcloud@77e7a554d41e2ee56fc945c52dfd3f33d12def9a # v2.1.4
- name: Download dylibs
uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
with:
name: dylibs
path: ./build
- name: Insert dylibs
run: |
mv ./build/*amd64.dylib ./site/out/bin/coder-vpn-darwin-amd64.dylib
mv ./build/*arm64.dylib ./site/out/bin/coder-vpn-darwin-arm64.dylib
mv ./build/*arm64.h ./site/out/bin/coder-vpn-darwin-dylib.h
- name: Build binaries
run: |
set -euo pipefail
@@ -310,6 +313,7 @@ jobs:
env:
CODER_SIGN_WINDOWS: "1"
CODER_SIGN_DARWIN: "1"
CODER_WINDOWS_RESOURCES: "1"
AC_CERTIFICATE_FILE: /tmp/apple_cert.p12
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
AC_APIKEY_ISSUER_ID: ${{ secrets.AC_APIKEY_ISSUER_ID }}
+3
@@ -78,3 +78,6 @@ result
# Zed
.zed_server
# dlv debug binaries for go tests
__debug_bin*
+6 -1
@@ -564,7 +564,8 @@ GEN_FILES := \
 	examples/examples.gen.json \
 	$(TAILNETTEST_MOCKS) \
 	coderd/database/pubsub/psmock/psmock.go \
-	agent/agentcontainers/acmock/acmock.go
+	agent/agentcontainers/acmock/acmock.go \
+	coderd/httpmw/loggermw/loggermock/loggermock.go
# all gen targets should be added here and to gen/mark-fresh
@@ -600,6 +601,7 @@ gen/mark-fresh:
$(TAILNETTEST_MOCKS) \
coderd/database/pubsub/psmock/psmock.go \
agent/agentcontainers/acmock/acmock.go \
coderd/httpmw/loggermw/loggermock/loggermock.go
"
for file in $$files; do
@@ -634,6 +636,9 @@ coderd/database/pubsub/psmock/psmock.go: coderd/database/pubsub/pubsub.go
agent/agentcontainers/acmock/acmock.go: agent/agentcontainers/containers.go
go generate ./agent/agentcontainers/acmock/
coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.go
go generate ./coderd/httpmw/loggermw/loggermock/
$(TAILNETTEST_MOCKS): tailnet/coordinator.go tailnet/service.go
go generate ./tailnet/tailnettest/
+14 -11
@@ -1193,19 +1193,22 @@ func (a *agent) createTailnet(
return nil, xerrors.Errorf("update host signer: %w", err)
}
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentSSHPort))
if err != nil {
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
}
defer func() {
for _, port := range []int{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort} {
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(port))
if err != nil {
_ = sshListener.Close()
return nil, xerrors.Errorf("listen on the ssh port (%v): %w", port, err)
}
// nolint:revive // We do want to run the deferred functions when createTailnet returns.
defer func() {
if err != nil {
_ = sshListener.Close()
}
}()
if err = a.trackGoroutine(func() {
_ = a.sshServer.Serve(sshListener)
}); err != nil {
return nil, err
}
}()
if err = a.trackGoroutine(func() {
_ = a.sshServer.Serve(sshListener)
}); err != nil {
return nil, err
}
reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentReconnectingPTYPort))
+120 -79
@@ -61,38 +61,48 @@ func TestMain(m *testing.M) {
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
}
var sshPorts = []uint16{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort}
// NOTE: These tests only work when your default shell is bash for some reason.
func TestAgent_Stats_SSH(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
for _, port := range sshPorts {
port := port
t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
t.Parallel()
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
stdin, err := session.StdinPipe()
require.NoError(t, err)
err = session.Shell()
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
var s *proto.Stats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
_ = stdin.Close()
err = session.Wait()
require.NoError(t, err)
//nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClientOnPort(ctx, port)
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
stdin, err := session.StdinPipe()
require.NoError(t, err)
err = session.Shell()
require.NoError(t, err)
var s *proto.Stats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
_ = stdin.Close()
err = session.Wait()
require.NoError(t, err)
})
}
}
func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
@@ -266,15 +276,23 @@ func TestAgent_Stats_Magic(t *testing.T) {
func TestAgent_SessionExec(t *testing.T) {
t.Parallel()
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
command := "echo test"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo test"
for _, port := range sshPorts {
port := port
t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
t.Parallel()
session := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
command := "echo test"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo test"
}
output, err := session.Output(command)
require.NoError(t, err)
require.Equal(t, "test", strings.TrimSpace(string(output)))
})
}
output, err := session.Output(command)
require.NoError(t, err)
require.Equal(t, "test", strings.TrimSpace(string(output)))
}
//nolint:tparallel // Sub tests need to run sequentially.
@@ -384,25 +402,33 @@ func TestAgent_SessionTTYShell(t *testing.T) {
// it seems like it could be either.
t.Skip("ConPTY appears to be inconsistent on Windows.")
}
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
command := "sh"
if runtime.GOOS == "windows" {
command = "cmd.exe"
for _, port := range sshPorts {
port := port
t.Run(fmt.Sprintf("(%d)", port), func(t *testing.T) {
t.Parallel()
session := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
command := "sh"
if runtime.GOOS == "windows" {
command = "cmd.exe"
}
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
require.NoError(t, err)
ptty := ptytest.New(t)
session.Stdout = ptty.Output()
session.Stderr = ptty.Output()
session.Stdin = ptty.Input()
err = session.Start(command)
require.NoError(t, err)
_ = ptty.Peek(ctx, 1) // wait for the prompt
ptty.WriteLine("echo test")
ptty.ExpectMatch("test")
ptty.WriteLine("exit")
err = session.Wait()
require.NoError(t, err)
})
}
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
require.NoError(t, err)
ptty := ptytest.New(t)
session.Stdout = ptty.Output()
session.Stderr = ptty.Output()
session.Stdin = ptty.Input()
err = session.Start(command)
require.NoError(t, err)
_ = ptty.Peek(ctx, 1) // wait for the prompt
ptty.WriteLine("echo test")
ptty.ExpectMatch("test")
ptty.WriteLine("exit")
err = session.Wait()
require.NoError(t, err)
}
func TestAgent_SessionTTYExitCode(t *testing.T) {
@@ -596,37 +622,41 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
//nolint:dogsled // Allow the blank identifiers.
conn, client, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, setSBInterval)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
t.Cleanup(func() {
_ = sshClient.Close()
})
//nolint:paralleltest // These tests need to swap the banner func.
for i, test := range tests {
test := test
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
// Set new banner func and wait for the agent to call it to update the
// banner.
ready := make(chan struct{}, 2)
client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
select {
case ready <- struct{}{}:
default:
}
return []codersdk.BannerConfig{test.banner}, nil
})
<-ready
<-ready // Wait for two updates to ensure the value has propagated.
for _, port := range sshPorts {
port := port
session, err := sshClient.NewSession()
require.NoError(t, err)
t.Cleanup(func() {
_ = session.Close()
})
testSessionOutput(t, session, test.expected, test.unexpected, nil)
sshClient, err := conn.SSHClientOnPort(ctx, port)
require.NoError(t, err)
t.Cleanup(func() {
_ = sshClient.Close()
})
for i, test := range tests {
test := test
t.Run(fmt.Sprintf("(:%d)/%d", port, i), func(t *testing.T) {
// Set new banner func and wait for the agent to call it to update the
// banner.
ready := make(chan struct{}, 2)
client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
select {
case ready <- struct{}{}:
default:
}
return []codersdk.BannerConfig{test.banner}, nil
})
<-ready
<-ready // Wait for two updates to ensure the value has propagated.
session, err := sshClient.NewSession()
require.NoError(t, err)
t.Cleanup(func() {
_ = session.Close()
})
testSessionOutput(t, session, test.expected, test.unexpected, nil)
})
}
}
}
@@ -2313,6 +2343,17 @@ func setupSSHSession(
banner codersdk.BannerConfig,
prepareFS func(fs afero.Fs),
opts ...func(*agenttest.Client, *agent.Options),
) *ssh.Session {
return setupSSHSessionOnPort(t, manifest, banner, prepareFS, workspacesdk.AgentSSHPort, opts...)
}
func setupSSHSessionOnPort(
t *testing.T,
manifest agentsdk.Manifest,
banner codersdk.BannerConfig,
prepareFS func(fs afero.Fs),
port uint16,
opts ...func(*agenttest.Client, *agent.Options),
) *ssh.Session {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
@@ -2326,7 +2367,7 @@ func setupSSHSession(
if prepareFS != nil {
prepareFS(fs)
}
sshClient, err := conn.SSHClient(ctx)
sshClient, err := conn.SSHClientOnPort(ctx, port)
require.NoError(t, err)
t.Cleanup(func() {
_ = sshClient.Close()
+1 -1
@@ -17,7 +17,7 @@ func Get(username string) (string, error) {
 		return "", xerrors.Errorf("username is nonlocal path: %s", username)
 	}
 	//nolint: gosec // input checked above
-	out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output()
+	out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output() //nolint:gocritic
 	s, ok := strings.CutPrefix(string(out), "UserShell: ")
 	if ok {
 		return strings.TrimSpace(s), nil
+1
@@ -0,0 +1 @@
*.syso
+8
@@ -0,0 +1,8 @@
// This package is used for embedding .syso resource files into the binary
// during build and does not contain any code. During build, .syso files will be
// dropped in this directory and then removed after the build completes.
//
// This package must be imported by all binaries for this to work.
//
// See build_go.sh for more details.
package resources
+3 -1
@@ -1911,8 +1911,10 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
 	}
 	params.clientID = GithubOAuth2DefaultProviderClientID
-	params.allowEveryone = GithubOAuth2DefaultProviderAllowEveryone
 	params.deviceFlow = GithubOAuth2DefaultProviderDeviceFlow
+	if len(params.allowOrgs) == 0 {
+		params.allowEveryone = GithubOAuth2DefaultProviderAllowEveryone
+	}
 	return &params, nil
 }
+10 -1
@@ -314,6 +314,7 @@ func TestServer(t *testing.T) {
githubDefaultProviderEnabled string
githubClientID string
githubClientSecret string
allowedOrg string
expectGithubEnabled bool
expectGithubDefaultProviderConfigured bool
createUserPreStart bool
@@ -355,7 +356,9 @@ func TestServer(t *testing.T) {
if tc.githubDefaultProviderEnabled != "" {
args = append(args, fmt.Sprintf("--oauth2-github-default-provider-enable=%s", tc.githubDefaultProviderEnabled))
}
if tc.allowedOrg != "" {
args = append(args, fmt.Sprintf("--oauth2-github-allowed-orgs=%s", tc.allowedOrg))
}
inv, cfg := clitest.New(t, args...)
errChan := make(chan error, 1)
go func() {
@@ -439,6 +442,12 @@ func TestServer(t *testing.T) {
expectGithubEnabled: true,
expectGithubDefaultProviderConfigured: false,
},
{
name: "AllowedOrg",
allowedOrg: "coder",
expectGithubEnabled: true,
expectGithubDefaultProviderConfigured: true,
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
+3 -1
@@ -1908,7 +1908,9 @@ Expire-Date: 0
 	tpty.WriteLine("gpg --list-keys && echo gpg-''-listkeys-command-done")
 	listKeysOutput := tpty.ExpectMatch("gpg--listkeys-command-done")
 	require.Contains(t, listKeysOutput, "[ultimate] Coder Test <test@coder.com>")
-	require.Contains(t, listKeysOutput, "[ultimate] Dean Sheather (work key) <dean@coder.com>")
+	// It's fine that this key is expired. We're just testing that the key trust
+	// gets synced properly.
+	require.Contains(t, listKeysOutput, "[ expired] Dean Sheather (work key) <dean@coder.com>")
// Try to sign something. This demonstrates that the forwarding is
// working as expected, since the workspace doesn't have access to the
+1
@@ -8,6 +8,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/coder/coder/v2/agent/agentexec"
_ "github.com/coder/coder/v2/buildinfo/resources"
"github.com/coder/coder/v2/cli"
)
+2
@@ -13699,6 +13699,7 @@ const docTemplate = `{
"read",
"read_personal",
"ssh",
"unassign",
"update",
"update_personal",
"use",
@@ -13714,6 +13715,7 @@ const docTemplate = `{
"ActionRead",
"ActionReadPersonal",
"ActionSSH",
"ActionUnassign",
"ActionUpdate",
"ActionUpdatePersonal",
"ActionUse",
+2
@@ -12388,6 +12388,7 @@
"read",
"read_personal",
"ssh",
"unassign",
"update",
"update_personal",
"use",
@@ -12403,6 +12404,7 @@
"ActionRead",
"ActionReadPersonal",
"ActionSSH",
"ActionUnassign",
"ActionUpdate",
"ActionUpdatePersonal",
"ActionUse",
+7 -5
@@ -63,6 +63,7 @@ import (
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/metricscache"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/portsharing"
@@ -652,10 +653,11 @@ func New(options *Options) *API {
api.Auditor.Store(&options.Auditor)
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
dialer := &InmemTailnetDialer{
CoordPtr: &api.TailnetCoordinator,
DERPFn: api.DERPMap,
Logger: options.Logger,
ClientID: uuid.New(),
CoordPtr: &api.TailnetCoordinator,
DERPFn: api.DERPMap,
Logger: options.Logger,
ClientID: uuid.New(),
DatabaseHealthCheck: api.Database,
}
stn, err := NewServerTailnet(api.ctx,
options.Logger,
@@ -787,7 +789,7 @@ func New(options *Options) *API {
tracing.Middleware(api.TracerProvider),
httpmw.AttachRequestID,
httpmw.ExtractRealIP(api.RealIPConfig),
httpmw.Logger(api.Logger),
loggermw.Logger(api.Logger),
singleSlashMW,
rolestore.CustomRoleMW,
prometheusMW,
+4 -1
@@ -215,7 +215,7 @@ func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values
// WithLogging is optional, but will log some HTTP calls made to the IDP.
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
return func(f *FakeIDP) {
f.logger = slogtest.Make(t, options)
f.logger = slogtest.Make(t, options).Named("fakeidp")
}
}
@@ -700,6 +700,7 @@ func (f *FakeIDP) newToken(t testing.TB, email string, expires time.Time) string
func (f *FakeIDP) newRefreshTokens(email string) string {
refreshToken := uuid.NewString()
f.refreshTokens.Store(refreshToken, email)
f.logger.Info(context.Background(), "new refresh token", slog.F("email", email), slog.F("token", refreshToken))
return refreshToken
}
@@ -909,6 +910,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
return
}
f.logger.Info(r.Context(), "http idp call refresh_token", slog.F("token", refreshToken))
_, ok := f.refreshTokens.Load(refreshToken)
if !assert.True(t, ok, "invalid refresh_token") {
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
@@ -932,6 +934,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
f.refreshTokensUsed.Store(refreshToken, true)
// Always invalidate the refresh token after it is used.
f.refreshTokens.Delete(refreshToken)
f.logger.Info(r.Context(), "refresh token invalidated", slog.F("token", refreshToken))
case "urn:ietf:params:oauth:grant-type:device_code":
// Device flow
var resp externalauth.ExchangeDeviceCodeResponse
+53 -69
@@ -34,11 +34,12 @@ func TestInsertCustomRoles(t *testing.T) {
}
}
canAssignRole := rbac.Role{
canCreateCustomRole := rbac.Role{
Identifier: rbac.RoleIdentifier{Name: "can-assign"},
DisplayName: "",
Site: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceAssignRole.Type: {policy.ActionRead, policy.ActionCreate},
rbac.ResourceAssignRole.Type: {policy.ActionRead},
rbac.ResourceAssignOrgRole.Type: {policy.ActionRead, policy.ActionCreate},
}),
}
@@ -61,17 +62,15 @@ func TestInsertCustomRoles(t *testing.T) {
return all
}
orgID := uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
}
orgID := uuid.New()
testCases := []struct {
name string
subject rbac.ExpandableRoles
// Perms to create on new custom role
organizationID uuid.NullUUID
organizationID uuid.UUID
site []codersdk.Permission
org []codersdk.Permission
user []codersdk.Permission
@@ -79,19 +78,21 @@ func TestInsertCustomRoles(t *testing.T) {
}{
{
// No roles, so no assign role
name: "no-roles",
subject: rbac.RoleIdentifiers{},
errorContains: "forbidden",
name: "no-roles",
organizationID: orgID,
subject: rbac.RoleIdentifiers{},
errorContains: "forbidden",
},
{
// This works because the new role has 0 perms
name: "empty",
subject: merge(canAssignRole),
name: "empty",
organizationID: orgID,
subject: merge(canCreateCustomRole),
},
{
name: "mixed-scopes",
subject: merge(canAssignRole, rbac.RoleOwner()),
organizationID: orgID,
subject: merge(canCreateCustomRole, rbac.RoleOwner()),
site: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {codersdk.ActionRead},
}),
@@ -101,27 +102,30 @@ func TestInsertCustomRoles(t *testing.T) {
errorContains: "organization roles specify site or user permissions",
},
{
name: "invalid-action",
subject: merge(canAssignRole, rbac.RoleOwner()),
site: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
name: "invalid-action",
organizationID: orgID,
subject: merge(canCreateCustomRole, rbac.RoleOwner()),
org: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
// Action does not go with resource
codersdk.ResourceWorkspace: {codersdk.ActionViewInsights},
}),
errorContains: "invalid action",
},
{
name: "invalid-resource",
subject: merge(canAssignRole, rbac.RoleOwner()),
site: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
name: "invalid-resource",
organizationID: orgID,
subject: merge(canCreateCustomRole, rbac.RoleOwner()),
org: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
"foobar": {codersdk.ActionViewInsights},
}),
errorContains: "invalid resource",
},
{
// Not allowing these at this time.
name: "negative-permission",
subject: merge(canAssignRole, rbac.RoleOwner()),
site: []codersdk.Permission{
name: "negative-permission",
organizationID: orgID,
subject: merge(canCreateCustomRole, rbac.RoleOwner()),
org: []codersdk.Permission{
{
Negate: true,
ResourceType: codersdk.ResourceWorkspace,
@@ -131,89 +135,69 @@ func TestInsertCustomRoles(t *testing.T) {
errorContains: "no negative permissions",
},
{
name: "wildcard", // not allowed
subject: merge(canAssignRole, rbac.RoleOwner()),
site: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
name: "wildcard", // not allowed
organizationID: orgID,
subject: merge(canCreateCustomRole, rbac.RoleOwner()),
org: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {"*"},
}),
errorContains: "no wildcard symbols",
},
// escalation checks
{
name: "read-workspace-escalation",
subject: merge(canAssignRole),
site: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
name: "read-workspace-escalation",
organizationID: orgID,
subject: merge(canCreateCustomRole),
org: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {codersdk.ActionRead},
}),
errorContains: "not allowed to grant this permission",
},
{
name: "read-workspace-outside-org",
organizationID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
subject: merge(canAssignRole, rbac.ScopedRoleOrgAdmin(orgID.UUID)),
name: "read-workspace-outside-org",
organizationID: uuid.New(),
subject: merge(canCreateCustomRole, rbac.ScopedRoleOrgAdmin(orgID)),
org: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {codersdk.ActionRead},
}),
errorContains: "forbidden",
errorContains: "not allowed to grant this permission",
},
{
name: "user-escalation",
// These roles do not grant user perms
subject: merge(canAssignRole, rbac.ScopedRoleOrgAdmin(orgID.UUID)),
organizationID: orgID,
subject: merge(canCreateCustomRole, rbac.ScopedRoleOrgAdmin(orgID)),
user: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {codersdk.ActionRead},
}),
errorContains: "not allowed to grant this permission",
errorContains: "organization roles specify site or user permissions",
},
{
name: "template-admin-escalation",
subject: merge(canAssignRole, rbac.RoleTemplateAdmin()),
name: "site-escalation",
organizationID: orgID,
subject: merge(canCreateCustomRole, rbac.RoleTemplateAdmin()),
site: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {codersdk.ActionRead}, // ok!
codersdk.ResourceDeploymentConfig: {codersdk.ActionUpdate}, // not ok!
}),
user: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {codersdk.ActionRead}, // ok!
}),
errorContains: "deployment_config",
errorContains: "organization roles specify site or user permissions",
},
// ok!
{
name: "read-workspace-template-admin",
subject: merge(canAssignRole, rbac.RoleTemplateAdmin()),
site: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
name: "read-workspace-template-admin",
organizationID: orgID,
subject: merge(canCreateCustomRole, rbac.RoleTemplateAdmin()),
org: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {codersdk.ActionRead},
}),
},
{
name: "read-workspace-in-org",
subject: merge(canAssignRole, rbac.ScopedRoleOrgAdmin(orgID.UUID)),
organizationID: orgID,
subject: merge(canCreateCustomRole, rbac.ScopedRoleOrgAdmin(orgID)),
org: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {codersdk.ActionRead},
}),
},
{
name: "user-perms",
// This is weird, but is ok
subject: merge(canAssignRole, rbac.RoleMember()),
user: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {codersdk.ActionRead},
}),
},
{
name: "site+user-perms",
subject: merge(canAssignRole, rbac.RoleMember(), rbac.RoleTemplateAdmin()),
site: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {codersdk.ActionRead},
}),
user: codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceWorkspace: {codersdk.ActionRead},
}),
},
}
for _, tc := range testCases {
@@ -234,7 +218,7 @@ func TestInsertCustomRoles(t *testing.T) {
_, err := az.InsertCustomRole(ctx, database.InsertCustomRoleParams{
Name: "test-role",
DisplayName: "",
OrganizationID: tc.organizationID,
OrganizationID: uuid.NullUUID{UUID: tc.organizationID, Valid: true},
SitePermissions: db2sdk.List(tc.site, convertSDKPerm),
OrgPermissions: db2sdk.List(tc.org, convertSDKPerm),
UserPermissions: db2sdk.List(tc.user, convertSDKPerm),
@@ -249,11 +233,11 @@ func TestInsertCustomRoles(t *testing.T) {
LookupRoles: []database.NameOrganizationPair{
{
Name: "test-role",
OrganizationID: tc.organizationID.UUID,
OrganizationID: tc.organizationID,
},
},
ExcludeOrgRoles: false,
OrganizationID: uuid.UUID{},
OrganizationID: uuid.Nil,
})
require.NoError(t, err)
require.Len(t, roles, 1)
+78 -46
@@ -24,6 +24,7 @@ import (
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi/httpapiconstraints"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/provisionersdk"
@@ -162,6 +163,7 @@ func ActorFromContext(ctx context.Context) (rbac.Subject, bool) {
var (
subjectProvisionerd = rbac.Subject{
Type: rbac.SubjectTypeProvisionerd,
FriendlyName: "Provisioner Daemon",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
@@ -195,6 +197,7 @@ var (
}.WithCachedASTValue()
subjectAutostart = rbac.Subject{
Type: rbac.SubjectTypeAutostart,
FriendlyName: "Autostart",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
@@ -218,6 +221,7 @@ var (
// See unhanger package.
subjectHangDetector = rbac.Subject{
Type: rbac.SubjectTypeHangDetector,
FriendlyName: "Hang Detector",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
@@ -238,6 +242,7 @@ var (
// See cryptokeys package.
subjectCryptoKeyRotator = rbac.Subject{
Type: rbac.SubjectTypeCryptoKeyRotator,
FriendlyName: "Crypto Key Rotator",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
@@ -256,6 +261,7 @@ var (
// See cryptokeys package.
subjectCryptoKeyReader = rbac.Subject{
Type: rbac.SubjectTypeCryptoKeyReader,
FriendlyName: "Crypto Key Reader",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
@@ -273,6 +279,7 @@ var (
}.WithCachedASTValue()
subjectNotifier = rbac.Subject{
Type: rbac.SubjectTypeNotifier,
FriendlyName: "Notifier",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
@@ -290,6 +297,7 @@ var (
}.WithCachedASTValue()
subjectResourceMonitor = rbac.Subject{
Type: rbac.SubjectTypeResourceMonitor,
FriendlyName: "Resource Monitor",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
@@ -308,6 +316,7 @@ var (
}.WithCachedASTValue()
subjectSystemRestricted = rbac.Subject{
Type: rbac.SubjectTypeSystemRestricted,
FriendlyName: "System",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
@@ -342,6 +351,7 @@ var (
}.WithCachedASTValue()
subjectSystemReadProvisionerDaemons = rbac.Subject{
Type: rbac.SubjectTypeSystemReadProvisionerDaemons,
FriendlyName: "Provisioner Daemons Reader",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
@@ -362,53 +372,53 @@ var (
// AsProvisionerd returns a context with an actor that has permissions required
// for provisionerd to function.
func AsProvisionerd(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectProvisionerd)
return As(ctx, subjectProvisionerd)
}
// AsAutostart returns a context with an actor that has permissions required
// for autostart to function.
func AsAutostart(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectAutostart)
return As(ctx, subjectAutostart)
}
// AsHangDetector returns a context with an actor that has permissions required
// for unhanger.Detector to function.
func AsHangDetector(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectHangDetector)
return As(ctx, subjectHangDetector)
}
// AsKeyRotator returns a context with an actor that has permissions required for rotating crypto keys.
func AsKeyRotator(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyRotator)
return As(ctx, subjectCryptoKeyRotator)
}
// AsKeyReader returns a context with an actor that has permissions required for reading crypto keys.
func AsKeyReader(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyReader)
return As(ctx, subjectCryptoKeyReader)
}
// AsNotifier returns a context with an actor that has permissions required for
// creating/reading/updating/deleting notifications.
func AsNotifier(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectNotifier)
return As(ctx, subjectNotifier)
}
// AsResourceMonitor returns a context with an actor that has permissions required for
// updating resource monitors.
func AsResourceMonitor(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectResourceMonitor)
return As(ctx, subjectResourceMonitor)
}
// AsSystemRestricted returns a context with an actor that has permissions
// required for various system operations (login, logout, metrics cache).
func AsSystemRestricted(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectSystemRestricted)
return As(ctx, subjectSystemRestricted)
}
// AsSystemReadProvisionerDaemons returns a context with an actor that has permissions
// to read provisioner daemons.
func AsSystemReadProvisionerDaemons(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectSystemReadProvisionerDaemons)
return As(ctx, subjectSystemReadProvisionerDaemons)
}
var AsRemoveActor = rbac.Subject{
@@ -426,6 +436,9 @@ func As(ctx context.Context, actor rbac.Subject) context.Context {
// should be removed from the context.
return context.WithValue(ctx, authContextKey{}, nil)
}
if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil {
rlogger.WithAuthContext(actor)
}
return context.WithValue(ctx, authContextKey{}, actor)
}
@@ -747,7 +760,7 @@ func (*querier) convertToDeploymentRoles(names []string) []rbac.RoleIdentifier {
}
// canAssignRoles handles assigning built in and custom roles.
func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, removed []rbac.RoleIdentifier) error {
func (q *querier) canAssignRoles(ctx context.Context, orgID uuid.UUID, added, removed []rbac.RoleIdentifier) error {
actor, ok := ActorFromContext(ctx)
if !ok {
return NoActorError
@@ -755,12 +768,14 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, r
roleAssign := rbac.ResourceAssignRole
shouldBeOrgRoles := false
if orgID != nil {
roleAssign = rbac.ResourceAssignOrgRole.InOrg(*orgID)
if orgID != uuid.Nil {
roleAssign = rbac.ResourceAssignOrgRole.InOrg(orgID)
shouldBeOrgRoles = true
}
grantedRoles := append(added, removed...)
grantedRoles := make([]rbac.RoleIdentifier, 0, len(added)+len(removed))
grantedRoles = append(grantedRoles, added...)
grantedRoles = append(grantedRoles, removed...)
customRoles := make([]rbac.RoleIdentifier, 0)
// Validate that the roles being assigned are valid.
for _, r := range grantedRoles {
@@ -774,11 +789,11 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, r
}
if shouldBeOrgRoles {
if orgID == nil {
if orgID == uuid.Nil {
return xerrors.Errorf("should never happen, orgID is nil, but trying to assign an organization role")
}
if r.OrganizationID != *orgID {
if r.OrganizationID != orgID {
return xerrors.Errorf("attempted to assign role from a different org, role %q to %q", r, orgID.String())
}
}
@@ -824,7 +839,7 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, r
}
if len(removed) > 0 {
if err := q.authorizeContext(ctx, policy.ActionDelete, roleAssign); err != nil {
if err := q.authorizeContext(ctx, policy.ActionUnassign, roleAssign); err != nil {
return err
}
}
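The switch from `append(added, removed...)` to a freshly allocated slice above avoids a classic Go aliasing pitfall: when the first slice has spare capacity, `append` writes into its backing array, which the caller may still be using. A self-contained demonstration:

```go
package main

import "fmt"

func main() {
	base := make([]string, 2, 4) // extra capacity is the trap
	base[0], base[1] = "admin", "member"
	added := base[:2]

	// append reuses added's backing array because capacity allows it...
	granted := append(added, "auditor")

	// ...so a second append through the same slice clobbers granted[2].
	other := append(added, "viewer")
	fmt.Println(granted[2], other[2]) // both slices now see "viewer"
}
```

Copying into a slice sized with `make([]T, 0, len(a)+len(b))`, as the new code does, guarantees the result owns its backing array.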
@@ -1124,11 +1139,15 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
return q.db.CleanTailnetTunnels(ctx)
}
// TODO: Handle org scoped lookups
func (q *querier) CustomRoles(ctx context.Context, arg database.CustomRolesParams) ([]database.CustomRole, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAssignRole); err != nil {
roleObject := rbac.ResourceAssignRole
if arg.OrganizationID != uuid.Nil {
roleObject = rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID)
}
if err := q.authorizeContext(ctx, policy.ActionRead, roleObject); err != nil {
return nil, err
}
return q.db.CustomRoles(ctx, arg)
}
@@ -1185,14 +1204,11 @@ func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCrypto
}
func (q *querier) DeleteCustomRole(ctx context.Context, arg database.DeleteCustomRoleParams) error {
if arg.OrganizationID.UUID != uuid.Nil {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil {
return err
}
} else {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceAssignRole); err != nil {
return err
}
if !arg.OrganizationID.Valid || arg.OrganizationID.UUID == uuid.Nil {
return NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")}
}
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil {
return err
}
return q.db.DeleteCustomRole(ctx, arg)
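The two-part guard above (`!arg.OrganizationID.Valid || arg.OrganizationID.UUID == uuid.Nil`) is needed because `uuid.NullUUID`, like `sql.NullString`, tracks SQL NULL separately from the zero value: a caller can pass `Valid: true` with a zero UUID, and both shapes must be rejected. A sketch with a stand-in type (the real code uses `github.com/google/uuid`):

```go
package main

import "fmt"

// nullID stands in for uuid.NullUUID: Valid distinguishes SQL NULL from a
// present-but-zero value, so both fields must be checked.
type nullID struct {
	ID    string
	Valid bool
}

var zero = "00000000-0000-0000-0000-000000000000"

func belongsToOrg(org nullID) bool {
	return org.Valid && org.ID != zero
}

func main() {
	fmt.Println(belongsToOrg(nullID{Valid: false}))            // SQL NULL: false
	fmt.Println(belongsToOrg(nullID{ID: zero, Valid: true}))   // zero UUID: false
	fmt.Println(belongsToOrg(nullID{ID: "1234", Valid: true})) // true
}
```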
@@ -1426,6 +1442,17 @@ func (q *querier) FetchMemoryResourceMonitorsByAgentID(ctx context.Context, agen
return q.db.FetchMemoryResourceMonitorsByAgentID(ctx, agentID)
}
func (q *querier) FetchMemoryResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.WorkspaceAgentMemoryResourceMonitor, error) {
// Ideally, we would return a list of monitors that the user has access to. However, that check would need to
// be implemented similarly to GetWorkspaces, which is more complex than what we're doing here. Since this query
// was introduced for telemetry, we perform a simpler check.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspaceAgentResourceMonitor); err != nil {
return nil, err
}
return q.db.FetchMemoryResourceMonitorsUpdatedAfter(ctx, updatedAt)
}
func (q *querier) FetchNewMessageMetadata(ctx context.Context, arg database.FetchNewMessageMetadataParams) (database.FetchNewMessageMetadataRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceNotificationMessage); err != nil {
return database.FetchNewMessageMetadataRow{}, err
@@ -1447,6 +1474,17 @@ func (q *querier) FetchVolumesResourceMonitorsByAgentID(ctx context.Context, age
return q.db.FetchVolumesResourceMonitorsByAgentID(ctx, agentID)
}
func (q *querier) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.WorkspaceAgentVolumeResourceMonitor, error) {
// Ideally, we would return a list of monitors that the user has access to. However, that check would need to
// be implemented similarly to GetWorkspaces, which is more complex than what we're doing here. Since this query
// was introduced for telemetry, we perform a simpler check.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspaceAgentResourceMonitor); err != nil {
return nil, err
}
return q.db.FetchVolumesResourceMonitorsUpdatedAfter(ctx, updatedAt)
}
func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id)
}
@@ -3009,14 +3047,11 @@ func (q *querier) InsertCryptoKey(ctx context.Context, arg database.InsertCrypto
func (q *querier) InsertCustomRole(ctx context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) {
// Org and site role upsert share the same query. So switch the assertion based on the org uuid.
if arg.OrganizationID.UUID != uuid.Nil {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil {
return database.CustomRole{}, err
}
} else {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceAssignRole); err != nil {
return database.CustomRole{}, err
}
if !arg.OrganizationID.Valid || arg.OrganizationID.UUID == uuid.Nil {
return database.CustomRole{}, NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")}
}
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil {
return database.CustomRole{}, err
}
if err := q.customRoleCheck(ctx, database.CustomRole{
@@ -3146,7 +3181,7 @@ func (q *querier) InsertOrganizationMember(ctx context.Context, arg database.Ins
// All roles are added roles. Org member is always implied.
addedRoles := append(orgRoles, rbac.ScopedRoleOrgMember(arg.OrganizationID))
err = q.canAssignRoles(ctx, &arg.OrganizationID, addedRoles, []rbac.RoleIdentifier{})
err = q.canAssignRoles(ctx, arg.OrganizationID, addedRoles, []rbac.RoleIdentifier{})
if err != nil {
return database.OrganizationMember{}, err
}
@@ -3270,7 +3305,7 @@ func (q *querier) InsertTemplateVersionWorkspaceTag(ctx context.Context, arg dat
func (q *querier) InsertUser(ctx context.Context, arg database.InsertUserParams) (database.User, error) {
// Always check if the assigned roles can actually be assigned by this actor.
impliedRoles := append([]rbac.RoleIdentifier{rbac.RoleMember()}, q.convertToDeploymentRoles(arg.RBACRoles)...)
err := q.canAssignRoles(ctx, nil, impliedRoles, []rbac.RoleIdentifier{})
err := q.canAssignRoles(ctx, uuid.Nil, impliedRoles, []rbac.RoleIdentifier{})
if err != nil {
return database.User{}, err
}
@@ -3608,14 +3643,11 @@ func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.Upd
}
func (q *querier) UpdateCustomRole(ctx context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) {
if arg.OrganizationID.UUID != uuid.Nil {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil {
return database.CustomRole{}, err
}
} else {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceAssignRole); err != nil {
return database.CustomRole{}, err
}
if !arg.OrganizationID.Valid || arg.OrganizationID.UUID == uuid.Nil {
return database.CustomRole{}, NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")}
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceAssignOrgRole.InOrg(arg.OrganizationID.UUID)); err != nil {
return database.CustomRole{}, err
}
if err := q.customRoleCheck(ctx, database.CustomRole{
@@ -3695,7 +3727,7 @@ func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemb
impliedTypes := append(scopedGranted, rbac.ScopedRoleOrgMember(arg.OrgID))
added, removed := rbac.ChangeRoleSet(originalRoles, impliedTypes)
err = q.canAssignRoles(ctx, &arg.OrgID, added, removed)
err = q.canAssignRoles(ctx, arg.OrgID, added, removed)
if err != nil {
return database.OrganizationMember{}, err
}
@@ -4102,7 +4134,7 @@ func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRo
impliedTypes := append(q.convertToDeploymentRoles(arg.GrantedRoles), rbac.RoleMember())
// If the changeset is empty, fewer RBAC checks need to be done.
added, removed := rbac.ChangeRoleSet(q.convertToDeploymentRoles(user.RBACRoles), impliedTypes)
err = q.canAssignRoles(ctx, nil, added, removed)
err = q.canAssignRoles(ctx, uuid.Nil, added, removed)
if err != nil {
return database.User{}, err
}
+28 -39
@@ -839,7 +839,7 @@ func (s *MethodTestSuite) TestOrganization() {
_ = dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID, OrganizationID: a.ID})
b := dbgen.Organization(s.T(), db, database.Organization{})
_ = dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID, OrganizationID: b.ID})
check.Args(database.GetOrganizationsByUserIDParams{UserID: u.ID, Deleted: false}).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
check.Args(database.GetOrganizationsByUserIDParams{UserID: u.ID, Deleted: sql.NullBool{Valid: true, Bool: false}}).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
}))
s.Run("InsertOrganization", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertOrganizationParams{
@@ -948,8 +948,7 @@ func (s *MethodTestSuite) TestOrganization() {
member, policy.ActionRead,
member, policy.ActionDelete).
WithNotAuthorized("no rows").
WithCancelled(cancelledErr).
ErrorsWithInMemDB(sql.ErrNoRows)
WithCancelled(cancelledErr)
}))
s.Run("UpdateOrganization", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{
@@ -1011,7 +1010,7 @@ func (s *MethodTestSuite) TestOrganization() {
Asserts(
mem, policy.ActionRead,
rbac.ResourceAssignOrgRole.InOrg(o.ID), policy.ActionAssign, // org-mem
rbac.ResourceAssignOrgRole.InOrg(o.ID), policy.ActionDelete, // org-admin
rbac.ResourceAssignOrgRole.InOrg(o.ID), policy.ActionUnassign, // org-admin
).Returns(out)
}))
}
@@ -1619,7 +1618,7 @@ func (s *MethodTestSuite) TestUser() {
}).Asserts(
u, policy.ActionRead,
rbac.ResourceAssignRole, policy.ActionAssign,
rbac.ResourceAssignRole, policy.ActionDelete,
rbac.ResourceAssignRole, policy.ActionUnassign,
).Returns(o)
}))
s.Run("AllUserIDs", s.Subtest(func(db database.Store, check *expects) {
@@ -1653,30 +1652,28 @@ func (s *MethodTestSuite) TestUser() {
check.Args(database.DeleteCustomRoleParams{
Name: customRole.Name,
}).Asserts(
rbac.ResourceAssignRole, policy.ActionDelete)
// fails immediately, missing organization id
).Errors(dbauthz.NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")})
}))
s.Run("Blank/UpdateCustomRole", s.Subtest(func(db database.Store, check *expects) {
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
customRole := dbgen.CustomRole(s.T(), db, database.CustomRole{})
customRole := dbgen.CustomRole(s.T(), db, database.CustomRole{
OrganizationID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
})
// Blank means the role has no perms
check.Args(database.UpdateCustomRoleParams{
Name: customRole.Name,
DisplayName: "Test Name",
OrganizationID: customRole.OrganizationID,
SitePermissions: nil,
OrgPermissions: nil,
UserPermissions: nil,
}).Asserts(rbac.ResourceAssignRole, policy.ActionUpdate).ErrorsWithPG(sql.ErrNoRows)
}).Asserts(rbac.ResourceAssignOrgRole.InOrg(customRole.OrganizationID.UUID), policy.ActionUpdate)
}))
s.Run("SitePermissions/UpdateCustomRole", s.Subtest(func(db database.Store, check *expects) {
customRole := dbgen.CustomRole(s.T(), db, database.CustomRole{
OrganizationID: uuid.NullUUID{
UUID: uuid.Nil,
Valid: false,
},
})
check.Args(database.UpdateCustomRoleParams{
Name: customRole.Name,
OrganizationID: customRole.OrganizationID,
Name: "",
OrganizationID: uuid.NullUUID{UUID: uuid.Nil, Valid: false},
DisplayName: "Test Name",
SitePermissions: db2sdk.List(codersdk.CreatePermissions(map[codersdk.RBACResource][]codersdk.RBACAction{
codersdk.ResourceTemplate: {codersdk.ActionCreate, codersdk.ActionRead, codersdk.ActionUpdate, codersdk.ActionDelete, codersdk.ActionViewInsights},
@@ -1686,17 +1683,8 @@ func (s *MethodTestSuite) TestUser() {
codersdk.ResourceWorkspace: {codersdk.ActionRead},
}), convertSDKPerm),
}).Asserts(
// First check
rbac.ResourceAssignRole, policy.ActionUpdate,
// Escalation checks
rbac.ResourceTemplate, policy.ActionCreate,
rbac.ResourceTemplate, policy.ActionRead,
rbac.ResourceTemplate, policy.ActionUpdate,
rbac.ResourceTemplate, policy.ActionDelete,
rbac.ResourceTemplate, policy.ActionViewInsights,
rbac.ResourceWorkspace.WithOwner(testActorID.String()), policy.ActionRead,
).ErrorsWithPG(sql.ErrNoRows)
// fails immediately, missing organization id
).Errors(dbauthz.NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")})
}))
s.Run("OrgPermissions/UpdateCustomRole", s.Subtest(func(db database.Store, check *expects) {
orgID := uuid.New()
@@ -1726,13 +1714,15 @@ func (s *MethodTestSuite) TestUser() {
}))
s.Run("Blank/InsertCustomRole", s.Subtest(func(db database.Store, check *expects) {
// Blank means the role has no perms
orgID := uuid.New()
check.Args(database.InsertCustomRoleParams{
Name: "test",
DisplayName: "Test Name",
OrganizationID: uuid.NullUUID{UUID: orgID, Valid: true},
SitePermissions: nil,
OrgPermissions: nil,
UserPermissions: nil,
}).Asserts(rbac.ResourceAssignRole, policy.ActionCreate)
}).Asserts(rbac.ResourceAssignOrgRole.InOrg(orgID), policy.ActionCreate)
}))
s.Run("SitePermissions/InsertCustomRole", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertCustomRoleParams{
@@ -1746,17 +1736,8 @@ func (s *MethodTestSuite) TestUser() {
codersdk.ResourceWorkspace: {codersdk.ActionRead},
}), convertSDKPerm),
}).Asserts(
-// First check
-rbac.ResourceAssignRole, policy.ActionCreate,
-// Escalation checks
-rbac.ResourceTemplate, policy.ActionCreate,
-rbac.ResourceTemplate, policy.ActionRead,
-rbac.ResourceTemplate, policy.ActionUpdate,
-rbac.ResourceTemplate, policy.ActionDelete,
-rbac.ResourceTemplate, policy.ActionViewInsights,
-rbac.ResourceWorkspace.WithOwner(testActorID.String()), policy.ActionRead,
-)
+// fails immediately, missing organization id
+).Errors(dbauthz.NotAuthorizedError{Err: xerrors.New("custom roles must belong to an organization")})
}))
s.Run("OrgPermissions/InsertCustomRole", s.Subtest(func(db database.Store, check *expects) {
orgID := uuid.New()
@@ -4802,6 +4783,14 @@ func (s *MethodTestSuite) TestResourcesMonitor() {
}).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionUpdate)
}))
s.Run("FetchMemoryResourceMonitorsUpdatedAfter", s.Subtest(func(db database.Store, check *expects) {
check.Args(dbtime.Now()).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionRead)
}))
s.Run("FetchVolumesResourceMonitorsUpdatedAfter", s.Subtest(func(db database.Store, check *expects) {
check.Args(dbtime.Now()).Asserts(rbac.ResourceWorkspaceAgentResourceMonitor, policy.ActionRead)
}))
s.Run("FetchMemoryResourceMonitorsByAgentID", s.Subtest(func(db database.Store, check *expects) {
agt, w := createAgent(s.T(), db)
@@ -17,6 +17,7 @@ type OrganizationBuilder struct {
t *testing.T
db database.Store
seed database.Organization
delete bool
allUsersAllowance int32
members []uuid.UUID
groups map[database.Group][]uuid.UUID
@@ -44,6 +45,12 @@ func (b OrganizationBuilder) EveryoneAllowance(allowance int) OrganizationBuilde
return b
}
func (b OrganizationBuilder) Deleted(deleted bool) OrganizationBuilder {
//nolint: revive // returns modified struct
b.delete = deleted
return b
}
func (b OrganizationBuilder) Seed(seed database.Organization) OrganizationBuilder {
//nolint: revive // returns modified struct
b.seed = seed
@@ -118,6 +125,17 @@ func (b OrganizationBuilder) Do() OrganizationResponse {
}
}
if b.delete {
now := dbtime.Now()
err = b.db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
UpdatedAt: now,
ID: org.ID,
})
require.NoError(b.t, err)
org.Deleted = true
org.UpdatedAt = now
}
return OrganizationResponse{
Org: org,
AllUsersGroup: everyone,
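The new `Deleted(...)` option above follows the value-receiver builder convention used throughout `dbgen` (hence the `//nolint: revive` hints): each setter mutates and returns a copy, so a base builder stays reusable. A minimal standalone sketch of that convention, with hypothetical names rather than the actual dbgen types:

```go
package main

import "fmt"

// orgBuilder is a toy value-receiver builder: setters receive the struct
// by value, mutate the copy, and return it, leaving the original intact.
type orgBuilder struct {
	name    string
	deleted bool
}

func (b orgBuilder) Deleted(d bool) orgBuilder {
	b.deleted = d // only the copy is changed
	return b
}

func main() {
	base := orgBuilder{name: "acme"}
	soft := base.Deleted(true)
	fmt.Println(base.deleted, soft.deleted) // false true
}
```

Because `base` is never mutated, a test can derive both a live and a soft-deleted organization from the same seed builder.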
@@ -2167,10 +2167,13 @@ func (q *FakeQuerier) DeleteOrganizationMember(ctx context.Context, arg database
q.mutex.Lock()
defer q.mutex.Unlock()
-deleted := slices.DeleteFunc(q.data.organizationMembers, func(member database.OrganizationMember) bool {
-return member.OrganizationID == arg.OrganizationID && member.UserID == arg.UserID
+deleted := false
+q.data.organizationMembers = slices.DeleteFunc(q.data.organizationMembers, func(member database.OrganizationMember) bool {
+match := member.OrganizationID == arg.OrganizationID && member.UserID == arg.UserID
+deleted = deleted || match
+return match
})
-if len(deleted) == 0 {
+if !deleted {
return sql.ErrNoRows
}
@@ -2361,6 +2364,19 @@ func (q *FakeQuerier) FetchMemoryResourceMonitorsByAgentID(_ context.Context, ag
return database.WorkspaceAgentMemoryResourceMonitor{}, sql.ErrNoRows
}
func (q *FakeQuerier) FetchMemoryResourceMonitorsUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.WorkspaceAgentMemoryResourceMonitor, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
monitors := []database.WorkspaceAgentMemoryResourceMonitor{}
for _, monitor := range q.workspaceAgentMemoryResourceMonitors {
if monitor.UpdatedAt.After(updatedAt) {
monitors = append(monitors, monitor)
}
}
return monitors, nil
}
func (q *FakeQuerier) FetchNewMessageMetadata(_ context.Context, arg database.FetchNewMessageMetadataParams) (database.FetchNewMessageMetadataRow, error) {
err := validateDatabaseType(arg)
if err != nil {
@@ -2405,6 +2421,19 @@ func (q *FakeQuerier) FetchVolumesResourceMonitorsByAgentID(_ context.Context, a
return monitors, nil
}
func (q *FakeQuerier) FetchVolumesResourceMonitorsUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.WorkspaceAgentVolumeResourceMonitor, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
monitors := []database.WorkspaceAgentVolumeResourceMonitor{}
for _, monitor := range q.workspaceAgentVolumeResourceMonitors {
if monitor.UpdatedAt.After(updatedAt) {
monitors = append(monitors, monitor)
}
}
return monitors, nil
}
func (q *FakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@@ -3732,6 +3761,9 @@ func (q *FakeQuerier) GetOrganizations(_ context.Context, args database.GetOrgan
if args.Name != "" && !strings.EqualFold(org.Name, args.Name) {
continue
}
if args.Deleted != org.Deleted {
continue
}
tmp = append(tmp, org)
}
@@ -3748,7 +3780,11 @@ func (q *FakeQuerier) GetOrganizationsByUserID(_ context.Context, arg database.G
continue
}
for _, organization := range q.organizations {
-if organization.ID != organizationMember.OrganizationID || organization.Deleted != arg.Deleted {
+if organization.ID != organizationMember.OrganizationID {
continue
}
if arg.Deleted.Valid && organization.Deleted != arg.Deleted.Bool {
continue
}
organizations = append(organizations, organization)
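The `DeleteOrganizationMember` fix above hinges on `slices.DeleteFunc` semantics: it returns the shortened slice (which must be assigned back) rather than the removed elements, so the old `len(deleted) == 0` check was testing the wrong thing. A small sketch of the corrected pattern, using a flag to record whether anything matched:

```go
package main

import (
	"fmt"
	"slices"
)

// deleteMatching removes every element equal to target and reports
// whether at least one element was removed, mirroring the fixed
// FakeQuerier logic above.
func deleteMatching(xs []int, target int) ([]int, bool) {
	deleted := false
	xs = slices.DeleteFunc(xs, func(v int) bool {
		match := v == target
		deleted = deleted || match
		return match
	})
	return xs, deleted
}

func main() {
	got, ok := deleteMatching([]int{1, 2, 3, 2}, 2)
	fmt.Println(got, ok) // [1 3] true
}
```

Assigning the result back matters because `DeleteFunc` shifts elements within the same backing array and returns a shorter slice header.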
@@ -444,6 +444,13 @@ func (m queryMetricsStore) FetchMemoryResourceMonitorsByAgentID(ctx context.Cont
return r0, r1
}
func (m queryMetricsStore) FetchMemoryResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.WorkspaceAgentMemoryResourceMonitor, error) {
start := time.Now()
r0, r1 := m.s.FetchMemoryResourceMonitorsUpdatedAfter(ctx, updatedAt)
m.queryLatencies.WithLabelValues("FetchMemoryResourceMonitorsUpdatedAfter").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) FetchNewMessageMetadata(ctx context.Context, arg database.FetchNewMessageMetadataParams) (database.FetchNewMessageMetadataRow, error) {
start := time.Now()
r0, r1 := m.s.FetchNewMessageMetadata(ctx, arg)
@@ -458,6 +465,13 @@ func (m queryMetricsStore) FetchVolumesResourceMonitorsByAgentID(ctx context.Con
return r0, r1
}
func (m queryMetricsStore) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.WorkspaceAgentVolumeResourceMonitor, error) {
start := time.Now()
r0, r1 := m.s.FetchVolumesResourceMonitorsUpdatedAfter(ctx, updatedAt)
m.queryLatencies.WithLabelValues("FetchVolumesResourceMonitorsUpdatedAfter").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
start := time.Now()
apiKey, err := m.s.GetAPIKeyByID(ctx, id)
@@ -772,6 +772,21 @@ func (mr *MockStoreMockRecorder) FetchMemoryResourceMonitorsByAgentID(ctx, agent
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchMemoryResourceMonitorsByAgentID", reflect.TypeOf((*MockStore)(nil).FetchMemoryResourceMonitorsByAgentID), ctx, agentID)
}
// FetchMemoryResourceMonitorsUpdatedAfter mocks base method.
func (m *MockStore) FetchMemoryResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.WorkspaceAgentMemoryResourceMonitor, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FetchMemoryResourceMonitorsUpdatedAfter", ctx, updatedAt)
ret0, _ := ret[0].([]database.WorkspaceAgentMemoryResourceMonitor)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FetchMemoryResourceMonitorsUpdatedAfter indicates an expected call of FetchMemoryResourceMonitorsUpdatedAfter.
func (mr *MockStoreMockRecorder) FetchMemoryResourceMonitorsUpdatedAfter(ctx, updatedAt any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchMemoryResourceMonitorsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).FetchMemoryResourceMonitorsUpdatedAfter), ctx, updatedAt)
}
// FetchNewMessageMetadata mocks base method.
func (m *MockStore) FetchNewMessageMetadata(ctx context.Context, arg database.FetchNewMessageMetadataParams) (database.FetchNewMessageMetadataRow, error) {
m.ctrl.T.Helper()
@@ -802,6 +817,21 @@ func (mr *MockStoreMockRecorder) FetchVolumesResourceMonitorsByAgentID(ctx, agen
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchVolumesResourceMonitorsByAgentID", reflect.TypeOf((*MockStore)(nil).FetchVolumesResourceMonitorsByAgentID), ctx, agentID)
}
// FetchVolumesResourceMonitorsUpdatedAfter mocks base method.
func (m *MockStore) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.WorkspaceAgentVolumeResourceMonitor, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FetchVolumesResourceMonitorsUpdatedAfter", ctx, updatedAt)
ret0, _ := ret[0].([]database.WorkspaceAgentVolumeResourceMonitor)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FetchVolumesResourceMonitorsUpdatedAfter indicates an expected call of FetchVolumesResourceMonitorsUpdatedAfter.
func (mr *MockStoreMockRecorder) FetchVolumesResourceMonitorsUpdatedAfter(ctx, updatedAt any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchVolumesResourceMonitorsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).FetchVolumesResourceMonitorsUpdatedAfter), ctx, updatedAt)
}
// GetAPIKeyByID mocks base method.
func (m *MockStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
m.ctrl.T.Helper()
@@ -112,9 +112,11 @@ type sqlcQuerier interface {
EnqueueNotificationMessage(ctx context.Context, arg EnqueueNotificationMessageParams) error
FavoriteWorkspace(ctx context.Context, id uuid.UUID) error
FetchMemoryResourceMonitorsByAgentID(ctx context.Context, agentID uuid.UUID) (WorkspaceAgentMemoryResourceMonitor, error)
FetchMemoryResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]WorkspaceAgentMemoryResourceMonitor, error)
// This is used to build up the notification_message's JSON payload.
FetchNewMessageMetadata(ctx context.Context, arg FetchNewMessageMetadataParams) (FetchNewMessageMetadataRow, error)
FetchVolumesResourceMonitorsByAgentID(ctx context.Context, agentID uuid.UUID) ([]WorkspaceAgentVolumeResourceMonitor, error)
FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]WorkspaceAgentVolumeResourceMonitor, error)
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
// there is no unique constraint on empty token names
GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error)
@@ -4967,7 +4967,7 @@ SELECT
FROM
organization_members
INNER JOIN
-users ON organization_members.user_id = users.id
+users ON organization_members.user_id = users.id AND users.deleted = false
WHERE
-- Filter by organization id
CASE
@@ -5221,8 +5221,13 @@ SELECT
FROM
organizations
WHERE
--- Optionally include deleted organizations
-deleted = $2 AND
+-- Optionally provide a filter for deleted organizations.
+CASE WHEN
+$2 :: boolean IS NULL THEN
+true
+ELSE
+deleted = $2
+END AND
id = ANY(
SELECT
organization_id
@@ -5234,8 +5239,8 @@ WHERE
`
type GetOrganizationsByUserIDParams struct {
-UserID uuid.UUID `db:"user_id" json:"user_id"`
-Deleted bool `db:"deleted" json:"deleted"`
+UserID uuid.UUID `db:"user_id" json:"user_id"`
+Deleted sql.NullBool `db:"deleted" json:"deleted"`
}
func (q *sqlQuerier) GetOrganizationsByUserID(ctx context.Context, arg GetOrganizationsByUserIDParams) ([]Organization, error) {
@@ -7775,25 +7780,25 @@ SELECT
FROM
custom_roles
WHERE
true
-- @lookup_roles will filter for exact (role_name, org_id) pairs
-- To do this manually in SQL, you can construct an array and cast it:
-- cast(ARRAY[('customrole','ece79dac-926e-44ca-9790-2ff7c5eb6e0c')] AS name_organization_pair[])
AND CASE WHEN array_length($1 :: name_organization_pair[], 1) > 0 THEN
-- Using 'coalesce' to avoid troubles with null literals being an empty string.
(name, coalesce(organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY ($1::name_organization_pair[])
ELSE true
END
-- This allows fetching all roles, or just site wide roles
AND CASE WHEN $2 :: boolean THEN
organization_id IS null
ELSE true
END
-- Allows fetching all roles to a particular organization
AND CASE WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
organization_id = $3
ELSE true
END
`
type CustomRolesParams struct {
@@ -7866,16 +7871,16 @@ INSERT INTO
updated_at
)
VALUES (
-- Always force lowercase names
lower($1),
$2,
$3,
$4,
$5,
$6,
now(),
now()
)
RETURNING name, display_name, site_permissions, org_permissions, user_permissions, created_at, updated_at, organization_id, id
`
@@ -11033,10 +11038,10 @@ func (q *sqlQuerier) GetActiveUserCount(ctx context.Context) (int64, error) {
const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one
SELECT
--- username is returned just to help for logging purposes
+-- username and email are returned just to help for logging purposes
-- status is used to enforce 'suspended' users, as all roles are ignored
-- when suspended.
-id, username, status,
+id, username, status, email,
-- All user roles, including their org roles.
array_cat(
-- All users are members
@@ -11077,6 +11082,7 @@ type GetAuthorizationUserRolesRow struct {
ID uuid.UUID `db:"id" json:"id"`
Username string `db:"username" json:"username"`
Status UserStatus `db:"status" json:"status"`
Email string `db:"email" json:"email"`
Roles []string `db:"roles" json:"roles"`
Groups []string `db:"groups" json:"groups"`
}
@@ -11090,6 +11096,7 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.
&i.ID,
&i.Username,
&i.Status,
&i.Email,
pq.Array(&i.Roles),
pq.Array(&i.Groups),
)
@@ -12135,6 +12142,46 @@ func (q *sqlQuerier) FetchMemoryResourceMonitorsByAgentID(ctx context.Context, a
return i, err
}
const fetchMemoryResourceMonitorsUpdatedAfter = `-- name: FetchMemoryResourceMonitorsUpdatedAfter :many
SELECT
agent_id, enabled, threshold, created_at, updated_at, state, debounced_until
FROM
workspace_agent_memory_resource_monitors
WHERE
updated_at > $1
`
func (q *sqlQuerier) FetchMemoryResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]WorkspaceAgentMemoryResourceMonitor, error) {
rows, err := q.db.QueryContext(ctx, fetchMemoryResourceMonitorsUpdatedAfter, updatedAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []WorkspaceAgentMemoryResourceMonitor
for rows.Next() {
var i WorkspaceAgentMemoryResourceMonitor
if err := rows.Scan(
&i.AgentID,
&i.Enabled,
&i.Threshold,
&i.CreatedAt,
&i.UpdatedAt,
&i.State,
&i.DebouncedUntil,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const fetchVolumesResourceMonitorsByAgentID = `-- name: FetchVolumesResourceMonitorsByAgentID :many
SELECT
agent_id, enabled, threshold, path, created_at, updated_at, state, debounced_until
@@ -12176,6 +12223,47 @@ func (q *sqlQuerier) FetchVolumesResourceMonitorsByAgentID(ctx context.Context,
return items, nil
}
const fetchVolumesResourceMonitorsUpdatedAfter = `-- name: FetchVolumesResourceMonitorsUpdatedAfter :many
SELECT
agent_id, enabled, threshold, path, created_at, updated_at, state, debounced_until
FROM
workspace_agent_volume_resource_monitors
WHERE
updated_at > $1
`
func (q *sqlQuerier) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]WorkspaceAgentVolumeResourceMonitor, error) {
rows, err := q.db.QueryContext(ctx, fetchVolumesResourceMonitorsUpdatedAfter, updatedAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []WorkspaceAgentVolumeResourceMonitor
for rows.Next() {
var i WorkspaceAgentVolumeResourceMonitor
if err := rows.Scan(
&i.AgentID,
&i.Enabled,
&i.Threshold,
&i.Path,
&i.CreatedAt,
&i.UpdatedAt,
&i.State,
&i.DebouncedUntil,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertMemoryResourceMonitor = `-- name: InsertMemoryResourceMonitor :one
INSERT INTO
workspace_agent_memory_resource_monitors (
@@ -9,7 +9,7 @@ SELECT
FROM
organization_members
INNER JOIN
-users ON organization_members.user_id = users.id
+users ON organization_members.user_id = users.id AND users.deleted = false
WHERE
-- Filter by organization id
CASE
@@ -55,8 +55,13 @@ SELECT
FROM
organizations
WHERE
--- Optionally include deleted organizations
-deleted = @deleted AND
+-- Optionally provide a filter for deleted organizations.
+CASE WHEN
+sqlc.narg('deleted') :: boolean IS NULL THEN
+true
+ELSE
+deleted = sqlc.narg('deleted')
+END AND
id = ANY(
SELECT
organization_id
@@ -4,25 +4,25 @@ SELECT
FROM
custom_roles
WHERE
true
-- @lookup_roles will filter for exact (role_name, org_id) pairs
-- To do this manually in SQL, you can construct an array and cast it:
-- cast(ARRAY[('customrole','ece79dac-926e-44ca-9790-2ff7c5eb6e0c')] AS name_organization_pair[])
AND CASE WHEN array_length(@lookup_roles :: name_organization_pair[], 1) > 0 THEN
-- Using 'coalesce' to avoid troubles with null literals being an empty string.
(name, coalesce(organization_id, '00000000-0000-0000-0000-000000000000' ::uuid)) = ANY (@lookup_roles::name_organization_pair[])
ELSE true
END
-- This allows fetching all roles, or just site wide roles
AND CASE WHEN @exclude_org_roles :: boolean THEN
organization_id IS null
ELSE true
END
-- Allows fetching all roles to a particular organization
AND CASE WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
organization_id = @organization_id
ELSE true
END
;
-- name: DeleteCustomRole :exec
@@ -46,16 +46,16 @@ INSERT INTO
updated_at
)
VALUES (
-- Always force lowercase names
lower(@name),
@display_name,
@organization_id,
@site_permissions,
@org_permissions,
@user_permissions,
now(),
now()
)
RETURNING *;
-- name: UpdateCustomRole :one
@@ -244,10 +244,10 @@ WHERE
-- This function returns roles for authorization purposes. Implied member roles
-- are included.
SELECT
--- username is returned just to help for logging purposes
+-- username and email are returned just to help for logging purposes
-- status is used to enforce 'suspended' users, as all roles are ignored
-- when suspended.
-id, username, status,
+id, username, status, email,
-- All user roles, including their org roles.
array_cat(
-- All users are members
@@ -1,3 +1,19 @@
-- name: FetchVolumesResourceMonitorsUpdatedAfter :many
SELECT
*
FROM
workspace_agent_volume_resource_monitors
WHERE
updated_at > $1;
-- name: FetchMemoryResourceMonitorsUpdatedAfter :many
SELECT
*
FROM
workspace_agent_memory_resource_monitors
WHERE
updated_at > $1;
-- name: FetchMemoryResourceMonitorsByAgentID :one
SELECT
*
@@ -151,11 +151,13 @@ func ResourceNotFound(rw http.ResponseWriter) {
Write(context.Background(), rw, http.StatusNotFound, ResourceNotFoundResponse)
}
var ResourceForbiddenResponse = codersdk.Response{
Message: "Forbidden.",
Detail: "You don't have permission to view this content. If you believe this is a mistake, please contact your administrator or try signing in with different credentials.",
}
func Forbidden(rw http.ResponseWriter) {
-Write(context.Background(), rw, http.StatusForbidden, codersdk.Response{
-Message: "Forbidden.",
-Detail: "You don't have permission to view this content. If you believe this is a mistake, please contact your administrator or try signing in with different credentials.",
-})
+Write(context.Background(), rw, http.StatusForbidden, ResourceForbiddenResponse)
}
func InternalServerError(rw http.ResponseWriter, err error) {
@@ -232,16 +232,21 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return optionalWrite(http.StatusUnauthorized, resp)
}
-var (
-link database.UserLink
-now = dbtime.Now()
-// Tracks if the API key has properties updated
-changed = false
-)
+now := dbtime.Now()
if key.ExpiresAt.Before(now) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
})
}
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
// refreshing the OIDC token.
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
-var err error
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
-link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
+link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
@@ -258,7 +263,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
})
}
// Check if the OAuth token is expired
-if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() && link.OAuthRefreshToken != "" {
+if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
if cfg.OAuth2Configs.IsZero() {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
@@ -267,12 +272,15 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
})
}
var friendlyName string
var oauthConfig promoauth.OAuth2Config
switch key.LoginType {
case database.LoginTypeGithub:
oauthConfig = cfg.OAuth2Configs.Github
friendlyName = "GitHub"
case database.LoginTypeOIDC:
oauthConfig = cfg.OAuth2Configs.OIDC
friendlyName = "OpenID Connect"
default:
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
@@ -292,7 +300,13 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
})
}
// If it is, let's refresh it from the provided config
if link.OAuthRefreshToken == "" {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
})
}
// We have a refresh token, so let's try it
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
@@ -300,28 +314,39 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
}).Token()
if err != nil {
return write(http.StatusUnauthorized, codersdk.Response{
-Message: "Could not refresh expired Oauth token. Try re-authenticating to resolve this issue.",
-Detail: err.Error(),
+Message: fmt.Sprintf(
+"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
+friendlyName),
+Detail: err.Error(),
})
}
link.OAuthAccessToken = token.AccessToken
link.OAuthRefreshToken = token.RefreshToken
link.OAuthExpiry = token.Expiry
-key.ExpiresAt = token.Expiry
-changed = true
//nolint:gocritic // system needs to update user link
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: link.OAuthExpiry,
// Refresh should keep the same debug context because we use
// the original claims for the group/role sync.
Claims: link.Claims,
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
})
}
}
}
// Checking if the key is expired.
// NOTE: The `RequireAuth` React component depends on this `Detail` to detect when
// the users token has expired. If you change the text here, make sure to update it
// in site/src/components/RequireAuth/RequireAuth.tsx as well.
if key.ExpiresAt.Before(now) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
})
}
// Tracks if the API key has properties updated
changed := false
// Only update LastUsed once an hour to prevent database spam.
if now.Sub(key.LastUsed) > time.Hour {
@@ -363,29 +388,6 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
})
}
-// If the API Key is associated with a user_link (e.g. Github/OIDC)
-// then we want to update the relevant oauth fields.
-if link.UserID != uuid.Nil {
-//nolint:gocritic // system needs to update user link
-link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
-UserID: link.UserID,
-LoginType: link.LoginType,
-OAuthAccessToken: link.OAuthAccessToken,
-OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
-OAuthRefreshToken: link.OAuthRefreshToken,
-OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
-OAuthExpiry: link.OAuthExpiry,
-// Refresh should keep the same debug context because we use
-// the original claims for the group/role sync.
-Claims: link.Claims,
-})
-if err != nil {
-return write(http.StatusInternalServerError, codersdk.Response{
-Message: internalErrorMessage,
-Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
-})
-}
-}
// We only want to update this occasionally to reduce DB write
// load. We update alongside the UserLink and APIKey since it's
@@ -465,7 +467,9 @@ func UserRBACSubject(ctx context.Context, db database.Store, userID uuid.UUID, s
}
actor := rbac.Subject{
Type: rbac.SubjectTypeUser,
FriendlyName: roles.Username,
Email: roles.Email,
ID: userID.String(),
Roles: rbacRoles,
Groups: roles.Groups,
@@ -508,6 +508,102 @@ func TestAPIKey(t *testing.T) {
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
LastUsed: dbtime.Now().AddDate(0, 0, -1),
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
LoginType: database.LoginTypeOIDC,
})
_ = dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
})
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.Header.Set(codersdk.SessionTokenHeader, token)
// Include a valid oauth token for refreshing. If this token is invalid,
// it is difficult to tell an auth failure from an expired api key, or
// an expired oauth key.
oauthToken := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: dbtime.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
OAuth2Configs: &httpmw.OAuth2Configs{
OIDC: &testutil.OAuth2Config{
Token: oauthToken,
},
},
RedirectToLogin: false,
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
LastUsed: dbtime.Now().AddDate(0, 0, -1),
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
LoginType: database.LoginTypeOIDC,
})
_ = dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
})
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.Header.Set(codersdk.SessionTokenHeader, token)
oauthToken := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: dbtime.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
OAuth2Configs: &httpmw.OAuth2Configs{
OIDC: &testutil.OAuth2Config{
Token: oauthToken,
},
},
RedirectToLogin: false,
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("OAuthRefresh", func(t *testing.T) {
t.Parallel()
var (
@@ -553,7 +649,67 @@ func TestAPIKey(t *testing.T) {
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
-require.Equal(t, oauthToken.Expiry, gotAPIKey.ExpiresAt)
+// Note that OAuth expiry is independent of APIKey expiry, so an OIDC refresh DOES NOT affect the expiry of the
+// APIKey
+require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
gotLink, err := db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
})
require.NoError(t, err)
require.Equal(t, gotLink.OAuthRefreshToken, "moo")
})
t.Run("OAuthExpiredNoRefresh", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
db = dbmem.New()
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
LastUsed: dbtime.Now(),
ExpiresAt: dbtime.Now().AddDate(0, 0, 1),
LoginType: database.LoginTypeGithub,
})
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
_, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
OAuthAccessToken: "letmein",
})
require.NoError(t, err)
r.Header.Set(codersdk.SessionTokenHeader, token)
oauthToken := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: dbtime.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
OAuth2Configs: &httpmw.OAuth2Configs{
Github: &testutil.OAuth2Config{
Token: oauthToken,
},
},
RedirectToLogin: false,
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("RemoteIPUpdates", func(t *testing.T) {
-76
@@ -1,76 +0,0 @@
package httpmw
import (
"context"
"fmt"
"net/http"
"time"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/tracing"
)
func Logger(log slog.Logger) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
start := time.Now()
sw, ok := rw.(*tracing.StatusWriter)
if !ok {
panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw))
}
httplog := log.With(
slog.F("host", httpapi.RequestHost(r)),
slog.F("path", r.URL.Path),
slog.F("proto", r.Proto),
slog.F("remote_addr", r.RemoteAddr),
// Include the start timestamp in the log so that we have the
// source of truth. There is at least a theoretical chance that
// there can be a delay between `next.ServeHTTP` ending and us
// actually logging the request. This can also be useful when
// filtering logs that started at a certain time (compared to
// trying to compute the value).
slog.F("start", start),
)
next.ServeHTTP(sw, r)
end := time.Now()
// Don't log successful health check requests.
if r.URL.Path == "/api/v2" && sw.Status == http.StatusOK {
return
}
httplog = httplog.With(
slog.F("took", end.Sub(start)),
slog.F("status_code", sw.Status),
slog.F("latency_ms", float64(end.Sub(start)/time.Millisecond)),
)
// For status codes 500 and higher we
// want to log the response body.
if sw.Status >= http.StatusInternalServerError {
httplog = httplog.With(
slog.F("response_body", string(sw.ResponseBody())),
)
}
// We should not log at level ERROR for 5xx status codes because 5xx
// includes proxy errors etc. It also causes slogtest to fail
// instantly without an error message by default.
logLevelFn := httplog.Debug
if sw.Status >= http.StatusInternalServerError {
logLevelFn = httplog.Warn
}
// We already capture most of this information in the span (minus
// the response body which we don't want to capture anyways).
tracing.RunWithoutSpan(r.Context(), func(ctx context.Context) {
logLevelFn(ctx, r.Method)
})
})
}
}
+203
@@ -0,0 +1,203 @@
package loggermw
import (
"context"
"fmt"
"net/http"
"sync"
"time"
"github.com/go-chi/chi/v5"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/tracing"
)
func Logger(log slog.Logger) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
start := time.Now()
sw, ok := rw.(*tracing.StatusWriter)
if !ok {
panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw))
}
httplog := log.With(
slog.F("host", httpapi.RequestHost(r)),
slog.F("path", r.URL.Path),
slog.F("proto", r.Proto),
slog.F("remote_addr", r.RemoteAddr),
// Include the start timestamp in the log so that we have the
// source of truth. There is at least a theoretical chance that
// there can be a delay between `next.ServeHTTP` ending and us
// actually logging the request. This can also be useful when
// filtering logs that started at a certain time (compared to
// trying to compute the value).
slog.F("start", start),
)
logContext := NewRequestLogger(httplog, r.Method, start)
ctx := WithRequestLogger(r.Context(), logContext)
next.ServeHTTP(sw, r.WithContext(ctx))
// Don't log successful health check requests.
if r.URL.Path == "/api/v2" && sw.Status == http.StatusOK {
return
}
// For status codes 500 and higher we
// want to log the response body.
if sw.Status >= http.StatusInternalServerError {
logContext.WithFields(
slog.F("response_body", string(sw.ResponseBody())),
)
}
logContext.WriteLog(r.Context(), sw.Status)
})
}
}
type RequestLogger interface {
WithFields(fields ...slog.Field)
WriteLog(ctx context.Context, status int)
WithAuthContext(actor rbac.Subject)
}
type SlogRequestLogger struct {
log slog.Logger
written bool
message string
start time.Time
// Protects actors map for concurrent writes.
mu sync.RWMutex
actors map[rbac.SubjectType]rbac.Subject
}
var _ RequestLogger = &SlogRequestLogger{}
func NewRequestLogger(log slog.Logger, message string, start time.Time) RequestLogger {
return &SlogRequestLogger{
log: log,
written: false,
message: message,
start: start,
actors: make(map[rbac.SubjectType]rbac.Subject),
}
}
func (c *SlogRequestLogger) WithFields(fields ...slog.Field) {
c.log = c.log.With(fields...)
}
func (c *SlogRequestLogger) WithAuthContext(actor rbac.Subject) {
c.mu.Lock()
defer c.mu.Unlock()
c.actors[actor.Type] = actor
}
func (c *SlogRequestLogger) addAuthContextFields() {
c.mu.RLock()
defer c.mu.RUnlock()
usr, ok := c.actors[rbac.SubjectTypeUser]
if ok {
c.log = c.log.With(
slog.F("requestor_id", usr.ID),
slog.F("requestor_name", usr.FriendlyName),
slog.F("requestor_email", usr.Email),
)
} else {
// If there is no user, we log the requestor name for the first
// actor in a defined order.
for _, v := range actorLogOrder {
subj, ok := c.actors[v]
if !ok {
continue
}
c.log = c.log.With(
slog.F("requestor_name", subj.FriendlyName),
)
break
}
}
}
var actorLogOrder = []rbac.SubjectType{
rbac.SubjectTypeAutostart,
rbac.SubjectTypeCryptoKeyReader,
rbac.SubjectTypeCryptoKeyRotator,
rbac.SubjectTypeHangDetector,
rbac.SubjectTypeNotifier,
rbac.SubjectTypePrebuildsOrchestrator,
rbac.SubjectTypeProvisionerd,
rbac.SubjectTypeResourceMonitor,
rbac.SubjectTypeSystemReadProvisionerDaemons,
rbac.SubjectTypeSystemRestricted,
}
func (c *SlogRequestLogger) WriteLog(ctx context.Context, status int) {
if c.written {
return
}
c.written = true
end := time.Now()
// Right before we write the log, we try to find the user in the actors
// and add the fields to the log.
c.addAuthContextFields()
logger := c.log.With(
slog.F("took", end.Sub(c.start)),
slog.F("status_code", status),
slog.F("latency_ms", float64(end.Sub(c.start)/time.Millisecond)),
)
// If the request is routed, add the route parameters to the log.
if chiCtx := chi.RouteContext(ctx); chiCtx != nil {
urlParams := chiCtx.URLParams
routeParamsFields := make([]slog.Field, 0, len(urlParams.Keys))
for k, v := range urlParams.Keys {
if urlParams.Values[k] != "" {
routeParamsFields = append(routeParamsFields, slog.F("params_"+v, urlParams.Values[k]))
}
}
if len(routeParamsFields) > 0 {
logger = logger.With(routeParamsFields...)
}
}
// We already capture most of this information in the span (minus
// the response body which we don't want to capture anyways).
tracing.RunWithoutSpan(ctx, func(ctx context.Context) {
// We should not log at level ERROR for 5xx status codes because 5xx
// includes proxy errors etc. It also causes slogtest to fail
// instantly without an error message by default.
if status >= http.StatusInternalServerError {
logger.Warn(ctx, c.message)
} else {
logger.Debug(ctx, c.message)
}
})
}
type logContextKey struct{}
func WithRequestLogger(ctx context.Context, rl RequestLogger) context.Context {
return context.WithValue(ctx, logContextKey{}, rl)
}
func RequestLoggerFromContext(ctx context.Context) RequestLogger {
val := ctx.Value(logContextKey{})
if logCtx, ok := val.(RequestLogger); ok {
return logCtx
}
return nil
}
@@ -0,0 +1,311 @@
package loggermw
import (
"context"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
)
func TestRequestLogger_WriteLog(t *testing.T) {
t.Parallel()
ctx := context.Background()
sink := &fakeSink{}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
logCtx := NewRequestLogger(logger, "GET", time.Now())
// Add custom fields
logCtx.WithFields(
slog.F("custom_field", "custom_value"),
)
// Write log for 200 status
logCtx.WriteLog(ctx, http.StatusOK)
require.Len(t, sink.entries, 1, "expected exactly one log entry")
require.Equal(t, sink.entries[0].Message, "GET")
require.Equal(t, sink.entries[0].Fields[0].Value, "custom_value")
// Attempt to write again (should be skipped).
logCtx.WriteLog(ctx, http.StatusInternalServerError)
require.Len(t, sink.entries, 1, "log was written twice")
}
func TestLoggerMiddleware_SingleRequest(t *testing.T) {
t.Parallel()
sink := &fakeSink{}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
// Create a test handler to simulate an HTTP request
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte("OK"))
})
// Wrap the test handler with the Logger middleware
loggerMiddleware := Logger(logger)
wrappedHandler := loggerMiddleware(testHandler)
// Create a test HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path", nil)
require.NoError(t, err, "failed to create request")
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
// Serve the request
wrappedHandler.ServeHTTP(sw, req)
require.Len(t, sink.entries, 1, "expected exactly one log entry")
require.Equal(t, sink.entries[0].Message, "GET")
fieldsMap := make(map[string]any)
for _, field := range sink.entries[0].Fields {
fieldsMap[field.Name] = field.Value
}
// Check that the log contains the expected fields
requiredFields := []string{"host", "path", "proto", "remote_addr", "start", "took", "status_code", "latency_ms"}
for _, field := range requiredFields {
_, exists := fieldsMap[field]
require.True(t, exists, "field %q is missing in log fields", field)
}
require.Len(t, sink.entries[0].Fields, len(requiredFields), "log should contain only the required fields")
// Check value of the status code
require.Equal(t, fieldsMap["status_code"], http.StatusOK)
}
func TestLoggerMiddleware_WebSocket(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
sink := &fakeSink{
newEntries: make(chan slog.SinkEntry, 2),
}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
done := make(chan struct{})
wg := sync.WaitGroup{}
// Create a test handler to simulate a WebSocket connection
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(rw, r, nil)
if !assert.NoError(t, err, "failed to accept websocket") {
return
}
defer conn.Close(websocket.StatusGoingAway, "")
requestLgr := RequestLoggerFromContext(r.Context())
requestLgr.WriteLog(r.Context(), http.StatusSwitchingProtocols)
// Block so we can be sure the end of the middleware isn't being called.
wg.Wait()
})
// Wrap the test handler with the Logger middleware
loggerMiddleware := Logger(logger)
wrappedHandler := loggerMiddleware(testHandler)
// RequestLogger expects the ResponseWriter to be *tracing.StatusWriter
customHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
defer close(done)
sw := &tracing.StatusWriter{ResponseWriter: rw}
wrappedHandler.ServeHTTP(sw, r)
})
srv := httptest.NewServer(customHandler)
defer srv.Close()
wg.Add(1)
// nolint: bodyclose
conn, _, err := websocket.Dial(ctx, srv.URL, nil)
require.NoError(t, err, "failed to dial WebSocket")
defer conn.Close(websocket.StatusNormalClosure, "")
// Wait for the log from within the handler
newEntry := testutil.RequireRecvCtx(ctx, t, sink.newEntries)
require.Equal(t, newEntry.Message, "GET")
// Signal the websocket handler to return (and read to handle the close frame)
wg.Done()
_, _, err = conn.Read(ctx)
require.ErrorAs(t, err, &websocket.CloseError{}, "websocket read should fail with close error")
// Wait for the request to finish completely and verify we only logged once
_ = testutil.RequireRecvCtx(ctx, t, done)
require.Len(t, sink.entries, 1, "log was written twice")
}
func TestRequestLogger_HTTPRouteParams(t *testing.T) {
t.Parallel()
sink := &fakeSink{}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("workspace", "test-workspace")
chiCtx.URLParams.Add("agent", "test-agent")
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
// Create a test handler to simulate an HTTP request
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte("OK"))
})
// Wrap the test handler with the Logger middleware
loggerMiddleware := Logger(logger)
wrappedHandler := loggerMiddleware(testHandler)
// Create a test HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path/}", nil)
require.NoError(t, err, "failed to create request")
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
// Serve the request
wrappedHandler.ServeHTTP(sw, req)
fieldsMap := make(map[string]any)
for _, field := range sink.entries[0].Fields {
fieldsMap[field.Name] = field.Value
}
// Check that the log contains the expected fields
requiredFields := []string{"workspace", "agent"}
for _, field := range requiredFields {
_, exists := fieldsMap["params_"+field]
require.True(t, exists, "field %q is missing in log fields", field)
}
}
func TestRequestLogger_RouteParamsLogging(t *testing.T) {
t.Parallel()
tests := []struct {
name string
params map[string]string
expectedFields []string
}{
{
name: "EmptyParams",
params: map[string]string{},
expectedFields: []string{},
},
{
name: "SingleParam",
params: map[string]string{
"workspace": "test-workspace",
},
expectedFields: []string{"params_workspace"},
},
{
name: "MultipleParams",
params: map[string]string{
"workspace": "test-workspace",
"agent": "test-agent",
"user": "test-user",
},
expectedFields: []string{"params_workspace", "params_agent", "params_user"},
},
{
name: "EmptyValueParam",
params: map[string]string{
"workspace": "test-workspace",
"agent": "",
},
expectedFields: []string{"params_workspace"},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
sink := &fakeSink{}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
// Create a route context with the test parameters
chiCtx := chi.NewRouteContext()
for key, value := range tt.params {
chiCtx.URLParams.Add(key, value)
}
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
logCtx := NewRequestLogger(logger, "GET", time.Now())
// Write the log
logCtx.WriteLog(ctx, http.StatusOK)
require.Len(t, sink.entries, 1, "expected exactly one log entry")
// Convert fields to map for easier checking
fieldsMap := make(map[string]any)
for _, field := range sink.entries[0].Fields {
fieldsMap[field.Name] = field.Value
}
// Verify expected fields are present
for _, field := range tt.expectedFields {
value, exists := fieldsMap[field]
require.True(t, exists, "field %q should be present in log", field)
require.Equal(t, tt.params[strings.TrimPrefix(field, "params_")], value, "field %q has incorrect value", field)
}
// Verify no unexpected fields are present
for field := range fieldsMap {
if field == "took" || field == "status_code" || field == "latency_ms" {
continue // Skip standard fields
}
require.True(t, slices.Contains(tt.expectedFields, field), "unexpected field %q in log", field)
}
})
}
}
type fakeSink struct {
entries []slog.SinkEntry
newEntries chan slog.SinkEntry
}
func (s *fakeSink) LogEntry(_ context.Context, e slog.SinkEntry) {
s.entries = append(s.entries, e)
if s.newEntries != nil {
select {
case s.newEntries <- e:
default:
}
}
}
func (*fakeSink) Sync() {}
@@ -0,0 +1,83 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/coderd/httpmw/loggermw (interfaces: RequestLogger)
//
// Generated by this command:
//
// mockgen -destination=loggermock/loggermock.go -package=loggermock . RequestLogger
//
// Package loggermock is a generated GoMock package.
package loggermock
import (
context "context"
reflect "reflect"
slog "cdr.dev/slog"
rbac "github.com/coder/coder/v2/coderd/rbac"
gomock "go.uber.org/mock/gomock"
)
// MockRequestLogger is a mock of RequestLogger interface.
type MockRequestLogger struct {
ctrl *gomock.Controller
recorder *MockRequestLoggerMockRecorder
isgomock struct{}
}
// MockRequestLoggerMockRecorder is the mock recorder for MockRequestLogger.
type MockRequestLoggerMockRecorder struct {
mock *MockRequestLogger
}
// NewMockRequestLogger creates a new mock instance.
func NewMockRequestLogger(ctrl *gomock.Controller) *MockRequestLogger {
mock := &MockRequestLogger{ctrl: ctrl}
mock.recorder = &MockRequestLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRequestLogger) EXPECT() *MockRequestLoggerMockRecorder {
return m.recorder
}
// WithAuthContext mocks base method.
func (m *MockRequestLogger) WithAuthContext(actor rbac.Subject) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "WithAuthContext", actor)
}
// WithAuthContext indicates an expected call of WithAuthContext.
func (mr *MockRequestLoggerMockRecorder) WithAuthContext(actor any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithAuthContext", reflect.TypeOf((*MockRequestLogger)(nil).WithAuthContext), actor)
}
// WithFields mocks base method.
func (m *MockRequestLogger) WithFields(fields ...slog.Field) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range fields {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "WithFields", varargs...)
}
// WithFields indicates an expected call of WithFields.
func (mr *MockRequestLoggerMockRecorder) WithFields(fields ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithFields", reflect.TypeOf((*MockRequestLogger)(nil).WithFields), fields...)
}
// WriteLog mocks base method.
func (m *MockRequestLogger) WriteLog(ctx context.Context, status int) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "WriteLog", ctx, status)
}
// WriteLog indicates an expected call of WriteLog.
func (mr *MockRequestLoggerMockRecorder) WriteLog(ctx, status any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteLog", reflect.TypeOf((*MockRequestLogger)(nil).WriteLog), ctx, status)
}
+42 -10
@@ -3,6 +3,7 @@ package httpmw
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/go-chi/chi/v5"
@@ -22,18 +23,18 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
Name: "requests_processed_total",
Help: "The total number of processed API requests",
}, []string{"code", "method", "path"})
requestsConcurrent := factory.NewGauge(prometheus.GaugeOpts{
requestsConcurrent := factory.NewGaugeVec(prometheus.GaugeOpts{
Namespace: "coderd",
Subsystem: "api",
Name: "concurrent_requests",
Help: "The number of concurrent API requests.",
})
websocketsConcurrent := factory.NewGauge(prometheus.GaugeOpts{
}, []string{"method", "path"})
websocketsConcurrent := factory.NewGaugeVec(prometheus.GaugeOpts{
Namespace: "coderd",
Subsystem: "api",
Name: "concurrent_websockets",
Help: "The total number of concurrent API websockets.",
})
}, []string{"path"})
websocketsDist := factory.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "coderd",
Subsystem: "api",
@@ -61,7 +62,6 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
var (
start = time.Now()
method = r.Method
rctx = chi.RouteContext(r.Context())
)
sw, ok := w.(*tracing.StatusWriter)
@@ -72,16 +72,18 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
var (
dist *prometheus.HistogramVec
distOpts []string
path = getRoutePattern(r)
)
// We want to count WebSockets separately.
if httpapi.IsWebsocketUpgrade(r) {
websocketsConcurrent.Inc()
defer websocketsConcurrent.Dec()
websocketsConcurrent.WithLabelValues(path).Inc()
defer websocketsConcurrent.WithLabelValues(path).Dec()
dist = websocketsDist
} else {
requestsConcurrent.Inc()
defer requestsConcurrent.Dec()
requestsConcurrent.WithLabelValues(method, path).Inc()
defer requestsConcurrent.WithLabelValues(method, path).Dec()
dist = requestsDist
distOpts = []string{method}
@@ -89,7 +91,6 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
next.ServeHTTP(w, r)
path := rctx.RoutePattern()
distOpts = append(distOpts, path)
statusStr := strconv.Itoa(sw.Status)
@@ -98,3 +99,34 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
})
}
}
func getRoutePattern(r *http.Request) string {
rctx := chi.RouteContext(r.Context())
if rctx == nil {
return ""
}
if pattern := rctx.RoutePattern(); pattern != "" {
// Pattern is already available
return pattern
}
routePath := r.URL.Path
if r.URL.RawPath != "" {
routePath = r.URL.RawPath
}
tctx := chi.NewRouteContext()
routes := rctx.Routes
if routes != nil && !routes.Match(tctx, r.Method, routePath) {
// No matching pattern. /api/* requests will be matched as "UNKNOWN"
// All other ones will be matched as "STATIC".
if strings.HasPrefix(routePath, "/api/") {
return "UNKNOWN"
}
return "STATIC"
}
// tctx has the updated pattern, since Match mutates it
return tctx.RoutePattern()
}
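getRoutePattern's fallback keeps Prometheus label cardinality bounded: unmatched /api/ paths collapse into a single "UNKNOWN" label and everything else into "STATIC", so arbitrary request paths never mint new metric series. A stdlib sketch of that classification (simplified; the real code first re-runs chi's route matcher against a throwaway route context):

```go
package main

import (
	"fmt"
	"strings"
)

// classify mirrors getRoutePattern's fallback: use the resolved route
// pattern when the router produced one, otherwise bucket the raw path
// so unmatched requests cannot explode label cardinality.
func classify(pattern, path string) string {
	if pattern != "" {
		return pattern // router already resolved a pattern
	}
	if strings.HasPrefix(path, "/api/") {
		return "UNKNOWN"
	}
	return "STATIC"
}

func main() {
	fmt.Println(classify("/api/v2/users/{user}", "/api/v2/users/john")) // /api/v2/users/{user}
	fmt.Println(classify("", "/api/v2/does-not-exist"))                 // UNKNOWN
	fmt.Println(classify("", "/assets/app.js"))                         // STATIC
}
```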
+91
@@ -8,14 +8,19 @@ import (
"github.com/go-chi/chi/v5"
"github.com/prometheus/client_golang/prometheus"
cm "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
)
func TestPrometheus(t *testing.T) {
t.Parallel()
t.Run("All", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/", nil)
@@ -29,4 +34,90 @@ func TestPrometheus(t *testing.T) {
require.NoError(t, err)
require.Greater(t, len(metrics), 0)
})
t.Run("Concurrent", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
reg := prometheus.NewRegistry()
promMW := httpmw.Prometheus(reg)
// Create a test handler to simulate a WebSocket connection
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(rw, r, nil)
if !assert.NoError(t, err, "failed to accept websocket") {
return
}
defer conn.Close(websocket.StatusGoingAway, "")
})
wrappedHandler := promMW(testHandler)
r := chi.NewRouter()
r.Use(tracing.StatusWriterMiddleware, promMW)
r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) {
wrappedHandler.ServeHTTP(rw, r)
})
srv := httptest.NewServer(r)
defer srv.Close()
// nolint: bodyclose
conn, _, err := websocket.Dial(ctx, srv.URL+"/api/v2/build/1/logs", nil)
require.NoError(t, err, "failed to dial WebSocket")
defer conn.Close(websocket.StatusNormalClosure, "")
metrics, err := reg.Gather()
require.NoError(t, err)
require.Greater(t, len(metrics), 0)
metricLabels := getMetricLabels(metrics)
concurrentWebsockets, ok := metricLabels["coderd_api_concurrent_websockets"]
require.True(t, ok, "coderd_api_concurrent_websockets metric not found")
require.Equal(t, "/api/v2/build/{build}/logs", concurrentWebsockets["path"])
})
t.Run("UserRoute", func(t *testing.T) {
t.Parallel()
reg := prometheus.NewRegistry()
promMW := httpmw.Prometheus(reg)
r := chi.NewRouter()
r.With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
req := httptest.NewRequest("GET", "/api/v2/users/john", nil)
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
r.ServeHTTP(sw, req)
metrics, err := reg.Gather()
require.NoError(t, err)
require.Greater(t, len(metrics), 0)
metricLabels := getMetricLabels(metrics)
reqProcessed, ok := metricLabels["coderd_api_requests_processed_total"]
require.True(t, ok, "coderd_api_requests_processed_total metric not found")
require.Equal(t, "/api/v2/users/{user}", reqProcessed["path"])
require.Equal(t, "GET", reqProcessed["method"])
concurrentRequests, ok := metricLabels["coderd_api_concurrent_requests"]
require.True(t, ok, "coderd_api_concurrent_requests metric not found")
require.Equal(t, "/api/v2/users/{user}", concurrentRequests["path"])
require.Equal(t, "GET", concurrentRequests["method"])
})
}
func getMetricLabels(metrics []*cm.MetricFamily) map[string]map[string]string {
metricLabels := map[string]map[string]string{}
for _, metricFamily := range metrics {
metricName := metricFamily.GetName()
metricLabels[metricName] = map[string]string{}
for _, metric := range metricFamily.GetMetric() {
for _, labelPair := range metric.GetLabel() {
metricLabels[metricName][labelPair.GetName()] = labelPair.GetValue()
}
}
}
return metricLabels
}
+11
@@ -6,8 +6,11 @@ import (
"github.com/go-chi/chi/v5"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/codersdk"
)
@@ -81,6 +84,14 @@ func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handl
ctx = context.WithValue(ctx, workspaceAgentParamContextKey{}, agent)
chi.RouteContext(ctx).URLParams.Add("workspace", build.WorkspaceID.String())
if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil {
rlogger.WithFields(
slog.F("workspace_name", resource.Name),
slog.F("agent_name", agent.Name),
)
}
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
+15
@@ -9,8 +9,11 @@ import (
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/codersdk"
)
@@ -48,6 +51,11 @@ func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler {
}
ctx = context.WithValue(ctx, workspaceParamContextKey{}, workspace)
if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil {
rlogger.WithFields(slog.F("workspace_name", workspace.Name))
}
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
@@ -154,6 +162,13 @@ func ExtractWorkspaceAndAgentParam(db database.Store) func(http.Handler) http.Ha
ctx = context.WithValue(ctx, workspaceParamContextKey{}, workspace)
ctx = context.WithValue(ctx, workspaceAgentParamContextKey{}, agent)
if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil {
rlogger.WithFields(
slog.F("workspace_name", workspace.Name),
slog.F("agent_name", agent.Name),
)
}
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
+3
@@ -268,6 +268,9 @@ func (s *GroupSyncSettings) Set(v string) error {
}
func (s *GroupSyncSettings) String() string {
if s.Mapping == nil {
s.Mapping = make(map[string][]uuid.UUID)
}
return runtimeconfig.JSONString(s)
}
+56 -4
@@ -92,14 +92,16 @@ func (s AGPLIDPSync) SyncOrganizations(ctx context.Context, tx database.Store, u
return nil // No sync configured, nothing to do
}
expectedOrgs, err := orgSettings.ParseClaims(ctx, tx, params.MergedClaims)
expectedOrgIDs, err := orgSettings.ParseClaims(ctx, tx, params.MergedClaims)
if err != nil {
return xerrors.Errorf("organization claims: %w", err)
}
// Fetch all organizations, even deleted ones. This is to remove a user
// from any deleted organizations they may be in.
existingOrgs, err := tx.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{
UserID: user.ID,
Deleted: false,
Deleted: sql.NullBool{},
})
if err != nil {
return xerrors.Errorf("failed to get user organizations: %w", err)
@@ -109,10 +111,35 @@ func (s AGPLIDPSync) SyncOrganizations(ctx context.Context, tx database.Store, u
return org.ID
})
// finalExpected is the final set of org ids the user is expected to be in.
// Deleted orgs are omitted from this set.
finalExpected := expectedOrgIDs
if len(expectedOrgIDs) > 0 {
// If you pass in an empty slice to the db arg, you get all orgs. So the slice
// has to be non-empty to get the expected set. Logically it also does not make
// sense to fetch an empty set from the db.
expectedOrganizations, err := tx.GetOrganizations(ctx, database.GetOrganizationsParams{
IDs: expectedOrgIDs,
// Do not include deleted organizations. Omitting deleted orgs will remove the
// user from any deleted organizations they are a member of.
Deleted: false,
})
if err != nil {
return xerrors.Errorf("failed to get expected organizations: %w", err)
}
finalExpected = db2sdk.List(expectedOrganizations, func(org database.Organization) uuid.UUID {
return org.ID
})
}
// Find the difference in the expected and the existing orgs, and
// correct the set of orgs the user is a member of.
add, remove := slice.SymmetricDifference(existingOrgIDs, expectedOrgs)
notExists := make([]uuid.UUID, 0)
add, remove := slice.SymmetricDifference(existingOrgIDs, finalExpected)
// notExists is purely for debugging. It logs when the settings want
// a user in an organization, but the organization does not exist.
notExists := slice.DifferenceFunc(expectedOrgIDs, finalExpected, func(a, b uuid.UUID) bool {
return a == b
})
for _, orgID := range add {
_, err := tx.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{
OrganizationID: orgID,
@@ -123,9 +150,30 @@ func (s AGPLIDPSync) SyncOrganizations(ctx context.Context, tx database.Store, u
})
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
// This should not happen because we check the org existence
// beforehand.
notExists = append(notExists, orgID)
continue
}
if database.IsUniqueViolation(err, database.UniqueOrganizationMembersPkey) {
// If we hit this error we have a bug. The user already exists in the
// organization, but was not detected to be at the start of this function.
// Instead of failing the function, an error will be logged. This is to not bring
// down the entire syncing behavior from a single failed org. Failing this can
// prevent user logins, so only fatal non-recoverable errors should be returned.
//
// Inserting a user is privilege escalation, so skipping instead of
// failing leaves the user with fewer permissions. This makes it safe,
// from a security perspective, to continue.
s.Logger.Error(ctx, "syncing user to organization failed as they are already a member, please report this failure to Coder",
slog.F("user_id", user.ID),
slog.F("username", user.Username),
slog.F("organization_id", orgID),
slog.Error(err),
)
continue
}
return xerrors.Errorf("add user to organization: %w", err)
}
}
@@ -141,6 +189,7 @@ func (s AGPLIDPSync) SyncOrganizations(ctx context.Context, tx database.Store, u
}
if len(notExists) > 0 {
notExists = slice.Unique(notExists) // Remove duplicates
s.Logger.Debug(ctx, "organizations do not exist but attempted to use in org sync",
slog.F("not_found", notExists),
slog.F("user_id", user.ID),
@@ -168,6 +217,9 @@ func (s *OrganizationSyncSettings) Set(v string) error {
}
func (s *OrganizationSyncSettings) String() string {
if s.Mapping == nil {
s.Mapping = make(map[string][]uuid.UUID)
}
return runtimeconfig.JSONString(s)
}
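The add/remove reconciliation in SyncOrganizations is a symmetric difference between the orgs the user is currently in and the (deleted-filtered) orgs the claims expect. A stdlib sketch of that step, using string IDs in place of uuid.UUID and the repo's slice.SymmetricDifference helper:

```go
package main

import (
	"fmt"
	"sort"
)

// symmetricDifference returns (add, remove): IDs in want but not have,
// and IDs in have but not want. Results are sorted for determinism.
func symmetricDifference(have, want []string) (add, remove []string) {
	haveSet := make(map[string]bool, len(have))
	for _, id := range have {
		haveSet[id] = true
	}
	wantSet := make(map[string]bool, len(want))
	for _, id := range want {
		wantSet[id] = true
		if !haveSet[id] {
			add = append(add, id) // expected but not yet a member
		}
	}
	for _, id := range have {
		if !wantSet[id] {
			remove = append(remove, id) // member but no longer expected
		}
	}
	sort.Strings(add)
	sort.Strings(remove)
	return add, remove
}

func main() {
	// Deleted orgs are filtered out of the expected set beforehand, so
	// membership in a deleted org always lands in remove.
	existing := []string{"stays", "leaves", "deletedLeaves"}
	expected := []string{"stays", "joins"}
	add, remove := symmetricDifference(existing, expected)
	fmt.Println(add)    // [joins]
	fmt.Println(remove) // [deletedLeaves leaves]
}
```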
+111
@@ -1,6 +1,7 @@
package idpsync_test
import (
"database/sql"
"testing"
"github.com/golang-jwt/jwt/v4"
@@ -8,6 +9,11 @@ import (
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/testutil"
@@ -38,3 +44,108 @@ func TestParseOrganizationClaims(t *testing.T) {
require.False(t, params.SyncEntitled)
})
}
func TestSyncOrganizations(t *testing.T) {
t.Parallel()
// This test creates some deleted organizations and checks that the
// behavior is correct.
t.Run("SyncUserToDeletedOrg", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
// Create orgs for:
// - stays = User is a member, and stays
// - leaves = User is a member, and leaves
// - joins = User is not a member, and joins
// For deleted orgs, the user **should not** be a member afterwards.
// - deletedStays = User is a member of deleted org, and wants to stay
// - deletedLeaves = User is a member of deleted org, and wants to leave
// - deletedJoins = User is not a member of deleted org, and wants to join
stays := dbfake.Organization(t, db).Members(user).Do()
leaves := dbfake.Organization(t, db).Members(user).Do()
joins := dbfake.Organization(t, db).Do()
deletedStays := dbfake.Organization(t, db).Members(user).Deleted(true).Do()
deletedLeaves := dbfake.Organization(t, db).Members(user).Deleted(true).Do()
deletedJoins := dbfake.Organization(t, db).Deleted(true).Do()
// Now sync the user against the mapping, which includes deleted orgs
s := idpsync.NewAGPLSync(
slogtest.Make(t, &slogtest.Options{}),
runtimeconfig.NewManager(),
idpsync.DeploymentSyncSettings{
OrganizationField: "orgs",
OrganizationMapping: map[string][]uuid.UUID{
"stay": {stays.Org.ID, deletedStays.Org.ID},
"leave": {leaves.Org.ID, deletedLeaves.Org.ID},
"join": {joins.Org.ID, deletedJoins.Org.ID},
},
OrganizationAssignDefault: false,
},
)
err := s.SyncOrganizations(ctx, db, user, idpsync.OrganizationParams{
SyncEntitled: true,
MergedClaims: map[string]interface{}{
"orgs": []string{"stay", "join"},
},
})
require.NoError(t, err)
orgs, err := db.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{
UserID: user.ID,
Deleted: sql.NullBool{},
})
require.NoError(t, err)
require.Len(t, orgs, 2)
// Verify the user only exists in 2 orgs: the one they stayed in, and the
// one they joined.
inIDs := db2sdk.List(orgs, func(org database.Organization) uuid.UUID {
return org.ID
})
require.ElementsMatch(t, []uuid.UUID{stays.Org.ID, joins.Org.ID}, inIDs)
})
t.Run("UserToZeroOrgs", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
deletedLeaves := dbfake.Organization(t, db).Members(user).Deleted(true).Do()
// Now sync the user to the deleted organization
s := idpsync.NewAGPLSync(
slogtest.Make(t, &slogtest.Options{}),
runtimeconfig.NewManager(),
idpsync.DeploymentSyncSettings{
OrganizationField: "orgs",
OrganizationMapping: map[string][]uuid.UUID{
"leave": {deletedLeaves.Org.ID},
},
OrganizationAssignDefault: false,
},
)
err := s.SyncOrganizations(ctx, db, user, idpsync.OrganizationParams{
SyncEntitled: true,
MergedClaims: map[string]interface{}{
"orgs": []string{},
},
})
require.NoError(t, err)
orgs, err := db.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{
UserID: user.ID,
Deleted: sql.NullBool{},
})
require.NoError(t, err)
require.Len(t, orgs, 0)
})
}
+3
@@ -284,5 +284,8 @@ func (s *RoleSyncSettings) Set(v string) error {
}
func (s *RoleSyncSettings) String() string {
if s.Mapping == nil {
s.Mapping = make(map[string][]string)
}
return runtimeconfig.JSONString(s)
}
+1 -1
@@ -323,7 +323,7 @@ func convertOrganizationMembers(ctx context.Context, db database.Store, mems []d
customRoles, err := db.CustomRoles(ctx, database.CustomRolesParams{
LookupRoles: roleLookup,
ExcludeOrgRoles: false,
OrganizationID: uuid.UUID{},
OrganizationID: uuid.Nil,
})
if err != nil {
// We are missing the display names, but that is not absolutely required. So just
+1
@@ -144,6 +144,7 @@ func TestAzureAKPKIWithCoderd(t *testing.T) {
return values, nil
}),
oidctest.WithServing(),
oidctest.WithLogging(t, nil),
)
cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
+4
@@ -20,6 +20,7 @@ import (
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/util/slice"
@@ -554,6 +555,9 @@ func (f *logFollower) follow() {
return
}
// Log the request immediately instead of after it completes.
loggermw.RequestLoggerFromContext(f.ctx).WriteLog(f.ctx, http.StatusAccepted)
// no need to wait if the job is done
if f.complete {
return
+7
@@ -19,6 +19,8 @@ import (
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/httpmw/loggermw/loggermock"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/testutil"
@@ -305,11 +307,16 @@ func Test_logFollower_EndOfLogs(t *testing.T) {
JobStatus: database.ProvisionerJobStatusRunning,
}
mockLogger := loggermock.NewMockRequestLogger(ctrl)
mockLogger.EXPECT().WriteLog(gomock.Any(), http.StatusAccepted).Times(1)
ctx = loggermw.WithRequestLogger(ctx, mockLogger)
// we need an HTTP server to get a websocket
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 0)
uut.follow()
}))
defer srv.Close()
// job was incomplete when we created the logFollower, and is still incomplete when it queries
+25
@@ -57,6 +57,23 @@ func hashAuthorizeCall(actor Subject, action policy.Action, object Object) [32]b
return hashOut
}
// SubjectType represents the type of subject in the RBAC system.
type SubjectType string
const (
SubjectTypeUser SubjectType = "user"
SubjectTypeProvisionerd SubjectType = "provisionerd"
SubjectTypeAutostart SubjectType = "autostart"
SubjectTypeHangDetector SubjectType = "hang_detector"
SubjectTypeResourceMonitor SubjectType = "resource_monitor"
SubjectTypeCryptoKeyRotator SubjectType = "crypto_key_rotator"
SubjectTypeCryptoKeyReader SubjectType = "crypto_key_reader"
SubjectTypePrebuildsOrchestrator SubjectType = "prebuilds_orchestrator"
SubjectTypeSystemReadProvisionerDaemons SubjectType = "system_read_provisioner_daemons"
SubjectTypeSystemRestricted SubjectType = "system_restricted"
SubjectTypeNotifier SubjectType = "notifier"
)
// Subject is a struct that contains all the elements of a subject in an rbac
// authorize.
type Subject struct {
@@ -66,6 +83,14 @@ type Subject struct {
// external workspace proxy or other service type actor.
FriendlyName string
// Email is entirely optional and is used for logging and debugging
// It is not used in any functional way.
Email string
// Type indicates what kind of subject this is (user, system, provisioner, etc.)
// It is not used in any functional way, only for logging.
Type SubjectType
ID string
Roles ExpandableRoles
Groups []string
+9 -9
@@ -27,22 +27,21 @@ var (
// ResourceAssignOrgRole
// Valid Actions
// - "ActionAssign" :: ability to assign org scoped roles
// - "ActionCreate" :: ability to create/delete custom roles within an organization
// - "ActionDelete" :: ability to delete org scoped roles
// - "ActionRead" :: view what roles are assignable
// - "ActionUpdate" :: ability to edit custom roles within an organization
// - "ActionAssign" :: assign org scoped roles
// - "ActionCreate" :: create/delete custom roles within an organization
// - "ActionDelete" :: delete roles within an organization
// - "ActionRead" :: view what roles are assignable within an organization
// - "ActionUnassign" :: unassign org scoped roles
// - "ActionUpdate" :: edit custom roles within an organization
ResourceAssignOrgRole = Object{
Type: "assign_org_role",
}
// ResourceAssignRole
// Valid Actions
// - "ActionAssign" :: ability to assign roles
// - "ActionCreate" :: ability to create/delete/edit custom roles
// - "ActionDelete" :: ability to unassign roles
// - "ActionAssign" :: assign user roles
// - "ActionRead" :: view what roles are assignable
// - "ActionUpdate" :: ability to edit custom roles
// - "ActionUnassign" :: unassign user roles
ResourceAssignRole = Object{
Type: "assign_role",
}
@@ -367,6 +366,7 @@ func AllActions() []policy.Action {
policy.ActionRead,
policy.ActionReadPersonal,
policy.ActionSSH,
policy.ActionUnassign,
policy.ActionUpdate,
policy.ActionUpdatePersonal,
policy.ActionUse,
+11 -11
@@ -19,7 +19,8 @@ const (
ActionWorkspaceStart Action = "start"
ActionWorkspaceStop Action = "stop"
ActionAssign Action = "assign"
ActionAssign Action = "assign"
ActionUnassign Action = "unassign"
ActionReadPersonal Action = "read_personal"
ActionUpdatePersonal Action = "update_personal"
@@ -221,20 +222,19 @@ var RBACPermissions = map[string]PermissionDefinition{
},
"assign_role": {
Actions: map[Action]ActionDefinition{
ActionAssign: actDef("ability to assign roles"),
ActionRead: actDef("view what roles are assignable"),
ActionDelete: actDef("ability to unassign roles"),
ActionCreate: actDef("ability to create/delete/edit custom roles"),
ActionUpdate: actDef("ability to edit custom roles"),
ActionAssign: actDef("assign user roles"),
ActionUnassign: actDef("unassign user roles"),
ActionRead: actDef("view what roles are assignable"),
},
},
"assign_org_role": {
Actions: map[Action]ActionDefinition{
ActionAssign: actDef("ability to assign org scoped roles"),
ActionRead: actDef("view what roles are assignable"),
ActionDelete: actDef("ability to delete org scoped roles"),
ActionCreate: actDef("ability to create/delete custom roles within an organization"),
ActionUpdate: actDef("ability to edit custom roles within an organization"),
ActionAssign: actDef("assign org scoped roles"),
ActionUnassign: actDef("unassign org scoped roles"),
ActionCreate: actDef("create/delete custom roles within an organization"),
ActionRead: actDef("view what roles are assignable within an organization"),
ActionUpdate: actDef("edit custom roles within an organization"),
ActionDelete: actDef("delete roles within an organization"),
},
},
"oauth2_app": {
+79 -40
@@ -27,11 +27,12 @@ const (
customSiteRole string = "custom-site-role"
customOrganizationRole string = "custom-organization-role"
orgAdmin string = "organization-admin"
orgMember string = "organization-member"
orgAuditor string = "organization-auditor"
orgUserAdmin string = "organization-user-admin"
orgTemplateAdmin string = "organization-template-admin"
orgAdmin string = "organization-admin"
orgMember string = "organization-member"
orgAuditor string = "organization-auditor"
orgUserAdmin string = "organization-user-admin"
orgTemplateAdmin string = "organization-template-admin"
orgWorkspaceCreationBan string = "organization-workspace-creation-ban"
)
func init() {
@@ -159,6 +160,10 @@ func RoleOrgTemplateAdmin() string {
return orgTemplateAdmin
}
func RoleOrgWorkspaceCreationBan() string {
return orgWorkspaceCreationBan
}
// ScopedRoleOrgAdmin is the org role with the organization ID
func ScopedRoleOrgAdmin(organizationID uuid.UUID) RoleIdentifier {
return RoleIdentifier{Name: RoleOrgAdmin(), OrganizationID: organizationID}
@@ -181,6 +186,10 @@ func ScopedRoleOrgTemplateAdmin(organizationID uuid.UUID) RoleIdentifier {
return RoleIdentifier{Name: RoleOrgTemplateAdmin(), OrganizationID: organizationID}
}
func ScopedRoleOrgWorkspaceCreationBan(organizationID uuid.UUID) RoleIdentifier {
return RoleIdentifier{Name: RoleOrgWorkspaceCreationBan(), OrganizationID: organizationID}
}
func allPermsExcept(excepts ...Objecter) []Permission {
resources := AllResources()
var perms []Permission
@@ -298,7 +307,8 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
Identifier: RoleAuditor(),
DisplayName: "Auditor",
Site: Permissions(map[string][]policy.Action{
ResourceAuditLog.Type: {policy.ActionRead},
ResourceAssignOrgRole.Type: {policy.ActionRead},
ResourceAuditLog.Type: {policy.ActionRead},
// Allow auditors to see the resources that audit logs reflect.
ResourceTemplate.Type: {policy.ActionRead, policy.ActionViewInsights},
ResourceUser.Type: {policy.ActionRead},
@@ -318,7 +328,8 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
Identifier: RoleTemplateAdmin(),
DisplayName: "Template Admin",
Site: Permissions(map[string][]policy.Action{
ResourceTemplate.Type: ResourceTemplate.AvailableActions(),
ResourceAssignOrgRole.Type: {policy.ActionRead},
ResourceTemplate.Type: ResourceTemplate.AvailableActions(),
// CRUD all files, even those they did not upload.
ResourceFile.Type: {policy.ActionCreate, policy.ActionRead},
ResourceWorkspace.Type: {policy.ActionRead},
@@ -339,10 +350,10 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
Identifier: RoleUserAdmin(),
DisplayName: "User Admin",
Site: Permissions(map[string][]policy.Action{
ResourceAssignRole.Type: {policy.ActionAssign, policy.ActionDelete, policy.ActionRead},
ResourceAssignRole.Type: {policy.ActionAssign, policy.ActionUnassign, policy.ActionRead},
// Need organization assign as well to create users. At present, creating a user
// will always assign them to some organization.
ResourceAssignOrgRole.Type: {policy.ActionAssign, policy.ActionDelete, policy.ActionRead},
ResourceAssignOrgRole.Type: {policy.ActionAssign, policy.ActionUnassign, policy.ActionRead},
ResourceUser.Type: {
policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete,
policy.ActionUpdatePersonal, policy.ActionReadPersonal,
@@ -459,7 +470,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
Org: map[string][]Permission{
organizationID.String(): Permissions(map[string][]policy.Action{
// Assign, remove, and read roles in the organization.
ResourceAssignOrgRole.Type: {policy.ActionAssign, policy.ActionDelete, policy.ActionRead},
ResourceAssignOrgRole.Type: {policy.ActionAssign, policy.ActionUnassign, policy.ActionRead},
ResourceOrganization.Type: {policy.ActionRead},
ResourceOrganizationMember.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
ResourceGroup.Type: ResourceGroup.AvailableActions(),
@@ -496,6 +507,31 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
User: []Permission{},
}
},
// orgWorkspaceCreationBan prevents creating & deleting workspaces. It
// overrides any permissions granted at the org or user level by using
// negative permissions.
orgWorkspaceCreationBan: func(organizationID uuid.UUID) Role {
return Role{
Identifier: RoleIdentifier{Name: orgWorkspaceCreationBan, OrganizationID: organizationID},
DisplayName: "Organization Workspace Creation Ban",
Site: []Permission{},
Org: map[string][]Permission{
organizationID.String(): {
{
Negate: true,
ResourceType: ResourceWorkspace.Type,
Action: policy.ActionCreate,
},
{
Negate: true,
ResourceType: ResourceWorkspace.Type,
Action: policy.ActionDelete,
},
},
},
User: []Permission{},
}
},
}
}
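The ban role above relies on negative permissions: during authorization, a matching `Negate` permission denies the action even if another role grants it. A simplified evaluator illustrating that deny-wins rule (a conceptual sketch, not Coder's actual RBAC engine):

```go
package main

import "fmt"

type permission struct {
	Negate       bool
	ResourceType string
	Action       string
}

// allowed applies the deny-wins rule: any matching negated permission
// blocks the action outright; otherwise a matching grant allows it.
func allowed(perms []permission, resource, action string) bool {
	granted := false
	for _, p := range perms {
		if p.ResourceType != resource || p.Action != action {
			continue
		}
		if p.Negate {
			return false // a negative permission always wins
		}
		granted = true
	}
	return granted
}

func main() {
	perms := []permission{
		{ResourceType: "workspace", Action: "create"},               // granted by another role
		{Negate: true, ResourceType: "workspace", Action: "create"}, // the creation ban
		{ResourceType: "workspace", Action: "read"},
	}
	fmt.Println(allowed(perms, "workspace", "create")) // false
	fmt.Println(allowed(perms, "workspace", "read"))   // true
}
```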
@@ -506,44 +542,47 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
// map[actor_role][assign_role]<can_assign>
var assignRoles = map[string]map[string]bool{
"system": {
owner: true,
auditor: true,
member: true,
orgAdmin: true,
orgMember: true,
orgAuditor: true,
orgUserAdmin: true,
orgTemplateAdmin: true,
templateAdmin: true,
userAdmin: true,
customSiteRole: true,
customOrganizationRole: true,
owner: true,
auditor: true,
member: true,
orgAdmin: true,
orgMember: true,
orgAuditor: true,
orgUserAdmin: true,
orgTemplateAdmin: true,
orgWorkspaceCreationBan: true,
templateAdmin: true,
userAdmin: true,
customSiteRole: true,
customOrganizationRole: true,
},
owner: {
owner: true,
auditor: true,
member: true,
orgAdmin: true,
orgMember: true,
orgAuditor: true,
orgUserAdmin: true,
orgTemplateAdmin: true,
templateAdmin: true,
userAdmin: true,
customSiteRole: true,
customOrganizationRole: true,
owner: true,
auditor: true,
member: true,
orgAdmin: true,
orgMember: true,
orgAuditor: true,
orgUserAdmin: true,
orgTemplateAdmin: true,
orgWorkspaceCreationBan: true,
templateAdmin: true,
userAdmin: true,
customSiteRole: true,
customOrganizationRole: true,
},
userAdmin: {
member: true,
orgMember: true,
},
orgAdmin: {
orgAdmin: true,
orgMember: true,
orgAuditor: true,
orgUserAdmin: true,
orgTemplateAdmin: true,
customOrganizationRole: true,
orgAdmin: true,
orgMember: true,
orgAuditor: true,
orgUserAdmin: true,
orgTemplateAdmin: true,
orgWorkspaceCreationBan: true,
customOrganizationRole: true,
},
orgUserAdmin: {
orgMember: true,
+22 -10
@@ -112,6 +112,7 @@ func TestRolePermissions(t *testing.T) {
// Subjects to user
memberMe := authSubject{Name: "member_me", Actor: rbac.Subject{ID: currentUser.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember()}}}
orgMemberMe := authSubject{Name: "org_member_me", Actor: rbac.Subject{ID: currentUser.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgMember(orgID)}}}
orgMemberMeBanWorkspace := authSubject{Name: "org_member_me_workspace_ban", Actor: rbac.Subject{ID: currentUser.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgMember(orgID), rbac.ScopedRoleOrgWorkspaceCreationBan(orgID)}}}
groupMemberMe := authSubject{Name: "group_member_me", Actor: rbac.Subject{ID: currentUser.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.ScopedRoleOrgMember(orgID)}, Groups: []string{groupID.String()}}}
owner := authSubject{Name: "owner", Actor: rbac.Subject{ID: adminID.String(), Roles: rbac.RoleIdentifiers{rbac.RoleMember(), rbac.RoleOwner()}}}
@@ -181,20 +182,30 @@ func TestRolePermissions(t *testing.T) {
Actions: []policy.Action{policy.ActionRead},
Resource: rbac.ResourceWorkspace.WithID(workspaceID).InOrg(orgID).WithOwner(currentUser.String()),
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner, orgMemberMe, orgAdmin, templateAdmin, orgTemplateAdmin},
true: {owner, orgMemberMe, orgAdmin, templateAdmin, orgTemplateAdmin, orgMemberMeBanWorkspace},
false: {setOtherOrg, memberMe, userAdmin, orgAuditor, orgUserAdmin},
},
},
{
Name: "C_RDMyWorkspaceInOrg",
Name: "UpdateMyWorkspaceInOrg",
// When creating, WithID won't be set, but that does not change the result.
Actions: []policy.Action{policy.ActionCreate, policy.ActionUpdate, policy.ActionDelete},
Actions: []policy.Action{policy.ActionUpdate},
Resource: rbac.ResourceWorkspace.WithID(workspaceID).InOrg(orgID).WithOwner(currentUser.String()),
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner, orgMemberMe, orgAdmin},
false: {setOtherOrg, memberMe, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor},
},
},
{
Name: "CreateDeleteMyWorkspaceInOrg",
// When creating, WithID won't be set, but that does not change the result.
Actions: []policy.Action{policy.ActionCreate, policy.ActionDelete},
Resource: rbac.ResourceWorkspace.WithID(workspaceID).InOrg(orgID).WithOwner(currentUser.String()),
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner, orgMemberMe, orgAdmin},
false: {setOtherOrg, memberMe, userAdmin, templateAdmin, orgTemplateAdmin, orgUserAdmin, orgAuditor, orgMemberMeBanWorkspace},
},
},
{
Name: "MyWorkspaceInOrgExecution",
// When creating, WithID won't be set, but that does not change the result.
@@ -292,9 +303,9 @@ func TestRolePermissions(t *testing.T) {
},
},
{
Name: "CreateCustomRole",
Actions: []policy.Action{policy.ActionCreate, policy.ActionUpdate},
Resource: rbac.ResourceAssignRole,
Name: "CreateUpdateDeleteCustomRole",
Actions: []policy.Action{policy.ActionCreate, policy.ActionUpdate, policy.ActionDelete},
Resource: rbac.ResourceAssignOrgRole,
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner},
false: {setOtherOrg, setOrgNotMe, userAdmin, orgMemberMe, memberMe, templateAdmin},
@@ -302,7 +313,7 @@ func TestRolePermissions(t *testing.T) {
},
{
Name: "RoleAssignment",
Actions: []policy.Action{policy.ActionAssign, policy.ActionDelete},
Actions: []policy.Action{policy.ActionAssign, policy.ActionUnassign},
Resource: rbac.ResourceAssignRole,
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner, userAdmin},
@@ -320,7 +331,7 @@ func TestRolePermissions(t *testing.T) {
},
{
Name: "OrgRoleAssignment",
Actions: []policy.Action{policy.ActionAssign, policy.ActionDelete},
Actions: []policy.Action{policy.ActionAssign, policy.ActionUnassign},
Resource: rbac.ResourceAssignOrgRole.InOrg(orgID),
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner, orgAdmin, userAdmin, orgUserAdmin},
@@ -341,8 +352,8 @@ func TestRolePermissions(t *testing.T) {
Actions: []policy.Action{policy.ActionRead},
Resource: rbac.ResourceAssignOrgRole.InOrg(orgID),
AuthorizeMap: map[bool][]hasAuthSubjects{
true: {owner, setOrgNotMe, orgMemberMe, userAdmin},
false: {setOtherOrg, memberMe, templateAdmin},
true: {owner, setOrgNotMe, orgMemberMe, userAdmin, templateAdmin},
false: {setOtherOrg, memberMe},
},
},
{
@@ -942,6 +953,7 @@ func TestListRoles(t *testing.T) {
fmt.Sprintf("organization-auditor:%s", orgID.String()),
fmt.Sprintf("organization-user-admin:%s", orgID.String()),
fmt.Sprintf("organization-template-admin:%s", orgID.String()),
fmt.Sprintf("organization-workspace-creation-ban:%s", orgID.String()),
},
orgRoleNames)
}
+15 -1
@@ -24,9 +24,11 @@ import (
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/site"
"github.com/coder/coder/v2/tailnet"
@@ -534,6 +536,10 @@ func NewMultiAgentController(ctx context.Context, logger slog.Logger, tracer tra
return m
}
type Pinger interface {
Ping(context.Context) (time.Duration, error)
}
// InmemTailnetDialer is a tailnet.ControlProtocolDialer that connects to a Coordinator and DERPMap
// service running in the same memory space.
type InmemTailnetDialer struct {
@@ -541,9 +547,17 @@ type InmemTailnetDialer struct {
DERPFn func() *tailcfg.DERPMap
Logger slog.Logger
ClientID uuid.UUID
// DatabaseHealthCheck is used to validate that the store is reachable.
DatabaseHealthCheck Pinger
}
func (a *InmemTailnetDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
func (a *InmemTailnetDialer) Dial(ctx context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
if a.DatabaseHealthCheck != nil {
if _, err := a.DatabaseHealthCheck.Ping(ctx); err != nil {
return tailnet.ControlProtocolClients{}, xerrors.Errorf("%w: %v", codersdk.ErrDatabaseNotReachable, err)
}
}
coord := a.CoordPtr.Load()
if coord == nil {
return tailnet.ControlProtocolClients{}, xerrors.Errorf("tailnet coordinator not initialized")
+41
@@ -11,6 +11,7 @@ import (
"strconv"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
@@ -18,6 +19,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"github.com/coder/coder/v2/agent"
@@ -25,6 +27,7 @@ import (
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/tailnet"
@@ -365,6 +368,44 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
})
}
func TestDialFailure(t *testing.T) {
t.Parallel()
// Setup.
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
// Given: a tailnet coordinator.
coord := tailnet.NewCoordinator(logger)
t.Cleanup(func() {
_ = coord.Close()
})
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
// Given: a fake DB healthchecker which will always fail.
fch := &failingHealthcheck{}
// When: dialing the in-memory coordinator.
dialer := &coderd.InmemTailnetDialer{
CoordPtr: &coordPtr,
Logger: logger,
ClientID: uuid.UUID{5},
DatabaseHealthCheck: fch,
}
_, err := dialer.Dial(ctx, nil)
// Then: the error returned reflects the database has failed its healthcheck.
require.ErrorIs(t, err, codersdk.ErrDatabaseNotReachable)
}
type failingHealthcheck struct{}
func (failingHealthcheck) Ping(context.Context) (time.Duration, error) {
// Simulate a database connection error.
return 0, xerrors.New("oops")
}
type wrappedListener struct {
net.Listener
dials int32
+82 -22
@@ -624,6 +624,28 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
}
return nil
})
eg.Go(func() error {
memoryMonitors, err := r.options.Database.FetchMemoryResourceMonitorsUpdatedAfter(ctx, createdAfter)
if err != nil {
return xerrors.Errorf("get memory resource monitors: %w", err)
}
snapshot.WorkspaceAgentMemoryResourceMonitors = make([]WorkspaceAgentMemoryResourceMonitor, 0, len(memoryMonitors))
for _, monitor := range memoryMonitors {
snapshot.WorkspaceAgentMemoryResourceMonitors = append(snapshot.WorkspaceAgentMemoryResourceMonitors, ConvertWorkspaceAgentMemoryResourceMonitor(monitor))
}
return nil
})
eg.Go(func() error {
volumeMonitors, err := r.options.Database.FetchVolumesResourceMonitorsUpdatedAfter(ctx, createdAfter)
if err != nil {
return xerrors.Errorf("get volume resource monitors: %w", err)
}
snapshot.WorkspaceAgentVolumeResourceMonitors = make([]WorkspaceAgentVolumeResourceMonitor, 0, len(volumeMonitors))
for _, monitor := range volumeMonitors {
snapshot.WorkspaceAgentVolumeResourceMonitors = append(snapshot.WorkspaceAgentVolumeResourceMonitors, ConvertWorkspaceAgentVolumeResourceMonitor(monitor))
}
return nil
})
eg.Go(func() error {
proxies, err := r.options.Database.GetWorkspaceProxies(ctx)
if err != nil {
@@ -765,6 +787,26 @@ func ConvertWorkspaceAgent(agent database.WorkspaceAgent) WorkspaceAgent {
return snapAgent
}
func ConvertWorkspaceAgentMemoryResourceMonitor(monitor database.WorkspaceAgentMemoryResourceMonitor) WorkspaceAgentMemoryResourceMonitor {
return WorkspaceAgentMemoryResourceMonitor{
AgentID: monitor.AgentID,
Enabled: monitor.Enabled,
Threshold: monitor.Threshold,
CreatedAt: monitor.CreatedAt,
UpdatedAt: monitor.UpdatedAt,
}
}
func ConvertWorkspaceAgentVolumeResourceMonitor(monitor database.WorkspaceAgentVolumeResourceMonitor) WorkspaceAgentVolumeResourceMonitor {
return WorkspaceAgentVolumeResourceMonitor{
AgentID: monitor.AgentID,
Enabled: monitor.Enabled,
Threshold: monitor.Threshold,
CreatedAt: monitor.CreatedAt,
UpdatedAt: monitor.UpdatedAt,
}
}
// ConvertWorkspaceAgentStat anonymizes a workspace agent stat.
func ConvertWorkspaceAgentStat(stat database.GetWorkspaceAgentStatsRow) WorkspaceAgentStat {
return WorkspaceAgentStat{
@@ -1083,28 +1125,30 @@ func ConvertTelemetryItem(item database.TelemetryItem) TelemetryItem {
type Snapshot struct {
DeploymentID string `json:"deployment_id"`
APIKeys []APIKey `json:"api_keys"`
CLIInvocations []clitelemetry.Invocation `json:"cli_invocations"`
ExternalProvisioners []ExternalProvisioner `json:"external_provisioners"`
Licenses []License `json:"licenses"`
ProvisionerJobs []ProvisionerJob `json:"provisioner_jobs"`
TemplateVersions []TemplateVersion `json:"template_versions"`
Templates []Template `json:"templates"`
Users []User `json:"users"`
Groups []Group `json:"groups"`
GroupMembers []GroupMember `json:"group_members"`
WorkspaceAgentStats []WorkspaceAgentStat `json:"workspace_agent_stats"`
WorkspaceAgents []WorkspaceAgent `json:"workspace_agents"`
WorkspaceApps []WorkspaceApp `json:"workspace_apps"`
WorkspaceBuilds []WorkspaceBuild `json:"workspace_build"`
WorkspaceProxies []WorkspaceProxy `json:"workspace_proxies"`
WorkspaceResourceMetadata []WorkspaceResourceMetadata `json:"workspace_resource_metadata"`
WorkspaceResources []WorkspaceResource `json:"workspace_resources"`
WorkspaceModules []WorkspaceModule `json:"workspace_modules"`
Workspaces []Workspace `json:"workspaces"`
NetworkEvents []NetworkEvent `json:"network_events"`
Organizations []Organization `json:"organizations"`
TelemetryItems []TelemetryItem `json:"telemetry_items"`
APIKeys []APIKey `json:"api_keys"`
CLIInvocations []clitelemetry.Invocation `json:"cli_invocations"`
ExternalProvisioners []ExternalProvisioner `json:"external_provisioners"`
Licenses []License `json:"licenses"`
ProvisionerJobs []ProvisionerJob `json:"provisioner_jobs"`
TemplateVersions []TemplateVersion `json:"template_versions"`
Templates []Template `json:"templates"`
Users []User `json:"users"`
Groups []Group `json:"groups"`
GroupMembers []GroupMember `json:"group_members"`
WorkspaceAgentStats []WorkspaceAgentStat `json:"workspace_agent_stats"`
WorkspaceAgents []WorkspaceAgent `json:"workspace_agents"`
WorkspaceApps []WorkspaceApp `json:"workspace_apps"`
WorkspaceBuilds []WorkspaceBuild `json:"workspace_build"`
WorkspaceProxies []WorkspaceProxy `json:"workspace_proxies"`
WorkspaceResourceMetadata []WorkspaceResourceMetadata `json:"workspace_resource_metadata"`
WorkspaceResources []WorkspaceResource `json:"workspace_resources"`
WorkspaceAgentMemoryResourceMonitors []WorkspaceAgentMemoryResourceMonitor `json:"workspace_agent_memory_resource_monitors"`
WorkspaceAgentVolumeResourceMonitors []WorkspaceAgentVolumeResourceMonitor `json:"workspace_agent_volume_resource_monitors"`
WorkspaceModules []WorkspaceModule `json:"workspace_modules"`
Workspaces []Workspace `json:"workspaces"`
NetworkEvents []NetworkEvent `json:"network_events"`
Organizations []Organization `json:"organizations"`
TelemetryItems []TelemetryItem `json:"telemetry_items"`
}
// Deployment contains information about the host running Coder.
@@ -1232,6 +1276,22 @@ type WorkspaceAgentStat struct {
SessionCountSSH int64 `json:"session_count_ssh"`
}
type WorkspaceAgentMemoryResourceMonitor struct {
AgentID uuid.UUID `json:"agent_id"`
Enabled bool `json:"enabled"`
Threshold int32 `json:"threshold"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type WorkspaceAgentVolumeResourceMonitor struct {
AgentID uuid.UUID `json:"agent_id"`
Enabled bool `json:"enabled"`
Threshold int32 `json:"threshold"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type WorkspaceApp struct {
ID uuid.UUID `json:"id"`
CreatedAt time.Time `json:"created_at"`
+4
@@ -112,6 +112,8 @@ func TestTelemetry(t *testing.T) {
_, _ = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{})
_ = dbgen.WorkspaceModule(t, db, database.WorkspaceModule{})
_ = dbgen.WorkspaceAgentMemoryResourceMonitor(t, db, database.WorkspaceAgentMemoryResourceMonitor{})
_ = dbgen.WorkspaceAgentVolumeResourceMonitor(t, db, database.WorkspaceAgentVolumeResourceMonitor{})
_, snapshot := collectSnapshot(t, db, nil)
require.Len(t, snapshot.ProvisionerJobs, 1)
@@ -133,6 +135,8 @@ func TestTelemetry(t *testing.T) {
require.Len(t, snapshot.Organizations, 1)
// We create one item manually above. The other is TelemetryEnabled, created by the snapshotter.
require.Len(t, snapshot.TelemetryItems, 2)
require.Len(t, snapshot.WorkspaceAgentMemoryResourceMonitors, 1)
require.Len(t, snapshot.WorkspaceAgentVolumeResourceMonitors, 1)
wsa := snapshot.WorkspaceAgents[0]
require.Len(t, wsa.Subsystems, 2)
require.Equal(t, string(database.WorkspaceAgentSubsystemEnvbox), wsa.Subsystems[0])
+22 -2
@@ -922,7 +922,17 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
}
}
if len(selectedMemberships) == 0 {
httpmw.CustomRedirectToLogin(rw, r, redirect, "You aren't a member of the authorized Github organizations!", http.StatusUnauthorized)
status := http.StatusUnauthorized
msg := "You aren't a member of the authorized Github organizations!"
if api.GithubOAuth2Config.DeviceFlowEnabled {
// In the device flow, the error is rendered client-side.
httpapi.Write(ctx, rw, status, codersdk.Response{
Message: "Unauthorized",
Detail: msg,
})
} else {
httpmw.CustomRedirectToLogin(rw, r, redirect, msg, status)
}
return
}
}
@@ -959,7 +969,17 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
}
}
if allowedTeam == nil {
httpmw.CustomRedirectToLogin(rw, r, redirect, fmt.Sprintf("You aren't a member of an authorized team in the %v Github organization(s)!", organizationNames), http.StatusUnauthorized)
msg := fmt.Sprintf("You aren't a member of an authorized team in the %v Github organization(s)!", organizationNames)
status := http.StatusUnauthorized
if api.GithubOAuth2Config.DeviceFlowEnabled {
// In the device flow, the error is rendered client-side.
httpapi.Write(ctx, rw, status, codersdk.Response{
Message: "Unauthorized",
Detail: msg,
})
} else {
httpmw.CustomRedirectToLogin(rw, r, redirect, msg, status)
}
return
}
}
+1 -1
@@ -1274,7 +1274,7 @@ func (api *API) organizationsByUser(rw http.ResponseWriter, r *http.Request) {
organizations, err := api.Database.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{
UserID: user.ID,
Deleted: false,
Deleted: sql.NullBool{Bool: false, Valid: true},
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
+20
@@ -31,6 +31,7 @@ import (
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
@@ -463,6 +464,9 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
t := time.NewTicker(recheckInterval)
defer t.Stop()
// Log the request immediately instead of after it completes.
loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted)
go func() {
defer func() {
logger.Debug(ctx, "end log streaming loop")
@@ -836,6 +840,9 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
encoder := wsjson.NewEncoder[*tailcfg.DERPMap](ws, websocket.MessageBinary)
defer encoder.Close(websocket.StatusGoingAway)
// Log the request immediately instead of after it completes.
loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted)
go func(ctx context.Context) {
// TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout?
t := time.NewTicker(api.AgentConnectionUpdateFrequency)
@@ -897,6 +904,16 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Ensure the database is reachable before proceeding.
_, err := api.Database.Ping(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: codersdk.DatabaseNotReachable,
Detail: err.Error(),
})
return
}
// This route accepts user API key auth and workspace proxy auth. The moon actor has
// full permissions so should be able to pass this authz check.
workspace := httpmw.WorkspaceParam(r)
@@ -1200,6 +1217,9 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
sendTicker := time.NewTicker(sendInterval)
defer sendTicker.Stop()
// Log the request immediately instead of after it completes.
loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted)
// Send initial metadata.
sendMetadata()
+48
@@ -375,6 +375,54 @@ func TestWorkspace(t *testing.T) {
require.Error(t, err, "create workspace with archived version")
require.ErrorContains(t, err, "Archived template versions cannot")
})
t.Run("WorkspaceBan", func(t *testing.T) {
t.Parallel()
owner, _, _ := coderdtest.NewWithAPI(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
first := coderdtest.CreateFirstUser(t, owner)
version := coderdtest.CreateTemplateVersion(t, owner, first.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, owner, version.ID)
template := coderdtest.CreateTemplate(t, owner, first.OrganizationID, version.ID)
goodClient, _ := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID)
// Given: a user with the workspace-creation-ban role
client, user := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID, rbac.ScopedRoleOrgWorkspaceCreationBan(first.OrganizationID))
// Ensure a user without the ban can create a workspace
coderdtest.CreateWorkspace(t, goodClient, template.ID)
ctx := testutil.Context(t, testutil.WaitLong)
// Then: Cannot create a workspace
_, err := client.CreateUserWorkspace(ctx, codersdk.Me, codersdk.CreateWorkspaceRequest{
TemplateID: template.ID,
TemplateVersionID: uuid.UUID{},
Name: "random",
})
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
// When: the workspace-banned user has a workspace
wrk, err := owner.CreateUserWorkspace(ctx, user.ID.String(), codersdk.CreateWorkspaceRequest{
TemplateID: template.ID,
TemplateVersionID: uuid.UUID{},
Name: "random",
})
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wrk.LatestBuild.ID)
// Then: They cannot delete said workspace
_, err = client.CreateWorkspaceBuild(ctx, wrk.ID, codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionDelete,
ProvisionerState: []byte{},
})
require.Error(t, err)
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
})
}
func TestResolveAutostart(t *testing.T) {
+9
@@ -790,6 +790,15 @@ func (b *Builder) authorize(authFunc func(action policy.Action, object rbac.Obje
return BuildError{http.StatusBadRequest, msg, xerrors.New(msg)}
}
if !authFunc(action, b.workspace) {
if authFunc(policy.ActionRead, b.workspace) {
// If the user can read the workspace but not delete/create/update it, show
// a more helpful error. They are allowed to know the workspace exists.
return BuildError{
Status: http.StatusForbidden,
Message: fmt.Sprintf("You do not have permission to %s this workspace.", action),
Wrapped: xerrors.New(httpapi.ResourceForbiddenResponse.Detail),
}
}
// We use the same wording as the httpapi to avoid leaking the existence of the workspace
return BuildError{http.StatusNotFound, httpapi.ResourceNotFoundResponse.Message, xerrors.New(httpapi.ResourceNotFoundResponse.Message)}
}
+7
@@ -0,0 +1,7 @@
package codersdk
import "golang.org/x/xerrors"
const DatabaseNotReachable = "database not reachable"
var ErrDatabaseNotReachable = xerrors.New(DatabaseNotReachable)
+3 -2
@@ -49,6 +49,7 @@ const (
ActionRead RBACAction = "read"
ActionReadPersonal RBACAction = "read_personal"
ActionSSH RBACAction = "ssh"
ActionUnassign RBACAction = "unassign"
ActionUpdate RBACAction = "update"
ActionUpdatePersonal RBACAction = "update_personal"
ActionUse RBACAction = "use"
@@ -62,8 +63,8 @@ const (
var RBACResourceActions = map[RBACResource][]RBACAction{
ResourceWildcard: {},
ResourceApiKey: {ActionCreate, ActionDelete, ActionRead, ActionUpdate},
ResourceAssignOrgRole: {ActionAssign, ActionCreate, ActionDelete, ActionRead, ActionUpdate},
ResourceAssignRole: {ActionAssign, ActionCreate, ActionDelete, ActionRead, ActionUpdate},
ResourceAssignOrgRole: {ActionAssign, ActionCreate, ActionDelete, ActionRead, ActionUnassign, ActionUpdate},
ResourceAssignRole: {ActionAssign, ActionRead, ActionUnassign},
ResourceAuditLog: {ActionCreate, ActionRead},
ResourceCryptoKey: {ActionCreate, ActionDelete, ActionRead, ActionUpdate},
ResourceDebugInfo: {ActionRead},
+6 -5
@@ -8,9 +8,10 @@ const (
RoleUserAdmin string = "user-admin"
RoleAuditor string = "auditor"
RoleOrganizationAdmin string = "organization-admin"
RoleOrganizationMember string = "organization-member"
RoleOrganizationAuditor string = "organization-auditor"
RoleOrganizationTemplateAdmin string = "organization-template-admin"
RoleOrganizationUserAdmin string = "organization-user-admin"
RoleOrganizationAdmin string = "organization-admin"
RoleOrganizationMember string = "organization-member"
RoleOrganizationAuditor string = "organization-auditor"
RoleOrganizationTemplateAdmin string = "organization-template-admin"
RoleOrganizationUserAdmin string = "organization-user-admin"
RoleOrganizationWorkspaceCreationBan string = "organization-workspace-creation-ban"
)
+15 -3
@@ -143,6 +143,12 @@ func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, w
// SSH pipes the SSH protocol over the returned net.Conn.
// This connects to the built-in SSH server in the workspace agent.
func (c *AgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) {
return c.SSHOnPort(ctx, AgentSSHPort)
}
// SSHOnPort pipes the SSH protocol over the returned net.Conn.
// This connects to the built-in SSH server in the workspace agent on the specified port.
func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
@@ -150,17 +156,23 @@ func (c *AgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
c.Conn.SendConnectedTelemetry(c.agentAddress(), tailnet.TelemetryApplicationSSH)
return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), AgentSSHPort))
c.SendConnectedTelemetry(c.agentAddress(), tailnet.TelemetryApplicationSSH)
return c.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), port))
}
// SSHClient calls SSH to create a client that uses a weak cipher
// to improve throughput.
func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) {
return c.SSHClientOnPort(ctx, AgentSSHPort)
}
// SSHClientOnPort calls SSH to create a client on a specific port
// that uses a weak cipher to improve throughput.
func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
netConn, err := c.SSH(ctx)
netConn, err := c.SSHOnPort(ctx, port)
if err != nil {
return nil, xerrors.Errorf("ssh: %w", err)
}
+11 -4
@@ -11,17 +11,19 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/websocket"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/websocket"
)
var permanentErrorStatuses = []int{
http.StatusConflict, // returned if client/agent connections disabled (browser only)
http.StatusBadRequest, // returned if API mismatch
http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist
http.StatusConflict, // returned if client/agent connections disabled (browser only)
http.StatusBadRequest, // returned if API mismatch
http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist
http.StatusInternalServerError, // returned if database is not reachable
}
type WebsocketDialer struct {
@@ -89,6 +91,11 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl
"Ensure your client release version (%s, different than the API version) matches the server release version",
buildinfo.Version())
}
if sdkErr.Message == codersdk.DatabaseNotReachable &&
sdkErr.StatusCode() == http.StatusInternalServerError {
err = xerrors.Errorf("%w: %v", codersdk.ErrDatabaseNotReachable, err)
}
}
w.connected <- err
return tailnet.ControlProtocolClients{}, err
+1
@@ -29,6 +29,7 @@ var ErrSkipClose = xerrors.New("skip tailnet close")
const (
AgentSSHPort = tailnet.WorkspaceAgentSSHPort
AgentStandardSSHPort = tailnet.WorkspaceAgentStandardSSHPort
AgentReconnectingPTYPort = tailnet.WorkspaceAgentReconnectingPTYPort
AgentSpeedtestPort = tailnet.WorkspaceAgentSpeedtestPort
// AgentHTTPAPIServerPort serves a HTTP server with endpoints for e.g.
@@ -1,13 +1,21 @@
package workspacesdk_test
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"github.com/coder/websocket"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil"
)
func TestWorkspaceRewriteDERPMap(t *testing.T) {
@@ -37,3 +45,30 @@ func TestWorkspaceRewriteDERPMap(t *testing.T) {
require.Equal(t, "coconuts.org", node.HostName)
require.Equal(t, 44558, node.DERPPort)
}
func TestWorkspaceDialerFailure(t *testing.T) {
t.Parallel()
// Setup.
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
// Given: a mock HTTP server which mimics an unreachable database when calling the coordination endpoint.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
Message: codersdk.DatabaseNotReachable,
Detail: "oops",
})
}))
t.Cleanup(srv.Close)
u, err := url.Parse(srv.URL)
require.NoError(t, err)
// When: calling the coordination endpoint.
dialer := workspacesdk.NewWebsocketDialer(logger, u, &websocket.DialOptions{})
_, err = dialer.Dial(ctx, nil)
// Then: an error indicating a database issue is returned, so the caller can adjust its behavior accordingly.
require.ErrorIs(t, err, codersdk.ErrDatabaseNotReachable)
}
+36
@@ -1,5 +1,28 @@
# GitHub
## Default Configuration
By default, new Coder deployments use a Coder-managed GitHub app to authenticate
users. We provide it for convenience, allowing you to experiment with Coder
without setting up your own GitHub OAuth app. Once you authenticate with it, you
grant the Coder server read access to:
- Your GitHub user email
- Your GitHub organization membership
- Other metadata listed during the authentication flow
This access is necessary for the Coder server to complete the authentication
process. To the best of our knowledge, Coder, the company, does not gain access
to this data by administering the GitHub app.
For production deployments, we recommend configuring your own GitHub OAuth app
as outlined below. The default is automatically disabled if you configure your
own app or set:
```env
CODER_OAUTH2_GITHUB_DEFAULT_PROVIDER_ENABLE=false
```
## Step 1: Configure the OAuth application in GitHub
First,
@@ -82,3 +105,16 @@ helm upgrade <release-name> coder-v2/coder -n <namespace> -f values.yaml
> We recommend requiring and auditing MFA usage for all users in your GitHub
> organizations. This can be enforced from the organization settings page in the
> "Authentication security" sidebar tab.
## Device Flow
Coder supports
[device flow](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps#device-flow)
for GitHub OAuth. To enable it, set:
```env
CODER_OAUTH2_GITHUB_DEVICE_FLOW=true
```
This is optional. We recommend using the standard OAuth flow instead, as it is
more convenient for end users.
+4
@@ -101,6 +101,10 @@ coder:
# postgres://coder:password@postgres:5432/coder?sslmode=disable
name: coder-db-url
key: url
# For production deployments, we recommend configuring your own GitHub
# OAuth2 provider and disabling the default one.
- name: CODER_OAUTH2_GITHUB_DEFAULT_PROVIDER_ENABLE
value: "false"
# (Optional) For production deployments the access URL should be set.
# If you're just trying Coder, access the dashboard via the service IP.
+5
@@ -173,6 +173,7 @@ Status Code **200**
| `action` | `read` |
| `action` | `read_personal` |
| `action` | `ssh` |
| `action` | `unassign` |
| `action` | `update` |
| `action` | `update_personal` |
| `action` | `use` |
@@ -335,6 +336,7 @@ Status Code **200**
| `action` | `read` |
| `action` | `read_personal` |
| `action` | `ssh` |
| `action` | `unassign` |
| `action` | `update` |
| `action` | `update_personal` |
| `action` | `use` |
@@ -497,6 +499,7 @@ Status Code **200**
| `action` | `read` |
| `action` | `read_personal` |
| `action` | `ssh` |
| `action` | `unassign` |
| `action` | `update` |
| `action` | `update_personal` |
| `action` | `use` |
@@ -628,6 +631,7 @@ Status Code **200**
| `action` | `read` |
| `action` | `read_personal` |
| `action` | `ssh` |
| `action` | `unassign` |
| `action` | `update` |
| `action` | `update_personal` |
| `action` | `use` |
@@ -891,6 +895,7 @@ Status Code **200**
| `action` | `read` |
| `action` | `read_personal` |
| `action` | `ssh` |
| `action` | `unassign` |
| `action` | `update` |
| `action` | `update_personal` |
| `action` | `use` |
+1
@@ -5104,6 +5104,7 @@ Git clone makes use of this by parsing the URL from: 'Username for "https://gith
| `read` |
| `read_personal` |
| `ssh` |
| `unassign` |
| `update` |
| `update_personal` |
| `use` |
+1
@@ -8,6 +8,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/coder/coder/v2/agent/agentexec"
_ "github.com/coder/coder/v2/buildinfo/resources"
entcli "github.com/coder/coder/v2/enterprise/cli"
)
@@ -14,6 +14,7 @@ import (
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/entitlements"
@@ -89,7 +90,8 @@ func TestOrganizationSync(t *testing.T) {
Name: "SingleOrgDeployment",
Case: func(t *testing.T, db database.Store) OrganizationSyncTestCase {
def, _ := db.GetDefaultOrganization(context.Background())
other := dbgen.Organization(t, db, database.Organization{})
other := dbfake.Organization(t, db).Do()
deleted := dbfake.Organization(t, db).Deleted(true).Do()
return OrganizationSyncTestCase{
Entitlements: entitled,
Settings: idpsync.DeploymentSyncSettings{
@@ -123,11 +125,19 @@ func TestOrganizationSync(t *testing.T) {
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: other.ID,
OrganizationID: other.Org.ID,
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: deleted.Org.ID,
})
},
Sync: ExpectedUser{
Organizations: []uuid.UUID{def.ID, other.ID},
Organizations: []uuid.UUID{
def.ID, other.Org.ID,
// The user remains in the deleted org because no IdP sync happens.
deleted.Org.ID,
},
},
},
},
@@ -138,17 +148,19 @@ func TestOrganizationSync(t *testing.T) {
Name: "MultiOrgWithDefault",
Case: func(t *testing.T, db database.Store) OrganizationSyncTestCase {
def, _ := db.GetDefaultOrganization(context.Background())
one := dbgen.Organization(t, db, database.Organization{})
two := dbgen.Organization(t, db, database.Organization{})
three := dbgen.Organization(t, db, database.Organization{})
one := dbfake.Organization(t, db).Do()
two := dbfake.Organization(t, db).Do()
three := dbfake.Organization(t, db).Do()
deleted := dbfake.Organization(t, db).Deleted(true).Do()
return OrganizationSyncTestCase{
Entitlements: entitled,
Settings: idpsync.DeploymentSyncSettings{
OrganizationField: "organizations",
OrganizationMapping: map[string][]uuid.UUID{
"first": {one.ID},
"second": {two.ID},
"third": {three.ID},
"first": {one.Org.ID},
"second": {two.Org.ID},
"third": {three.Org.ID},
"deleted": {deleted.Org.ID},
},
OrganizationAssignDefault: true,
},
@@ -167,7 +179,7 @@ func TestOrganizationSync(t *testing.T) {
{
Name: "AlreadyInOrgs",
Claims: jwt.MapClaims{
"organizations": []string{"second", "extra"},
"organizations": []string{"second", "extra", "deleted"},
},
ExpectedParams: idpsync.OrganizationParams{
SyncEntitled: true,
@@ -180,18 +192,18 @@ func TestOrganizationSync(t *testing.T) {
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: one.ID,
OrganizationID: one.Org.ID,
})
},
Sync: ExpectedUser{
Organizations: []uuid.UUID{def.ID, two.ID},
Organizations: []uuid.UUID{def.ID, two.Org.ID},
},
},
{
Name: "ManyClaims",
Claims: jwt.MapClaims{
// Add some repeats
"organizations": []string{"second", "extra", "first", "third", "second", "second"},
"organizations": []string{"second", "extra", "first", "third", "second", "second", "deleted"},
},
ExpectedParams: idpsync.OrganizationParams{
SyncEntitled: true,
@@ -204,11 +216,11 @@ func TestOrganizationSync(t *testing.T) {
})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: one.ID,
OrganizationID: one.Org.ID,
})
},
Sync: ExpectedUser{
Organizations: []uuid.UUID{def.ID, one.ID, two.ID, three.ID},
Organizations: []uuid.UUID{def.ID, one.Org.ID, two.Org.ID, three.Org.ID},
},
},
},
+5
@@ -24,6 +24,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
@@ -381,6 +382,10 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
logger.Debug(ctx, "drpc server error", slog.Error(err))
},
})
// Log the request immediately instead of after it completes.
loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted)
err = server.Serve(ctx, session)
srvCancel()
logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err))
+1 -2
@@ -127,8 +127,7 @@ func (api *API) putOrgRoles(rw http.ResponseWriter, r *http.Request) {
},
},
ExcludeOrgRoles: false,
// Linter requires all fields to be set. This field is not actually required.
OrganizationID: organization.ID,
OrganizationID: organization.ID,
})
// If it is a 404 (not found) error, ignore it.
if err != nil && !httpapi.Is404Error(err) {
+15 -12
@@ -441,10 +441,11 @@ func TestListRoles(t *testing.T) {
return member.ListOrganizationRoles(ctx, owner.OrganizationID)
},
ExpectedRoles: convertRoles(map[rbac.RoleIdentifier]bool{
{Name: codersdk.RoleOrganizationAdmin, OrganizationID: owner.OrganizationID}: false,
{Name: codersdk.RoleOrganizationAuditor, OrganizationID: owner.OrganizationID}: false,
{Name: codersdk.RoleOrganizationTemplateAdmin, OrganizationID: owner.OrganizationID}: false,
{Name: codersdk.RoleOrganizationUserAdmin, OrganizationID: owner.OrganizationID}: false,
{Name: codersdk.RoleOrganizationAdmin, OrganizationID: owner.OrganizationID}: false,
{Name: codersdk.RoleOrganizationAuditor, OrganizationID: owner.OrganizationID}: false,
{Name: codersdk.RoleOrganizationTemplateAdmin, OrganizationID: owner.OrganizationID}: false,
{Name: codersdk.RoleOrganizationUserAdmin, OrganizationID: owner.OrganizationID}: false,
{Name: codersdk.RoleOrganizationWorkspaceCreationBan, OrganizationID: owner.OrganizationID}: false,
}),
},
{
@@ -473,10 +474,11 @@ func TestListRoles(t *testing.T) {
return orgAdmin.ListOrganizationRoles(ctx, owner.OrganizationID)
},
ExpectedRoles: convertRoles(map[rbac.RoleIdentifier]bool{
{Name: codersdk.RoleOrganizationAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationAuditor, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationTemplateAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationUserAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationAuditor, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationTemplateAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationUserAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationWorkspaceCreationBan, OrganizationID: owner.OrganizationID}: true,
}),
},
{
@@ -505,10 +507,11 @@ func TestListRoles(t *testing.T) {
return client.ListOrganizationRoles(ctx, owner.OrganizationID)
},
ExpectedRoles: convertRoles(map[rbac.RoleIdentifier]bool{
{Name: codersdk.RoleOrganizationAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationAuditor, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationTemplateAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationUserAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationAuditor, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationTemplateAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationUserAdmin, OrganizationID: owner.OrganizationID}: true,
{Name: codersdk.RoleOrganizationWorkspaceCreationBan, OrganizationID: owner.OrganizationID}: true,
}),
},
}
+2 -1
@@ -32,6 +32,7 @@ import (
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/codersdk"
@@ -336,7 +337,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
tracing.Middleware(s.TracerProvider),
httpmw.AttachRequestID,
httpmw.ExtractRealIP(s.Options.RealIPConfig),
httpmw.Logger(s.Logger),
loggermw.Logger(s.Logger),
prometheusMW,
corsMW,
+1 -1
@@ -36,7 +36,7 @@ replace github.com/tcnksm/go-httpstat => github.com/coder/go-httpstat v0.0.0-202
// There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here:
// https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main
replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20250129014916-8086c871eae6
replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20250227024825-c9983534152a
// This is replaced to include
// 1. a fix for a data race: c.f. https://github.com/tailscale/wireguard-go/pull/25
+2 -2
@@ -236,8 +236,8 @@ github.com/coder/serpent v0.10.0 h1:ofVk9FJXSek+SmL3yVE3GoArP83M+1tX+H7S4t8BSuM=
github.com/coder/serpent v0.10.0/go.mod h1:cZFW6/fP+kE9nd/oRkEHJpG6sXCtQ+AX7WMMEHv0Y3Q=
github.com/coder/ssh v0.0.0-20231128192721-70855dedb788 h1:YoUSJ19E8AtuUFVYBpXuOD6a/zVP3rcxezNsoDseTUw=
github.com/coder/ssh v0.0.0-20231128192721-70855dedb788/go.mod h1:aGQbuCLyhRLMzZF067xc84Lh7JDs1FKwCmF1Crl9dxQ=
github.com/coder/tailscale v1.1.1-0.20250129014916-8086c871eae6 h1:prDIwUcsSEKbs1Rc5FfdvtSfz2XGpW3FnJtWR+Mc7MY=
github.com/coder/tailscale v1.1.1-0.20250129014916-8086c871eae6/go.mod h1:1ggFFdHTRjPRu9Yc1yA7nVHBYB50w9Ce7VIXNqcW6Ko=
github.com/coder/tailscale v1.1.1-0.20250227024825-c9983534152a h1:18TQ03KlYrkW8hOohTQaDnlmkY1H9pDPGbZwOnUUmm8=
github.com/coder/tailscale v1.1.1-0.20250227024825-c9983534152a/go.mod h1:1ggFFdHTRjPRu9Yc1yA7nVHBYB50w9Ce7VIXNqcW6Ko=
github.com/coder/terraform-config-inspect v0.0.0-20250107175719-6d06d90c630e h1:JNLPDi2P73laR1oAclY6jWzAbucf70ASAvf5mh2cME0=
github.com/coder/terraform-config-inspect v0.0.0-20250107175719-6d06d90c630e/go.mod h1:Gz/z9Hbn+4KSp8A2FBtNszfLSdT2Tn/uAKGuVqqWmDI=
github.com/coder/terraform-provider-coder/v2 v2.1.3 h1:zB7ObGsiOGBHcJUUMmcSauEPlTWRIYmMYieF05LxHSc=
+4
@@ -47,6 +47,10 @@ coder:
# This env enables the Prometheus metrics endpoint.
- name: CODER_PROMETHEUS_ADDRESS
value: "0.0.0.0:2112"
# For production deployments, we recommend configuring your own GitHub
# OAuth2 provider and disabling the default one.
- name: CODER_OAUTH2_GITHUB_DEFAULT_PROVIDER_ENABLE
value: "false"
tls:
secretNames:
- my-tls-secret-name
+21 -9
@@ -20,12 +20,13 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/retry"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionerd/runner"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/retry"
)
// Dialer represents the function to create a daemon client connection.
@@ -290,7 +291,7 @@ func (p *Server) acquireLoop() {
defer p.wg.Done()
defer func() { close(p.acquireDoneCh) }()
ctx := p.closeContext
for {
for retrier := retry.New(10*time.Millisecond, 1*time.Second); retrier.Wait(ctx); {
if p.acquireExit() {
return
}
@@ -299,7 +300,17 @@ func (p *Server) acquireLoop() {
p.opts.Logger.Debug(ctx, "shut down before client (re) connected")
return
}
p.acquireAndRunOne(client)
err := p.acquireAndRunOne(client)
if err != nil && ctx.Err() == nil { // Only log if context is not done.
// Short-circuit: if we need to exit, don't wait out the retry delay.
if p.acquireExit() {
return
}
p.opts.Logger.Warn(ctx, "failed to acquire job, retrying", slog.F("delay", fmt.Sprintf("%vms", retrier.Delay.Milliseconds())), slog.Error(err))
} else {
// Reset the retrier after each successful acquisition.
retrier.Reset()
}
}
}
@@ -318,7 +329,7 @@ func (p *Server) acquireExit() bool {
return false
}
func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) {
func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) error {
ctx := p.closeContext
p.opts.Logger.Debug(ctx, "start of acquireAndRunOne")
job, err := p.acquireGraceful(client)
@@ -327,15 +338,15 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) {
if errors.Is(err, context.Canceled) ||
errors.Is(err, yamux.ErrSessionShutdown) ||
errors.Is(err, fasthttputil.ErrInmemoryListenerClosed) {
return
return err
}
p.opts.Logger.Warn(ctx, "provisionerd was unable to acquire job", slog.Error(err))
return
return xerrors.Errorf("failed to acquire job: %w", err)
}
if job.JobId == "" {
p.opts.Logger.Debug(ctx, "acquire job successfully canceled")
return
return nil
}
if len(job.TraceMetadata) > 0 {
@@ -390,9 +401,9 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) {
Error: fmt.Sprintf("failed to connect to provisioner: %s", resp.Error),
})
if err != nil {
p.opts.Logger.Error(ctx, "provisioner job failed", slog.F("job_id", job.JobId), slog.Error(err))
p.opts.Logger.Error(ctx, "failed to report provisioner job failed", slog.F("job_id", job.JobId), slog.Error(err))
}
return
return xerrors.Errorf("failed to report provisioner job failed: %w", err)
}
p.mutex.Lock()
@@ -416,6 +427,7 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) {
p.mutex.Lock()
p.activeJob = nil
p.mutex.Unlock()
return nil
}
// acquireGraceful attempts to acquire a job from the server, handling canceling the acquisition if we gracefully shut
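The change above replaces a bare `for` loop with a retrier so repeated acquire failures back off instead of spinning, and the delay is reset after each success. A rough stdlib-only sketch of that shape (the actual `coder/retry` API differs; `work` is a stand-in for `acquireAndRunOne`):

```go
package main

import (
	"fmt"
	"time"
)

// work is a stand-in for acquireAndRunOne: it fails twice, then succeeds.
var calls int

func work() error {
	calls++
	if calls < 3 {
		return fmt.Errorf("transient failure %d", calls)
	}
	return nil
}

func main() {
	const maxDelay = 8 * time.Millisecond
	delay := time.Millisecond
	for i := 0; i < 5; i++ {
		if err := work(); err != nil {
			// Failure: wait, then grow the delay up to a cap.
			time.Sleep(delay)
			if delay < maxDelay {
				delay *= 2
			}
			continue
		}
		// Success: reset the backoff, mirroring retrier.Reset().
		delay = time.Millisecond
	}
	fmt.Println(delay) // 1ms: the last iterations succeeded and reset it
}
```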
+2 -2
@@ -1,7 +1,7 @@
# This is the base image used for Coder images. It's a multi-arch image that is
# built in depot.dev for all supported architectures. Since it's built on real
# hardware and not cross-compiled, it can have "RUN" commands.
FROM alpine:3.21.2
FROM alpine:3.21.3
# We use a single RUN command to reduce the number of layers in the image.
# NOTE: Keep the Terraform version in sync with minTerraformVersion and
@@ -26,7 +26,7 @@ RUN apk add --no-cache \
# Terraform was disabled in the edge repo due to a build issue.
# https://gitlab.alpinelinux.org/alpine/aports/-/commit/f3e263d94cfac02d594bef83790c280e045eba35
# Using wget for now. Note that busybox unzip doesn't support streaming.
RUN ARCH="$(arch)"; if [ "${ARCH}" == "x86_64" ]; then ARCH="amd64"; elif [ "${ARCH}" == "aarch64" ]; then ARCH="arm64"; fi; wget -O /tmp/terraform.zip "https://releases.hashicorp.com/terraform/1.10.5/terraform_1.10.5_linux_${ARCH}.zip" && \
RUN ARCH="$(arch)"; if [ "${ARCH}" == "x86_64" ]; then ARCH="amd64"; elif [ "${ARCH}" == "aarch64" ]; then ARCH="arm64"; elif [ "${ARCH}" == "armv7l" ]; then ARCH="arm"; fi; wget -O /tmp/terraform.zip "https://releases.hashicorp.com/terraform/1.10.5/terraform_1.10.5_linux_${ARCH}.zip" && \
busybox unzip /tmp/terraform.zip -d /usr/local/bin && \
rm -f /tmp/terraform.zip && \
chmod +x /usr/local/bin/terraform && \
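The Dockerfile fix above adds an `armv7l` branch to the inline arch mapping. The same normalization, pulled out as a standalone shell function for clarity (a sketch; the Dockerfile keeps it inline in the RUN command):

```shell
#!/bin/sh
# Map `arch`/`uname -m` output to the arch names HashiCorp uses in
# Terraform release filenames. Unknown values pass through unchanged.
normalize_arch() {
	case "$1" in
	x86_64) echo "amd64" ;;
	aarch64) echo "arm64" ;;
	armv7l) echo "arm" ;;
	*) echo "$1" ;;
	esac
}

normalize_arch x86_64  # amd64
normalize_arch aarch64 # arm64
normalize_arch armv7l  # arm
```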
+108 -6
@@ -36,17 +36,19 @@ source "$(dirname "${BASH_SOURCE[0]}")/lib.sh"
version=""
os="${GOOS:-linux}"
arch="${GOARCH:-amd64}"
output_path=""
slim="${CODER_SLIM_BUILD:-0}"
agpl="${CODER_BUILD_AGPL:-0}"
sign_darwin="${CODER_SIGN_DARWIN:-0}"
sign_windows="${CODER_SIGN_WINDOWS:-0}"
bin_ident="com.coder.cli"
output_path=""
agpl="${CODER_BUILD_AGPL:-0}"
boringcrypto=${CODER_BUILD_BORINGCRYPTO:-0}
debug=0
dylib=0
windows_resources="${CODER_WINDOWS_RESOURCES:-0}"
debug=0
args="$(getopt -o "" -l version:,os:,arch:,output:,slim,agpl,sign-darwin,boringcrypto,dylib,debug -- "$@")"
bin_ident="com.coder.cli"
args="$(getopt -o "" -l version:,os:,arch:,output:,slim,agpl,sign-darwin,sign-windows,boringcrypto,dylib,windows-resources,debug -- "$@")"
eval set -- "$args"
while true; do
case "$1" in
@@ -79,6 +81,10 @@ while true; do
sign_darwin=1
shift
;;
--sign-windows)
sign_windows=1
shift
;;
--boringcrypto)
boringcrypto=1
shift
@@ -87,6 +93,10 @@ while true; do
dylib=1
shift
;;
--windows-resources)
windows_resources=1
shift
;;
--debug)
debug=1
shift
@@ -115,11 +125,13 @@ if [[ "$sign_darwin" == 1 ]]; then
dependencies rcodesign
requiredenvs AC_CERTIFICATE_FILE AC_CERTIFICATE_PASSWORD_FILE
fi
if [[ "$sign_windows" == 1 ]]; then
dependencies java
requiredenvs JSIGN_PATH EV_KEYSTORE EV_KEY EV_CERTIFICATE_PATH EV_TSA_URL GCLOUD_ACCESS_TOKEN
fi
if [[ "$windows_resources" == 1 ]]; then
dependencies go-winres
fi
ldflags=(
-X "'github.com/coder/coder/v2/buildinfo.tag=$version'"
@@ -204,10 +216,100 @@ if [[ "$boringcrypto" == 1 ]]; then
goexp="boringcrypto"
fi
# On Windows, we use go-winres to embed the resources into the binary.
if [[ "$windows_resources" == 1 ]] && [[ "$os" == "windows" ]]; then
# Convert the version to a format that Windows understands.
# Remove any trailing data after a "+" or "-".
version_windows=$version
version_windows="${version_windows%+*}"
version_windows="${version_windows%-*}"
# If there wasn't any extra data, add a .0 to the version. Otherwise, add
# a .1 to the version to signify that this is not a release build so it can
# be distinguished from a release build.
non_release_build=0
if [[ "$version_windows" == "$version" ]]; then
version_windows+=".0"
else
version_windows+=".1"
non_release_build=1
fi
if [[ ! "$version_windows" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-1]$ ]]; then
error "Computed invalid windows version format: $version_windows"
fi
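The version mangling above can be exercised in isolation. `win_version` is a hypothetical wrapper around the same parameter expansions: Windows expects a four-part numeric version, so release builds get a trailing `.0` and anything carrying `-prerelease` or `+metadata` data gets `.1`:

```shell
# win_version: hypothetical demo of the expansions above; the fourth
# version component distinguishes non-release builds.
win_version() {
  local v="$1" w
  w="${v%+*}"          # drop any "+metadata" suffix
  w="${w%-*}"          # drop any "-prerelease" suffix
  if [[ "$w" == "$v" ]]; then
    w+=".0"            # unchanged: this is a release build
  else
    w+=".1"            # trimmed: mark as a non-release build
  fi
  echo "$w"
}

win_version 2.19.0        # 2.19.0.0
win_version 2.19.0-devel  # 2.19.0.1
```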
# File description changes based on slimness, AGPL status, and architecture.
file_description="Coder"
if [[ "$agpl" == 1 ]]; then
file_description+=" AGPL"
fi
if [[ "$slim" == 1 ]]; then
file_description+=" CLI"
fi
if [[ "$non_release_build" == 1 ]]; then
file_description+=" (development build)"
fi
# Because this writes to a file with the OS and arch in the filename, we
# don't support concurrent builds for the same OS and arch (regardless of
# slimness or AGPL status).
#
# This is fine since we only embed resources during dogfood and release
# builds, which use make (which will build all slim targets in parallel,
# then all non-slim targets in parallel).
expected_rsrc_file="./buildinfo/resources/resources_windows_${arch}.syso"
if [[ -f "$expected_rsrc_file" ]]; then
rm "$expected_rsrc_file"
fi
touch "$expected_rsrc_file"
pushd ./buildinfo/resources
GOARCH="$arch" go-winres simply \
--arch "$arch" \
--out "resources" \
--product-version "$version_windows" \
--file-version "$version_windows" \
--manifest "cli" \
--file-description "$file_description" \
--product-name "Coder" \
--copyright "Copyright $(date +%Y) Coder Technologies Inc." \
--original-filename "coder.exe" \
--icon ../../scripts/win-installer/coder.ico
popd
if [[ ! -f "$expected_rsrc_file" ]]; then
error "Failed to generate $expected_rsrc_file"
fi
fi
set +e
GOEXPERIMENT="$goexp" CGO_ENABLED="$cgo" GOOS="$os" GOARCH="$arch" GOARM="$arm_version" \
go build \
"${build_args[@]}" \
"$cmd_path" 1>&2
exit_code=$?
set -e
# Clean up the resources file if it was generated.
if [[ "$windows_resources" == 1 ]] && [[ "$os" == "windows" ]]; then
rm "$expected_rsrc_file"
fi
if [[ "$exit_code" != 0 ]]; then
exit "$exit_code"
fi
# If we did embed resources, verify that they were included.
if [[ "$windows_resources" == 1 ]] && [[ "$os" == "windows" ]]; then
winres_dir=$(mktemp -d)
if ! go-winres extract --dir "$winres_dir" "$output_path" 1>&2; then
rm -rf "$winres_dir"
error "Compiled binary does not contain embedded resources"
fi
# If go-winres didn't return an error, it means it did find embedded
# resources.
rm -rf "$winres_dir"
fi
if [[ "$sign_darwin" == 1 ]] && [[ "$os" == "darwin" ]]; then
execrelative ./sign_darwin.sh "$output_path" "$bin_ident" 1>&2
+17
@@ -118,6 +118,23 @@ main() {
title2=${parts2[*]:2}
fi
# Handle the cherry-pick bot: it turns "chore: foo bar (#42)" into
# "chore: foo bar (cherry-pick #42) (#43)".
if [[ ${title1} == *"(cherry-pick #"* ]]; then
title1=${title1%" ("*}
pr=${title1##*#}
pr=${pr%)}
title1=${title1%" ("*}
title1="${title1} (#${pr})"$'\n'
fi
if [[ ${title2} == *"(cherry-pick #"* ]]; then
title2=${title2%" ("*}
pr=${title2##*#}
pr=${pr%)}
title2=${title2%" ("*}
title2="${title2} (#${pr})"$'\n'
fi
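The new branch can be traced with a self-contained snippet; the sample title below is made up, but the parameter expansions are the same as in the script:

```shell
# Hypothetical walk-through of the title normalization above: undo the
# cherry-pick bot's rewrite so both commits compare under the original PR.
title='chore: foo bar (cherry-pick #42) (#43)'
if [[ ${title} == *"(cherry-pick #"* ]]; then
  title=${title%" ("*}  # strip the bot's "(#43)" -> '... (cherry-pick #42)'
  pr=${title##*#}       # everything after the last '#' -> '42)'
  pr=${pr%)}            # drop trailing ')' -> '42'
  title=${title%" ("*}  # strip "(cherry-pick #42)" -> 'chore: foo bar'
  title="${title} (#${pr})"
fi
echo "$title"           # chore: foo bar (#42)
```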
if [[ ${title1} != "${title2}" ]]; then
log "Invariant failed, cherry-picked commits have different titles: \"${title1%$'\n'}\" != \"${title2%$'\n'}\", attempting to check commit body for cherry-pick information..."
+29 -3
@@ -3,8 +3,8 @@ import { expect } from "@playwright/test";
import { API, type DeploymentConfig } from "api/api";
import type { SerpentOption } from "api/typesGenerated";
import { formatDuration, intervalToDuration } from "date-fns";
import { coderPort } from "./constants";
import { findSessionToken, randomName } from "./helpers";
import { coderPort, defaultPassword } from "./constants";
import { type LoginOptions, findSessionToken, randomName } from "./helpers";
let currentOrgId: string;
@@ -29,14 +29,40 @@ export const createUser = async (...orgIds: string[]) => {
email: `${name}@coder.com`,
username: name,
name: name,
password: "s3cure&password!",
password: defaultPassword,
login_type: "password",
organization_ids: orgIds,
user_status: null,
});
return user;
};
export const createOrganizationMember = async (
orgRoles: Record<string, string[]>,
): Promise<LoginOptions> => {
const name = randomName();
const user = await API.createUser({
email: `${name}@coder.com`,
username: name,
name: name,
password: defaultPassword,
login_type: "password",
organization_ids: Object.keys(orgRoles),
user_status: null,
});
for (const [org, roles] of Object.entries(orgRoles)) {
await API.updateOrganizationMemberRoles(org, user.id, roles);
}
return {
username: user.username,
email: user.email,
password: defaultPassword,
};
};
export const createGroup = async (orgId: string) => {
const name = randomName();
const group = await API.createGroup(orgId, {
+7
@@ -15,6 +15,7 @@ export const coderdPProfPort = 6062;
// The name of the organization that should be used by default when needed.
export const defaultOrganizationName = "coder";
export const defaultOrganizationId = "00000000-0000-0000-0000-000000000000";
export const defaultPassword = "SomeSecurePassword!";
// Credentials for users
@@ -30,6 +31,12 @@ export const users = {
email: "templateadmin@coder.com",
roles: ["Template Admin"],
},
userAdmin: {
username: "user-admin",
password: defaultPassword,
email: "useradmin@coder.com",
roles: ["User Admin"],
},
auditor: {
username: "auditor",
password: defaultPassword,
+28 -1
@@ -61,7 +61,7 @@ export function requireTerraformProvisioner() {
test.skip(!requireTerraformTests);
}
type LoginOptions = {
export type LoginOptions = {
username: string;
email: string;
password: string;
@@ -1127,3 +1127,30 @@ export async function createOrganization(page: Page): Promise<{
return { name, displayName, description };
}
/**
* @param organization organization name
* @param user user email or username
*/
export async function addUserToOrganization(
page: Page,
organization: string,
user: string,
roles: string[] = [],
): Promise<void> {
await page.goto(`/organizations/${organization}`, {
waitUntil: "domcontentloaded",
});
await page.getByPlaceholder("User email or username").fill(user);
await page.getByRole("option", { name: user }).click();
await page.getByRole("button", { name: "Add user" }).click();
const addedRow = page.locator("tr", { hasText: user });
await expect(addedRow).toBeVisible();
await addedRow.getByLabel("Edit user roles").click();
for (const role of roles) {
await page.getByText(role).click();
}
await page.mouse.click(10, 10); // close the popover by clicking outside of it
}
+12 -3
@@ -2,10 +2,11 @@ import { expect, test } from "@playwright/test";
import {
createGroup,
createOrganization,
createOrganizationMember,
createUser,
setupApiCalls,
} from "../api";
import { defaultOrganizationName } from "../constants";
import { defaultOrganizationId, defaultOrganizationName } from "../constants";
import { expectUrl } from "../expectUrl";
import { login, randomName, requiresLicense } from "../helpers";
import { beforeCoderTest } from "../hooks";
@@ -32,6 +33,11 @@ test("create group", async ({ page }) => {
// Create a new organization
const org = await createOrganization();
const orgUserAdmin = await createOrganizationMember({
[org.id]: ["organization-user-admin"],
});
await login(page, orgUserAdmin);
await page.goto(`/organizations/${org.name}`);
// Navigate to groups page
@@ -64,8 +70,7 @@ test("create group", async ({ page }) => {
await expect(addedRow).toBeVisible();
// Ensure we can't add a user who isn't in the org
const otherOrg = await createOrganization();
const personToReject = await createUser(otherOrg.id);
const personToReject = await createUser(defaultOrganizationId);
await page
.getByPlaceholder("User email or username")
.fill(personToReject.email);
@@ -93,8 +98,12 @@ test("change quota settings", async ({ page }) => {
// Create a new organization and group
const org = await createOrganization();
const group = await createGroup(org.id);
const orgUserAdmin = await createOrganizationMember({
[org.id]: ["organization-user-admin"],
});
// Go to settings
await login(page, orgUserAdmin);
await page.goto(`/organizations/${org.name}/groups/${group.name}`);
await page.getByRole("button", { name: "Settings", exact: true }).click();
expectUrl(page).toHavePathName(
+9 -11
@@ -1,6 +1,7 @@
import { expect, test } from "@playwright/test";
import { setupApiCalls } from "../api";
import {
addUserToOrganization,
createOrganization,
createUser,
login,
@@ -18,7 +19,7 @@ test("add and remove organization member", async ({ page }) => {
requiresLicense();
// Create a new organization
const { displayName } = await createOrganization(page);
const { name: orgName, displayName } = await createOrganization(page);
// Navigate to members page
await page.getByRole("link", { name: "Members" }).click();
@@ -26,17 +27,14 @@ test("add and remove organization member", async ({ page }) => {
// Add a user to the org
const personToAdd = await createUser(page);
await page.getByPlaceholder("User email or username").fill(personToAdd.email);
await page.getByRole("option", { name: personToAdd.email }).click();
await page.getByRole("button", { name: "Add user" }).click();
const addedRow = page.locator("tr", { hasText: personToAdd.email });
await expect(addedRow).toBeVisible();
// This must be done as an admin, because you can't assign a role that has more
// permissions than you, even if you have the ability to assign roles.
await addUserToOrganization(page, orgName, personToAdd.email, [
"Organization User Admin",
"Organization Template Admin",
]);
// Give them a role
await addedRow.getByLabel("Edit user roles").click();
await page.getByText("Organization User Admin").click();
await page.getByText("Organization Template Admin").click();
await page.mouse.click(10, 10); // close the popover by clicking outside of it
const addedRow = page.locator("tr", { hasText: personToAdd.email });
await expect(addedRow.getByText("Organization User Admin")).toBeVisible();
await expect(addedRow.getByText("+1 more")).toBeVisible();
+157
@@ -0,0 +1,157 @@
import { type Page, expect, test } from "@playwright/test";
import {
createOrganization,
createOrganizationMember,
setupApiCalls,
} from "../api";
import { license, users } from "../constants";
import { login, requiresLicense } from "../helpers";
import { beforeCoderTest } from "../hooks";
test.beforeEach(async ({ page }) => {
beforeCoderTest(page);
});
type AdminSetting = (typeof adminSettings)[number];
const adminSettings = [
"Deployment",
"Organizations",
"Healthcheck",
"Audit Logs",
] as const;
async function hasAccessToAdminSettings(page: Page, settings: AdminSetting[]) {
// Organizations and Audit Logs both require a license to be visible
const visibleSettings = license
? settings
: settings.filter((it) => it !== "Organizations" && it !== "Audit Logs");
const adminSettingsButton = page.getByRole("button", {
name: "Admin settings",
});
if (visibleSettings.length < 1) {
await expect(adminSettingsButton).not.toBeVisible();
return;
}
await adminSettingsButton.click();
for (const name of visibleSettings) {
await expect(page.getByText(name, { exact: true })).toBeVisible();
}
const hiddenSettings = adminSettings.filter(
(it) => !visibleSettings.includes(it),
);
for (const name of hiddenSettings) {
await expect(page.getByText(name, { exact: true })).not.toBeVisible();
}
}
test.describe("roles admin settings access", () => {
test("member cannot see admin settings", async ({ page }) => {
await login(page, users.member);
await page.goto("/", { waitUntil: "domcontentloaded" });
// None, "Admin settings" button should not be visible
await hasAccessToAdminSettings(page, []);
});
test("template admin can see admin settings", async ({ page }) => {
await login(page, users.templateAdmin);
await page.goto("/", { waitUntil: "domcontentloaded" });
await hasAccessToAdminSettings(page, ["Deployment", "Organizations"]);
});
test("user admin can see admin settings", async ({ page }) => {
await login(page, users.userAdmin);
await page.goto("/", { waitUntil: "domcontentloaded" });
await hasAccessToAdminSettings(page, ["Deployment", "Organizations"]);
});
test("auditor can see admin settings", async ({ page }) => {
await login(page, users.auditor);
await page.goto("/", { waitUntil: "domcontentloaded" });
await hasAccessToAdminSettings(page, [
"Deployment",
"Organizations",
"Audit Logs",
]);
});
test("admin can see admin settings", async ({ page }) => {
await login(page, users.admin);
await page.goto("/", { waitUntil: "domcontentloaded" });
await hasAccessToAdminSettings(page, [
"Deployment",
"Organizations",
"Healthcheck",
"Audit Logs",
]);
});
});
test.describe("org-scoped roles admin settings access", () => {
requiresLicense();
test.beforeEach(async ({ page }) => {
await login(page);
await setupApiCalls(page);
});
test("org template admin can see admin settings", async ({ page }) => {
const org = await createOrganization();
const orgTemplateAdmin = await createOrganizationMember({
[org.id]: ["organization-template-admin"],
});
await login(page, orgTemplateAdmin);
await page.goto("/", { waitUntil: "domcontentloaded" });
await hasAccessToAdminSettings(page, ["Organizations"]);
});
test("org user admin can see admin settings", async ({ page }) => {
const org = await createOrganization();
const orgUserAdmin = await createOrganizationMember({
[org.id]: ["organization-user-admin"],
});
await login(page, orgUserAdmin);
await page.goto("/", { waitUntil: "domcontentloaded" });
await hasAccessToAdminSettings(page, ["Deployment", "Organizations"]);
});
test("org auditor can see admin settings", async ({ page }) => {
const org = await createOrganization();
const orgAuditor = await createOrganizationMember({
[org.id]: ["organization-auditor"],
});
await login(page, orgAuditor);
await page.goto("/", { waitUntil: "domcontentloaded" });
await hasAccessToAdminSettings(page, ["Organizations", "Audit Logs"]);
});
test("org admin can see admin settings", async ({ page }) => {
const org = await createOrganization();
const orgAdmin = await createOrganizationMember({
[org.id]: ["organization-admin"],
});
await login(page, orgAdmin);
await page.goto("/", { waitUntil: "domcontentloaded" });
await hasAccessToAdminSettings(page, [
"Deployment",
"Organizations",
"Audit Logs",
]);
});
});
-17
@@ -6,10 +6,8 @@ import type {
UpdateOrganizationRequest,
} from "api/typesGenerated";
import {
type AnyOrganizationPermissions,
type OrganizationPermissionName,
type OrganizationPermissions,
anyOrganizationPermissionChecks,
organizationPermissionChecks,
} from "modules/management/organizationPermissions";
import type { QueryClient } from "react-query";
@@ -266,21 +264,6 @@ export const organizationsPermissions = (
};
};
export const anyOrganizationPermissionsKey = [
"authorization",
"anyOrganization",
];
export const anyOrganizationPermissions = () => {
return {
queryKey: anyOrganizationPermissionsKey,
queryFn: () =>
API.checkAuthorization({
checks: anyOrganizationPermissionChecks,
}) as Promise<AnyOrganizationPermissions>,
};
};
export const getOrganizationIdpSyncClaimFieldValuesKey = (
organization: string,
field: string,
+8 -9
@@ -15,18 +15,17 @@ export const RBACResourceActions: Partial<
update: "update an api key, eg expires",
},
assign_org_role: {
assign: "ability to assign org scoped roles",
create: "ability to create/delete custom roles within an organization",
delete: "ability to delete org scoped roles",
read: "view what roles are assignable",
update: "ability to edit custom roles within an organization",
assign: "assign org scoped roles",
create: "create/delete custom roles within an organization",
delete: "delete roles within an organization",
read: "view what roles are assignable within an organization",
unassign: "unassign org scoped roles",
update: "edit custom roles within an organization",
},
assign_role: {
assign: "ability to assign roles",
create: "ability to create/delete/edit custom roles",
delete: "ability to unassign roles",
assign: "assign user roles",
read: "view what roles are assignable",
update: "ability to edit custom roles",
unassign: "unassign user roles",
},
audit_log: {
create: "create new audit log entries",
+9
@@ -585,6 +585,9 @@ export interface DangerousConfig {
readonly allow_all_cors: boolean;
}
// From codersdk/database.go
export const DatabaseNotReachable = "database not reachable";
// From healthsdk/healthsdk.go
export interface DatabaseReport extends BaseReport {
readonly healthy: boolean;
@@ -1856,6 +1859,7 @@ export type RBACAction =
| "read"
| "read_personal"
| "ssh"
| "unassign"
| "update"
| "update_personal"
| "use"
@@ -1871,6 +1875,7 @@ export const RBACActions: RBACAction[] = [
"read",
"read_personal",
"ssh",
"unassign",
"update",
"update_personal",
"use",
@@ -2101,6 +2106,10 @@ export const RoleOrganizationTemplateAdmin = "organization-template-admin";
// From codersdk/rbacroles.go
export const RoleOrganizationUserAdmin = "organization-user-admin";
// From codersdk/rbacroles.go
export const RoleOrganizationWorkspaceCreationBan =
"organization-workspace-creation-ban";
// From codersdk/rbacroles.go
export const RoleOwner = "owner";
-1
@@ -57,7 +57,6 @@ const avatarVariants = cva(
export type AvatarProps = AvatarPrimitive.AvatarProps &
VariantProps<typeof avatarVariants> & {
src?: string;
fallback?: string;
};
@@ -0,0 +1,120 @@
import type { Meta, StoryObj } from "@storybook/react";
import { Button } from "../Button/Button";
import { CollapsibleSummary } from "./CollapsibleSummary";
const meta: Meta<typeof CollapsibleSummary> = {
title: "components/CollapsibleSummary",
component: CollapsibleSummary,
args: {
label: "Advanced options",
children: (
<>
<div className="p-2 border border-border rounded-md border-solid">
Option 1
</div>
<div className="p-2 border border-border rounded-md border-solid">
Option 2
</div>
<div className="p-2 border border-border rounded-md border-solid">
Option 3
</div>
</>
),
},
};
export default meta;
type Story = StoryObj<typeof CollapsibleSummary>;
export const Default: Story = {};
export const DefaultOpen: Story = {
args: {
defaultOpen: true,
},
};
export const MediumSize: Story = {
args: {
size: "md",
},
};
export const SmallSize: Story = {
args: {
size: "sm",
},
};
export const CustomClassName: Story = {
args: {
className: "text-blue-500 font-bold",
},
};
export const ManyChildren: Story = {
args: {
defaultOpen: true,
children: (
<>
{Array.from({ length: 10 }).map((_, i) => (
<div
key={`option-${i + 1}`}
className="p-2 border border-border rounded-md border-solid"
>
Option {i + 1}
</div>
))}
</>
),
},
};
export const NestedCollapsible: Story = {
args: {
defaultOpen: true,
children: (
<>
<div className="p-2 border border-border rounded-md border-solid">
Option 1
</div>
<CollapsibleSummary label="Nested options" size="sm">
<div className="p-2 border border-border rounded-md border-solid">
Nested Option 1
</div>
<div className="p-2 border border-border rounded-md border-solid">
Nested Option 2
</div>
</CollapsibleSummary>
<div className="p-2 border border-border rounded-md border-solid">
Option 3
</div>
</>
),
},
};
export const ComplexContent: Story = {
args: {
defaultOpen: true,
children: (
<div className="p-4 border border-border rounded-md bg-surface-secondary">
<h3 className="text-lg font-bold mb-2">Complex Content</h3>
<p className="mb-4">
This is a more complex content example with various elements.
</p>
<div className="flex gap-2">
<Button>Action 1</Button>
<Button>Action 2</Button>
</div>
</div>
),
},
};
export const LongLabel: Story = {
args: {
label:
"This is a very long label that might wrap or cause layout issues if not handled properly",
},
};
@@ -0,0 +1,91 @@
import { type VariantProps, cva } from "class-variance-authority";
import { ChevronRightIcon } from "lucide-react";
import { type FC, type ReactNode, useState } from "react";
import { cn } from "utils/cn";
const collapsibleSummaryVariants = cva(
`flex items-center gap-1 p-0 bg-transparent border-0 text-inherit cursor-pointer
transition-colors text-content-secondary hover:text-content-primary font-medium
whitespace-nowrap`,
{
variants: {
size: {
md: "text-sm",
sm: "text-xs",
},
},
defaultVariants: {
size: "md",
},
},
);
export interface CollapsibleSummaryProps
extends VariantProps<typeof collapsibleSummaryVariants> {
/**
* The label to display for the collapsible section
*/
label: string;
/**
* The content to show when expanded
*/
children: ReactNode;
/**
* Whether the section is initially expanded
*/
defaultOpen?: boolean;
/**
* Optional className for the button
*/
className?: string;
/**
* The size of the component
*/
size?: "md" | "sm";
}
export const CollapsibleSummary: FC<CollapsibleSummaryProps> = ({
label,
children,
defaultOpen = false,
className,
size,
}) => {
const [isOpen, setIsOpen] = useState(defaultOpen);
return (
<div className="flex flex-col gap-4">
<button
className={cn(
collapsibleSummaryVariants({ size }),
isOpen && "text-content-primary",
className,
)}
type="button"
onClick={() => {
setIsOpen((v) => !v);
}}
>
<div
className={cn(
"flex items-center justify-center transition-transform duration-200",
isOpen ? "rotate-90" : "rotate-0",
)}
>
<ChevronRightIcon
className={cn(
"p-0.5",
size === "sm" ? "size-icon-xs" : "size-icon-sm",
)}
/>
</div>
<span className="sr-only">
({isOpen ? "Hide" : "Show"}) {label}
</span>
<span className="[&:first-letter]:uppercase">{label}</span>
</button>
{isOpen && <div className="flex flex-col gap-4">{children}</div>}
</div>
);
};

Some files were not shown because too many files have changed in this diff.